diff --git a/benchmark/query_with_mysql_casting.rb b/benchmark/query_with_mysql_casting.rb index 61546fe..f8b6dd8 100644 --- a/benchmark/query_with_mysql_casting.rb +++ b/benchmark/query_with_mysql_casting.rb @@ -45,8 +45,8 @@ Benchmark.bmbm do |x| x.report do puts "Mysql2" number_of.times do - mysql2_result = mysql2.query sql - mysql2_result.each(:symbolize_keys => true) do |res| + mysql2_result = mysql2.query sql, :symbolize_keys => true + mysql2_result.each do |res| # puts res.inspect end end diff --git a/benchmark/query_without_mysql_casting.rb b/benchmark/query_without_mysql_casting.rb index 7d0bb03..fa12bf6 100644 --- a/benchmark/query_without_mysql_casting.rb +++ b/benchmark/query_without_mysql_casting.rb @@ -17,8 +17,8 @@ Benchmark.bmbm do |x| x.report do puts "Mysql2" number_of.times do - mysql2_result = mysql2.query sql - mysql2_result.each(:symbolize_keys => true) do |res| + mysql2_result = mysql2.query sql, :symbolize_keys => true + mysql2_result.each do |res| # puts res.inspect end end diff --git a/ext/mysql2/client.c b/ext/mysql2/client.c index 356c6f5..dd99784 100644 --- a/ext/mysql2/client.c +++ b/ext/mysql2/client.c @@ -3,7 +3,8 @@ VALUE cMysql2Client; extern VALUE mMysql2, cMysql2Error, intern_encoding_from_charset; -extern ID sym_id, sym_version, sym_async; +extern ID sym_id, sym_version, sym_async, sym_symbolize_keys, sym_as, sym_array; +extern ID intern_merge; #define REQUIRE_OPEN_DB(_ctxt) \ if(!_ctxt->net.vio) { \ @@ -225,7 +226,12 @@ static VALUE rb_mysql_client_async_result(VALUE self) { } VALUE resultObj = rb_mysql_result_to_obj(result); + // pass-through query options for result construction later + rb_iv_set(resultObj, "@query_options", rb_obj_dup(rb_iv_get(self, "@query_options"))); + +#ifdef HAVE_RUBY_ENCODING_H rb_iv_set(resultObj, "@encoding", rb_iv_get(self, "@encoding")); +#endif return resultObj; } @@ -234,29 +240,31 @@ static VALUE rb_mysql_client_query(int argc, VALUE * argv, VALUE self) { fd_set fdset; int fd, retval; int async = 0; - VALUE opts; - VALUE rb_async; + VALUE opts, defaults; + MYSQL *client; - MYSQL * client; + Data_Get_Struct(self, MYSQL, client); + REQUIRE_OPEN_DB(client); + args.mysql = client; + defaults = rb_iv_get(self, "@query_options"); if (rb_scan_args(argc, argv, "11", &args.sql, &opts) == 2) { - if ((rb_async = rb_hash_aref(opts, sym_async)) != Qnil) { - async = rb_async == Qtrue ? 1 : 0; + opts = rb_funcall(defaults, intern_merge, 1, opts); + rb_iv_set(self, "@query_options", opts); + + if (rb_hash_aref(opts, sym_async) == Qtrue) { + async = 1; } + } else { + opts = defaults; } - Check_Type(args.sql, T_STRING); #ifdef HAVE_RUBY_ENCODING_H rb_encoding *conn_enc = rb_to_encoding(rb_iv_get(self, "@encoding")); // ensure the string is in the encoding the connection is expecting args.sql = rb_str_export_to_enc(args.sql, conn_enc); #endif - Data_Get_Struct(self, MYSQL, client); - - REQUIRE_OPEN_DB(client); - - args.mysql = client; if (rb_thread_blocking_region(nogvl_send_query, &args, RUBY_UBF_IO, 0) == Qfalse) { return rb_raise_mysql2_error(client); } @@ -272,15 +280,20 @@ static VALUE rb_mysql_client_query(int argc, VALUE * argv, VALUE self) { retval = rb_thread_select(fd + 1, &fdset, NULL, NULL, NULL); if (retval < 0) { - rb_sys_fail(0); + rb_sys_fail(0); } if (retval > 0) { - break; + break; } } - return rb_mysql_client_async_result(self); + VALUE result = rb_mysql_client_async_result(self); + + // pass-through query options for result construction later + rb_iv_set(result, "@query_options", rb_obj_dup(opts)); + + return result; } else { return Qnil; } diff --git a/ext/mysql2/mysql2_ext.c b/ext/mysql2/mysql2_ext.c index 68e276b..284431f 100644 --- a/ext/mysql2/mysql2_ext.c +++ b/ext/mysql2/mysql2_ext.c @@ -1,16 +1,22 @@ #include VALUE mMysql2, cMysql2Error, intern_encoding_from_charset; -ID sym_id, sym_version, sym_async; +ID sym_id, sym_version, sym_async, sym_symbolize_keys, sym_as, sym_array; +ID intern_merge; /* Ruby Extension initializer */ void Init_mysql2() { mMysql2 = rb_define_module("Mysql2"); cMysql2Error = rb_const_get(mMysql2, rb_intern("Error")); - sym_id = ID2SYM(rb_intern("id")); - sym_version = ID2SYM(rb_intern("version")); - sym_async = ID2SYM(rb_intern("async")); + intern_merge = rb_intern("merge"); + + sym_array = ID2SYM(rb_intern("array")); + sym_as = ID2SYM(rb_intern("as")); + sym_id = ID2SYM(rb_intern("id")); + sym_version = ID2SYM(rb_intern("version")); + sym_async = ID2SYM(rb_intern("async")); + sym_symbolize_keys = ID2SYM(rb_intern("symbolize_keys")); intern_encoding_from_charset = rb_intern("encoding_from_charset"); diff --git a/ext/mysql2/result.c b/ext/mysql2/result.c index 6a2df6c..f2afbd0 100644 --- a/ext/mysql2/result.c +++ b/ext/mysql2/result.c @@ -4,12 +4,13 @@ rb_encoding *binaryEncoding; #endif -ID sym_symbolize_keys; ID intern_new, intern_utc, intern_encoding_from_charset_code; VALUE cMysql2Result; VALUE cBigDecimal, cDate, cDateTime; extern VALUE mMysql2, cMysql2Client, cMysql2Error, intern_encoding_from_charset; +extern ID sym_symbolize_keys, sym_as, sym_array; +extern ID intern_merge; static void rb_mysql_result_mark(void * wrapper) { mysql2_result_wrapper * w = wrapper; @@ -85,12 +86,13 @@ static VALUE rb_mysql_result_fetch_field(VALUE self, unsigned int idx, short int return rb_field; } -static VALUE rb_mysql_result_fetch_row(int argc, VALUE * argv, VALUE self) { - VALUE rowHash, opts, block; +static VALUE rb_mysql_result_fetch_row(VALUE self, VALUE opts) { + VALUE rowVal; mysql2_result_wrapper * wrapper; MYSQL_ROW row; MYSQL_FIELD * fields = NULL; - unsigned int i = 0, symbolizeKeys = 0; + unsigned int i = 0; + int symbolizeKeys = 0, asArray = 0; unsigned long * fieldLengths; void * ptr; #ifdef HAVE_RUBY_ENCODING_H @@ -100,11 +102,12 @@ static VALUE rb_mysql_result_fetch_row(int argc, VALUE * argv, VALUE self) { GetMysql2Result(self, wrapper); - if (rb_scan_args(argc, argv, "01&", &opts, &block) == 1) { - Check_Type(opts, T_HASH); - if (rb_hash_aref(opts, sym_symbolize_keys) == Qtrue) { - symbolizeKeys = 1; - } + if (rb_hash_aref(opts, sym_symbolize_keys) == Qtrue) { + symbolizeKeys = 1; + } + + if (rb_hash_aref(opts, sym_as) == sym_array) { + asArray = 1; } ptr = wrapper->result; @@ -113,7 +116,11 @@ static VALUE rb_mysql_result_fetch_row(int argc, VALUE * argv, VALUE self) { return Qnil; } - rowHash = rb_hash_new(); + if (asArray) { + rowVal = rb_ary_new2(wrapper->numberOfFields); + } else { + rowVal = rb_hash_new(); + } fields = mysql_fetch_fields(wrapper->result); fieldLengths = mysql_fetch_lengths(wrapper->result); if (wrapper->fields == Qnil) { @@ -220,12 +227,20 @@ static VALUE rb_mysql_result_fetch_row(int argc, VALUE * argv, VALUE self) { #endif break; } - rb_hash_aset(rowHash, field, val); + if (asArray) { + rb_ary_push(rowVal, val); + } else { + rb_hash_aset(rowVal, field, val); + } } else { - rb_hash_aset(rowHash, field, Qnil); + if (asArray) { + rb_ary_push(rowVal, Qnil); + } else { + rb_hash_aset(rowVal, field, Qnil); + } } } - return rowHash; + return rowVal; } static VALUE rb_mysql_result_fetch_fields(VALUE self) { @@ -249,18 +264,24 @@ static VALUE rb_mysql_result_fetch_fields(VALUE self) { } static VALUE rb_mysql_result_each(int argc, VALUE * argv, VALUE self) { - VALUE opts, block; + VALUE defaults, opts, block; mysql2_result_wrapper * wrapper; unsigned long i; GetMysql2Result(self, wrapper); - rb_scan_args(argc, argv, "01&", &opts, &block); + defaults = rb_iv_get(self, "@query_options"); + if (rb_scan_args(argc, argv, "01&", &opts, &block) == 1) { + opts = rb_funcall(defaults, intern_merge, 1, opts); + } else { + opts = defaults; + } if (wrapper->lastRowProcessed == 0) { wrapper->numberOfRows = mysql_num_rows(wrapper->result); if (wrapper->numberOfRows == 0) { - return Qnil; + wrapper->rows = rb_ary_new(); + return wrapper->rows; } wrapper->rows = rb_ary_new2(wrapper->numberOfRows); } @@ -279,7 +300,7 @@ static VALUE rb_mysql_result_each(int argc, VALUE * argv, VALUE self) { if (i < rowsProcessed) { row = rb_ary_entry(wrapper->rows, i); } else { - row = rb_mysql_result_fetch_row(argc, argv, self); + row = rb_mysql_result_fetch_row(self, opts); rb_ary_store(wrapper->rows, i, row); wrapper->lastRowProcessed++; } @@ -319,8 +340,7 @@ VALUE rb_mysql_result_to_obj(MYSQL_RES * r) { return obj; } -void init_mysql2_result() -{ +void init_mysql2_result() { cBigDecimal = rb_const_get(rb_cObject, rb_intern("BigDecimal")); cDate = rb_const_get(rb_cObject, rb_intern("Date")); cDateTime = rb_const_get(rb_cObject, rb_intern("DateTime")); @@ -329,7 +349,6 @@ void init_mysql2_result() rb_define_method(cMysql2Result, "each", rb_mysql_result_each, -1); rb_define_method(cMysql2Result, "fields", rb_mysql_result_fetch_fields, 0); - sym_symbolize_keys = ID2SYM(rb_intern("symbolize_keys")); intern_new = rb_intern("new"); intern_utc = rb_intern("utc"); intern_encoding_from_charset_code = rb_intern("encoding_from_charset_code"); diff --git a/lib/active_record/connection_adapters/mysql2_adapter.rb b/lib/active_record/connection_adapters/mysql2_adapter.rb index 13e847a..01327b7 100644 --- a/lib/active_record/connection_adapters/mysql2_adapter.rb +++ b/lib/active_record/connection_adapters/mysql2_adapter.rb @@ -161,6 +161,7 @@ module ActiveRecord super(connection, logger) @connection_options, @config = connection_options, config @quoted_column_names, @quoted_table_names = {}, {} + configure_connection end def adapter_name @@ -263,14 +264,36 @@ module ActiveRecord # DATABASE STATEMENTS ====================================== + # Returns a record hash with the column names as keys and column values + # as values. + def select_one(sql, name = nil) + result = execute(sql, name) + result.each(:as => :hash) do |r| + return r + end + end + + # Returns a single value from a record + def select_value(sql, name = nil) + result = execute(sql, name) + if first = result.first + first.first + end + end + + # Returns an array of the values of the first column in a select: + # select_values("SELECT id FROM companies LIMIT 3") => [1,2,3] def select_values(sql, name = nil) - select(sql, name).map { |row| row.values.first } + execute(sql, name).map { |row| row.first } end + # Returns an array of arrays containing the field values. + # Order is the same as that returned by +columns+. def select_rows(sql, name = nil) - select(sql, name).map { |row| row.values } + execute(sql, name).to_a end + # Executes the SQL statement in the context of this connection. def execute(sql, name = nil) if name == :skip_logging @connection.query(sql) @@ -393,8 +416,8 @@ module ActiveRecord def tables(name = nil) tables = [] - execute("SHOW TABLES", name).each(:symbolize_keys => true) do |field| - tables << field.values.first + execute("SHOW TABLES", name).each do |field| + tables << field.first end tables end @@ -407,7 +430,7 @@ module ActiveRecord indexes = [] current_index = nil result = execute("SHOW KEYS FROM #{quote_table_name(table_name)}", name) - result.each(:symbolize_keys => true) do |row| + result.each(:symbolize_keys => true, :as => :hash) do |row| if current_index != row[:Key_name] next if row[:Key_name] == PRIMARY # skip the primary key current_index = row[:Key_name] @@ -423,7 +446,7 @@ module ActiveRecord sql = "SHOW FIELDS FROM #{quote_table_name(table_name)}" columns = [] result = execute(sql, :skip_logging) - result.each(:symbolize_keys => true) { |field| + result.each(:symbolize_keys => true, :as => :hash) { |field| columns << Mysql2Column.new(field[:Field], field[:Default], field[:Type], field[:Null] == "YES") } columns @@ -520,7 +543,7 @@ module ActiveRecord def pk_and_sequence_for(table) keys = [] result = execute("describe #{quote_table_name(table)}") - result.each(:symbolize_keys => true) do |row| + result.each(:symbolize_keys => true, :as => :hash) do |row| keys << row[:Field] if row[:Key] == "PRI" end keys.length == 1 ? [keys.first, nil] : nil @@ -574,6 +597,7 @@ module ActiveRecord end def configure_connection + @connection.query_options.merge!(:as => :array) encoding = @config[:encoding] execute("SET NAMES '#{encoding}'", :skip_logging) if encoding @@ -582,8 +606,10 @@ module ActiveRecord execute("SET SQL_AUTO_IS_NULL=0", :skip_logging) end + # Returns an array of record hashes with the column names as keys and + # column values as values. def select(sql, name = nil) - execute(sql, name).to_a + result = execute(sql, name).each(:as => :hash) end def supports_views? diff --git a/lib/mysql2/client.rb b/lib/mysql2/client.rb index d80bb89..d2c380e 100644 --- a/lib/mysql2/client.rb +++ b/lib/mysql2/client.rb @@ -1,6 +1,15 @@ module Mysql2 class Client - def initialize opts = {} + attr_reader :query_options + @@default_query_options = { + :symbolize_keys => false, + :async => false, + :as => :hash + } + + def initialize(opts = {}) + @query_options = @@default_query_options.dup + init_connection [:reconnect, :connect_timeout].each do |key| @@ -23,6 +32,10 @@ module Mysql2 connect user, pass, host, port, database, socket end + def self.default_query_options + @@default_query_options + end + # NOTE: from ruby-mysql if defined? Encoding CHARSET_MAP = { diff --git a/spec/active_record/active_record_spec.rb b/spec/active_record/active_record_spec.rb index 922ac7e..2b6df1b 100644 --- a/spec/active_record/active_record_spec.rb +++ b/spec/active_record/active_record_spec.rb @@ -21,8 +21,8 @@ describe ActiveRecord::ConnectionAdapters::Mysql2Adapter do end it "should be able to execute a raw query" do - @connection.execute("SELECT 1 as one").first['one'].should eql(1) - @connection.execute("SELECT NOW() as n").first['n'].class.should eql(Time) + @connection.execute("SELECT 1 as one").first.first.should eql(1) + @connection.execute("SELECT NOW() as n").first.first.class.should eql(Time) end end @@ -72,7 +72,7 @@ describe ActiveRecord::ConnectionAdapters::Mysql2Adapter do end after(:all) do - Mysql2Test2.connection.execute("DELETE FROM mysql2_test WHERE id=#{@test_result['id']}") + Mysql2Test2.connection.execute("DELETE FROM mysql2_test WHERE id=#{@test_result.first}") end it "default value should be cast to the expected type of the field" do diff --git a/spec/mysql2/client_spec.rb b/spec/mysql2/client_spec.rb index 7898390..262dc6f 100644 --- a/spec/mysql2/client_spec.rb +++ b/spec/mysql2/client_spec.rb @@ -14,6 +14,10 @@ describe Mysql2::Client do end end + it "should have a global default_query_options hash" do + Mysql2::Client.should respond_to(:default_query_options) + end + it "should be able to connect via SSL options" do pending("DON'T WORRY, THIS TEST PASSES :) - but is machine-specific. You need to have MySQL running with SSL configured and enabled. Then update the paths in this test to your needs and remove the pending state.") ssl_client = nil @@ -49,6 +53,26 @@ describe Mysql2::Client do @client.should respond_to(:query) end + context "#query" do + it "should accept an options hash that inherits from Mysql2::Client.default_query_options" do + @client.query "SELECT 1", :something => :else + @client.query_options.should eql(@client.query_options.merge(:something => :else)) + end + + it "should return results as a hash by default" do + @client.query("SELECT 1").first.class.should eql(Hash) + end + + it "should be able to return results as an array" do + @client.query("SELECT 1", :as => :array).first.class.should eql(Array) + @client.query("SELECT 1").each(:as => :array) + end + + it "should be able to return results with symbolized keys" do + @client.query("SELECT 1", :symbolize_keys => true).first.keys[0].class.should eql(Symbol) + end + end + it "should respond to #escape" do @client.should respond_to(:escape) end diff --git a/spec/mysql2/result_spec.rb b/spec/mysql2/result_spec.rb index 900101a..3d553b2 100644 --- a/spec/mysql2/result_spec.rb +++ b/spec/mysql2/result_spec.rb @@ -2,7 +2,7 @@ require 'spec_helper' describe Mysql2::Result do - before(:all) do + before(:each) do @client = Mysql2::Client.new :host => "localhost", :username => "root" end @@ -37,18 +37,23 @@ describe Mysql2::Result do it "should yield rows as hash's with symbol keys if :symbolize_keys was set to true" do @result.each(:symbolize_keys => true) do |row| - row.class.should eql(Hash) row.keys.first.class.should eql(Symbol) end end + it "should be able to return results as an array" do + @result.each(:as => :array) do |row| + row.class.should eql(Array) + end + end + it "should cache previously yielded results" do @result.first.should eql(@result.first) end end context "#fields" do - before(:all) do + before(:each) do @client.query "USE test" @test_result = @client.query("SELECT * FROM mysql2_test ORDER BY id DESC LIMIT 1") end @@ -64,7 +69,7 @@ describe Mysql2::Result do end context "row data type mapping" do - before(:all) do + before(:each) do @client.query "USE test" @client.query %[ CREATE TABLE IF NOT EXISTS mysql2_test (