From e385e7cf6bf7c13005e3ee1af4875fe39def39ae Mon Sep 17 00:00:00 2001 From: Brian Lopez Date: Thu, 8 Jul 2010 22:22:20 -0700 Subject: [PATCH] Initial refactor of encoding support to ensure we map Ruby encodings to MySQL encodings properly. --- ext/mysql2/mysql2_ext.c | 37 ++++++++++++++++----------- ext/mysql2/mysql2_ext.h | 4 --- ext/mysql2/result.c | 27 ++++++++++++-------- lib/mysql2/client.rb | 51 +++++++++++++++++++++++++++++++++++++- spec/mysql2/client_spec.rb | 14 ++++++++--- spec/mysql2/result_spec.rb | 27 +++++++++++++++----- 6 files changed, 120 insertions(+), 40 deletions(-) diff --git a/ext/mysql2/mysql2_ext.c b/ext/mysql2/mysql2_ext.c index 2b0dbf9..a992cbf 100644 --- a/ext/mysql2/mysql2_ext.c +++ b/ext/mysql2/mysql2_ext.c @@ -1,6 +1,6 @@ #include -VALUE mMysql2; +VALUE mMysql2, cMysql2Client; VALUE cMysql2Error; ID sym_id, sym_version, sym_async; @@ -10,10 +10,6 @@ ID sym_id, sym_version, sym_async; return Qnil; \ } -#ifdef HAVE_RUBY_ENCODING_H -rb_encoding *utf8Encoding; -#endif - /* * non-blocking mysql_*() functions that we won't be wrapping since * they do not appear to hit the network nor issue any interruptible @@ -204,7 +200,9 @@ static VALUE rb_mysql_client_async_result(VALUE self) { return Qnil; } - return rb_mysql_result_to_obj(result); + VALUE resultObj = rb_mysql_result_to_obj(result); + rb_iv_set(resultObj, "@encoding", rb_iv_get(self, "@encoding")); + return resultObj; } static VALUE rb_mysql_client_query(int argc, VALUE * argv, VALUE self) { @@ -264,6 +262,9 @@ static VALUE rb_mysql_client_escape(VALUE self, VALUE str) { unsigned long newLen, oldLen; #ifdef HAVE_RUBY_ENCODING_H rb_encoding *default_internal_enc = rb_default_internal_encoding(); + rb_encoding *conn_enc = rb_to_encoding(rb_iv_get(self, "@encoding")); + // ensure the string is in the encoding the connection is expecting + str = rb_str_export_to_enc(str, conn_enc); #endif Check_Type(str, T_STRING); @@ -280,7 +281,7 @@ static VALUE rb_mysql_client_escape(VALUE self, VALUE str) { } else { newStr = rb_str_new(escaped, newLen); #ifdef HAVE_RUBY_ENCODING_H - rb_enc_associate(newStr, utf8Encoding); + rb_enc_associate(newStr, conn_enc); if (default_internal_enc) { newStr = rb_str_export_to_enc(newStr, default_internal_enc); } @@ -293,12 +294,13 @@ static VALUE rb_mysql_client_info(RB_MYSQL_UNUSED VALUE self) { VALUE version = rb_hash_new(), client_info; #ifdef HAVE_RUBY_ENCODING_H rb_encoding *default_internal_enc = rb_default_internal_encoding(); + rb_encoding *conn_enc = rb_to_encoding(rb_iv_get(self, "@encoding")); #endif rb_hash_aset(version, sym_id, LONG2NUM(mysql_get_client_version())); client_info = rb_str_new2(mysql_get_client_info()); #ifdef HAVE_RUBY_ENCODING_H - rb_enc_associate(client_info, utf8Encoding); + rb_enc_associate(client_info, conn_enc); if (default_internal_enc) { client_info = rb_str_export_to_enc(client_info, default_internal_enc); } @@ -312,6 +314,7 @@ static VALUE rb_mysql_client_server_info(VALUE self) { VALUE version, server_info; #ifdef HAVE_RUBY_ENCODING_H rb_encoding *default_internal_enc = rb_default_internal_encoding(); + rb_encoding *conn_enc = rb_to_encoding(rb_iv_get(self, "@encoding")); #endif Data_Get_Struct(self, MYSQL, client); @@ -321,7 +324,7 @@ static VALUE rb_mysql_client_server_info(VALUE self) { rb_hash_aset(version, sym_id, LONG2FIX(mysql_get_server_version(client))); server_info = rb_str_new2(mysql_get_server_info(client)); #ifdef HAVE_RUBY_ENCODING_H - rb_enc_associate(server_info, utf8Encoding); + rb_enc_associate(server_info, conn_enc); if (default_internal_enc) { server_info = rb_str_export_to_enc(server_info, default_internal_enc); } @@ -397,6 +400,15 @@ static VALUE set_charset_name(VALUE self, VALUE value) Data_Get_Struct(self, MYSQL, client); +#ifdef HAVE_RUBY_ENCODING_H + VALUE new_encoding, old_encoding; + new_encoding = rb_funcall(cMysql2Client, rb_intern("encoding_from_charset"), 1, value); + old_encoding = rb_iv_get(self, "@encoding"); + if (old_encoding == Qnil) { + rb_iv_set(self, "@encoding", new_encoding); + } +#endif + charset_name = StringValuePtr(value); if (mysql_options(client, MYSQL_SET_CHARSET_NAME, charset_name)) { @@ -440,7 +452,7 @@ static VALUE init_connection(VALUE self) /* Ruby Extension initializer */ void Init_mysql2() { mMysql2 = rb_define_module("Mysql2"); - VALUE cMysql2Client = rb_define_class_under(mMysql2, "Client", rb_cObject); + cMysql2Client = rb_define_class_under(mMysql2, "Client", rb_cObject); rb_define_alloc_func(cMysql2Client, allocate); @@ -463,11 +475,6 @@ void Init_mysql2() { cMysql2Error = rb_const_get(mMysql2, rb_intern("Error")); - -#ifdef HAVE_RUBY_ENCODING_H - utf8Encoding = rb_utf8_encoding(); -#endif - init_mysql2_result(); sym_id = ID2SYM(rb_intern("id")); diff --git a/ext/mysql2/mysql2_ext.h b/ext/mysql2/mysql2_ext.h index 602ecd9..0c18884 100644 --- a/ext/mysql2/mysql2_ext.h +++ b/ext/mysql2/mysql2_ext.h @@ -20,10 +20,6 @@ #include #endif -#ifdef HAVE_RUBY_ENCODING_H -extern rb_encoding *utf8Encoding; -#endif - #if defined(__GNUC__) && (__GNUC__ >= 3) #define RB_MYSQL_UNUSED __attribute__ ((unused)) #else diff --git a/ext/mysql2/result.c b/ext/mysql2/result.c index 5bf0baf..2aff55d 100644 --- a/ext/mysql2/result.c +++ b/ext/mysql2/result.c @@ -45,7 +45,11 @@ static VALUE nogvl_fetch_row(void *ptr) { return (VALUE)mysql_fetch_row(result); } -static VALUE rb_mysql_result_fetch_field(mysql2_result_wrapper * wrapper, unsigned int idx, short int symbolize_keys) { +static VALUE rb_mysql_result_fetch_field(VALUE self, unsigned int idx, short int symbolize_keys) { + mysql2_result_wrapper * wrapper; + + GetMysql2Result(self, wrapper); + if (wrapper->fields == Qnil) { wrapper->numberOfFields = mysql_num_fields(wrapper->result); wrapper->fields = rb_ary_new2(wrapper->numberOfFields); @@ -54,9 +58,10 @@ static VALUE rb_mysql_result_fetch_field(mysql2_result_wrapper * wrapper, unsign VALUE rb_field = rb_ary_entry(wrapper->fields, idx); if (rb_field == Qnil) { MYSQL_FIELD *field = NULL; - #ifdef HAVE_RUBY_ENCODING_H +#ifdef HAVE_RUBY_ENCODING_H rb_encoding *default_internal_enc = rb_default_internal_encoding(); - #endif + rb_encoding *conn_enc = rb_to_encoding(rb_iv_get(self, "@encoding")); +#endif field = mysql_fetch_field_direct(wrapper->result, idx); if (symbolize_keys) { @@ -66,12 +71,12 @@ static VALUE rb_mysql_result_fetch_field(mysql2_result_wrapper * wrapper, unsign rb_field = ID2SYM(rb_intern(buf)); } else { rb_field = rb_str_new(field->name, field->name_length); - #ifdef HAVE_RUBY_ENCODING_H - rb_enc_associate(rb_field, utf8Encoding); +#ifdef HAVE_RUBY_ENCODING_H + rb_enc_associate(rb_field, conn_enc); if (default_internal_enc) { rb_field = rb_str_export_to_enc(rb_field, default_internal_enc); } - #endif +#endif } rb_ary_store(wrapper->fields, idx, rb_field); } @@ -89,6 +94,7 @@ static VALUE rb_mysql_result_fetch_row(int argc, VALUE * argv, VALUE self) { void * ptr; #ifdef HAVE_RUBY_ENCODING_H rb_encoding *default_internal_enc = rb_default_internal_encoding(); + rb_encoding *conn_enc = rb_to_encoding(rb_iv_get(self, "@encoding")); #endif GetMysql2Result(self, wrapper); @@ -115,7 +121,7 @@ static VALUE rb_mysql_result_fetch_row(int argc, VALUE * argv, VALUE self) { } for (i = 0; i < wrapper->numberOfFields; i++) { - VALUE field = rb_mysql_result_fetch_field(wrapper, i, symbolizeKeys); + VALUE field = rb_mysql_result_fetch_field(self, i, symbolizeKeys); if (row[i]) { VALUE val; switch(fields[i].type) { @@ -196,7 +202,9 @@ static VALUE rb_mysql_result_fetch_row(int argc, VALUE * argv, VALUE self) { if ((fields[i].flags & BINARY_FLAG) || fields[i].charsetnr == 63) { rb_enc_associate(val, binaryEncoding); } else { - rb_enc_associate(val, utf8Encoding); + // TODO: we should probably lookup the encoding that was set on the field + // and either use that and/or fall back to the connection's encoding + rb_enc_associate(val, conn_enc); if (default_internal_enc) { val = rb_str_export_to_enc(val, default_internal_enc); } @@ -225,7 +233,7 @@ static VALUE rb_mysql_result_fetch_fields(VALUE self) { if (RARRAY_LEN(wrapper->fields) != wrapper->numberOfFields) { for (i=0; inumberOfFields; i++) { - rb_mysql_result_fetch_field(wrapper, i, 0); + rb_mysql_result_fetch_field(self, i, 0); } } @@ -318,7 +326,6 @@ void init_mysql2_result() intern_utc = rb_intern("utc"); #ifdef HAVE_RUBY_ENCODING_H - utf8Encoding = rb_utf8_encoding(); binaryEncoding = rb_enc_find("binary"); #endif } diff --git a/lib/mysql2/client.rb b/lib/mysql2/client.rb index 0fc61fd..9450f91 100644 --- a/lib/mysql2/client.rb +++ b/lib/mysql2/client.rb @@ -8,7 +8,7 @@ module Mysql2 send(:"#{key}=", opts[key]) end # force the encoding to utf8 - self.charset_name = 'utf8' + self.charset_name = opts[:encoding] || 'utf8' ssl_set(*opts.values_at(:sslkey, :sslcert, :sslca, :sslcapath, :sslciper)) @@ -21,5 +21,54 @@ module Mysql2 connect user, pass, host, port, database, socket end + + # NOTE: from ruby-mysql + if defined? Encoding + CHARSET_MAP = { + "armscii8" => nil, + "ascii" => Encoding::US_ASCII, + "big5" => Encoding::Big5, + "binary" => Encoding::ASCII_8BIT, + "cp1250" => Encoding::Windows_1250, + "cp1251" => Encoding::Windows_1251, + "cp1256" => Encoding::Windows_1256, + "cp1257" => Encoding::Windows_1257, + "cp850" => Encoding::CP850, + "cp852" => Encoding::CP852, + "cp866" => Encoding::IBM866, + "cp932" => Encoding::Windows_31J, + "dec8" => nil, + "eucjpms" => Encoding::EucJP_ms, + "euckr" => Encoding::EUC_KR, + "gb2312" => Encoding::EUC_CN, + "gbk" => Encoding::GBK, + "geostd8" => nil, + "greek" => Encoding::ISO_8859_7, + "hebrew" => Encoding::ISO_8859_8, + "hp8" => nil, + "keybcs2" => nil, + "koi8r" => Encoding::KOI8_R, + "koi8u" => Encoding::KOI8_U, + "latin1" => Encoding::ISO_8859_1, + "latin2" => Encoding::ISO_8859_2, + "latin5" => Encoding::ISO_8859_9, + "latin7" => Encoding::ISO_8859_13, + "macce" => Encoding::MacCentEuro, + "macroman" => Encoding::MacRoman, + "sjis" => Encoding::SHIFT_JIS, + "swe7" => nil, + "tis620" => Encoding::TIS_620, + "ucs2" => Encoding::UTF_16BE, + "ujis" => Encoding::EucJP_ms, + "utf8" => Encoding::UTF_8, + } + + def self.encoding_from_charset(charset) + charset = charset.to_s.downcase + enc = CHARSET_MAP[charset] + raise Mysql2::Error, "unsupported charset: #{charset}" unless enc + enc + end + end end end diff --git a/spec/mysql2/client_spec.rb b/spec/mysql2/client_spec.rb index bcfec2a..88cf602 100644 --- a/spec/mysql2/client_spec.rb +++ b/spec/mysql2/client_spec.rb @@ -74,11 +74,14 @@ describe Mysql2::Client do info[:version].class.should eql(String) end - if RUBY_VERSION =~ /^1.9/ + if defined? Encoding context "strings returned by #info" do - it "should default to utf-8 if Encoding.default_internal is nil" do + it "should default to the connection's encoding if Encoding.default_internal is nil" do Encoding.default_internal = nil @client.info[:version].encoding.should eql(Encoding.find('utf-8')) + + client2 = Mysql2::Client.new :encoding => 'ascii' + client2.info[:version].encoding.should eql(Encoding.find('us-ascii')) end it "should use Encoding.default_internal" do @@ -103,11 +106,14 @@ describe Mysql2::Client do server_info[:version].class.should eql(String) end - if RUBY_VERSION =~ /^1.9/ + if defined? Encoding context "strings returned by #server_info" do - it "should default to utf-8 if Encoding.default_internal is nil" do + it "should default to the connection's encoding if Encoding.default_internal is nil" do Encoding.default_internal = nil @client.server_info[:version].encoding.should eql(Encoding.find('utf-8')) + + client2 = Mysql2::Client.new :encoding => 'ascii' + client2.server_info[:version].encoding.should eql(Encoding.find('us-ascii')) end it "should use Encoding.default_internal" do diff --git a/spec/mysql2/result_spec.rb b/spec/mysql2/result_spec.rb index d167c2a..2fa0bb8 100644 --- a/spec/mysql2/result_spec.rb +++ b/spec/mysql2/result_spec.rb @@ -209,12 +209,17 @@ describe Mysql2::Result do @test_result['enum_test'].should eql('val1') end - if RUBY_VERSION =~ /^1.9/ + if defined? Encoding context "string encoding for ENUM values" do - it "should default to utf-8 if Encoding.default_internal is nil" do + it "should default to the connection's encoding if Encoding.default_internal is nil" do Encoding.default_internal = nil result = @client.query("SELECT * FROM mysql2_test ORDER BY id DESC LIMIT 1").first result['enum_test'].encoding.should eql(Encoding.find('utf-8')) + + client2 = Mysql2::Client.new :encoding => 'ascii' + client2.query "USE test" + result = client2.query("SELECT * FROM mysql2_test ORDER BY id DESC LIMIT 1").first + result['enum_test'].encoding.should eql(Encoding.find('us-ascii')) end it "should use Encoding.default_internal" do @@ -233,12 +238,17 @@ describe Mysql2::Result do @test_result['set_test'].should eql('val1,val2') end - if RUBY_VERSION =~ /^1.9/ + if defined? Encoding context "string encoding for SET values" do - it "should default to utf-8 if Encoding.default_internal is nil" do + it "should default to the connection's encoding if Encoding.default_internal is nil" do Encoding.default_internal = nil result = @client.query("SELECT * FROM mysql2_test ORDER BY id DESC LIMIT 1").first result['set_test'].encoding.should eql(Encoding.find('utf-8')) + + client2 = Mysql2::Client.new :encoding => 'ascii' + client2.query "USE test" + result = client2.query("SELECT * FROM mysql2_test ORDER BY id DESC LIMIT 1").first + result['set_test'].encoding.should eql(Encoding.find('us-ascii')) end it "should use Encoding.default_internal" do @@ -257,7 +267,7 @@ describe Mysql2::Result do @test_result['binary_test'].should eql("test#{"\000"*6}") end - if RUBY_VERSION =~ /^1.9/ + if defined? Encoding context "string encoding for BINARY values" do it "should default to binary if Encoding.default_internal is nil" do Encoding.default_internal = nil @@ -294,7 +304,7 @@ describe Mysql2::Result do @test_result[field].should eql("test") end - if RUBY_VERSION =~ /^1.9/ + if defined? Encoding context "string encoding for #{type} values" do if ['VARBINARY', 'TINYBLOB', 'BLOB', 'MEDIUMBLOB', 'LONGBLOB'].include?(type) it "should default to binary if Encoding.default_internal is nil" do @@ -316,6 +326,11 @@ describe Mysql2::Result do Encoding.default_internal = nil result = @client.query("SELECT * FROM mysql2_test ORDER BY id DESC LIMIT 1").first result[field].encoding.should eql(Encoding.find('utf-8')) + + client2 = Mysql2::Client.new :encoding => 'ascii' + client2.query "USE test" + result = client2.query("SELECT * FROM mysql2_test ORDER BY id DESC LIMIT 1").first + result[field].encoding.should eql(Encoding.find('us-ascii')) end it "should use Encoding.default_internal" do