diff --git a/ext/mysql2/mysql2_ext.c b/ext/mysql2/mysql2_ext.c index fe6e087..83a4831 100644 --- a/ext/mysql2/mysql2_ext.c +++ b/ext/mysql2/mysql2_ext.c @@ -1,5 +1,11 @@ #include "mysql2_ext.h" +#define REQUIRE_OPEN_DB(_ctxt) \ + if(!_ctxt->net.vio) { \ + rb_raise(cMysql2Error, "closed MySQL connection"); \ + return Qnil; \ + } + /* * non-blocking mysql_*() functions that we won't be wrapping since * they do not appear to hit the network nor issue any interruptible @@ -23,12 +29,12 @@ */ static VALUE nogvl_init(void *ptr) { - struct nogvl_connect_args *args = ptr; + MYSQL * client = (MYSQL *)ptr; /* may initialize embedded server and read /etc/services off disk */ - args->mysql = mysql_init(NULL); + mysql_init(client); - return args->mysql == NULL ? Qfalse : Qtrue; + return client ? Qtrue : Qfalse; } static VALUE nogvl_connect(void *ptr) { @@ -45,11 +51,11 @@ static VALUE nogvl_connect(void *ptr) { static VALUE allocate(VALUE klass) { - mysql2_client_wrapper * client; + MYSQL * client; return Data_Make_Struct( klass, - mysql2_client_wrapper, + MYSQL, NULL, rb_mysql_client_free, client @@ -58,7 +64,7 @@ static VALUE allocate(VALUE klass) /* Mysql2::Client */ static VALUE rb_mysql_client_init(int argc, VALUE * argv, VALUE self) { - mysql2_client_wrapper * client; + MYSQL * client; struct nogvl_connect_args args = { .host = "localhost", .user = NULL, @@ -79,8 +85,7 @@ static VALUE rb_mysql_client_init(int argc, VALUE * argv, VALUE self) { unsigned int connect_timeout = 0; my_bool reconnect = 1; - /* FIXME: refactor this to not use mysql2_client_wrapper */ - Data_Get_Struct(self, mysql2_client_wrapper, client); + Data_Get_Struct(self, MYSQL, client); if (rb_scan_args(argc, argv, "01", &opts) == 1) { Check_Type(opts, T_HASH); @@ -151,67 +156,65 @@ static VALUE rb_mysql_client_init(int argc, VALUE * argv, VALUE self) { } } - if (rb_thread_blocking_region(nogvl_init, &args, RUBY_UBF_IO, 0) == Qfalse) { + if (rb_thread_blocking_region(nogvl_init, client, RUBY_UBF_IO, 0) == Qfalse) { // TODO: warning - not enough memory? - return rb_raise_mysql2_error(args.mysql); + return rb_raise_mysql2_error(client); } // set default reconnect behavior - if (mysql_options(args.mysql, MYSQL_OPT_RECONNECT, &reconnect) != 0) { + if (mysql_options(client, MYSQL_OPT_RECONNECT, &reconnect) != 0) { // TODO: warning - unable to set reconnect behavior - rb_warn("%s\n", mysql_error(args.mysql)); + rb_warn("%s\n", mysql_error(client)); } // set default connection timeout behavior - if (connect_timeout != 0 && mysql_options(args.mysql, MYSQL_OPT_CONNECT_TIMEOUT, (const char *)&connect_timeout) != 0) { + if (connect_timeout != 0 && mysql_options(client, MYSQL_OPT_CONNECT_TIMEOUT, (const char *)&connect_timeout) != 0) { // TODO: warning - unable to set connection timeout - rb_warn("%s\n", mysql_error(args.mysql)); + rb_warn("%s\n", mysql_error(client)); } // force the encoding to utf8 - if (mysql_options(args.mysql, MYSQL_SET_CHARSET_NAME, "utf8") != 0) { + if (mysql_options(client, MYSQL_SET_CHARSET_NAME, "utf8") != 0) { // TODO: warning - unable to set charset - rb_warn("%s\n", mysql_error(args.mysql)); + rb_warn("%s\n", mysql_error(client)); } if (ssl_ca_cert != NULL || ssl_client_key != NULL) { - mysql_ssl_set(args.mysql, ssl_client_key, ssl_client_cert, ssl_ca_cert, ssl_ca_path, ssl_cipher); + mysql_ssl_set(client, ssl_client_key, ssl_client_cert, ssl_ca_cert, ssl_ca_path, ssl_cipher); } + args.mysql = client; if (rb_thread_blocking_region(nogvl_connect, &args, RUBY_UBF_IO, 0) == Qfalse) { // unable to connect - return rb_raise_mysql2_error(args.mysql);; + return rb_raise_mysql2_error(client); } - client->client = args.mysql; - return self; } static void rb_mysql_client_free(void * ptr) { - mysql2_client_wrapper * client = ptr; + MYSQL * client = (MYSQL *)ptr; - if (client->client) { + /* + * we'll send a QUIT message to the server, but that message is more of a + * formality than a hard requirement since the socket is getting shutdown + * anyways, so ensure the socket write does not block our interpreter + */ + int fd = client->net.fd; + int flags; + + if (fd >= 0) { /* - * we'll send a QUIT message to the server, but that message is more of a - * formality than a hard requirement since the socket is getting shutdown - * anyways, so ensure the socket write does not block our interpreter + * if the socket is dead we have no chance of blocking, + * so ignore any potential fcntl errors since they don't matter */ - int fd = client->client->net.fd; - int flags; - - if (fd >= 0) { - /* - * if the socket is dead we have no chance of blocking, - * so ignore any potential fcntl errors since they don't matter - */ - flags = fcntl(fd, F_GETFL); - if (flags > 0 && !(flags & O_NONBLOCK)) - fcntl(fd, F_SETFL, flags | O_NONBLOCK); - } - - mysql_close(client->client); + flags = fcntl(fd, F_GETFL); + if (flags > 0 && !(flags & O_NONBLOCK)) + fcntl(fd, F_SETFL, flags | O_NONBLOCK); } + + /* It's safe to call mysql_close() on an already closed connection. */ + mysql_close(client); xfree(ptr); } @@ -227,16 +230,13 @@ static VALUE nogvl_close(void * ptr) { * for the garbage collector. */ static VALUE rb_mysql_client_close(VALUE self) { - mysql2_client_wrapper *client; + MYSQL *client; - Data_Get_Struct(self, mysql2_client_wrapper, client); + Data_Get_Struct(self, MYSQL, client); + + REQUIRE_OPEN_DB(client); + rb_thread_blocking_region(nogvl_close, client, RUBY_UBF_IO, 0); - if (client->client) { - rb_thread_blocking_region(nogvl_close, client->client, RUBY_UBF_IO, 0); - client->client = NULL; - } else { - rb_raise(cMysql2Error, "already closed MySQL connection"); - } return Qnil; } @@ -264,6 +264,8 @@ static VALUE rb_mysql_client_query(int argc, VALUE * argv, VALUE self) { VALUE opts; VALUE rb_async; + MYSQL * client; + if (rb_scan_args(argc, argv, "11", &args.sql, &opts) == 2) { if ((rb_async = rb_hash_aref(opts, sym_async)) != Qnil) { async = rb_async == Qtrue ? 1 : 0; @@ -271,20 +273,19 @@ static VALUE rb_mysql_client_query(int argc, VALUE * argv, VALUE self) { } Check_Type(args.sql, T_STRING); + Data_Get_Struct(self, MYSQL, client); - GetMysql2Client(self, args.mysql); - if (!args.mysql) { - rb_raise(cMysql2Error, "closed MySQL connection"); - return Qnil; - } + REQUIRE_OPEN_DB(client); + + args.mysql = client; if (rb_thread_blocking_region(nogvl_send_query, &args, RUBY_UBF_IO, 0) == Qfalse) { - return rb_raise_mysql2_error(args.mysql);; + return rb_raise_mysql2_error(client); } if (!async) { // the below code is largely from do_mysql // http://github.com/datamapper/do - fd = args.mysql->net.fd; + fd = client->net.fd; for(;;) { FD_ZERO(&fdset); FD_SET(fd, &fdset); @@ -318,11 +319,9 @@ static VALUE rb_mysql_client_escape(VALUE self, VALUE str) { oldLen = RSTRING_LEN(str); char escaped[(oldLen*2)+1]; - GetMysql2Client(self, client); - if (!client) { - rb_raise(cMysql2Error, "closed MySQL connection"); - return Qnil; - } + Data_Get_Struct(self, MYSQL, client); + + REQUIRE_OPEN_DB(client); newLen = mysql_real_escape_string(client, escaped, RSTRING_PTR(str), RSTRING_LEN(str)); if (newLen == oldLen) { // no need to return a new ruby string if nothing changed @@ -364,11 +363,9 @@ static VALUE rb_mysql_client_server_info(VALUE self) { rb_encoding *default_internal_enc = rb_default_internal_encoding(); #endif - GetMysql2Client(self, client); - if (!client) { - rb_raise(cMysql2Error, "closed MySQL connection"); - return Qnil; - } + Data_Get_Struct(self, MYSQL, client); + REQUIRE_OPEN_DB(client); + version = rb_hash_new(); rb_hash_aset(version, sym_id, LONG2FIX(mysql_get_server_version(client))); server_info = rb_str_new2(mysql_get_server_info(client)); @@ -383,11 +380,9 @@ static VALUE rb_mysql_client_server_info(VALUE self) { } static VALUE rb_mysql_client_socket(VALUE self) { - MYSQL * client = GetMysql2Client(self, client); - if (!client) { - rb_raise(cMysql2Error, "closed MySQL connection"); - return Qnil; - } + MYSQL * client; + Data_Get_Struct(self, MYSQL, client); + REQUIRE_OPEN_DB(client); return INT2NUM(client->net.fd); } @@ -412,11 +407,10 @@ static VALUE nogvl_store_result(void *ptr) { static VALUE rb_mysql_client_async_result(VALUE self) { MYSQL * client; MYSQL_RES * result; - GetMysql2Client(self, client); - if (!client) { - rb_raise(cMysql2Error, "closed MySQL connection"); - return Qnil; - } + + Data_Get_Struct(self, MYSQL, client); + + REQUIRE_OPEN_DB(client); if (rb_thread_blocking_region(nogvl_read_query_result, client, RUBY_UBF_IO, 0) == Qfalse) { return rb_raise_mysql2_error(client); } @@ -434,21 +428,15 @@ static VALUE rb_mysql_client_async_result(VALUE self) { static VALUE rb_mysql_client_last_id(VALUE self) { MYSQL * client; - GetMysql2Client(self, client); - if (!client) { - rb_raise(cMysql2Error, "closed MySQL connection"); - return Qnil; - } + Data_Get_Struct(self, MYSQL, client); + REQUIRE_OPEN_DB(client); return ULL2NUM(mysql_insert_id(client)); } static VALUE rb_mysql_client_affected_rows(VALUE self) { MYSQL * client; - GetMysql2Client(self, client); - if (!client) { - rb_raise(cMysql2Error, "closed MySQL connection"); - return Qnil; - } + Data_Get_Struct(self, MYSQL, client); + REQUIRE_OPEN_DB(client); return ULL2NUM(mysql_affected_rows(client)); } diff --git a/ext/mysql2/mysql2_ext.h b/ext/mysql2/mysql2_ext.h index 503b65d..3580dbb 100644 --- a/ext/mysql2/mysql2_ext.h +++ b/ext/mysql2/mysql2_ext.h @@ -30,11 +30,6 @@ static ID intern_new, intern_utc; /* Mysql2::Error */ static VALUE cMysql2Error; -/* Mysql2::Client */ -typedef struct { - MYSQL * client; -} mysql2_client_wrapper; -#define GetMysql2Client(obj, sval) (sval = ((mysql2_client_wrapper*)(DATA_PTR(obj)))->client); static ID sym_socket, sym_host, sym_port, sym_username, sym_password, sym_database, sym_reconnect, sym_connect_timeout, sym_id, sym_version, sym_sslkey, sym_sslcert, sym_sslca, sym_sslcapath, sym_sslcipher,