diff --git a/ext/mysql2/mysql2_ext.c b/ext/mysql2/mysql2_ext.c index f15eee0..181d6fc 100644 --- a/ext/mysql2/mysql2_ext.c +++ b/ext/mysql2/mysql2_ext.c @@ -4,6 +4,12 @@ VALUE mMysql2, cMysql2Client; VALUE cMysql2Error, intern_encoding_from_charset; ID sym_id, sym_version, sym_async; +#define REQUIRE_OPEN_DB(_ctxt) \ + if(!_ctxt->net.vio) { \ + rb_raise(cMysql2Error, "closed MySQL connection"); \ + return Qnil; \ + } + /* * non-blocking mysql_*() functions that we won't be wrapping since * they do not appear to hit the network nor issue any interruptible @@ -180,6 +186,7 @@ static VALUE rb_mysql_client_async_result(VALUE self) { Data_Get_Struct(self, MYSQL, client); + REQUIRE_OPEN_DB(client); if (rb_thread_blocking_region(nogvl_read_query_result, client, RUBY_UBF_IO, 0) == Qfalse) { return rb_raise_mysql2_error(client); } @@ -222,6 +229,8 @@ static VALUE rb_mysql_client_query(int argc, VALUE * argv, VALUE self) { Data_Get_Struct(self, MYSQL, client); + REQUIRE_OPEN_DB(client); + args.mysql = client; if (rb_thread_blocking_region(nogvl_send_query, &args, RUBY_UBF_IO, 0) == Qfalse) { return rb_raise_mysql2_error(client); @@ -270,6 +279,7 @@ static VALUE rb_mysql_client_escape(VALUE self, VALUE str) { Data_Get_Struct(self, MYSQL, client); + REQUIRE_OPEN_DB(client); newLen = mysql_real_escape_string(client, escaped, StringValuePtr(str), RSTRING_LEN(str)); if (newLen == oldLen) { // no need to return a new ruby string if nothing changed @@ -314,6 +324,7 @@ static VALUE rb_mysql_client_server_info(VALUE self) { #endif Data_Get_Struct(self, MYSQL, client); + REQUIRE_OPEN_DB(client); version = rb_hash_new(); rb_hash_aset(version, sym_id, LONG2FIX(mysql_get_server_version(client))); @@ -331,18 +342,21 @@ static VALUE rb_mysql_client_server_info(VALUE self) { static VALUE rb_mysql_client_socket(VALUE self) { MYSQL * client; Data_Get_Struct(self, MYSQL, client); + REQUIRE_OPEN_DB(client); return INT2NUM(client->net.fd); } static VALUE rb_mysql_client_last_id(VALUE self) { MYSQL * client; Data_Get_Struct(self, MYSQL, client); + REQUIRE_OPEN_DB(client); return ULL2NUM(mysql_insert_id(client)); } static VALUE rb_mysql_client_affected_rows(VALUE self) { MYSQL * client; Data_Get_Struct(self, MYSQL, client); + REQUIRE_OPEN_DB(client); return ULL2NUM(mysql_affected_rows(client)); } diff --git a/lib/mysql2/client.rb b/lib/mysql2/client.rb index b98aa21..e49ed4b 100644 --- a/lib/mysql2/client.rb +++ b/lib/mysql2/client.rb @@ -3,14 +3,13 @@ module Mysql2 def initialize opts = {} init_connection - self.connect_timeout = opts[:connect_timeout] - + [:reconnect, :connect_timeout].each do |key| + next unless opts.key?(key) + send(:"#{key}=", opts[key]) + end # force the encoding to utf8 self.charset_name = opts[:encoding] || 'utf8' - # force reconnection behavior in libmysql - self.reconnect = true - ssl_set(*opts.values_at(:sslkey, :sslcert, :sslca, :sslcapath, :sslciper)) user = opts[:username] diff --git a/spec/mysql2/client_spec.rb b/spec/mysql2/client_spec.rb index 065c129..7165f48 100644 --- a/spec/mysql2/client_spec.rb +++ b/spec/mysql2/client_spec.rb @@ -45,6 +45,13 @@ describe Mysql2::Client do @client.close.should be_nil end + it "should raise an exception when closed twice" do + @client.close.should be_nil + lambda { + @client.close + }.should raise_error(Mysql2::Error) + end + it "should respond to #query" do @client.should respond_to :query end