diff --git a/ext/mysql2_ext.c b/ext/mysql2_ext.c index a49916e..7c09869 100644 --- a/ext/mysql2_ext.c +++ b/ext/mysql2_ext.c @@ -125,7 +125,7 @@ static VALUE rb_mysql_client_new(int argc, VALUE * argv, VALUE klass) { return obj; } -static VALUE rb_mysql_client_init(VALUE self, int argc, VALUE * argv) { +static VALUE rb_mysql_client_init(int argc, VALUE * argv, VALUE self) { return self; } @@ -136,11 +136,21 @@ void rb_mysql_client_free(void * client) { } } -static VALUE rb_mysql_client_query(VALUE self, VALUE sql) { +static VALUE rb_mysql_client_query(int argc, VALUE * argv, VALUE self) { MYSQL * client; MYSQL_RES * result; fd_set fdset; int fd, retval; + int async = 0; + VALUE sql, opts; + VALUE rb_async; + + if (rb_scan_args(argc, argv, "11", &sql, &opts) == 2) { + if ((rb_async = rb_hash_aref(opts, sym_async)) != Qnil) { + async = rb_async == Qtrue ? 1 : 0; + } + } + Check_Type(sql, T_STRING); GetMysql2Client(self, client); @@ -149,37 +159,29 @@ static VALUE rb_mysql_client_query(VALUE self, VALUE sql) { return Qnil; } - // the below code is largely from do_mysql - // http://github.com/datamapper/do - fd = client->net.fd; - for(;;) { - FD_ZERO(&fdset); - FD_SET(fd, &fdset); + if (!async) { + // the below code is largely from do_mysql + // http://github.com/datamapper/do + fd = client->net.fd; + for(;;) { + FD_ZERO(&fdset); + FD_SET(fd, &fdset); - retval = rb_thread_select(fd + 1, &fdset, NULL, NULL, NULL); + retval = rb_thread_select(fd + 1, &fdset, NULL, NULL, NULL); - if (retval < 0) { - rb_sys_fail(0); + if (retval < 0) { + rb_sys_fail(0); + } + + if (retval > 0) { + break; + } } - if (retval > 0) { - break; - } - } - - if (mysql_read_query_result(client) != 0) { - rb_raise(cMysql2Error, "%s", mysql_error(client)); + return rb_mysql_client_async_result(self); + } else { return Qnil; } - - result = mysql_store_result(client); - if (result == NULL) { - if (mysql_field_count(client) != 0) { - rb_raise(cMysql2Error, "%s", mysql_error(client)); - } - return Qnil; - } - return rb_mysql_result_to_obj(result); } static VALUE rb_mysql_client_escape(VALUE self, VALUE str) { @@ -229,6 +231,27 @@ static VALUE rb_mysql_client_socket(VALUE self) { return INT2NUM(client->net.fd); } +static VALUE rb_mysql_client_async_result(VALUE self) { + MYSQL * client; + MYSQL_RES * result; + GetMysql2Client(self, client); + + if (mysql_read_query_result(client) != 0) { + rb_raise(cMysql2Error, "%s", mysql_error(client)); + return Qnil; + } + + result = mysql_store_result(client); + if (result == NULL) { + if (mysql_field_count(client) != 0) { + rb_raise(cMysql2Error, "%s", mysql_error(client)); + } + return Qnil; + } + + return rb_mysql_result_to_obj(result); +} + /* Mysql2::Result */ static VALUE rb_mysql_result_to_obj(MYSQL_RES * r) { VALUE obj; @@ -420,11 +443,12 @@ void Init_mysql2_ext() { VALUE cMysql2Client = rb_define_class_under(mMysql2, "Client", rb_cObject); rb_define_singleton_method(cMysql2Client, "new", rb_mysql_client_new, -1); rb_define_method(cMysql2Client, "initialize", rb_mysql_client_init, -1); - rb_define_method(cMysql2Client, "query", rb_mysql_client_query, 1); + rb_define_method(cMysql2Client, "query", rb_mysql_client_query, -1); rb_define_method(cMysql2Client, "escape", rb_mysql_client_escape, 1); rb_define_method(cMysql2Client, "info", rb_mysql_client_info, 0); rb_define_method(cMysql2Client, "server_info", rb_mysql_client_server_info, 0); rb_define_method(cMysql2Client, "socket", rb_mysql_client_socket, 0); + rb_define_method(cMysql2Client, "async_result", rb_mysql_client_async_result, 0); cMysql2Error = rb_define_class_under(mMysql2, "Error", rb_eStandardError); @@ -453,6 +477,7 @@ void Init_mysql2_ext() { sym_sslca = ID2SYM(rb_intern("sslca")); sym_sslcapath = ID2SYM(rb_intern("sslcapath")); sym_sslcipher = ID2SYM(rb_intern("sslcipher")); + sym_async = ID2SYM(rb_intern("async")); #ifdef HAVE_RUBY_ENCODING_H utf8Encoding = rb_enc_find_index("UTF-8"); diff --git a/ext/mysql2_ext.h b/ext/mysql2_ext.h index f366001..92fcb6a 100644 --- a/ext/mysql2_ext.h +++ b/ext/mysql2_ext.h @@ -28,20 +28,21 @@ VALUE cMysql2Error; #define GetMysql2Client(obj, sval) (sval = (MYSQL*)DATA_PTR(obj)); static ID sym_socket, sym_host, sym_port, sym_username, sym_password, sym_database, sym_reconnect, sym_connect_timeout, sym_id, sym_version, - sym_sslkey, sym_sslcert, sym_sslca, sym_sslcapath, sym_sslcipher; + sym_sslkey, sym_sslcert, sym_sslca, sym_sslcapath, sym_sslcipher, + sym_symbolize_keys, sym_async; static VALUE rb_mysql_client_new(int argc, VALUE * argv, VALUE klass); -static VALUE rb_mysql_client_init(VALUE self, int argc, VALUE * argv); -static VALUE rb_mysql_client_query(VALUE self, VALUE query); +static VALUE rb_mysql_client_init(int argc, VALUE * argv, VALUE self); +static VALUE rb_mysql_client_query(int argc, VALUE * argv, VALUE self); static VALUE rb_mysql_client_escape(VALUE self, VALUE str); static VALUE rb_mysql_client_info(VALUE self); static VALUE rb_mysql_client_server_info(VALUE self); static VALUE rb_mysql_client_socket(VALUE self); +static VALUE rb_mysql_client_async_result(VALUE self); void rb_mysql_client_free(void * client); /* Mysql2::Result */ #define GetMysql2Result(obj, sval) (sval = (MYSQL_RES*)DATA_PTR(obj)); VALUE cMysql2Result; -static ID sym_symbolize_keys; static VALUE rb_mysql_result_to_obj(MYSQL_RES * res); static VALUE rb_mysql_result_fetch_row(int argc, VALUE * argv, VALUE self); static VALUE rb_mysql_result_each(int argc, VALUE * argv, VALUE self); diff --git a/spec/mysql2/client_spec.rb b/spec/mysql2/client_spec.rb index d318234..a184e52 100644 --- a/spec/mysql2/client_spec.rb +++ b/spec/mysql2/client_spec.rb @@ -6,6 +6,11 @@ describe Mysql2::Client do @client = Mysql2::Client.new end + after(:each) do + # forcefully clean up old connections + GC.start + end + it "should be able to connect via SSL options" do pending("DON'T WORRY, THIS TEST PASSES :) - but is machine-specific. You need to have MySQL running with SSL configured and enabled. Then update the paths in this test to your needs and remove the pending state.") ssl_client = nil @@ -90,4 +95,25 @@ describe Mysql2::Client do good_client = Mysql2::Client.new }.should_not raise_error(Mysql2::Error) end + + it "evented async queries should be supported" do + # should immediately return nil + @client.query("SELECT sleep(0.1)", :async => true).should eql(nil) + + io_wrapper = IO.for_fd(@client.socket) + loops = 0 + loop do + if IO.select([io_wrapper], nil, nil, 0.05) + break + else + loops += 1 + end + end + + # make sure we waited some period of time + (loops >= 1).should be_true + + result = @client.async_result + result.class.should eql(Mysql2::Result) + end end \ No newline at end of file