From 74e99bae5f3fb19fa84ccd2821a481569d40aa36 Mon Sep 17 00:00:00 2001 From: Aaron Patterson Date: Sun, 4 Jul 2010 19:29:12 -0700 Subject: [PATCH] refactoring initialize to be done in ruby, adding setter methods for connection options --- ext/mysql2/mysql2_ext.c | 226 +++++++++++++++++++--------------------- ext/mysql2/mysql2_ext.h | 6 +- lib/mysql2.rb | 1 + lib/mysql2/client.rb | 25 +++++ 4 files changed, 133 insertions(+), 125 deletions(-) create mode 100644 lib/mysql2/client.rb diff --git a/ext/mysql2/mysql2_ext.c b/ext/mysql2/mysql2_ext.c index bb74690..6deb23a 100644 --- a/ext/mysql2/mysql2_ext.c +++ b/ext/mysql2/mysql2_ext.c @@ -62,117 +62,24 @@ static VALUE allocate(VALUE klass) ); } -/* Mysql2::Client */ -static VALUE rb_mysql_client_init(int argc, VALUE * argv, VALUE self) { +static VALUE rb_connect(VALUE self, VALUE user, VALUE pass, VALUE host, VALUE port, VALUE database, VALUE socket) +{ MYSQL * client; - struct nogvl_connect_args args = { - .host = "localhost", - .user = NULL, - .passwd = NULL, - .db = NULL, - .port = 3306, - .unix_socket = NULL, - .client_flag = 0 - }; - VALUE opts; - VALUE rb_host, rb_socket, rb_port, rb_database, - rb_username, rb_password, rb_reconnect, - rb_connect_timeout; - VALUE rb_ssl_client_key, rb_ssl_client_cert, rb_ssl_ca_cert, - rb_ssl_ca_path, rb_ssl_cipher; - char *ssl_client_key = NULL, *ssl_client_cert = NULL, *ssl_ca_cert = NULL, - *ssl_ca_path = NULL, *ssl_cipher = NULL; - unsigned int connect_timeout = 0; - my_bool reconnect = 1; + struct nogvl_connect_args args; Data_Get_Struct(self, MYSQL, client); - if (rb_scan_args(argc, argv, "01", &opts) == 1) { - Check_Type(opts, T_HASH); - - if ((rb_host = rb_hash_aref(opts, sym_host)) != Qnil) { - args.host = StringValuePtr(rb_host); - } - - if ((rb_socket = rb_hash_aref(opts, sym_socket)) != Qnil) { - args.unix_socket = StringValuePtr(rb_socket); - } - - if ((rb_port = rb_hash_aref(opts, sym_port)) != Qnil) { - args.port = NUM2INT(rb_port); - } - - if ((rb_username = rb_hash_aref(opts, sym_username)) != Qnil) { - args.user = StringValuePtr(rb_username); - } - - if ((rb_password = rb_hash_aref(opts, sym_password)) != Qnil) { - args.passwd = StringValuePtr(rb_password); - } - - if ((rb_database = rb_hash_aref(opts, sym_database)) != Qnil) { - args.db = StringValuePtr(rb_database); - } - - if ((rb_reconnect = rb_hash_aref(opts, sym_reconnect)) != Qnil) { - reconnect = rb_reconnect == Qfalse ? 0 : 1; - } - - if ((rb_connect_timeout = rb_hash_aref(opts, sym_connect_timeout)) != Qnil) { - connect_timeout = NUM2INT(rb_connect_timeout); - } - - // SSL options - if ((rb_ssl_client_key = rb_hash_aref(opts, sym_sslkey)) != Qnil) { - ssl_client_key = StringValuePtr(rb_ssl_client_key); - } - - if ((rb_ssl_client_cert = rb_hash_aref(opts, sym_sslcert)) != Qnil) { - ssl_client_cert = StringValuePtr(rb_ssl_client_cert); - } - - if ((rb_ssl_ca_cert = rb_hash_aref(opts, sym_sslca)) != Qnil) { - ssl_ca_cert = StringValuePtr(rb_ssl_ca_cert); - } - - if ((rb_ssl_ca_path = rb_hash_aref(opts, sym_sslcapath)) != Qnil) { - ssl_ca_path = StringValuePtr(rb_ssl_ca_path); - } - - if ((rb_ssl_cipher = rb_hash_aref(opts, sym_sslcipher)) != Qnil) { - ssl_cipher = StringValuePtr(rb_ssl_cipher); - } - } - - if (rb_thread_blocking_region(nogvl_init, client, RUBY_UBF_IO, 0) == Qfalse) { - // TODO: warning - not enough memory? - return rb_raise_mysql2_error(client); - } - - // set default reconnect behavior - if (mysql_options(client, MYSQL_OPT_RECONNECT, &reconnect) != 0) { - // TODO: warning - unable to set reconnect behavior - rb_warn("%s\n", mysql_error(client)); - } - - // set default connection timeout behavior - if (connect_timeout != 0 && mysql_options(client, MYSQL_OPT_CONNECT_TIMEOUT, (const char *)&connect_timeout) != 0) { - // TODO: warning - unable to set connection timeout - rb_warn("%s\n", mysql_error(client)); - } - - // force the encoding to utf8 - if (mysql_options(client, MYSQL_SET_CHARSET_NAME, "utf8") != 0) { - // TODO: warning - unable to set charset - rb_warn("%s\n", mysql_error(client)); - } - - if (ssl_ca_cert != NULL || ssl_client_key != NULL) { - mysql_ssl_set(client, ssl_client_key, ssl_client_cert, ssl_ca_cert, ssl_ca_path, ssl_cipher); - } - + args.host = NIL_P(host) ? "localhost" : StringValuePtr(host); + args.unix_socket = NIL_P(socket) ? NULL : StringValuePtr(socket); + args.port = NIL_P(port) ? 3306 : NUM2INT(port); + args.user = NIL_P(user) ? NULL : StringValuePtr(user); + args.passwd = NIL_P(pass) ? NULL : StringValuePtr(pass); + args.db = NIL_P(database) ? NULL : StringValuePtr(database); args.mysql = client; - if (rb_thread_blocking_region(nogvl_connect, &args, RUBY_UBF_IO, 0) == Qfalse) { + args.client_flag = 0; + + if (rb_thread_blocking_region(nogvl_connect, &args, RUBY_UBF_IO, 0) == Qfalse) + { // unable to connect return rb_raise_mysql2_error(client); } @@ -729,6 +636,92 @@ static VALUE rb_raise_mysql2_error(MYSQL *client) { return Qnil; } +static VALUE set_reconnect(VALUE self, VALUE value) +{ + my_bool reconnect; + MYSQL * client; + + Data_Get_Struct(self, MYSQL, client); + + if(!NIL_P(value)) { + reconnect = value == Qfalse ? 0 : 1; + + /* set default reconnect behavior */ + if (mysql_options(client, MYSQL_OPT_RECONNECT, &reconnect)) { + /* TODO: warning - unable to set reconnect behavior */ + rb_warn("%s\n", mysql_error(client)); + } + } + return value; +} + +static VALUE set_connect_timeout(VALUE self, VALUE value) +{ + unsigned int connect_timeout = 0; + MYSQL * client; + + Data_Get_Struct(self, MYSQL, client); + + if(!NIL_P(value)) { + connect_timeout = NUM2INT(value); + if(0 == connect_timeout) return value; + + /* set default connection timeout behavior */ + if (mysql_options(client, MYSQL_OPT_CONNECT_TIMEOUT, &connect_timeout)) { + /* TODO: warning - unable to set connection timeout */ + rb_warn("%s\n", mysql_error(client)); + } + } + return value; +} + +static VALUE set_charset_name(VALUE self, VALUE value) +{ + char * charset_name; + MYSQL * client; + + Data_Get_Struct(self, MYSQL, client); + + charset_name = StringValuePtr(value); + + if (mysql_options(client, MYSQL_SET_CHARSET_NAME, charset_name)) { + /* TODO: warning - unable to set charset */ + rb_warn("%s\n", mysql_error(client)); + } + + return value; +} + +static VALUE set_ssl_options(VALUE self, VALUE key, VALUE cert, VALUE ca, VALUE capath, VALUE cipher) +{ + MYSQL * client; + Data_Get_Struct(self, MYSQL, client); + + if(!NIL_P(ca) || !NIL_P(key)) { + mysql_ssl_set(client, + NIL_P(key) ? NULL : StringValuePtr(key), + NIL_P(cert) ? NULL : StringValuePtr(cert), + NIL_P(ca) ? NULL : StringValuePtr(ca), + NIL_P(capath) ? NULL : StringValuePtr(capath), + NIL_P(cipher) ? NULL : StringValuePtr(cipher)); + } + + return self; +} + +static VALUE init_connection(VALUE self) +{ + MYSQL * client; + Data_Get_Struct(self, MYSQL, client); + + if (rb_thread_blocking_region(nogvl_init, client, RUBY_UBF_IO, 0) == Qfalse) { + /* TODO: warning - not enough memory? */ + return rb_raise_mysql2_error(client); + } + + return self; +} + /* Ruby Extension initializer */ void Init_mysql2() { cBigDecimal = rb_const_get(rb_cObject, rb_intern("BigDecimal")); @@ -740,7 +733,6 @@ void Init_mysql2() { rb_define_alloc_func(cMysql2Client, allocate); - rb_define_method(cMysql2Client, "initialize", rb_mysql_client_init, -1); rb_define_method(cMysql2Client, "close", rb_mysql_client_close, 0); rb_define_method(cMysql2Client, "query", rb_mysql_client_query, -1); rb_define_method(cMysql2Client, "escape", rb_mysql_client_escape, 1); @@ -751,6 +743,13 @@ void Init_mysql2() { rb_define_method(cMysql2Client, "last_id", rb_mysql_client_last_id, 0); rb_define_method(cMysql2Client, "affected_rows", rb_mysql_client_affected_rows, 0); + rb_define_private_method(cMysql2Client, "reconnect=", set_reconnect, 1); + rb_define_private_method(cMysql2Client, "connect_timeout=", set_connect_timeout, 1); + rb_define_private_method(cMysql2Client, "charset_name=", set_charset_name, 1); + rb_define_private_method(cMysql2Client, "ssl_set", set_ssl_options, 5); + rb_define_private_method(cMysql2Client, "init_connection", init_connection, 0); + rb_define_private_method(cMysql2Client, "connect", rb_connect, 6); + cMysql2Error = rb_const_get(mMysql2, rb_intern("Error")); cMysql2Result = rb_define_class_under(mMysql2, "Result", rb_cObject); @@ -764,21 +763,8 @@ void Init_mysql2() { intern_utc = rb_intern("utc"); sym_symbolize_keys = ID2SYM(rb_intern("symbolize_keys")); - sym_reconnect = ID2SYM(rb_intern("reconnect")); - sym_database = ID2SYM(rb_intern("database")); - sym_username = ID2SYM(rb_intern("username")); - sym_password = ID2SYM(rb_intern("password")); - sym_host = ID2SYM(rb_intern("host")); - sym_port = ID2SYM(rb_intern("port")); - sym_socket = ID2SYM(rb_intern("socket")); - sym_connect_timeout = ID2SYM(rb_intern("connect_timeout")); sym_id = ID2SYM(rb_intern("id")); sym_version = ID2SYM(rb_intern("version")); - sym_sslkey = ID2SYM(rb_intern("sslkey")); - sym_sslcert = ID2SYM(rb_intern("sslcert")); - sym_sslca = ID2SYM(rb_intern("sslca")); - sym_sslcapath = ID2SYM(rb_intern("sslcapath")); - sym_sslcipher = ID2SYM(rb_intern("sslcipher")); sym_async = ID2SYM(rb_intern("async")); #ifdef HAVE_RUBY_ENCODING_H diff --git a/ext/mysql2/mysql2_ext.h b/ext/mysql2/mysql2_ext.h index 3580dbb..3e1b4ad 100644 --- a/ext/mysql2/mysql2_ext.h +++ b/ext/mysql2/mysql2_ext.h @@ -30,11 +30,7 @@ static ID intern_new, intern_utc; /* Mysql2::Error */ static VALUE cMysql2Error; -static ID sym_socket, sym_host, sym_port, sym_username, sym_password, - sym_database, sym_reconnect, sym_connect_timeout, sym_id, sym_version, - sym_sslkey, sym_sslcert, sym_sslca, sym_sslcapath, sym_sslcipher, - sym_symbolize_keys, sym_async; -static VALUE rb_mysql_client_init(int argc, VALUE * argv, VALUE self); +static ID sym_id, sym_version, sym_symbolize_keys, sym_async; static VALUE rb_mysql_client_query(int argc, VALUE * argv, VALUE self); static VALUE rb_mysql_client_escape(VALUE self, VALUE str); static VALUE rb_mysql_client_info(VALUE self); diff --git a/lib/mysql2.rb b/lib/mysql2.rb index 26bb058..2ee1522 100644 --- a/lib/mysql2.rb +++ b/lib/mysql2.rb @@ -4,6 +4,7 @@ require 'bigdecimal' require 'mysql2/error' require 'mysql2/mysql2' +require 'mysql2/client' # = Mysql2 # diff --git a/lib/mysql2/client.rb b/lib/mysql2/client.rb new file mode 100644 index 0000000..0fc61fd --- /dev/null +++ b/lib/mysql2/client.rb @@ -0,0 +1,25 @@ +module Mysql2 + class Client + def initialize opts = {} + init_connection + + [:reconnect, :connect_timeout].each do |key| + next unless opts.key?(key) + send(:"#{key}=", opts[key]) + end + # force the encoding to utf8 + self.charset_name = 'utf8' + + ssl_set(*opts.values_at(:sslkey, :sslcert, :sslca, :sslcapath, :sslciper)) + + user = opts[:username] + pass = opts[:password] + host = opts[:host] || 'localhost' + port = opts[:port] || 3306 + database = opts[:database] + socket = opts[:socket] + + connect user, pass, host, port, database, socket + end + end +end