remove mysql2_client_wrapper struct, refactor open connection requirement to a macro

This commit is contained in:
Aaron Patterson 2010-07-01 10:01:19 -07:00
parent 8c96aa1fcd
commit ff55ef5c87
2 changed files with 72 additions and 89 deletions

View File

@ -1,5 +1,11 @@
#include "mysql2_ext.h" #include "mysql2_ext.h"
#define REQUIRE_OPEN_DB(_ctxt) \
if(!_ctxt->net.vio) { \
rb_raise(cMysql2Error, "closed MySQL connection"); \
return Qnil; \
}
/* /*
* non-blocking mysql_*() functions that we won't be wrapping since * non-blocking mysql_*() functions that we won't be wrapping since
* they do not appear to hit the network nor issue any interruptible * they do not appear to hit the network nor issue any interruptible
@ -23,12 +29,12 @@
*/ */
static VALUE nogvl_init(void *ptr) { static VALUE nogvl_init(void *ptr) {
struct nogvl_connect_args *args = ptr; MYSQL * client = (MYSQL *)ptr;
/* may initialize embedded server and read /etc/services off disk */ /* may initialize embedded server and read /etc/services off disk */
args->mysql = mysql_init(NULL); mysql_init(client);
return args->mysql == NULL ? Qfalse : Qtrue; return client ? Qtrue : Qfalse;
} }
static VALUE nogvl_connect(void *ptr) { static VALUE nogvl_connect(void *ptr) {
@ -45,11 +51,11 @@ static VALUE nogvl_connect(void *ptr) {
static VALUE allocate(VALUE klass) static VALUE allocate(VALUE klass)
{ {
mysql2_client_wrapper * client; MYSQL * client;
return Data_Make_Struct( return Data_Make_Struct(
klass, klass,
mysql2_client_wrapper, MYSQL,
NULL, NULL,
rb_mysql_client_free, rb_mysql_client_free,
client client
@ -58,7 +64,7 @@ static VALUE allocate(VALUE klass)
/* Mysql2::Client */ /* Mysql2::Client */
static VALUE rb_mysql_client_init(int argc, VALUE * argv, VALUE self) { static VALUE rb_mysql_client_init(int argc, VALUE * argv, VALUE self) {
mysql2_client_wrapper * client; MYSQL * client;
struct nogvl_connect_args args = { struct nogvl_connect_args args = {
.host = "localhost", .host = "localhost",
.user = NULL, .user = NULL,
@ -79,8 +85,7 @@ static VALUE rb_mysql_client_init(int argc, VALUE * argv, VALUE self) {
unsigned int connect_timeout = 0; unsigned int connect_timeout = 0;
my_bool reconnect = 1; my_bool reconnect = 1;
/* FIXME: refactor this to not use mysql2_client_wrapper */ Data_Get_Struct(self, MYSQL, client);
Data_Get_Struct(self, mysql2_client_wrapper, client);
if (rb_scan_args(argc, argv, "01", &opts) == 1) { if (rb_scan_args(argc, argv, "01", &opts) == 1) {
Check_Type(opts, T_HASH); Check_Type(opts, T_HASH);
@ -151,67 +156,65 @@ static VALUE rb_mysql_client_init(int argc, VALUE * argv, VALUE self) {
} }
} }
if (rb_thread_blocking_region(nogvl_init, &args, RUBY_UBF_IO, 0) == Qfalse) { if (rb_thread_blocking_region(nogvl_init, client, RUBY_UBF_IO, 0) == Qfalse) {
// TODO: warning - not enough memory? // TODO: warning - not enough memory?
return rb_raise_mysql2_error(args.mysql); return rb_raise_mysql2_error(client);
} }
// set default reconnect behavior // set default reconnect behavior
if (mysql_options(args.mysql, MYSQL_OPT_RECONNECT, &reconnect) != 0) { if (mysql_options(client, MYSQL_OPT_RECONNECT, &reconnect) != 0) {
// TODO: warning - unable to set reconnect behavior // TODO: warning - unable to set reconnect behavior
rb_warn("%s\n", mysql_error(args.mysql)); rb_warn("%s\n", mysql_error(client));
} }
// set default connection timeout behavior // set default connection timeout behavior
if (connect_timeout != 0 && mysql_options(args.mysql, MYSQL_OPT_CONNECT_TIMEOUT, (const char *)&connect_timeout) != 0) { if (connect_timeout != 0 && mysql_options(client, MYSQL_OPT_CONNECT_TIMEOUT, (const char *)&connect_timeout) != 0) {
// TODO: warning - unable to set connection timeout // TODO: warning - unable to set connection timeout
rb_warn("%s\n", mysql_error(args.mysql)); rb_warn("%s\n", mysql_error(client));
} }
// force the encoding to utf8 // force the encoding to utf8
if (mysql_options(args.mysql, MYSQL_SET_CHARSET_NAME, "utf8") != 0) { if (mysql_options(client, MYSQL_SET_CHARSET_NAME, "utf8") != 0) {
// TODO: warning - unable to set charset // TODO: warning - unable to set charset
rb_warn("%s\n", mysql_error(args.mysql)); rb_warn("%s\n", mysql_error(client));
} }
if (ssl_ca_cert != NULL || ssl_client_key != NULL) { if (ssl_ca_cert != NULL || ssl_client_key != NULL) {
mysql_ssl_set(args.mysql, ssl_client_key, ssl_client_cert, ssl_ca_cert, ssl_ca_path, ssl_cipher); mysql_ssl_set(client, ssl_client_key, ssl_client_cert, ssl_ca_cert, ssl_ca_path, ssl_cipher);
} }
args.mysql = client;
if (rb_thread_blocking_region(nogvl_connect, &args, RUBY_UBF_IO, 0) == Qfalse) { if (rb_thread_blocking_region(nogvl_connect, &args, RUBY_UBF_IO, 0) == Qfalse) {
// unable to connect // unable to connect
return rb_raise_mysql2_error(args.mysql);; return rb_raise_mysql2_error(client);
} }
client->client = args.mysql;
return self; return self;
} }
static void rb_mysql_client_free(void * ptr) { static void rb_mysql_client_free(void * ptr) {
mysql2_client_wrapper * client = ptr; MYSQL * client = (MYSQL *)ptr;
if (client->client) { /*
* we'll send a QUIT message to the server, but that message is more of a
* formality than a hard requirement since the socket is getting shutdown
* anyways, so ensure the socket write does not block our interpreter
*/
int fd = client->net.fd;
int flags;
if (fd >= 0) {
/* /*
* we'll send a QUIT message to the server, but that message is more of a * if the socket is dead we have no chance of blocking,
* formality than a hard requirement since the socket is getting shutdown * so ignore any potential fcntl errors since they don't matter
* anyways, so ensure the socket write does not block our interpreter
*/ */
int fd = client->client->net.fd; flags = fcntl(fd, F_GETFL);
int flags; if (flags > 0 && !(flags & O_NONBLOCK))
fcntl(fd, F_SETFL, flags | O_NONBLOCK);
if (fd >= 0) {
/*
* if the socket is dead we have no chance of blocking,
* so ignore any potential fcntl errors since they don't matter
*/
flags = fcntl(fd, F_GETFL);
if (flags > 0 && !(flags & O_NONBLOCK))
fcntl(fd, F_SETFL, flags | O_NONBLOCK);
}
mysql_close(client->client);
} }
/* It's safe to call mysql_close() on an already closed connection. */
mysql_close(client);
xfree(ptr); xfree(ptr);
} }
@ -227,16 +230,13 @@ static VALUE nogvl_close(void * ptr) {
* for the garbage collector. * for the garbage collector.
*/ */
static VALUE rb_mysql_client_close(VALUE self) { static VALUE rb_mysql_client_close(VALUE self) {
mysql2_client_wrapper *client; MYSQL *client;
Data_Get_Struct(self, mysql2_client_wrapper, client); Data_Get_Struct(self, MYSQL, client);
REQUIRE_OPEN_DB(client);
rb_thread_blocking_region(nogvl_close, client, RUBY_UBF_IO, 0);
if (client->client) {
rb_thread_blocking_region(nogvl_close, client->client, RUBY_UBF_IO, 0);
client->client = NULL;
} else {
rb_raise(cMysql2Error, "already closed MySQL connection");
}
return Qnil; return Qnil;
} }
@ -264,6 +264,8 @@ static VALUE rb_mysql_client_query(int argc, VALUE * argv, VALUE self) {
VALUE opts; VALUE opts;
VALUE rb_async; VALUE rb_async;
MYSQL * client;
if (rb_scan_args(argc, argv, "11", &args.sql, &opts) == 2) { if (rb_scan_args(argc, argv, "11", &args.sql, &opts) == 2) {
if ((rb_async = rb_hash_aref(opts, sym_async)) != Qnil) { if ((rb_async = rb_hash_aref(opts, sym_async)) != Qnil) {
async = rb_async == Qtrue ? 1 : 0; async = rb_async == Qtrue ? 1 : 0;
@ -271,20 +273,19 @@ static VALUE rb_mysql_client_query(int argc, VALUE * argv, VALUE self) {
} }
Check_Type(args.sql, T_STRING); Check_Type(args.sql, T_STRING);
Data_Get_Struct(self, MYSQL, client);
GetMysql2Client(self, args.mysql); REQUIRE_OPEN_DB(client);
if (!args.mysql) {
rb_raise(cMysql2Error, "closed MySQL connection"); args.mysql = client;
return Qnil;
}
if (rb_thread_blocking_region(nogvl_send_query, &args, RUBY_UBF_IO, 0) == Qfalse) { if (rb_thread_blocking_region(nogvl_send_query, &args, RUBY_UBF_IO, 0) == Qfalse) {
return rb_raise_mysql2_error(args.mysql);; return rb_raise_mysql2_error(client);
} }
if (!async) { if (!async) {
// the below code is largely from do_mysql // the below code is largely from do_mysql
// http://github.com/datamapper/do // http://github.com/datamapper/do
fd = args.mysql->net.fd; fd = client->net.fd;
for(;;) { for(;;) {
FD_ZERO(&fdset); FD_ZERO(&fdset);
FD_SET(fd, &fdset); FD_SET(fd, &fdset);
@ -318,11 +319,9 @@ static VALUE rb_mysql_client_escape(VALUE self, VALUE str) {
oldLen = RSTRING_LEN(str); oldLen = RSTRING_LEN(str);
char escaped[(oldLen*2)+1]; char escaped[(oldLen*2)+1];
GetMysql2Client(self, client); Data_Get_Struct(self, MYSQL, client);
if (!client) {
rb_raise(cMysql2Error, "closed MySQL connection"); REQUIRE_OPEN_DB(client);
return Qnil;
}
newLen = mysql_real_escape_string(client, escaped, RSTRING_PTR(str), RSTRING_LEN(str)); newLen = mysql_real_escape_string(client, escaped, RSTRING_PTR(str), RSTRING_LEN(str));
if (newLen == oldLen) { if (newLen == oldLen) {
// no need to return a new ruby string if nothing changed // no need to return a new ruby string if nothing changed
@ -364,11 +363,9 @@ static VALUE rb_mysql_client_server_info(VALUE self) {
rb_encoding *default_internal_enc = rb_default_internal_encoding(); rb_encoding *default_internal_enc = rb_default_internal_encoding();
#endif #endif
GetMysql2Client(self, client); Data_Get_Struct(self, MYSQL, client);
if (!client) { REQUIRE_OPEN_DB(client);
rb_raise(cMysql2Error, "closed MySQL connection");
return Qnil;
}
version = rb_hash_new(); version = rb_hash_new();
rb_hash_aset(version, sym_id, LONG2FIX(mysql_get_server_version(client))); rb_hash_aset(version, sym_id, LONG2FIX(mysql_get_server_version(client)));
server_info = rb_str_new2(mysql_get_server_info(client)); server_info = rb_str_new2(mysql_get_server_info(client));
@ -383,11 +380,9 @@ static VALUE rb_mysql_client_server_info(VALUE self) {
} }
static VALUE rb_mysql_client_socket(VALUE self) { static VALUE rb_mysql_client_socket(VALUE self) {
MYSQL * client = GetMysql2Client(self, client); MYSQL * client;
if (!client) { Data_Get_Struct(self, MYSQL, client);
rb_raise(cMysql2Error, "closed MySQL connection"); REQUIRE_OPEN_DB(client);
return Qnil;
}
return INT2NUM(client->net.fd); return INT2NUM(client->net.fd);
} }
@ -412,11 +407,10 @@ static VALUE nogvl_store_result(void *ptr) {
static VALUE rb_mysql_client_async_result(VALUE self) { static VALUE rb_mysql_client_async_result(VALUE self) {
MYSQL * client; MYSQL * client;
MYSQL_RES * result; MYSQL_RES * result;
GetMysql2Client(self, client);
if (!client) { Data_Get_Struct(self, MYSQL, client);
rb_raise(cMysql2Error, "closed MySQL connection");
return Qnil; REQUIRE_OPEN_DB(client);
}
if (rb_thread_blocking_region(nogvl_read_query_result, client, RUBY_UBF_IO, 0) == Qfalse) { if (rb_thread_blocking_region(nogvl_read_query_result, client, RUBY_UBF_IO, 0) == Qfalse) {
return rb_raise_mysql2_error(client); return rb_raise_mysql2_error(client);
} }
@ -434,21 +428,15 @@ static VALUE rb_mysql_client_async_result(VALUE self) {
static VALUE rb_mysql_client_last_id(VALUE self) { static VALUE rb_mysql_client_last_id(VALUE self) {
MYSQL * client; MYSQL * client;
GetMysql2Client(self, client); Data_Get_Struct(self, MYSQL, client);
if (!client) { REQUIRE_OPEN_DB(client);
rb_raise(cMysql2Error, "closed MySQL connection");
return Qnil;
}
return ULL2NUM(mysql_insert_id(client)); return ULL2NUM(mysql_insert_id(client));
} }
static VALUE rb_mysql_client_affected_rows(VALUE self) { static VALUE rb_mysql_client_affected_rows(VALUE self) {
MYSQL * client; MYSQL * client;
GetMysql2Client(self, client); Data_Get_Struct(self, MYSQL, client);
if (!client) { REQUIRE_OPEN_DB(client);
rb_raise(cMysql2Error, "closed MySQL connection");
return Qnil;
}
return ULL2NUM(mysql_affected_rows(client)); return ULL2NUM(mysql_affected_rows(client));
} }

View File

@ -30,11 +30,6 @@ static ID intern_new, intern_utc;
/* Mysql2::Error */ /* Mysql2::Error */
static VALUE cMysql2Error; static VALUE cMysql2Error;
/* Mysql2::Client */
typedef struct {
MYSQL * client;
} mysql2_client_wrapper;
#define GetMysql2Client(obj, sval) (sval = ((mysql2_client_wrapper*)(DATA_PTR(obj)))->client);
static ID sym_socket, sym_host, sym_port, sym_username, sym_password, static ID sym_socket, sym_host, sym_port, sym_username, sym_password,
sym_database, sym_reconnect, sym_connect_timeout, sym_id, sym_version, sym_database, sym_reconnect, sym_connect_timeout, sym_id, sym_version,
sym_sslkey, sym_sslcert, sym_sslca, sym_sslcapath, sym_sslcipher, sym_sslkey, sym_sslcert, sym_sslca, sym_sslcapath, sym_sslcipher,