Merge remote branch 'origin/encodingz_4_life'

* origin/encodingz_4_life:
  make sure the query string is converted to the connection's encoding before it's handed to libmysql
  Initial refactor of encoding support to ensure we map Ruby encodings to MySQL encodings properly.
This commit is contained in:
Brian Lopez 2010-07-09 09:04:08 -07:00
commit 5cc79feb35
6 changed files with 126 additions and 40 deletions

View File

@ -1,6 +1,6 @@
#include <mysql2_ext.h> #include <mysql2_ext.h>
VALUE mMysql2; VALUE mMysql2, cMysql2Client;
VALUE cMysql2Error; VALUE cMysql2Error;
ID sym_id, sym_version, sym_async; ID sym_id, sym_version, sym_async;
@ -10,10 +10,6 @@ ID sym_id, sym_version, sym_async;
return Qnil; \ return Qnil; \
} }
#ifdef HAVE_RUBY_ENCODING_H
rb_encoding *utf8Encoding;
#endif
/* /*
* non-blocking mysql_*() functions that we won't be wrapping since * non-blocking mysql_*() functions that we won't be wrapping since
* they do not appear to hit the network nor issue any interruptible * they do not appear to hit the network nor issue any interruptible
@ -204,7 +200,9 @@ static VALUE rb_mysql_client_async_result(VALUE self) {
return Qnil; return Qnil;
} }
return rb_mysql_result_to_obj(result); VALUE resultObj = rb_mysql_result_to_obj(result);
rb_iv_set(resultObj, "@encoding", rb_iv_get(self, "@encoding"));
return resultObj;
} }
static VALUE rb_mysql_client_query(int argc, VALUE * argv, VALUE self) { static VALUE rb_mysql_client_query(int argc, VALUE * argv, VALUE self) {
@ -223,6 +221,12 @@ static VALUE rb_mysql_client_query(int argc, VALUE * argv, VALUE self) {
} }
} }
#ifdef HAVE_RUBY_ENCODING_H
rb_encoding *conn_enc = rb_to_encoding(rb_iv_get(self, "@encoding"));
// ensure the string is in the encoding the connection is expecting
args.sql = rb_str_export_to_enc(args.sql, conn_enc);
#endif
Check_Type(args.sql, T_STRING); Check_Type(args.sql, T_STRING);
Data_Get_Struct(self, MYSQL, client); Data_Get_Struct(self, MYSQL, client);
@ -264,6 +268,9 @@ static VALUE rb_mysql_client_escape(VALUE self, VALUE str) {
unsigned long newLen, oldLen; unsigned long newLen, oldLen;
#ifdef HAVE_RUBY_ENCODING_H #ifdef HAVE_RUBY_ENCODING_H
rb_encoding *default_internal_enc = rb_default_internal_encoding(); rb_encoding *default_internal_enc = rb_default_internal_encoding();
rb_encoding *conn_enc = rb_to_encoding(rb_iv_get(self, "@encoding"));
// ensure the string is in the encoding the connection is expecting
str = rb_str_export_to_enc(str, conn_enc);
#endif #endif
Check_Type(str, T_STRING); Check_Type(str, T_STRING);
@ -280,7 +287,7 @@ static VALUE rb_mysql_client_escape(VALUE self, VALUE str) {
} else { } else {
newStr = rb_str_new(escaped, newLen); newStr = rb_str_new(escaped, newLen);
#ifdef HAVE_RUBY_ENCODING_H #ifdef HAVE_RUBY_ENCODING_H
rb_enc_associate(newStr, utf8Encoding); rb_enc_associate(newStr, conn_enc);
if (default_internal_enc) { if (default_internal_enc) {
newStr = rb_str_export_to_enc(newStr, default_internal_enc); newStr = rb_str_export_to_enc(newStr, default_internal_enc);
} }
@ -293,12 +300,13 @@ static VALUE rb_mysql_client_info(RB_MYSQL_UNUSED VALUE self) {
VALUE version = rb_hash_new(), client_info; VALUE version = rb_hash_new(), client_info;
#ifdef HAVE_RUBY_ENCODING_H #ifdef HAVE_RUBY_ENCODING_H
rb_encoding *default_internal_enc = rb_default_internal_encoding(); rb_encoding *default_internal_enc = rb_default_internal_encoding();
rb_encoding *conn_enc = rb_to_encoding(rb_iv_get(self, "@encoding"));
#endif #endif
rb_hash_aset(version, sym_id, LONG2NUM(mysql_get_client_version())); rb_hash_aset(version, sym_id, LONG2NUM(mysql_get_client_version()));
client_info = rb_str_new2(mysql_get_client_info()); client_info = rb_str_new2(mysql_get_client_info());
#ifdef HAVE_RUBY_ENCODING_H #ifdef HAVE_RUBY_ENCODING_H
rb_enc_associate(client_info, utf8Encoding); rb_enc_associate(client_info, conn_enc);
if (default_internal_enc) { if (default_internal_enc) {
client_info = rb_str_export_to_enc(client_info, default_internal_enc); client_info = rb_str_export_to_enc(client_info, default_internal_enc);
} }
@ -312,6 +320,7 @@ static VALUE rb_mysql_client_server_info(VALUE self) {
VALUE version, server_info; VALUE version, server_info;
#ifdef HAVE_RUBY_ENCODING_H #ifdef HAVE_RUBY_ENCODING_H
rb_encoding *default_internal_enc = rb_default_internal_encoding(); rb_encoding *default_internal_enc = rb_default_internal_encoding();
rb_encoding *conn_enc = rb_to_encoding(rb_iv_get(self, "@encoding"));
#endif #endif
Data_Get_Struct(self, MYSQL, client); Data_Get_Struct(self, MYSQL, client);
@ -321,7 +330,7 @@ static VALUE rb_mysql_client_server_info(VALUE self) {
rb_hash_aset(version, sym_id, LONG2FIX(mysql_get_server_version(client))); rb_hash_aset(version, sym_id, LONG2FIX(mysql_get_server_version(client)));
server_info = rb_str_new2(mysql_get_server_info(client)); server_info = rb_str_new2(mysql_get_server_info(client));
#ifdef HAVE_RUBY_ENCODING_H #ifdef HAVE_RUBY_ENCODING_H
rb_enc_associate(server_info, utf8Encoding); rb_enc_associate(server_info, conn_enc);
if (default_internal_enc) { if (default_internal_enc) {
server_info = rb_str_export_to_enc(server_info, default_internal_enc); server_info = rb_str_export_to_enc(server_info, default_internal_enc);
} }
@ -397,6 +406,15 @@ static VALUE set_charset_name(VALUE self, VALUE value)
Data_Get_Struct(self, MYSQL, client); Data_Get_Struct(self, MYSQL, client);
#ifdef HAVE_RUBY_ENCODING_H
VALUE new_encoding, old_encoding;
new_encoding = rb_funcall(cMysql2Client, rb_intern("encoding_from_charset"), 1, value);
old_encoding = rb_iv_get(self, "@encoding");
if (old_encoding == Qnil) {
rb_iv_set(self, "@encoding", new_encoding);
}
#endif
charset_name = StringValuePtr(value); charset_name = StringValuePtr(value);
if (mysql_options(client, MYSQL_SET_CHARSET_NAME, charset_name)) { if (mysql_options(client, MYSQL_SET_CHARSET_NAME, charset_name)) {
@ -440,7 +458,7 @@ static VALUE init_connection(VALUE self)
/* Ruby Extension initializer */ /* Ruby Extension initializer */
void Init_mysql2() { void Init_mysql2() {
mMysql2 = rb_define_module("Mysql2"); mMysql2 = rb_define_module("Mysql2");
VALUE cMysql2Client = rb_define_class_under(mMysql2, "Client", rb_cObject); cMysql2Client = rb_define_class_under(mMysql2, "Client", rb_cObject);
rb_define_alloc_func(cMysql2Client, allocate); rb_define_alloc_func(cMysql2Client, allocate);
@ -463,11 +481,6 @@ void Init_mysql2() {
cMysql2Error = rb_const_get(mMysql2, rb_intern("Error")); cMysql2Error = rb_const_get(mMysql2, rb_intern("Error"));
#ifdef HAVE_RUBY_ENCODING_H
utf8Encoding = rb_utf8_encoding();
#endif
init_mysql2_result(); init_mysql2_result();
sym_id = ID2SYM(rb_intern("id")); sym_id = ID2SYM(rb_intern("id"));

View File

@ -20,10 +20,6 @@
#include <ruby/encoding.h> #include <ruby/encoding.h>
#endif #endif
#ifdef HAVE_RUBY_ENCODING_H
extern rb_encoding *utf8Encoding;
#endif
#if defined(__GNUC__) && (__GNUC__ >= 3) #if defined(__GNUC__) && (__GNUC__ >= 3)
#define RB_MYSQL_UNUSED __attribute__ ((unused)) #define RB_MYSQL_UNUSED __attribute__ ((unused))
#else #else

View File

@ -45,7 +45,11 @@ static VALUE nogvl_fetch_row(void *ptr) {
return (VALUE)mysql_fetch_row(result); return (VALUE)mysql_fetch_row(result);
} }
static VALUE rb_mysql_result_fetch_field(mysql2_result_wrapper * wrapper, unsigned int idx, short int symbolize_keys) { static VALUE rb_mysql_result_fetch_field(VALUE self, unsigned int idx, short int symbolize_keys) {
mysql2_result_wrapper * wrapper;
GetMysql2Result(self, wrapper);
if (wrapper->fields == Qnil) { if (wrapper->fields == Qnil) {
wrapper->numberOfFields = mysql_num_fields(wrapper->result); wrapper->numberOfFields = mysql_num_fields(wrapper->result);
wrapper->fields = rb_ary_new2(wrapper->numberOfFields); wrapper->fields = rb_ary_new2(wrapper->numberOfFields);
@ -54,9 +58,10 @@ static VALUE rb_mysql_result_fetch_field(mysql2_result_wrapper * wrapper, unsign
VALUE rb_field = rb_ary_entry(wrapper->fields, idx); VALUE rb_field = rb_ary_entry(wrapper->fields, idx);
if (rb_field == Qnil) { if (rb_field == Qnil) {
MYSQL_FIELD *field = NULL; MYSQL_FIELD *field = NULL;
#ifdef HAVE_RUBY_ENCODING_H #ifdef HAVE_RUBY_ENCODING_H
rb_encoding *default_internal_enc = rb_default_internal_encoding(); rb_encoding *default_internal_enc = rb_default_internal_encoding();
#endif rb_encoding *conn_enc = rb_to_encoding(rb_iv_get(self, "@encoding"));
#endif
field = mysql_fetch_field_direct(wrapper->result, idx); field = mysql_fetch_field_direct(wrapper->result, idx);
if (symbolize_keys) { if (symbolize_keys) {
@ -66,12 +71,12 @@ static VALUE rb_mysql_result_fetch_field(mysql2_result_wrapper * wrapper, unsign
rb_field = ID2SYM(rb_intern(buf)); rb_field = ID2SYM(rb_intern(buf));
} else { } else {
rb_field = rb_str_new(field->name, field->name_length); rb_field = rb_str_new(field->name, field->name_length);
#ifdef HAVE_RUBY_ENCODING_H #ifdef HAVE_RUBY_ENCODING_H
rb_enc_associate(rb_field, utf8Encoding); rb_enc_associate(rb_field, conn_enc);
if (default_internal_enc) { if (default_internal_enc) {
rb_field = rb_str_export_to_enc(rb_field, default_internal_enc); rb_field = rb_str_export_to_enc(rb_field, default_internal_enc);
} }
#endif #endif
} }
rb_ary_store(wrapper->fields, idx, rb_field); rb_ary_store(wrapper->fields, idx, rb_field);
} }
@ -89,6 +94,7 @@ static VALUE rb_mysql_result_fetch_row(int argc, VALUE * argv, VALUE self) {
void * ptr; void * ptr;
#ifdef HAVE_RUBY_ENCODING_H #ifdef HAVE_RUBY_ENCODING_H
rb_encoding *default_internal_enc = rb_default_internal_encoding(); rb_encoding *default_internal_enc = rb_default_internal_encoding();
rb_encoding *conn_enc = rb_to_encoding(rb_iv_get(self, "@encoding"));
#endif #endif
GetMysql2Result(self, wrapper); GetMysql2Result(self, wrapper);
@ -115,7 +121,7 @@ static VALUE rb_mysql_result_fetch_row(int argc, VALUE * argv, VALUE self) {
} }
for (i = 0; i < wrapper->numberOfFields; i++) { for (i = 0; i < wrapper->numberOfFields; i++) {
VALUE field = rb_mysql_result_fetch_field(wrapper, i, symbolizeKeys); VALUE field = rb_mysql_result_fetch_field(self, i, symbolizeKeys);
if (row[i]) { if (row[i]) {
VALUE val; VALUE val;
switch(fields[i].type) { switch(fields[i].type) {
@ -196,7 +202,9 @@ static VALUE rb_mysql_result_fetch_row(int argc, VALUE * argv, VALUE self) {
if ((fields[i].flags & BINARY_FLAG) || fields[i].charsetnr == 63) { if ((fields[i].flags & BINARY_FLAG) || fields[i].charsetnr == 63) {
rb_enc_associate(val, binaryEncoding); rb_enc_associate(val, binaryEncoding);
} else { } else {
rb_enc_associate(val, utf8Encoding); // TODO: we should probably lookup the encoding that was set on the field
// and either use that and/or fall back to the connection's encoding
rb_enc_associate(val, conn_enc);
if (default_internal_enc) { if (default_internal_enc) {
val = rb_str_export_to_enc(val, default_internal_enc); val = rb_str_export_to_enc(val, default_internal_enc);
} }
@ -225,7 +233,7 @@ static VALUE rb_mysql_result_fetch_fields(VALUE self) {
if (RARRAY_LEN(wrapper->fields) != wrapper->numberOfFields) { if (RARRAY_LEN(wrapper->fields) != wrapper->numberOfFields) {
for (i=0; i<wrapper->numberOfFields; i++) { for (i=0; i<wrapper->numberOfFields; i++) {
rb_mysql_result_fetch_field(wrapper, i, 0); rb_mysql_result_fetch_field(self, i, 0);
} }
} }
@ -318,7 +326,6 @@ void init_mysql2_result()
intern_utc = rb_intern("utc"); intern_utc = rb_intern("utc");
#ifdef HAVE_RUBY_ENCODING_H #ifdef HAVE_RUBY_ENCODING_H
utf8Encoding = rb_utf8_encoding();
binaryEncoding = rb_enc_find("binary"); binaryEncoding = rb_enc_find("binary");
#endif #endif
} }

View File

@ -8,7 +8,7 @@ module Mysql2
send(:"#{key}=", opts[key]) send(:"#{key}=", opts[key])
end end
# force the encoding to utf8 # force the encoding to utf8
self.charset_name = 'utf8' self.charset_name = opts[:encoding] || 'utf8'
ssl_set(*opts.values_at(:sslkey, :sslcert, :sslca, :sslcapath, :sslciper)) ssl_set(*opts.values_at(:sslkey, :sslcert, :sslca, :sslcapath, :sslciper))
@ -21,5 +21,54 @@ module Mysql2
connect user, pass, host, port, database, socket connect user, pass, host, port, database, socket
end end
# NOTE: from ruby-mysql
if defined? Encoding
CHARSET_MAP = {
"armscii8" => nil,
"ascii" => Encoding::US_ASCII,
"big5" => Encoding::Big5,
"binary" => Encoding::ASCII_8BIT,
"cp1250" => Encoding::Windows_1250,
"cp1251" => Encoding::Windows_1251,
"cp1256" => Encoding::Windows_1256,
"cp1257" => Encoding::Windows_1257,
"cp850" => Encoding::CP850,
"cp852" => Encoding::CP852,
"cp866" => Encoding::IBM866,
"cp932" => Encoding::Windows_31J,
"dec8" => nil,
"eucjpms" => Encoding::EucJP_ms,
"euckr" => Encoding::EUC_KR,
"gb2312" => Encoding::EUC_CN,
"gbk" => Encoding::GBK,
"geostd8" => nil,
"greek" => Encoding::ISO_8859_7,
"hebrew" => Encoding::ISO_8859_8,
"hp8" => nil,
"keybcs2" => nil,
"koi8r" => Encoding::KOI8_R,
"koi8u" => Encoding::KOI8_U,
"latin1" => Encoding::ISO_8859_1,
"latin2" => Encoding::ISO_8859_2,
"latin5" => Encoding::ISO_8859_9,
"latin7" => Encoding::ISO_8859_13,
"macce" => Encoding::MacCentEuro,
"macroman" => Encoding::MacRoman,
"sjis" => Encoding::SHIFT_JIS,
"swe7" => nil,
"tis620" => Encoding::TIS_620,
"ucs2" => Encoding::UTF_16BE,
"ujis" => Encoding::EucJP_ms,
"utf8" => Encoding::UTF_8,
}
def self.encoding_from_charset(charset)
charset = charset.to_s.downcase
enc = CHARSET_MAP[charset]
raise Mysql2::Error, "unsupported charset: #{charset}" unless enc
enc
end
end
end end
end end

View File

@ -74,11 +74,14 @@ describe Mysql2::Client do
info[:version].class.should eql(String) info[:version].class.should eql(String)
end end
if RUBY_VERSION =~ /^1.9/ if defined? Encoding
context "strings returned by #info" do context "strings returned by #info" do
it "should default to utf-8 if Encoding.default_internal is nil" do it "should default to the connection's encoding if Encoding.default_internal is nil" do
Encoding.default_internal = nil Encoding.default_internal = nil
@client.info[:version].encoding.should eql(Encoding.find('utf-8')) @client.info[:version].encoding.should eql(Encoding.find('utf-8'))
client2 = Mysql2::Client.new :encoding => 'ascii'
client2.info[:version].encoding.should eql(Encoding.find('us-ascii'))
end end
it "should use Encoding.default_internal" do it "should use Encoding.default_internal" do
@ -103,11 +106,14 @@ describe Mysql2::Client do
server_info[:version].class.should eql(String) server_info[:version].class.should eql(String)
end end
if RUBY_VERSION =~ /^1.9/ if defined? Encoding
context "strings returned by #server_info" do context "strings returned by #server_info" do
it "should default to utf-8 if Encoding.default_internal is nil" do it "should default to the connection's encoding if Encoding.default_internal is nil" do
Encoding.default_internal = nil Encoding.default_internal = nil
@client.server_info[:version].encoding.should eql(Encoding.find('utf-8')) @client.server_info[:version].encoding.should eql(Encoding.find('utf-8'))
client2 = Mysql2::Client.new :encoding => 'ascii'
client2.server_info[:version].encoding.should eql(Encoding.find('us-ascii'))
end end
it "should use Encoding.default_internal" do it "should use Encoding.default_internal" do

View File

@ -209,12 +209,17 @@ describe Mysql2::Result do
@test_result['enum_test'].should eql('val1') @test_result['enum_test'].should eql('val1')
end end
if RUBY_VERSION =~ /^1.9/ if defined? Encoding
context "string encoding for ENUM values" do context "string encoding for ENUM values" do
it "should default to utf-8 if Encoding.default_internal is nil" do it "should default to the connection's encoding if Encoding.default_internal is nil" do
Encoding.default_internal = nil Encoding.default_internal = nil
result = @client.query("SELECT * FROM mysql2_test ORDER BY id DESC LIMIT 1").first result = @client.query("SELECT * FROM mysql2_test ORDER BY id DESC LIMIT 1").first
result['enum_test'].encoding.should eql(Encoding.find('utf-8')) result['enum_test'].encoding.should eql(Encoding.find('utf-8'))
client2 = Mysql2::Client.new :encoding => 'ascii'
client2.query "USE test"
result = client2.query("SELECT * FROM mysql2_test ORDER BY id DESC LIMIT 1").first
result['enum_test'].encoding.should eql(Encoding.find('us-ascii'))
end end
it "should use Encoding.default_internal" do it "should use Encoding.default_internal" do
@ -233,12 +238,17 @@ describe Mysql2::Result do
@test_result['set_test'].should eql('val1,val2') @test_result['set_test'].should eql('val1,val2')
end end
if RUBY_VERSION =~ /^1.9/ if defined? Encoding
context "string encoding for SET values" do context "string encoding for SET values" do
it "should default to utf-8 if Encoding.default_internal is nil" do it "should default to the connection's encoding if Encoding.default_internal is nil" do
Encoding.default_internal = nil Encoding.default_internal = nil
result = @client.query("SELECT * FROM mysql2_test ORDER BY id DESC LIMIT 1").first result = @client.query("SELECT * FROM mysql2_test ORDER BY id DESC LIMIT 1").first
result['set_test'].encoding.should eql(Encoding.find('utf-8')) result['set_test'].encoding.should eql(Encoding.find('utf-8'))
client2 = Mysql2::Client.new :encoding => 'ascii'
client2.query "USE test"
result = client2.query("SELECT * FROM mysql2_test ORDER BY id DESC LIMIT 1").first
result['set_test'].encoding.should eql(Encoding.find('us-ascii'))
end end
it "should use Encoding.default_internal" do it "should use Encoding.default_internal" do
@ -257,7 +267,7 @@ describe Mysql2::Result do
@test_result['binary_test'].should eql("test#{"\000"*6}") @test_result['binary_test'].should eql("test#{"\000"*6}")
end end
if RUBY_VERSION =~ /^1.9/ if defined? Encoding
context "string encoding for BINARY values" do context "string encoding for BINARY values" do
it "should default to binary if Encoding.default_internal is nil" do it "should default to binary if Encoding.default_internal is nil" do
Encoding.default_internal = nil Encoding.default_internal = nil
@ -294,7 +304,7 @@ describe Mysql2::Result do
@test_result[field].should eql("test") @test_result[field].should eql("test")
end end
if RUBY_VERSION =~ /^1.9/ if defined? Encoding
context "string encoding for #{type} values" do context "string encoding for #{type} values" do
if ['VARBINARY', 'TINYBLOB', 'BLOB', 'MEDIUMBLOB', 'LONGBLOB'].include?(type) if ['VARBINARY', 'TINYBLOB', 'BLOB', 'MEDIUMBLOB', 'LONGBLOB'].include?(type)
it "should default to binary if Encoding.default_internal is nil" do it "should default to binary if Encoding.default_internal is nil" do
@ -316,6 +326,11 @@ describe Mysql2::Result do
Encoding.default_internal = nil Encoding.default_internal = nil
result = @client.query("SELECT * FROM mysql2_test ORDER BY id DESC LIMIT 1").first result = @client.query("SELECT * FROM mysql2_test ORDER BY id DESC LIMIT 1").first
result[field].encoding.should eql(Encoding.find('utf-8')) result[field].encoding.should eql(Encoding.find('utf-8'))
client2 = Mysql2::Client.new :encoding => 'ascii'
client2.query "USE test"
result = client2.query("SELECT * FROM mysql2_test ORDER BY id DESC LIMIT 1").first
result[field].encoding.should eql(Encoding.find('us-ascii'))
end end
it "should use Encoding.default_internal" do it "should use Encoding.default_internal" do