make sure we only attempt to close/free the MYSQL pointer once

This commit is contained in:
Brian Lopez 2010-09-14 22:27:07 -07:00
parent 5cadce3417
commit e7924df06a
2 changed files with 20 additions and 14 deletions

View File

@ -8,8 +8,8 @@ static VALUE intern_encoding_from_charset;
static ID sym_id, sym_version, sym_async, sym_symbolize_keys, sym_as, sym_array; static ID sym_id, sym_version, sym_async, sym_symbolize_keys, sym_as, sym_array;
static ID intern_merge, intern_error_number_eql, intern_sql_state_eql; static ID intern_merge, intern_error_number_eql, intern_sql_state_eql;
#define REQUIRE_OPEN_DB(_ctxt) \ #define REQUIRE_OPEN_DB(wrapper) \
if(!_ctxt->net.vio) { \ if(wrapper->closed || !wrapper->client->net.vio) { \
rb_raise(cMysql2Error, "closed MySQL connection"); \ rb_raise(cMysql2Error, "closed MySQL connection"); \
return Qnil; \ return Qnil; \
} }
@ -130,14 +130,18 @@ static void rb_mysql_client_free(void * ptr) {
} }
/* It's safe to call mysql_close() on an already closed connection. */ /* It's safe to call mysql_close() on an already closed connection. */
mysql_close(wrapper->client); if (!wrapper->closed) {
mysql_close(wrapper->client);
}
xfree(ptr); xfree(ptr);
} }
static VALUE nogvl_close(void * ptr) { static VALUE nogvl_close(void * ptr) {
MYSQL *client = (MYSQL *)ptr; mysql_client_wrapper *wrapper = ptr;
mysql_close(client); if (!wrapper->closed) {
client->net.fd = -1; mysql_close(wrapper->client);
wrapper->closed = 1;
}
return Qnil; return Qnil;
} }
@ -147,6 +151,7 @@ static VALUE allocate(VALUE klass) {
obj = Data_Make_Struct(klass, mysql_client_wrapper, rb_mysql_client_mark, rb_mysql_client_free, wrapper); obj = Data_Make_Struct(klass, mysql_client_wrapper, rb_mysql_client_mark, rb_mysql_client_free, wrapper);
wrapper->encoding = Qnil; wrapper->encoding = Qnil;
wrapper->active = 0; wrapper->active = 0;
wrapper->closed = 0;
return obj; return obj;
} }
@ -180,7 +185,7 @@ static VALUE rb_connect(VALUE self, VALUE user, VALUE pass, VALUE host, VALUE po
static VALUE rb_mysql_client_close(VALUE self) { static VALUE rb_mysql_client_close(VALUE self) {
GET_CLIENT(self); GET_CLIENT(self);
rb_thread_blocking_region(nogvl_close, wrapper->client, RUBY_UBF_IO, 0); rb_thread_blocking_region(nogvl_close, wrapper, RUBY_UBF_IO, 0);
return Qnil; return Qnil;
} }
@ -223,7 +228,7 @@ static VALUE rb_mysql_client_async_result(VALUE self) {
MYSQL_RES * result; MYSQL_RES * result;
GET_CLIENT(self); GET_CLIENT(self);
REQUIRE_OPEN_DB(wrapper->client); REQUIRE_OPEN_DB(wrapper);
if (rb_thread_blocking_region(nogvl_read_query_result, wrapper->client, RUBY_UBF_IO, 0) == Qfalse) { if (rb_thread_blocking_region(nogvl_read_query_result, wrapper->client, RUBY_UBF_IO, 0) == Qfalse) {
// an error occurred, mark this connection inactive // an error occurred, mark this connection inactive
MARK_CONN_INACTIVE(self); MARK_CONN_INACTIVE(self);
@ -262,7 +267,7 @@ static VALUE rb_mysql_client_query(int argc, VALUE * argv, VALUE self) {
VALUE opts, defaults; VALUE opts, defaults;
GET_CLIENT(self); GET_CLIENT(self);
REQUIRE_OPEN_DB(wrapper->client); REQUIRE_OPEN_DB(wrapper);
args.mysql = wrapper->client; args.mysql = wrapper->client;
// see if this connection is still waiting on a result from a previous query // see if this connection is still waiting on a result from a previous query
@ -340,7 +345,7 @@ static VALUE rb_mysql_client_escape(VALUE self, VALUE str) {
oldLen = RSTRING_LEN(str); oldLen = RSTRING_LEN(str);
newStr = rb_str_new(0, oldLen*2+1); newStr = rb_str_new(0, oldLen*2+1);
REQUIRE_OPEN_DB(wrapper->client); REQUIRE_OPEN_DB(wrapper);
newLen = mysql_real_escape_string(wrapper->client, RSTRING_PTR(newStr), StringValuePtr(str), oldLen); newLen = mysql_real_escape_string(wrapper->client, RSTRING_PTR(newStr), StringValuePtr(str), oldLen);
if (newLen == oldLen) { if (newLen == oldLen) {
// no need to return a new ruby string if nothing changed // no need to return a new ruby string if nothing changed
@ -385,7 +390,7 @@ static VALUE rb_mysql_client_server_info(VALUE self) {
rb_encoding *conn_enc = rb_to_encoding(wrapper->encoding); rb_encoding *conn_enc = rb_to_encoding(wrapper->encoding);
#endif #endif
REQUIRE_OPEN_DB(wrapper->client); REQUIRE_OPEN_DB(wrapper);
version = rb_hash_new(); version = rb_hash_new();
rb_hash_aset(version, sym_id, LONG2FIX(mysql_get_server_version(wrapper->client))); rb_hash_aset(version, sym_id, LONG2FIX(mysql_get_server_version(wrapper->client)));
@ -402,19 +407,19 @@ static VALUE rb_mysql_client_server_info(VALUE self) {
static VALUE rb_mysql_client_socket(VALUE self) { static VALUE rb_mysql_client_socket(VALUE self) {
GET_CLIENT(self); GET_CLIENT(self);
REQUIRE_OPEN_DB(wrapper->client); REQUIRE_OPEN_DB(wrapper);
return INT2NUM(wrapper->client->net.fd); return INT2NUM(wrapper->client->net.fd);
} }
static VALUE rb_mysql_client_last_id(VALUE self) { static VALUE rb_mysql_client_last_id(VALUE self) {
GET_CLIENT(self); GET_CLIENT(self);
REQUIRE_OPEN_DB(wrapper->client); REQUIRE_OPEN_DB(wrapper);
return ULL2NUM(mysql_insert_id(wrapper->client)); return ULL2NUM(mysql_insert_id(wrapper->client));
} }
static VALUE rb_mysql_client_affected_rows(VALUE self) { static VALUE rb_mysql_client_affected_rows(VALUE self) {
GET_CLIENT(self); GET_CLIENT(self);
REQUIRE_OPEN_DB(wrapper->client); REQUIRE_OPEN_DB(wrapper);
return ULL2NUM(mysql_affected_rows(wrapper->client)); return ULL2NUM(mysql_affected_rows(wrapper->client));
} }

View File

@ -34,6 +34,7 @@ void init_mysql2_client();
typedef struct { typedef struct {
VALUE encoding; VALUE encoding;
short int active; short int active;
short int closed;
MYSQL *client; MYSQL *client;
} mysql_client_wrapper; } mysql_client_wrapper;