Merge branch 'master' into stmt

* master:
  avoid potential race-condition with closing a connection
  add option for setting the wait_timeout in the AR adapter (this can be done in database.yml)
  add some more defaults to the connect flags
  add connect_flags to default options and add REMEMBER_OPTIONS to that list. fix NUM2INT to be NUM2ULONG as it should be for flags
  free the client after close if we can
  forgot to remove this
  get rid of double-pointer casting
  a couple of minor updates to connection management with some specs
  check for error from mysql_affected_rows call
  change connection check symantecs
This commit is contained in:
Brian Lopez 2010-10-15 16:19:11 -07:00
commit ddcd1064cd
5 changed files with 95 additions and 34 deletions

View File

@ -9,7 +9,7 @@ static ID sym_id, sym_version, sym_async, sym_symbolize_keys, sym_as, sym_array;
static ID intern_merge, intern_error_number_eql, intern_sql_state_eql; static ID intern_merge, intern_error_number_eql, intern_sql_state_eql;
#define REQUIRE_OPEN_DB(wrapper) \ #define REQUIRE_OPEN_DB(wrapper) \
if(wrapper->closed || !wrapper->client->net.vio) { \ if(wrapper->closed) { \
rb_raise(cMysql2Error, "closed MySQL connection"); \ rb_raise(cMysql2Error, "closed MySQL connection"); \
return Qnil; \ return Qnil; \
} }
@ -83,11 +83,11 @@ static VALUE rb_raise_mysql2_error(MYSQL *client) {
} }
static VALUE nogvl_init(void *ptr) { static VALUE nogvl_init(void *ptr) {
MYSQL **client = (MYSQL **)ptr; MYSQL *client;
/* may initialize embedded server and read /etc/services off disk */ /* may initialize embedded server and read /etc/services off disk */
*client = mysql_init(NULL); client = mysql_init((MYSQL *)ptr);
return *client ? Qtrue : Qfalse; return client ? Qtrue : Qfalse;
} }
static VALUE nogvl_connect(void *ptr) { static VALUE nogvl_connect(void *ptr) {
@ -132,6 +132,9 @@ static void rb_mysql_client_free(void * ptr) {
/* It's safe to call mysql_close() on an already closed connection. */ /* It's safe to call mysql_close() on an already closed connection. */
if (!wrapper->closed) { if (!wrapper->closed) {
mysql_close(wrapper->client); mysql_close(wrapper->client);
if (!wrapper->freed) {
free(wrapper->client);
}
} }
xfree(ptr); xfree(ptr);
} }
@ -139,9 +142,11 @@ static void rb_mysql_client_free(void * ptr) {
static VALUE nogvl_close(void * ptr) { static VALUE nogvl_close(void * ptr) {
mysql_client_wrapper *wrapper = ptr; mysql_client_wrapper *wrapper = ptr;
if (!wrapper->closed) { if (!wrapper->closed) {
mysql_close(wrapper->client);
wrapper->client->net.fd = -1;
wrapper->closed = 1; wrapper->closed = 1;
mysql_close(wrapper->client);
if (!wrapper->freed) {
free(wrapper->client);
}
} }
return Qnil; return Qnil;
} }
@ -153,6 +158,8 @@ static VALUE allocate(VALUE klass) {
wrapper->encoding = Qnil; wrapper->encoding = Qnil;
wrapper->active = 0; wrapper->active = 0;
wrapper->closed = 0; wrapper->closed = 0;
wrapper->freed = 0;
wrapper->client = (MYSQL*)malloc(sizeof(MYSQL));
return obj; return obj;
} }
@ -167,7 +174,7 @@ static VALUE rb_connect(VALUE self, VALUE user, VALUE pass, VALUE host, VALUE po
args.passwd = NIL_P(pass) ? NULL : StringValuePtr(pass); args.passwd = NIL_P(pass) ? NULL : StringValuePtr(pass);
args.db = NIL_P(database) ? NULL : StringValuePtr(database); args.db = NIL_P(database) ? NULL : StringValuePtr(database);
args.mysql = wrapper->client; args.mysql = wrapper->client;
args.client_flag = NUM2INT(flags); args.client_flag = NUM2ULONG(flags);
if (rb_thread_blocking_region(nogvl_connect, &args, RUBY_UBF_IO, 0) == Qfalse) { if (rb_thread_blocking_region(nogvl_connect, &args, RUBY_UBF_IO, 0) == Qfalse) {
// unable to connect // unable to connect
@ -186,7 +193,9 @@ static VALUE rb_connect(VALUE self, VALUE user, VALUE pass, VALUE host, VALUE po
static VALUE rb_mysql_client_close(VALUE self) { static VALUE rb_mysql_client_close(VALUE self) {
GET_CLIENT(self); GET_CLIENT(self);
rb_thread_blocking_region(nogvl_close, wrapper, RUBY_UBF_IO, 0); if (!wrapper->closed) {
rb_thread_blocking_region(nogvl_close, wrapper, RUBY_UBF_IO, 0);
}
return Qnil; return Qnil;
} }
@ -335,6 +344,7 @@ static VALUE rb_mysql_client_escape(VALUE self, VALUE str) {
unsigned long newLen, oldLen; unsigned long newLen, oldLen;
GET_CLIENT(self); GET_CLIENT(self);
REQUIRE_OPEN_DB(wrapper);
Check_Type(str, T_STRING); Check_Type(str, T_STRING);
#ifdef HAVE_RUBY_ENCODING_H #ifdef HAVE_RUBY_ENCODING_H
rb_encoding *default_internal_enc = rb_default_internal_encoding(); rb_encoding *default_internal_enc = rb_default_internal_encoding();
@ -346,7 +356,6 @@ static VALUE rb_mysql_client_escape(VALUE self, VALUE str) {
oldLen = RSTRING_LEN(str); oldLen = RSTRING_LEN(str);
newStr = rb_str_new(0, oldLen*2+1); newStr = rb_str_new(0, oldLen*2+1);
REQUIRE_OPEN_DB(wrapper);
newLen = mysql_real_escape_string(wrapper->client, RSTRING_PTR(newStr), StringValuePtr(str), oldLen); newLen = mysql_real_escape_string(wrapper->client, RSTRING_PTR(newStr), StringValuePtr(str), oldLen);
if (newLen == oldLen) { if (newLen == oldLen) {
// no need to return a new ruby string if nothing changed // no need to return a new ruby string if nothing changed
@ -366,6 +375,7 @@ static VALUE rb_mysql_client_escape(VALUE self, VALUE str) {
static VALUE rb_mysql_client_info(VALUE self) { static VALUE rb_mysql_client_info(VALUE self) {
VALUE version = rb_hash_new(), client_info; VALUE version = rb_hash_new(), client_info;
GET_CLIENT(self); GET_CLIENT(self);
#ifdef HAVE_RUBY_ENCODING_H #ifdef HAVE_RUBY_ENCODING_H
rb_encoding *default_internal_enc = rb_default_internal_encoding(); rb_encoding *default_internal_enc = rb_default_internal_encoding();
rb_encoding *conn_enc = rb_to_encoding(wrapper->encoding); rb_encoding *conn_enc = rb_to_encoding(wrapper->encoding);
@ -386,13 +396,13 @@ static VALUE rb_mysql_client_info(VALUE self) {
static VALUE rb_mysql_client_server_info(VALUE self) { static VALUE rb_mysql_client_server_info(VALUE self) {
VALUE version, server_info; VALUE version, server_info;
GET_CLIENT(self); GET_CLIENT(self);
REQUIRE_OPEN_DB(wrapper);
#ifdef HAVE_RUBY_ENCODING_H #ifdef HAVE_RUBY_ENCODING_H
rb_encoding *default_internal_enc = rb_default_internal_encoding(); rb_encoding *default_internal_enc = rb_default_internal_encoding();
rb_encoding *conn_enc = rb_to_encoding(wrapper->encoding); rb_encoding *conn_enc = rb_to_encoding(wrapper->encoding);
#endif #endif
REQUIRE_OPEN_DB(wrapper);
version = rb_hash_new(); version = rb_hash_new();
rb_hash_aset(version, sym_id, LONG2FIX(mysql_get_server_version(wrapper->client))); rb_hash_aset(version, sym_id, LONG2FIX(mysql_get_server_version(wrapper->client)));
server_info = rb_str_new2(mysql_get_server_info(wrapper->client)); server_info = rb_str_new2(mysql_get_server_info(wrapper->client));
@ -420,8 +430,14 @@ static VALUE rb_mysql_client_last_id(VALUE self) {
static VALUE rb_mysql_client_affected_rows(VALUE self) { static VALUE rb_mysql_client_affected_rows(VALUE self) {
GET_CLIENT(self); GET_CLIENT(self);
my_ulonglong retVal;
REQUIRE_OPEN_DB(wrapper); REQUIRE_OPEN_DB(wrapper);
return ULL2NUM(mysql_affected_rows(wrapper->client)); retVal = mysql_affected_rows(wrapper->client);
if (retVal == (my_ulonglong)-1) {
rb_raise_mysql2_error(wrapper->client);
}
return ULL2NUM(retVal);
} }
static VALUE set_reconnect(VALUE self, VALUE value) { static VALUE set_reconnect(VALUE self, VALUE value) {
@ -501,7 +517,7 @@ static VALUE set_ssl_options(VALUE self, VALUE key, VALUE cert, VALUE ca, VALUE
static VALUE init_connection(VALUE self) { static VALUE init_connection(VALUE self) {
GET_CLIENT(self); GET_CLIENT(self);
if (rb_thread_blocking_region(nogvl_init, ((void *) &wrapper->client), RUBY_UBF_IO, 0) == Qfalse) { if (rb_thread_blocking_region(nogvl_init, wrapper->client, RUBY_UBF_IO, 0) == Qfalse) {
/* TODO: warning - not enough memory? */ /* TODO: warning - not enough memory? */
return rb_raise_mysql2_error(wrapper->client); return rb_raise_mysql2_error(wrapper->client);
} }

View File

@ -35,6 +35,7 @@ typedef struct {
VALUE encoding; VALUE encoding;
short int active; short int active;
short int closed; short int closed;
short int freed;
MYSQL *client; MYSQL *client;
} mysql_client_wrapper; } mysql_client_wrapper;

View File

@ -617,8 +617,13 @@ module ActiveRecord
# Turn this off. http://dev.rubyonrails.org/ticket/6778 # Turn this off. http://dev.rubyonrails.org/ticket/6778
variable_assignments = ['SQL_AUTO_IS_NULL=0'] variable_assignments = ['SQL_AUTO_IS_NULL=0']
encoding = @config[:encoding] encoding = @config[:encoding]
# make sure we set the encoding
variable_assignments << "NAMES '#{encoding}'" if encoding variable_assignments << "NAMES '#{encoding}'" if encoding
# increase timeout so mysql server doesn't disconnect us
variable_assignments << "@@wait_timeout = #{@config[:wait_timeout] || 2592000}"
execute("SET #{variable_assignments.join(', ')}", :skip_logging) execute("SET #{variable_assignments.join(', ')}", :skip_logging)
end end

View File

@ -8,7 +8,8 @@ module Mysql2
:symbolize_keys => false, # return field names as symbols instead of strings :symbolize_keys => false, # return field names as symbols instead of strings
:database_timezone => :local, # timezone Mysql2 will assume datetime objects are stored in :database_timezone => :local, # timezone Mysql2 will assume datetime objects are stored in
:application_timezone => nil, # timezone Mysql2 will convert to before handing the object back to the caller :application_timezone => nil, # timezone Mysql2 will convert to before handing the object back to the caller
:cache_rows => true # tells Mysql2 to use it's internal row cache for results :cache_rows => true, # tells Mysql2 to use it's internal row cache for results
:connect_flags => REMEMBER_OPTIONS | LONG_PASSWORD | LONG_FLAG | TRANSACTIONS | PROTOCOL_41 | SECURE_CONNECTION
} }
def initialize(opts = {}) def initialize(opts = {})
@ -31,7 +32,7 @@ module Mysql2
port = opts[:port] || 3306 port = opts[:port] || 3306
database = opts[:database] database = opts[:database]
socket = opts[:socket] socket = opts[:socket]
flags = opts[:flags] || 0 flags = opts[:flags] ? opts[:flags] | @query_options[:connect_flags] : @query_options[:connect_flags]
connect user, pass, host, port, database, socket, flags connect user, pass, host, port, database, socket, flags
end end

View File

@ -23,10 +23,10 @@ describe Mysql2::Client do
end end
end end
client = klient.new :flags => Mysql2::Client::FOUND_ROWS client = klient.new :flags => Mysql2::Client::FOUND_ROWS
client.connect_args.last.last.should == Mysql2::Client::FOUND_ROWS (client.connect_args.last.last & Mysql2::Client::FOUND_ROWS).should be_true
end end
it "should default flags to 0" do it "should default flags to (REMEMBER_OPTIONS, LONG_PASSWORD, LONG_FLAG, TRANSACTIONS, PROTOCOL_41, SECURE_CONNECTION)" do
klient = Class.new(Mysql2::Client) do klient = Class.new(Mysql2::Client) do
attr_reader :connect_args attr_reader :connect_args
def connect *args def connect *args
@ -35,7 +35,12 @@ describe Mysql2::Client do
end end
end end
client = klient.new client = klient.new
client.connect_args.last.last.should == 0 (client.connect_args.last.last & (Mysql2::Client::REMEMBER_OPTIONS |
Mysql2::Client::LONG_PASSWORD |
Mysql2::Client::LONG_FLAG |
Mysql2::Client::TRANSACTIONS |
Mysql2::Client::PROTOCOL_41 |
Mysql2::Client::SECURE_CONNECTION)).should be_true
end end
it "should have a global default_query_options hash" do it "should have a global default_query_options hash" do
@ -71,6 +76,9 @@ describe Mysql2::Client do
it "should be able to close properly" do it "should be able to close properly" do
@client.close.should be_nil @client.close.should be_nil
lambda {
@client.query "SELECT 1"
}.should raise_error(Mysql2::Error)
end end
it "should respond to #query" do it "should respond to #query" do
@ -103,6 +111,13 @@ describe Mysql2::Client do
}.should raise_error(Mysql2::Error) }.should raise_error(Mysql2::Error)
end end
it "should require an open connection" do
@client.close
lambda {
@client.query "SELECT 1"
}.should raise_error(Mysql2::Error)
end
# XXX this test is not deterministic (because Unix signal handling is not) # XXX this test is not deterministic (because Unix signal handling is not)
# and may fail on a loaded system # and may fail on a loaded system
if RUBY_PLATFORM !~ /mingw|mswin/ if RUBY_PLATFORM !~ /mingw|mswin/
@ -137,25 +152,34 @@ describe Mysql2::Client do
@client.should respond_to(:escape) @client.should respond_to(:escape)
end end
it "#escape should return a new SQL-escape version of the passed string" do context "#escape" do
@client.escape("abc'def\"ghi\0jkl%mno").should eql("abc\\'def\\\"ghi\\0jkl%mno") it "should return a new SQL-escape version of the passed string" do
end @client.escape("abc'def\"ghi\0jkl%mno").should eql("abc\\'def\\\"ghi\\0jkl%mno")
end
it "#escape should return the passed string if nothing was escaped" do it "should return the passed string if nothing was escaped" do
str = "plain" str = "plain"
@client.escape(str).object_id.should eql(str.object_id) @client.escape(str).object_id.should eql(str.object_id)
end end
it "#escape should not overflow the thread stack" do it "should not overflow the thread stack" do
lambda { lambda {
Thread.new { @client.escape("'" * 256 * 1024) }.join Thread.new { @client.escape("'" * 256 * 1024) }.join
}.should_not raise_error(SystemStackError) }.should_not raise_error(SystemStackError)
end end
it "#escape should not overflow the process stack" do it "should not overflow the process stack" do
lambda { lambda {
Thread.new { @client.escape("'" * 1024 * 1024 * 4) }.join Thread.new { @client.escape("'" * 1024 * 1024 * 4) }.join
}.should_not raise_error(SystemStackError) }.should_not raise_error(SystemStackError)
end
it "should require an open connection" do
@client.close
lambda {
@client.escape ""
}.should raise_error(Mysql2::Error)
end
end end
it "should respond to #info" do it "should respond to #info" do
@ -203,6 +227,13 @@ describe Mysql2::Client do
server_info[:version].class.should eql(String) server_info[:version].class.should eql(String)
end end
it "#server_info should require an open connection" do
@client.close
lambda {
@client.server_info
}.should raise_error(Mysql2::Error)
end
if defined? Encoding if defined? Encoding
context "strings returned by #server_info" do context "strings returned by #server_info" do
it "should default to the connection's encoding if Encoding.default_internal is nil" do it "should default to the connection's encoding if Encoding.default_internal is nil" do
@ -231,6 +262,13 @@ describe Mysql2::Client do
@client.socket.should_not eql(0) @client.socket.should_not eql(0)
end end
it "#socket should require an open connection" do
@client.close
lambda {
@client.socket
}.should raise_error(Mysql2::Error)
end
it "should raise a Mysql2::Error exception upon connection failure" do it "should raise a Mysql2::Error exception upon connection failure" do
lambda { lambda {
bad_client = Mysql2::Client.new :host => "dfjhdi9wrhw", :username => 'asdfasdf8d2h' bad_client = Mysql2::Client.new :host => "dfjhdi9wrhw", :username => 'asdfasdf8d2h'