diff --git a/ext/mysql2_ext.c b/ext/mysql2_ext.c index 669f076..dbbea33 100644 --- a/ext/mysql2_ext.c +++ b/ext/mysql2_ext.c @@ -453,6 +453,8 @@ static VALUE rb_mysql_result_to_obj(MYSQL_RES * r) { wrapper->lastRowProcessed = 0; wrapper->resultFreed = 0; wrapper->result = r; + wrapper->fields = Qnil; + wrapper->rows = Qnil; rb_obj_call_init(obj, 0, NULL); return obj; } @@ -481,6 +483,60 @@ static void rb_mysql_result_mark(void * wrapper) { } } +static VALUE rb_mysql_result_fetch_field(mysql2_result_wrapper * wrapper, unsigned int idx, short int symbolize_keys) { + if (wrapper->fields == Qnil) { + wrapper->numberOfFields = mysql_num_fields(wrapper->result); + wrapper->fields = rb_ary_new2(wrapper->numberOfFields); + } + + VALUE rb_field = rb_ary_entry(wrapper->fields, idx); + if (rb_field == Qnil) { + MYSQL_FIELD *field = NULL; + #ifdef HAVE_RUBY_ENCODING_H + rb_encoding *default_internal_enc = rb_default_internal_encoding(); + #endif + + field = mysql_fetch_field_direct(wrapper->result, idx); + if (symbolize_keys) { + char buf[field->name_length+1]; + memcpy(buf, field->name, field->name_length); + buf[field->name_length] = 0; + rb_field = ID2SYM(rb_intern(buf)); + } else { + rb_field = rb_str_new(field->name, field->name_length); + #ifdef HAVE_RUBY_ENCODING_H + rb_enc_associate(rb_field, utf8Encoding); + if (default_internal_enc) { + rb_field = rb_str_export_to_enc(rb_field, default_internal_enc); + } + #endif + } + rb_ary_store(wrapper->fields, idx, rb_field); + } + + return rb_field; +} + +static VALUE rb_mysql_result_fetch_fields(VALUE self) { + mysql2_result_wrapper * wrapper; + unsigned int i = 0; + + GetMysql2Result(self, wrapper); + + if (wrapper->fields == Qnil) { + wrapper->numberOfFields = mysql_num_fields(wrapper->result); + wrapper->fields = rb_ary_new2(wrapper->numberOfFields); + } + + if (RARRAY_LEN(wrapper->fields) != wrapper->numberOfFields) { + for (i=0; inumberOfFields; i++) { + rb_mysql_result_fetch_field(wrapper, i, 0); + } + } + + return wrapper->fields; +} + /* * for small results, this won't hit the network, but there's no * reliable way for us to tell this so we'll always release the GVL @@ -519,37 +575,16 @@ static VALUE rb_mysql_result_fetch_row(int argc, VALUE * argv, VALUE self) { return Qnil; } - if (wrapper->numberOfFields == 0) { + rowHash = rb_hash_new(); + fields = mysql_fetch_fields(wrapper->result); + fieldLengths = mysql_fetch_lengths(wrapper->result); + if (wrapper->fields == Qnil) { wrapper->numberOfFields = mysql_num_fields(wrapper->result); wrapper->fields = rb_ary_new2(wrapper->numberOfFields); } - rowHash = rb_hash_new(); - fields = mysql_fetch_fields(wrapper->result); - fieldLengths = mysql_fetch_lengths(wrapper->result); for (i = 0; i < wrapper->numberOfFields; i++) { - - // lazily create fields, but only once - // we'll use cached versions from here on out - VALUE field = rb_ary_entry(wrapper->fields, i); - if (field == Qnil) { - if (symbolizeKeys) { - char buf[fields[i].name_length+1]; - memcpy(buf, fields[i].name, fields[i].name_length); - buf[fields[i].name_length] = 0; - field = ID2SYM(rb_intern(buf)); - } else { - field = rb_str_new(fields[i].name, fields[i].name_length); -#ifdef HAVE_RUBY_ENCODING_H - rb_enc_associate(field, utf8Encoding); - if (default_internal_enc) { - field = rb_str_export_to_enc(field, default_internal_enc); - } -#endif - } - rb_ary_store(wrapper->fields, i, field); - } - + VALUE field = rb_mysql_result_fetch_field(wrapper, i, symbolizeKeys); if (row[i]) { VALUE val; switch(fields[i].type) { @@ -747,6 +782,7 @@ void Init_mysql2_ext() { cMysql2Result = rb_define_class_under(mMysql2, "Result", rb_cObject); rb_define_method(cMysql2Result, "each", rb_mysql_result_each, -1); + rb_define_method(cMysql2Result, "fields", rb_mysql_result_fetch_fields, 0); VALUE mEnumerable = rb_const_get(rb_cObject, rb_intern("Enumerable")); rb_include_module(cMysql2Result, mEnumerable); diff --git a/ext/mysql2_ext.h b/ext/mysql2_ext.h index 3009ccd..b44af2c 100644 --- a/ext/mysql2_ext.h +++ b/ext/mysql2_ext.h @@ -55,10 +55,10 @@ static void rb_mysql_client_free(void * client); typedef struct { VALUE fields; VALUE rows; - unsigned long numberOfFields; + unsigned int numberOfFields; unsigned long numberOfRows; unsigned long lastRowProcessed; - int resultFreed; + short int resultFreed; MYSQL_RES *result; } mysql2_result_wrapper; #define GetMysql2Result(obj, sval) (sval = (mysql2_result_wrapper*)DATA_PTR(obj)); diff --git a/spec/mysql2/result_spec.rb b/spec/mysql2/result_spec.rb index 7c6f7f6..d747a23 100644 --- a/spec/mysql2/result_spec.rb +++ b/spec/mysql2/result_spec.rb @@ -47,6 +47,22 @@ describe Mysql2::Result do end end + context "#fields" do + before(:all) do + @client.query "USE test" + @test_result = @client.query("SELECT * FROM mysql2_test ORDER BY id DESC LIMIT 1") + end + + it "method should exist" do + @test_result.should respond_to(:fields) + end + + it "should return an array of field names in proper order" do + result = @client.query "SELECT 'a', 'b', 'c'" + result.fields.should eql(['a', 'b', 'c']) + end + end + context "row data type mapping" do before(:all) do @client.query "USE test"