From 7b8d6359c2d2c4e943635a7d459bd7fb2b40599c Mon Sep 17 00:00:00 2001 From: Joe Damato Date: Tue, 14 Sep 2010 18:01:14 -0700 Subject: [PATCH] Fix data corruption bug --- ext/mysql2/client.c | 84 +++++++++++++++++++++------------------------ ext/mysql2/client.h | 2 +- 2 files changed, 41 insertions(+), 45 deletions(-) diff --git a/ext/mysql2/client.c b/ext/mysql2/client.c index 432d9e3..7fe8ba8 100644 --- a/ext/mysql2/client.c +++ b/ext/mysql2/client.c @@ -19,9 +19,7 @@ static ID intern_merge, intern_error_number_eql, intern_sql_state_eql; #define GET_CLIENT(self) \ mysql_client_wrapper *wrapper; \ - MYSQL *client; \ - Data_Get_Struct(self, mysql_client_wrapper, wrapper); \ - client = &wrapper->client; + Data_Get_Struct(self, mysql_client_wrapper, wrapper); /* * used to pass all arguments to mysql_real_connect while inside @@ -85,12 +83,11 @@ static VALUE rb_raise_mysql2_error(MYSQL *client) { } static VALUE nogvl_init(void *ptr) { - MYSQL * client = (MYSQL *)ptr; + MYSQL **client = (MYSQL **)ptr; /* may initialize embedded server and read /etc/services off disk */ - client = mysql_init(NULL); - - return client ? Qtrue : Qfalse; + *client = mysql_init(NULL); + return *client ? Qtrue : Qfalse; } static VALUE nogvl_connect(void *ptr) { @@ -108,15 +105,14 @@ static VALUE nogvl_connect(void *ptr) { } static void rb_mysql_client_free(void * ptr) { - mysql_client_wrapper * wrapper = (mysql_client_wrapper *)ptr; - MYSQL * client = &wrapper->client; + GET_CLIENT(ptr) /* * we'll send a QUIT message to the server, but that message is more of a * formality than a hard requirement since the socket is getting shutdown * anyways, so ensure the socket write does not block our interpreter */ - int fd = client->net.fd; + int fd = wrapper->client->net.fd; if (fd >= 0) { /* @@ -134,7 +130,7 @@ static void rb_mysql_client_free(void * ptr) { } /* It's safe to call mysql_close() on an already closed connection. */ - mysql_close(client); + mysql_close(wrapper->client); xfree(ptr); } @@ -164,12 +160,12 @@ static VALUE rb_connect(VALUE self, VALUE user, VALUE pass, VALUE host, VALUE po args.user = NIL_P(user) ? NULL : StringValuePtr(user); args.passwd = NIL_P(pass) ? NULL : StringValuePtr(pass); args.db = NIL_P(database) ? NULL : StringValuePtr(database); - args.mysql = client; + args.mysql = wrapper->client; args.client_flag = NUM2INT(flags); if (rb_thread_blocking_region(nogvl_connect, &args, RUBY_UBF_IO, 0) == Qfalse) { // unable to connect - return rb_raise_mysql2_error(client); + return rb_raise_mysql2_error(wrapper->client); } return self; @@ -184,7 +180,7 @@ static VALUE rb_connect(VALUE self, VALUE user, VALUE pass, VALUE host, VALUE po static VALUE rb_mysql_client_close(VALUE self) { GET_CLIENT(self) - rb_thread_blocking_region(nogvl_close, client, RUBY_UBF_IO, 0); + rb_thread_blocking_region(nogvl_close, wrapper->client, RUBY_UBF_IO, 0); return Qnil; } @@ -227,21 +223,21 @@ static VALUE rb_mysql_client_async_result(VALUE self) { MYSQL_RES * result; GET_CLIENT(self) - REQUIRE_OPEN_DB(client); - if (rb_thread_blocking_region(nogvl_read_query_result, client, RUBY_UBF_IO, 0) == Qfalse) { + REQUIRE_OPEN_DB(wrapper->client); + if (rb_thread_blocking_region(nogvl_read_query_result, wrapper->client, RUBY_UBF_IO, 0) == Qfalse) { // an error occurred, mark this connection inactive MARK_CONN_INACTIVE(self); - return rb_raise_mysql2_error(client); + return rb_raise_mysql2_error(wrapper->client); } - result = (MYSQL_RES *)rb_thread_blocking_region(nogvl_store_result, client, RUBY_UBF_IO, 0); + result = (MYSQL_RES *)rb_thread_blocking_region(nogvl_store_result, wrapper->client, RUBY_UBF_IO, 0); // we have our result, mark this connection inactive MARK_CONN_INACTIVE(self); if (result == NULL) { - if (mysql_field_count(client) != 0) { - rb_raise_mysql2_error(client); + if (mysql_field_count(wrapper->client) != 0) { + rb_raise_mysql2_error(wrapper->client); } return Qnil; } @@ -266,8 +262,8 @@ static VALUE rb_mysql_client_query(int argc, VALUE * argv, VALUE self) { VALUE opts, defaults; GET_CLIENT(self) - REQUIRE_OPEN_DB(client); - args.mysql = client; + REQUIRE_OPEN_DB(wrapper->client); + args.mysql = wrapper->client; // see if this connection is still waiting on a result from a previous query if (wrapper->active == 0) { @@ -298,13 +294,13 @@ static VALUE rb_mysql_client_query(int argc, VALUE * argv, VALUE self) { if (rb_thread_blocking_region(nogvl_send_query, &args, RUBY_UBF_IO, 0) == Qfalse) { // an error occurred, we're not active anymore MARK_CONN_INACTIVE(self); - return rb_raise_mysql2_error(client); + return rb_raise_mysql2_error(wrapper->client); } if (!async) { // the below code is largely from do_mysql // http://github.com/datamapper/do - fd = client->net.fd; + fd = wrapper->client->net.fd; for(;;) { FD_ZERO(&fdset); FD_SET(fd, &fdset); @@ -344,8 +340,8 @@ static VALUE rb_mysql_client_escape(VALUE self, VALUE str) { oldLen = RSTRING_LEN(str); newStr = rb_str_new(0, oldLen*2+1); - REQUIRE_OPEN_DB(client); - newLen = mysql_real_escape_string(client, RSTRING_PTR(newStr), StringValuePtr(str), oldLen); + REQUIRE_OPEN_DB(wrapper->client); + newLen = mysql_real_escape_string(wrapper->client, RSTRING_PTR(newStr), StringValuePtr(str), oldLen); if (newLen == oldLen) { // no need to return a new ruby string if nothing changed return str; @@ -389,11 +385,11 @@ static VALUE rb_mysql_client_server_info(VALUE self) { rb_encoding *conn_enc = rb_to_encoding(wrapper->encoding); #endif - REQUIRE_OPEN_DB(client); + REQUIRE_OPEN_DB(wrapper->client); version = rb_hash_new(); - rb_hash_aset(version, sym_id, LONG2FIX(mysql_get_server_version(client))); - server_info = rb_str_new2(mysql_get_server_info(client)); + rb_hash_aset(version, sym_id, LONG2FIX(mysql_get_server_version(wrapper->client))); + server_info = rb_str_new2(mysql_get_server_info(wrapper->client)); #ifdef HAVE_RUBY_ENCODING_H rb_enc_associate(server_info, conn_enc); if (default_internal_enc) { @@ -406,20 +402,20 @@ static VALUE rb_mysql_client_server_info(VALUE self) { static VALUE rb_mysql_client_socket(VALUE self) { GET_CLIENT(self) - REQUIRE_OPEN_DB(client); - return INT2NUM(client->net.fd); + REQUIRE_OPEN_DB(wrapper->client); + return INT2NUM(wrapper->client->net.fd); } static VALUE rb_mysql_client_last_id(VALUE self) { GET_CLIENT(self) - REQUIRE_OPEN_DB(client); - return ULL2NUM(mysql_insert_id(client)); + REQUIRE_OPEN_DB(wrapper->client); + return ULL2NUM(mysql_insert_id(wrapper->client)); } static VALUE rb_mysql_client_affected_rows(VALUE self) { GET_CLIENT(self) - REQUIRE_OPEN_DB(client); - return ULL2NUM(mysql_affected_rows(client)); + REQUIRE_OPEN_DB(wrapper->client); + return ULL2NUM(mysql_affected_rows(wrapper->client)); } static VALUE set_reconnect(VALUE self, VALUE value) { @@ -430,9 +426,9 @@ static VALUE set_reconnect(VALUE self, VALUE value) { reconnect = value == Qfalse ? 0 : 1; /* set default reconnect behavior */ - if (mysql_options(client, MYSQL_OPT_RECONNECT, &reconnect)) { + if (mysql_options(wrapper->client, MYSQL_OPT_RECONNECT, &reconnect)) { /* TODO: warning - unable to set reconnect behavior */ - rb_warn("%s\n", mysql_error(client)); + rb_warn("%s\n", mysql_error(wrapper->client)); } } return value; @@ -447,9 +443,9 @@ static VALUE set_connect_timeout(VALUE self, VALUE value) { if(0 == connect_timeout) return value; /* set default connection timeout behavior */ - if (mysql_options(client, MYSQL_OPT_CONNECT_TIMEOUT, &connect_timeout)) { + if (mysql_options(wrapper->client, MYSQL_OPT_CONNECT_TIMEOUT, &connect_timeout)) { /* TODO: warning - unable to set connection timeout */ - rb_warn("%s\n", mysql_error(client)); + rb_warn("%s\n", mysql_error(wrapper->client)); } } return value; @@ -473,9 +469,9 @@ static VALUE set_charset_name(VALUE self, VALUE value) { charset_name = StringValuePtr(value); - if (mysql_options(client, MYSQL_SET_CHARSET_NAME, charset_name)) { + if (mysql_options(wrapper->client, MYSQL_SET_CHARSET_NAME, charset_name)) { /* TODO: warning - unable to set charset */ - rb_warn("%s\n", mysql_error(client)); + rb_warn("%s\n", mysql_error(wrapper->client)); } return value; @@ -485,7 +481,7 @@ static VALUE set_ssl_options(VALUE self, VALUE key, VALUE cert, VALUE ca, VALUE GET_CLIENT(self) if(!NIL_P(ca) || !NIL_P(key)) { - mysql_ssl_set(client, + mysql_ssl_set(wrapper->client, NIL_P(key) ? NULL : StringValuePtr(key), NIL_P(cert) ? NULL : StringValuePtr(cert), NIL_P(ca) ? NULL : StringValuePtr(ca), @@ -499,9 +495,9 @@ static VALUE set_ssl_options(VALUE self, VALUE key, VALUE cert, VALUE ca, VALUE static VALUE init_connection(VALUE self) { GET_CLIENT(self) - if (rb_thread_blocking_region(nogvl_init, client, RUBY_UBF_IO, 0) == Qfalse) { + if (rb_thread_blocking_region(nogvl_init, ((void *) &wrapper->client), RUBY_UBF_IO, 0) == Qfalse) { /* TODO: warning - not enough memory? */ - return rb_raise_mysql2_error(client); + return rb_raise_mysql2_error(wrapper->client); } return self; diff --git a/ext/mysql2/client.h b/ext/mysql2/client.h index 6bd9963..781ccda 100644 --- a/ext/mysql2/client.h +++ b/ext/mysql2/client.h @@ -34,7 +34,7 @@ void init_mysql2_client(); typedef struct { VALUE encoding; short int active; - MYSQL client; + MYSQL *client; } mysql_client_wrapper; #endif \ No newline at end of file