diff --git a/ext/mysql2/client.c b/ext/mysql2/client.c index 68c3efc..65c463f 100644 --- a/ext/mysql2/client.c +++ b/ext/mysql2/client.c @@ -9,7 +9,7 @@ static ID sym_id, sym_version, sym_async, sym_symbolize_keys, sym_as, sym_array; static ID intern_merge, intern_error_number_eql, intern_sql_state_eql; #define REQUIRE_OPEN_DB(wrapper) \ - if(wrapper->closed || !wrapper->client->net.vio) { \ + if(wrapper->closed) { \ rb_raise(cMysql2Error, "closed MySQL connection"); \ return Qnil; \ } @@ -83,11 +83,11 @@ static VALUE rb_raise_mysql2_error(MYSQL *client) { } static VALUE nogvl_init(void *ptr) { - MYSQL **client = (MYSQL **)ptr; + MYSQL *client; /* may initialize embedded server and read /etc/services off disk */ - *client = mysql_init(NULL); - return *client ? Qtrue : Qfalse; + client = mysql_init((MYSQL *)ptr); + return client ? Qtrue : Qfalse; } static VALUE nogvl_connect(void *ptr) { @@ -132,6 +132,9 @@ static void rb_mysql_client_free(void * ptr) { /* It's safe to call mysql_close() on an already closed connection. */ if (!wrapper->closed) { mysql_close(wrapper->client); + if (!wrapper->freed) { + free(wrapper->client); + } } xfree(ptr); } @@ -139,9 +142,11 @@ static void rb_mysql_client_free(void * ptr) { static VALUE nogvl_close(void * ptr) { mysql_client_wrapper *wrapper = ptr; if (!wrapper->closed) { - mysql_close(wrapper->client); - wrapper->client->net.fd = -1; wrapper->closed = 1; + mysql_close(wrapper->client); + if (!wrapper->freed) { + free(wrapper->client); + } } return Qnil; } @@ -153,6 +158,8 @@ static VALUE allocate(VALUE klass) { wrapper->encoding = Qnil; wrapper->active = 0; wrapper->closed = 0; + wrapper->freed = 0; + wrapper->client = (MYSQL*)malloc(sizeof(MYSQL)); return obj; } @@ -167,7 +174,7 @@ static VALUE rb_connect(VALUE self, VALUE user, VALUE pass, VALUE host, VALUE po args.passwd = NIL_P(pass) ? NULL : StringValuePtr(pass); args.db = NIL_P(database) ? NULL : StringValuePtr(database); args.mysql = wrapper->client; - args.client_flag = NUM2INT(flags); + args.client_flag = NUM2ULONG(flags); if (rb_thread_blocking_region(nogvl_connect, &args, RUBY_UBF_IO, 0) == Qfalse) { // unable to connect @@ -186,7 +193,9 @@ static VALUE rb_connect(VALUE self, VALUE user, VALUE pass, VALUE host, VALUE po static VALUE rb_mysql_client_close(VALUE self) { GET_CLIENT(self); - rb_thread_blocking_region(nogvl_close, wrapper, RUBY_UBF_IO, 0); + if (!wrapper->closed) { + rb_thread_blocking_region(nogvl_close, wrapper, RUBY_UBF_IO, 0); + } return Qnil; } @@ -335,6 +344,7 @@ static VALUE rb_mysql_client_escape(VALUE self, VALUE str) { unsigned long newLen, oldLen; GET_CLIENT(self); + REQUIRE_OPEN_DB(wrapper); Check_Type(str, T_STRING); #ifdef HAVE_RUBY_ENCODING_H rb_encoding *default_internal_enc = rb_default_internal_encoding(); @@ -346,7 +356,6 @@ static VALUE rb_mysql_client_escape(VALUE self, VALUE str) { oldLen = RSTRING_LEN(str); newStr = rb_str_new(0, oldLen*2+1); - REQUIRE_OPEN_DB(wrapper); newLen = mysql_real_escape_string(wrapper->client, RSTRING_PTR(newStr), StringValuePtr(str), oldLen); if (newLen == oldLen) { // no need to return a new ruby string if nothing changed @@ -366,6 +375,7 @@ static VALUE rb_mysql_client_escape(VALUE self, VALUE str) { static VALUE rb_mysql_client_info(VALUE self) { VALUE version = rb_hash_new(), client_info; GET_CLIENT(self); + #ifdef HAVE_RUBY_ENCODING_H rb_encoding *default_internal_enc = rb_default_internal_encoding(); rb_encoding *conn_enc = rb_to_encoding(wrapper->encoding); @@ -386,13 +396,13 @@ static VALUE rb_mysql_client_info(VALUE self) { static VALUE rb_mysql_client_server_info(VALUE self) { VALUE version, server_info; GET_CLIENT(self); + + REQUIRE_OPEN_DB(wrapper); #ifdef HAVE_RUBY_ENCODING_H rb_encoding *default_internal_enc = rb_default_internal_encoding(); rb_encoding *conn_enc = rb_to_encoding(wrapper->encoding); #endif - REQUIRE_OPEN_DB(wrapper); - version = rb_hash_new(); rb_hash_aset(version, sym_id, LONG2FIX(mysql_get_server_version(wrapper->client))); server_info = rb_str_new2(mysql_get_server_info(wrapper->client)); @@ -420,8 +430,14 @@ static VALUE rb_mysql_client_last_id(VALUE self) { static VALUE rb_mysql_client_affected_rows(VALUE self) { GET_CLIENT(self); + my_ulonglong retVal; + REQUIRE_OPEN_DB(wrapper); - return ULL2NUM(mysql_affected_rows(wrapper->client)); + retVal = mysql_affected_rows(wrapper->client); + if (retVal == (my_ulonglong)-1) { + rb_raise_mysql2_error(wrapper->client); + } + return ULL2NUM(retVal); } static VALUE set_reconnect(VALUE self, VALUE value) { @@ -501,7 +517,7 @@ static VALUE set_ssl_options(VALUE self, VALUE key, VALUE cert, VALUE ca, VALUE static VALUE init_connection(VALUE self) { GET_CLIENT(self); - if (rb_thread_blocking_region(nogvl_init, ((void *) &wrapper->client), RUBY_UBF_IO, 0) == Qfalse) { + if (rb_thread_blocking_region(nogvl_init, wrapper->client, RUBY_UBF_IO, 0) == Qfalse) { /* TODO: warning - not enough memory? */ return rb_raise_mysql2_error(wrapper->client); } diff --git a/ext/mysql2/client.h b/ext/mysql2/client.h index d5c2993..fca99c4 100644 --- a/ext/mysql2/client.h +++ b/ext/mysql2/client.h @@ -35,6 +35,7 @@ typedef struct { VALUE encoding; short int active; short int closed; + short int freed; MYSQL *client; } mysql_client_wrapper; diff --git a/lib/active_record/connection_adapters/mysql2_adapter.rb b/lib/active_record/connection_adapters/mysql2_adapter.rb index 78a7901..a442fa8 100644 --- a/lib/active_record/connection_adapters/mysql2_adapter.rb +++ b/lib/active_record/connection_adapters/mysql2_adapter.rb @@ -617,8 +617,13 @@ module ActiveRecord # Turn this off. http://dev.rubyonrails.org/ticket/6778 variable_assignments = ['SQL_AUTO_IS_NULL=0'] encoding = @config[:encoding] + + # make sure we set the encoding variable_assignments << "NAMES '#{encoding}'" if encoding + # increase timeout so mysql server doesn't disconnect us + variable_assignments << "@@wait_timeout = #{@config[:wait_timeout] || 2592000}" + execute("SET #{variable_assignments.join(', ')}", :skip_logging) end diff --git a/lib/mysql2/client.rb b/lib/mysql2/client.rb index bf55621..ee1a36b 100644 --- a/lib/mysql2/client.rb +++ b/lib/mysql2/client.rb @@ -8,7 +8,8 @@ module Mysql2 :symbolize_keys => false, # return field names as symbols instead of strings :database_timezone => :local, # timezone Mysql2 will assume datetime objects are stored in :application_timezone => nil, # timezone Mysql2 will convert to before handing the object back to the caller - :cache_rows => true # tells Mysql2 to use it's internal row cache for results + :cache_rows => true, # tells Mysql2 to use it's internal row cache for results + :connect_flags => REMEMBER_OPTIONS | LONG_PASSWORD | LONG_FLAG | TRANSACTIONS | PROTOCOL_41 | SECURE_CONNECTION } def initialize(opts = {}) @@ -31,7 +32,7 @@ module Mysql2 port = opts[:port] || 3306 database = opts[:database] socket = opts[:socket] - flags = opts[:flags] || 0 + flags = opts[:flags] ? opts[:flags] | @query_options[:connect_flags] : @query_options[:connect_flags] connect user, pass, host, port, database, socket, flags end diff --git a/spec/mysql2/client_spec.rb b/spec/mysql2/client_spec.rb index cad599e..8927410 100644 --- a/spec/mysql2/client_spec.rb +++ b/spec/mysql2/client_spec.rb @@ -23,10 +23,10 @@ describe Mysql2::Client do end end client = klient.new :flags => Mysql2::Client::FOUND_ROWS - client.connect_args.last.last.should == Mysql2::Client::FOUND_ROWS + (client.connect_args.last.last & Mysql2::Client::FOUND_ROWS).should be_true end - it "should default flags to 0" do + it "should default flags to (REMEMBER_OPTIONS, LONG_PASSWORD, LONG_FLAG, TRANSACTIONS, PROTOCOL_41, SECURE_CONNECTION)" do klient = Class.new(Mysql2::Client) do attr_reader :connect_args def connect *args @@ -35,7 +35,12 @@ describe Mysql2::Client do end end client = klient.new - client.connect_args.last.last.should == 0 + (client.connect_args.last.last & (Mysql2::Client::REMEMBER_OPTIONS | + Mysql2::Client::LONG_PASSWORD | + Mysql2::Client::LONG_FLAG | + Mysql2::Client::TRANSACTIONS | + Mysql2::Client::PROTOCOL_41 | + Mysql2::Client::SECURE_CONNECTION)).should be_true end it "should have a global default_query_options hash" do @@ -71,6 +76,9 @@ describe Mysql2::Client do it "should be able to close properly" do @client.close.should be_nil + lambda { + @client.query "SELECT 1" + }.should raise_error(Mysql2::Error) end it "should respond to #query" do @@ -103,6 +111,13 @@ describe Mysql2::Client do }.should raise_error(Mysql2::Error) end + it "should require an open connection" do + @client.close + lambda { + @client.query "SELECT 1" + }.should raise_error(Mysql2::Error) + end + # XXX this test is not deterministic (because Unix signal handling is not) # and may fail on a loaded system if RUBY_PLATFORM !~ /mingw|mswin/ @@ -137,25 +152,34 @@ describe Mysql2::Client do @client.should respond_to(:escape) end - it "#escape should return a new SQL-escape version of the passed string" do - @client.escape("abc'def\"ghi\0jkl%mno").should eql("abc\\'def\\\"ghi\\0jkl%mno") - end + context "#escape" do + it "should return a new SQL-escape version of the passed string" do + @client.escape("abc'def\"ghi\0jkl%mno").should eql("abc\\'def\\\"ghi\\0jkl%mno") + end - it "#escape should return the passed string if nothing was escaped" do - str = "plain" - @client.escape(str).object_id.should eql(str.object_id) - end + it "should return the passed string if nothing was escaped" do + str = "plain" + @client.escape(str).object_id.should eql(str.object_id) + end - it "#escape should not overflow the thread stack" do - lambda { - Thread.new { @client.escape("'" * 256 * 1024) }.join - }.should_not raise_error(SystemStackError) - end + it "should not overflow the thread stack" do + lambda { + Thread.new { @client.escape("'" * 256 * 1024) }.join + }.should_not raise_error(SystemStackError) + end - it "#escape should not overflow the process stack" do - lambda { - Thread.new { @client.escape("'" * 1024 * 1024 * 4) }.join - }.should_not raise_error(SystemStackError) + it "should not overflow the process stack" do + lambda { + Thread.new { @client.escape("'" * 1024 * 1024 * 4) }.join + }.should_not raise_error(SystemStackError) + end + + it "should require an open connection" do + @client.close + lambda { + @client.escape "" + }.should raise_error(Mysql2::Error) + end end it "should respond to #info" do @@ -203,6 +227,13 @@ describe Mysql2::Client do server_info[:version].class.should eql(String) end + it "#server_info should require an open connection" do + @client.close + lambda { + @client.server_info + }.should raise_error(Mysql2::Error) + end + if defined? Encoding context "strings returned by #server_info" do it "should default to the connection's encoding if Encoding.default_internal is nil" do @@ -231,6 +262,13 @@ describe Mysql2::Client do @client.socket.should_not eql(0) end + it "#socket should require an open connection" do + @client.close + lambda { + @client.socket + }.should raise_error(Mysql2::Error) + end + it "should raise a Mysql2::Error exception upon connection failure" do lambda { bad_client = Mysql2::Client.new :host => "dfjhdi9wrhw", :username => 'asdfasdf8d2h'