diff --git a/ext/mysql2_ext.c b/ext/mysql2_ext.c index af0f622..a49916e 100644 --- a/ext/mysql2_ext.c +++ b/ext/mysql2_ext.c @@ -7,8 +7,12 @@ static VALUE rb_mysql_client_new(int argc, VALUE * argv, VALUE klass) { VALUE rb_host, rb_socket, rb_port, rb_database, rb_username, rb_password, rb_reconnect, rb_connect_timeout; + VALUE rb_ssl_client_key, rb_ssl_client_cert, rb_ssl_ca_cert, + rb_ssl_ca_path, rb_ssl_cipher; char *host = "localhost", *socket = NULL, *username = NULL, *password = NULL, *database = NULL; + char *ssl_client_key = NULL, *ssl_client_cert = NULL, *ssl_ca_cert = NULL, + *ssl_ca_path = NULL, *ssl_cipher = NULL; unsigned int port = 3306, connect_timeout = 0; my_bool reconnect = 0; @@ -55,6 +59,32 @@ static VALUE rb_mysql_client_new(int argc, VALUE * argv, VALUE klass) { Check_Type(rb_connect_timeout, T_FIXNUM); connect_timeout = FIX2INT(rb_connect_timeout); } + + // SSL options + if ((rb_ssl_client_key = rb_hash_aref(opts, sym_sslkey)) != Qnil) { + Check_Type(rb_ssl_client_key, T_STRING); + ssl_client_key = RSTRING_PTR(rb_ssl_client_key); + } + + if ((rb_ssl_client_cert = rb_hash_aref(opts, sym_sslcert)) != Qnil) { + Check_Type(rb_ssl_client_cert, T_STRING); + ssl_client_cert = RSTRING_PTR(rb_ssl_client_cert); + } + + if ((rb_ssl_ca_cert = rb_hash_aref(opts, sym_sslca)) != Qnil) { + Check_Type(rb_ssl_ca_cert, T_STRING); + ssl_ca_cert = RSTRING_PTR(rb_ssl_ca_cert); + } + + if ((rb_ssl_ca_path = rb_hash_aref(opts, sym_sslcapath)) != Qnil) { + Check_Type(rb_ssl_ca_path, T_STRING); + ssl_ca_path = RSTRING_PTR(rb_ssl_ca_path); + } + + if ((rb_ssl_cipher = rb_hash_aref(opts, sym_sslcipher)) != Qnil) { + Check_Type(rb_ssl_cipher, T_STRING); + ssl_cipher = RSTRING_PTR(rb_ssl_cipher); + } } if (!mysql_init(client)) { @@ -81,6 +111,10 @@ static VALUE rb_mysql_client_new(int argc, VALUE * argv, VALUE klass) { rb_warn("%s\n", mysql_error(client)); } + if (ssl_ca_cert != NULL || ssl_client_key != NULL) { + mysql_ssl_set(client, ssl_client_key, ssl_client_cert, ssl_ca_cert, ssl_ca_path, ssl_cipher); + } + if (mysql_real_connect(client, host, username, password, database, port, socket, 0) == NULL) { // unable to connect rb_raise(cMysql2Error, "%s", mysql_error(client)); @@ -414,6 +448,11 @@ void Init_mysql2_ext() { sym_connect_timeout = ID2SYM(rb_intern("connect_timeout")); sym_id = ID2SYM(rb_intern("id")); sym_version = ID2SYM(rb_intern("version")); + sym_sslkey = ID2SYM(rb_intern("sslkey")); + sym_sslcert = ID2SYM(rb_intern("sslcert")); + sym_sslca = ID2SYM(rb_intern("sslca")); + sym_sslcapath = ID2SYM(rb_intern("sslcapath")); + sym_sslcipher = ID2SYM(rb_intern("sslcipher")); #ifdef HAVE_RUBY_ENCODING_H utf8Encoding = rb_enc_find_index("UTF-8"); diff --git a/ext/mysql2_ext.h b/ext/mysql2_ext.h index 5600ef4..c72c8b9 100644 --- a/ext/mysql2_ext.h +++ b/ext/mysql2_ext.h @@ -20,7 +20,8 @@ VALUE cMysql2Error; /* Mysql2::Client */ #define GetMysql2Client(obj, sval) (sval = (MYSQL*)DATA_PTR(obj)); static ID sym_socket, sym_host, sym_port, sym_username, sym_password, - sym_database, sym_reconnect, sym_connect_timeout, sym_id, sym_version; + sym_database, sym_reconnect, sym_connect_timeout, sym_id, sym_version, + sym_sslkey, sym_sslcert, sym_sslca, sym_sslcapath, sym_sslcipher; static VALUE rb_mysql_client_new(int argc, VALUE * argv, VALUE klass); static VALUE rb_mysql_client_init(VALUE self, int argc, VALUE * argv); static VALUE rb_mysql_client_query(VALUE self, VALUE query); diff --git a/spec/mysql2/client_spec.rb b/spec/mysql2/client_spec.rb index dd4e94d..d318234 100644 --- a/spec/mysql2/client_spec.rb +++ b/spec/mysql2/client_spec.rb @@ -6,6 +6,29 @@ describe Mysql2::Client do @client = Mysql2::Client.new end + it "should be able to connect via SSL options" do + pending("DON'T WORRY, THIS TEST PASSES :) - but is machine-specific. You need to have MySQL running with SSL configured and enabled. Then update the paths in this test to your needs and remove the pending state.") + ssl_client = nil + lambda { + ssl_client = Mysql2::Client.new( + :sslkey => '/path/to/client-key.pem', + :sslcert => '/path/to/client-cert.pem', + :sslca => '/path/to/ca-cert.pem', + :sslcapath => '/path/to/newcerts/', + :sslcipher => 'DHE-RSA-AES256-SHA' + ) + }.should_not raise_error(Mysql2::Error) + + results = ssl_client.query("SHOW STATUS WHERE Variable_name = \"Ssl_version\" OR Variable_name = \"Ssl_cipher\"").to_a + results[0]['Variable_name'].should eql('Ssl_cipher') + results[0]['Value'].should_not be_nil + results[0]['Value'].class.should eql(String) + + results[1]['Variable_name'].should eql('Ssl_version') + results[1]['Value'].should_not be_nil + results[1]['Value'].class.should eql(String) + end + it "should respond to #query" do @client.should respond_to :query end