diff --git a/ext/mysql2/mysql2_ext.c b/ext/mysql2/mysql2_ext.c index 07d4a57..2b0dbf9 100644 --- a/ext/mysql2/mysql2_ext.c +++ b/ext/mysql2/mysql2_ext.c @@ -1,4 +1,8 @@ -#include "mysql2_ext.h" +#include + +VALUE mMysql2; +VALUE cMysql2Error; +ID sym_id, sym_version, sym_async; #define REQUIRE_OPEN_DB(_ctxt) \ if(!_ctxt->net.vio) { \ @@ -6,6 +10,10 @@ return Qnil; \ } +#ifdef HAVE_RUBY_ENCODING_H +rb_encoding *utf8Encoding; +#endif + /* * non-blocking mysql_*() functions that we won't be wrapping since * they do not appear to hit the network nor issue any interruptible @@ -28,6 +36,14 @@ * - mysql_ssl_set() */ +static VALUE rb_raise_mysql2_error(MYSQL *client) { + VALUE e = rb_exc_new2(cMysql2Error, mysql_error(client)); + rb_funcall(e, rb_intern("error_number="), 1, INT2NUM(mysql_errno(client))); + rb_funcall(e, rb_intern("sql_state="), 1, rb_tainted_str_new2(mysql_sqlstate(client))); + rb_exc_raise(e); + return Qnil; +} + static VALUE nogvl_init(void *ptr) { MYSQL * client = (MYSQL *)ptr; @@ -49,6 +65,37 @@ static VALUE nogvl_connect(void *ptr) { return client ? Qtrue : Qfalse; } +static void rb_mysql_client_free(void * ptr) { + MYSQL * client = (MYSQL *)ptr; + + /* + * we'll send a QUIT message to the server, but that message is more of a + * formality than a hard requirement since the socket is getting shutdown + * anyways, so ensure the socket write does not block our interpreter + */ + int fd = client->net.fd; + int flags; + + if (fd >= 0) { + /* + * if the socket is dead we have no chance of blocking, + * so ignore any potential fcntl errors since they don't matter + */ + flags = fcntl(fd, F_GETFL); + if (flags > 0 && !(flags & O_NONBLOCK)) + fcntl(fd, F_SETFL, flags | O_NONBLOCK); + } + + /* It's safe to call mysql_close() on an already closed connection. */ + mysql_close(client); + xfree(ptr); +} + +static VALUE nogvl_close(void * ptr) { + mysql_close((MYSQL *)ptr); + return Qnil; +} + static VALUE allocate(VALUE klass) { MYSQL * client; @@ -87,37 +134,6 @@ static VALUE rb_connect(VALUE self, VALUE user, VALUE pass, VALUE host, VALUE po return self; } -static void rb_mysql_client_free(void * ptr) { - MYSQL * client = (MYSQL *)ptr; - - /* - * we'll send a QUIT message to the server, but that message is more of a - * formality than a hard requirement since the socket is getting shutdown - * anyways, so ensure the socket write does not block our interpreter - */ - int fd = client->net.fd; - int flags; - - if (fd >= 0) { - /* - * if the socket is dead we have no chance of blocking, - * so ignore any potential fcntl errors since they don't matter - */ - flags = fcntl(fd, F_GETFL); - if (flags > 0 && !(flags & O_NONBLOCK)) - fcntl(fd, F_SETFL, flags | O_NONBLOCK); - } - - /* It's safe to call mysql_close() on an already closed connection. */ - mysql_close(client); - xfree(ptr); -} - -static VALUE nogvl_close(void * ptr) { - mysql_close((MYSQL *)ptr); - return Qnil; -} - /* * Immediately disconnect from the server, normally the garbage collector * will disconnect automatically when a connection is no longer needed. @@ -151,6 +167,46 @@ static VALUE nogvl_send_query(void *ptr) { return rv == 0 ? Qtrue : Qfalse; } +/* + * even though we did rb_thread_select before calling this, a large + * response can overflow the socket buffers and cause us to eventually + * block while calling mysql_read_query_result + */ +static VALUE nogvl_read_query_result(void *ptr) { + MYSQL * client = ptr; + my_bool res = mysql_read_query_result(client); + + return res == 0 ? Qtrue : Qfalse; +} + +/* mysql_store_result may (unlikely) read rows off the socket */ +static VALUE nogvl_store_result(void *ptr) { + MYSQL * client = ptr; + return (VALUE)mysql_store_result(client); +} + +static VALUE rb_mysql_client_async_result(VALUE self) { + MYSQL * client; + MYSQL_RES * result; + + Data_Get_Struct(self, MYSQL, client); + + REQUIRE_OPEN_DB(client); + if (rb_thread_blocking_region(nogvl_read_query_result, client, RUBY_UBF_IO, 0) == Qfalse) { + return rb_raise_mysql2_error(client); + } + + result = (MYSQL_RES *)rb_thread_blocking_region(nogvl_store_result, client, RUBY_UBF_IO, 0); + if (result == NULL) { + if (mysql_field_count(client) != 0) { + rb_raise_mysql2_error(client); + } + return Qnil; + } + + return rb_mysql_result_to_obj(result); +} + static VALUE rb_mysql_client_query(int argc, VALUE * argv, VALUE self) { struct nogvl_send_query_args args; fd_set fdset; @@ -281,46 +337,6 @@ static VALUE rb_mysql_client_socket(VALUE self) { return INT2NUM(client->net.fd); } -/* - * even though we did rb_thread_select before calling this, a large - * response can overflow the socket buffers and cause us to eventually - * block while calling mysql_read_query_result - */ -static VALUE nogvl_read_query_result(void *ptr) { - MYSQL * client = ptr; - my_bool res = mysql_read_query_result(client); - - return res == 0 ? Qtrue : Qfalse; -} - -/* mysql_store_result may (unlikely) read rows off the socket */ -static VALUE nogvl_store_result(void *ptr) { - MYSQL * client = ptr; - return (VALUE)mysql_store_result(client); -} - -static VALUE rb_mysql_client_async_result(VALUE self) { - MYSQL * client; - MYSQL_RES * result; - - Data_Get_Struct(self, MYSQL, client); - - REQUIRE_OPEN_DB(client); - if (rb_thread_blocking_region(nogvl_read_query_result, client, RUBY_UBF_IO, 0) == Qfalse) { - return rb_raise_mysql2_error(client); - } - - result = (MYSQL_RES *)rb_thread_blocking_region(nogvl_store_result, client, RUBY_UBF_IO, 0); - if (result == NULL) { - if (mysql_field_count(client) != 0) { - rb_raise_mysql2_error(client); - } - return Qnil; - } - - return rb_mysql_result_to_obj(result); -} - static VALUE rb_mysql_client_last_id(VALUE self) { MYSQL * client; Data_Get_Struct(self, MYSQL, client); @@ -335,307 +351,6 @@ static VALUE rb_mysql_client_affected_rows(VALUE self) { return ULL2NUM(mysql_affected_rows(client)); } -/* Mysql2::Result */ -static VALUE rb_mysql_result_to_obj(MYSQL_RES * r) { - VALUE obj; - mysql2_result_wrapper * wrapper; - obj = Data_Make_Struct(cMysql2Result, mysql2_result_wrapper, rb_mysql_result_mark, rb_mysql_result_free, wrapper); - wrapper->numberOfFields = 0; - wrapper->numberOfRows = 0; - wrapper->lastRowProcessed = 0; - wrapper->resultFreed = 0; - wrapper->result = r; - wrapper->fields = Qnil; - wrapper->rows = Qnil; - rb_obj_call_init(obj, 0, NULL); - return obj; -} - -/* this may be called manually or during GC */ -static void rb_mysql_result_free_result(mysql2_result_wrapper * wrapper) { - if (wrapper && wrapper->resultFreed != 1) { - mysql_free_result(wrapper->result); - wrapper->resultFreed = 1; - } -} - -/* this is called during GC */ -static void rb_mysql_result_free(void * wrapper) { - mysql2_result_wrapper * w = wrapper; - /* FIXME: this may call flush_use_result, which can hit the socket */ - rb_mysql_result_free_result(w); - xfree(wrapper); -} - -static void rb_mysql_result_mark(void * wrapper) { - mysql2_result_wrapper * w = wrapper; - if (w) { - rb_gc_mark(w->fields); - rb_gc_mark(w->rows); - } -} - -static VALUE rb_mysql_result_fetch_field(mysql2_result_wrapper * wrapper, unsigned int idx, short int symbolize_keys) { - if (wrapper->fields == Qnil) { - wrapper->numberOfFields = mysql_num_fields(wrapper->result); - wrapper->fields = rb_ary_new2(wrapper->numberOfFields); - } - - VALUE rb_field = rb_ary_entry(wrapper->fields, idx); - if (rb_field == Qnil) { - MYSQL_FIELD *field = NULL; - #ifdef HAVE_RUBY_ENCODING_H - rb_encoding *default_internal_enc = rb_default_internal_encoding(); - #endif - - field = mysql_fetch_field_direct(wrapper->result, idx); - if (symbolize_keys) { - char buf[field->name_length+1]; - memcpy(buf, field->name, field->name_length); - buf[field->name_length] = 0; - rb_field = ID2SYM(rb_intern(buf)); - } else { - rb_field = rb_str_new(field->name, field->name_length); - #ifdef HAVE_RUBY_ENCODING_H - rb_enc_associate(rb_field, utf8Encoding); - if (default_internal_enc) { - rb_field = rb_str_export_to_enc(rb_field, default_internal_enc); - } - #endif - } - rb_ary_store(wrapper->fields, idx, rb_field); - } - - return rb_field; -} - -static VALUE rb_mysql_result_fetch_fields(VALUE self) { - mysql2_result_wrapper * wrapper; - unsigned int i = 0; - - GetMysql2Result(self, wrapper); - - if (wrapper->fields == Qnil) { - wrapper->numberOfFields = mysql_num_fields(wrapper->result); - wrapper->fields = rb_ary_new2(wrapper->numberOfFields); - } - - if (RARRAY_LEN(wrapper->fields) != wrapper->numberOfFields) { - for (i=0; inumberOfFields; i++) { - rb_mysql_result_fetch_field(wrapper, i, 0); - } - } - - return wrapper->fields; -} - -/* - * for small results, this won't hit the network, but there's no - * reliable way for us to tell this so we'll always release the GVL - * to be safe - */ -static VALUE nogvl_fetch_row(void *ptr) { - MYSQL_RES *result = ptr; - - return (VALUE)mysql_fetch_row(result); -} - -static VALUE rb_mysql_result_fetch_row(int argc, VALUE * argv, VALUE self) { - VALUE rowHash, opts, block; - mysql2_result_wrapper * wrapper; - MYSQL_ROW row; - MYSQL_FIELD * fields = NULL; - unsigned int i = 0, symbolizeKeys = 0; - unsigned long * fieldLengths; - void * ptr; -#ifdef HAVE_RUBY_ENCODING_H - rb_encoding *default_internal_enc = rb_default_internal_encoding(); -#endif - - GetMysql2Result(self, wrapper); - - if (rb_scan_args(argc, argv, "01&", &opts, &block) == 1) { - Check_Type(opts, T_HASH); - if (rb_hash_aref(opts, sym_symbolize_keys) == Qtrue) { - symbolizeKeys = 1; - } - } - - ptr = wrapper->result; - row = (MYSQL_ROW)rb_thread_blocking_region(nogvl_fetch_row, ptr, RUBY_UBF_IO, 0); - if (row == NULL) { - return Qnil; - } - - rowHash = rb_hash_new(); - fields = mysql_fetch_fields(wrapper->result); - fieldLengths = mysql_fetch_lengths(wrapper->result); - if (wrapper->fields == Qnil) { - wrapper->numberOfFields = mysql_num_fields(wrapper->result); - wrapper->fields = rb_ary_new2(wrapper->numberOfFields); - } - - for (i = 0; i < wrapper->numberOfFields; i++) { - VALUE field = rb_mysql_result_fetch_field(wrapper, i, symbolizeKeys); - if (row[i]) { - VALUE val; - switch(fields[i].type) { - case MYSQL_TYPE_NULL: // NULL-type field - val = Qnil; - break; - case MYSQL_TYPE_BIT: // BIT field (MySQL 5.0.3 and up) - val = rb_str_new(row[i], fieldLengths[i]); - break; - case MYSQL_TYPE_TINY: // TINYINT field - case MYSQL_TYPE_SHORT: // SMALLINT field - case MYSQL_TYPE_LONG: // INTEGER field - case MYSQL_TYPE_INT24: // MEDIUMINT field - case MYSQL_TYPE_LONGLONG: // BIGINT field - case MYSQL_TYPE_YEAR: // YEAR field - val = rb_cstr2inum(row[i], 10); - break; - case MYSQL_TYPE_DECIMAL: // DECIMAL or NUMERIC field - case MYSQL_TYPE_NEWDECIMAL: // Precision math DECIMAL or NUMERIC field (MySQL 5.0.3 and up) - val = rb_funcall(cBigDecimal, intern_new, 1, rb_str_new(row[i], fieldLengths[i])); - break; - case MYSQL_TYPE_FLOAT: // FLOAT field - case MYSQL_TYPE_DOUBLE: // DOUBLE or REAL field - val = rb_float_new(strtod(row[i], NULL)); - break; - case MYSQL_TYPE_TIME: { // TIME field - int hour, min, sec, tokens; - tokens = sscanf(row[i], "%2d:%2d:%2d", &hour, &min, &sec); - val = rb_funcall(rb_cTime, intern_utc, 6, INT2NUM(0), INT2NUM(1), INT2NUM(1), INT2NUM(hour), INT2NUM(min), INT2NUM(sec)); - break; - } - case MYSQL_TYPE_TIMESTAMP: // TIMESTAMP field - case MYSQL_TYPE_DATETIME: { // DATETIME field - int year, month, day, hour, min, sec, tokens; - tokens = sscanf(row[i], "%4d-%2d-%2d %2d:%2d:%2d", &year, &month, &day, &hour, &min, &sec); - if (year+month+day+hour+min+sec == 0) { - val = Qnil; - } else { - if (month < 1 || day < 1) { - rb_raise(cMysql2Error, "Invalid date: %s", row[i]); - val = Qnil; - } else { - val = rb_funcall(rb_cTime, intern_utc, 6, INT2NUM(year), INT2NUM(month), INT2NUM(day), INT2NUM(hour), INT2NUM(min), INT2NUM(sec)); - } - } - break; - } - case MYSQL_TYPE_DATE: // DATE field - case MYSQL_TYPE_NEWDATE: { // Newer const used > 5.0 - int year, month, day, tokens; - tokens = sscanf(row[i], "%4d-%2d-%2d", &year, &month, &day); - if (year+month+day == 0) { - val = Qnil; - } else { - if (month < 1 || day < 1) { - rb_raise(cMysql2Error, "Invalid date: %s", row[i]); - val = Qnil; - } else { - val = rb_funcall(cDate, intern_new, 3, INT2NUM(year), INT2NUM(month), INT2NUM(day)); - } - } - break; - } - case MYSQL_TYPE_TINY_BLOB: - case MYSQL_TYPE_MEDIUM_BLOB: - case MYSQL_TYPE_LONG_BLOB: - case MYSQL_TYPE_BLOB: - case MYSQL_TYPE_VAR_STRING: - case MYSQL_TYPE_VARCHAR: - case MYSQL_TYPE_STRING: // CHAR or BINARY field - case MYSQL_TYPE_SET: // SET field - case MYSQL_TYPE_ENUM: // ENUM field - case MYSQL_TYPE_GEOMETRY: // Spatial fielda - default: - val = rb_str_new(row[i], fieldLengths[i]); -#ifdef HAVE_RUBY_ENCODING_H - // rudimentary check for binary content - if ((fields[i].flags & BINARY_FLAG) || fields[i].charsetnr == 63) { - rb_enc_associate(val, binaryEncoding); - } else { - rb_enc_associate(val, utf8Encoding); - if (default_internal_enc) { - val = rb_str_export_to_enc(val, default_internal_enc); - } - } -#endif - break; - } - rb_hash_aset(rowHash, field, val); - } else { - rb_hash_aset(rowHash, field, Qnil); - } - } - return rowHash; -} - -static VALUE rb_mysql_result_each(int argc, VALUE * argv, VALUE self) { - VALUE opts, block; - mysql2_result_wrapper * wrapper; - unsigned long i; - - GetMysql2Result(self, wrapper); - - rb_scan_args(argc, argv, "01&", &opts, &block); - - if (wrapper->lastRowProcessed == 0) { - wrapper->numberOfRows = mysql_num_rows(wrapper->result); - if (wrapper->numberOfRows == 0) { - return Qnil; - } - wrapper->rows = rb_ary_new2(wrapper->numberOfRows); - } - - if (wrapper->lastRowProcessed == wrapper->numberOfRows) { - // we've already read the entire dataset from the C result into our - // internal array. Lets hand that over to the user since it's ready to go - for (i = 0; i < wrapper->numberOfRows; i++) { - rb_yield(rb_ary_entry(wrapper->rows, i)); - } - } else { - unsigned long rowsProcessed = 0; - rowsProcessed = RARRAY_LEN(wrapper->rows); - for (i = 0; i < wrapper->numberOfRows; i++) { - VALUE row; - if (i < rowsProcessed) { - row = rb_ary_entry(wrapper->rows, i); - } else { - row = rb_mysql_result_fetch_row(argc, argv, self); - rb_ary_store(wrapper->rows, i, row); - wrapper->lastRowProcessed++; - } - - if (row == Qnil) { - // we don't need the mysql C dataset around anymore, peace it - rb_mysql_result_free_result(wrapper); - return Qnil; - } - - if (block != Qnil) { - rb_yield(row); - } - } - if (wrapper->lastRowProcessed == wrapper->numberOfRows) { - // we don't need the mysql C dataset around anymore, peace it - rb_mysql_result_free_result(wrapper); - } - } - - return wrapper->rows; -} - -static VALUE rb_raise_mysql2_error(MYSQL *client) { - VALUE e = rb_exc_new2(cMysql2Error, mysql_error(client)); - rb_funcall(e, rb_intern("error_number="), 1, INT2NUM(mysql_errno(client))); - rb_funcall(e, rb_intern("sql_state="), 1, rb_tainted_str_new2(mysql_sqlstate(client))); - rb_exc_raise(e); - return Qnil; -} - static VALUE set_reconnect(VALUE self, VALUE value) { my_bool reconnect; @@ -724,11 +439,7 @@ static VALUE init_connection(VALUE self) /* Ruby Extension initializer */ void Init_mysql2() { - cBigDecimal = rb_const_get(rb_cObject, rb_intern("BigDecimal")); - cDate = rb_const_get(rb_cObject, rb_intern("Date")); - cDateTime = rb_const_get(rb_cObject, rb_intern("DateTime")); - - VALUE mMysql2 = rb_define_module("Mysql2"); + mMysql2 = rb_define_module("Mysql2"); VALUE cMysql2Client = rb_define_class_under(mMysql2, "Client", rb_cObject); rb_define_alloc_func(cMysql2Client, allocate); @@ -752,20 +463,14 @@ void Init_mysql2() { cMysql2Error = rb_const_get(mMysql2, rb_intern("Error")); - cMysql2Result = rb_define_class_under(mMysql2, "Result", rb_cObject); - rb_define_method(cMysql2Result, "each", rb_mysql_result_each, -1); - rb_define_method(cMysql2Result, "fields", rb_mysql_result_fetch_fields, 0); - - intern_new = rb_intern("new"); - intern_utc = rb_intern("utc"); - - sym_symbolize_keys = ID2SYM(rb_intern("symbolize_keys")); - sym_id = ID2SYM(rb_intern("id")); - sym_version = ID2SYM(rb_intern("version")); - sym_async = ID2SYM(rb_intern("async")); #ifdef HAVE_RUBY_ENCODING_H utf8Encoding = rb_utf8_encoding(); - binaryEncoding = rb_enc_find("binary"); #endif + + init_mysql2_result(); + + sym_id = ID2SYM(rb_intern("id")); + sym_version = ID2SYM(rb_intern("version")); + sym_async = ID2SYM(rb_intern("async")); } diff --git a/ext/mysql2/mysql2_ext.h b/ext/mysql2/mysql2_ext.h index 3e1b4ad..602ecd9 100644 --- a/ext/mysql2/mysql2_ext.h +++ b/ext/mysql2/mysql2_ext.h @@ -1,3 +1,6 @@ +#ifndef MYSQL2_EXT +#define MYSQL2_EXT + #include #include @@ -15,7 +18,10 @@ #ifdef HAVE_RUBY_ENCODING_H #include -static rb_encoding *utf8Encoding, *binaryEncoding; +#endif + +#ifdef HAVE_RUBY_ENCODING_H +extern rb_encoding *utf8Encoding; #endif #if defined(__GNUC__) && (__GNUC__ >= 3) @@ -24,22 +30,12 @@ static rb_encoding *utf8Encoding, *binaryEncoding; #define RB_MYSQL_UNUSED #endif -static VALUE cBigDecimal, cDate, cDateTime; -static ID intern_new, intern_utc; +#include + +extern VALUE mMysql2; /* Mysql2::Error */ -static VALUE cMysql2Error; - -static ID sym_id, sym_version, sym_symbolize_keys, sym_async; -static VALUE rb_mysql_client_query(int argc, VALUE * argv, VALUE self); -static VALUE rb_mysql_client_escape(VALUE self, VALUE str); -static VALUE rb_mysql_client_info(VALUE self); -static VALUE rb_mysql_client_server_info(VALUE self); -static VALUE rb_mysql_client_socket(VALUE self); -static VALUE rb_mysql_client_async_result(VALUE self); -static VALUE rb_mysql_client_last_id(VALUE self); -static VALUE rb_mysql_client_affected_rows(VALUE self); -static void rb_mysql_client_free(void * client); +extern VALUE cMysql2Error; /* Mysql2::Result */ typedef struct { @@ -52,16 +48,6 @@ typedef struct { MYSQL_RES *result; } mysql2_result_wrapper; #define GetMysql2Result(obj, sval) (sval = (mysql2_result_wrapper*)DATA_PTR(obj)); -static VALUE cMysql2Result; -static VALUE rb_mysql_result_to_obj(MYSQL_RES * res); -static VALUE rb_mysql_result_fetch_row(int argc, VALUE * argv, VALUE self); -static VALUE rb_mysql_result_each(int argc, VALUE * argv, VALUE self); -static void rb_mysql_result_free(void * wrapper); -static void rb_mysql_result_mark(void * wrapper); -static void rb_mysql_result_free_result(mysql2_result_wrapper * wrapper); - -/* Mysql2::Error */ -static VALUE rb_raise_mysql2_error(MYSQL *client); /* * used to pass all arguments to mysql_real_connect while inside @@ -112,3 +98,5 @@ rb_thread_blocking_region( return rv; } #endif /* ! HAVE_RB_THREAD_BLOCKING_REGION */ + +#endif diff --git a/ext/mysql2/result.c b/ext/mysql2/result.c new file mode 100644 index 0000000..5bf0baf --- /dev/null +++ b/ext/mysql2/result.c @@ -0,0 +1,324 @@ +#include + +#ifdef HAVE_RUBY_ENCODING_H +rb_encoding *binaryEncoding; +#endif + +ID sym_symbolize_keys; +ID intern_new, intern_utc; + +VALUE cBigDecimal, cDate, cDateTime; +VALUE cMysql2Result; + +static void rb_mysql_result_mark(void * wrapper) { + mysql2_result_wrapper * w = wrapper; + if (w) { + rb_gc_mark(w->fields); + rb_gc_mark(w->rows); + } +} + +/* this may be called manually or during GC */ +static void rb_mysql_result_free_result(mysql2_result_wrapper * wrapper) { + if (wrapper && wrapper->resultFreed != 1) { + mysql_free_result(wrapper->result); + wrapper->resultFreed = 1; + } +} + +/* this is called during GC */ +static void rb_mysql_result_free(void * wrapper) { + mysql2_result_wrapper * w = wrapper; + /* FIXME: this may call flush_use_result, which can hit the socket */ + rb_mysql_result_free_result(w); + xfree(wrapper); +} + +/* + * for small results, this won't hit the network, but there's no + * reliable way for us to tell this so we'll always release the GVL + * to be safe + */ +static VALUE nogvl_fetch_row(void *ptr) { + MYSQL_RES *result = ptr; + + return (VALUE)mysql_fetch_row(result); +} + +static VALUE rb_mysql_result_fetch_field(mysql2_result_wrapper * wrapper, unsigned int idx, short int symbolize_keys) { + if (wrapper->fields == Qnil) { + wrapper->numberOfFields = mysql_num_fields(wrapper->result); + wrapper->fields = rb_ary_new2(wrapper->numberOfFields); + } + + VALUE rb_field = rb_ary_entry(wrapper->fields, idx); + if (rb_field == Qnil) { + MYSQL_FIELD *field = NULL; + #ifdef HAVE_RUBY_ENCODING_H + rb_encoding *default_internal_enc = rb_default_internal_encoding(); + #endif + + field = mysql_fetch_field_direct(wrapper->result, idx); + if (symbolize_keys) { + char buf[field->name_length+1]; + memcpy(buf, field->name, field->name_length); + buf[field->name_length] = 0; + rb_field = ID2SYM(rb_intern(buf)); + } else { + rb_field = rb_str_new(field->name, field->name_length); + #ifdef HAVE_RUBY_ENCODING_H + rb_enc_associate(rb_field, utf8Encoding); + if (default_internal_enc) { + rb_field = rb_str_export_to_enc(rb_field, default_internal_enc); + } + #endif + } + rb_ary_store(wrapper->fields, idx, rb_field); + } + + return rb_field; +} + +static VALUE rb_mysql_result_fetch_row(int argc, VALUE * argv, VALUE self) { + VALUE rowHash, opts, block; + mysql2_result_wrapper * wrapper; + MYSQL_ROW row; + MYSQL_FIELD * fields = NULL; + unsigned int i = 0, symbolizeKeys = 0; + unsigned long * fieldLengths; + void * ptr; +#ifdef HAVE_RUBY_ENCODING_H + rb_encoding *default_internal_enc = rb_default_internal_encoding(); +#endif + + GetMysql2Result(self, wrapper); + + if (rb_scan_args(argc, argv, "01&", &opts, &block) == 1) { + Check_Type(opts, T_HASH); + if (rb_hash_aref(opts, sym_symbolize_keys) == Qtrue) { + symbolizeKeys = 1; + } + } + + ptr = wrapper->result; + row = (MYSQL_ROW)rb_thread_blocking_region(nogvl_fetch_row, ptr, RUBY_UBF_IO, 0); + if (row == NULL) { + return Qnil; + } + + rowHash = rb_hash_new(); + fields = mysql_fetch_fields(wrapper->result); + fieldLengths = mysql_fetch_lengths(wrapper->result); + if (wrapper->fields == Qnil) { + wrapper->numberOfFields = mysql_num_fields(wrapper->result); + wrapper->fields = rb_ary_new2(wrapper->numberOfFields); + } + + for (i = 0; i < wrapper->numberOfFields; i++) { + VALUE field = rb_mysql_result_fetch_field(wrapper, i, symbolizeKeys); + if (row[i]) { + VALUE val; + switch(fields[i].type) { + case MYSQL_TYPE_NULL: // NULL-type field + val = Qnil; + break; + case MYSQL_TYPE_BIT: // BIT field (MySQL 5.0.3 and up) + val = rb_str_new(row[i], fieldLengths[i]); + break; + case MYSQL_TYPE_TINY: // TINYINT field + case MYSQL_TYPE_SHORT: // SMALLINT field + case MYSQL_TYPE_LONG: // INTEGER field + case MYSQL_TYPE_INT24: // MEDIUMINT field + case MYSQL_TYPE_LONGLONG: // BIGINT field + case MYSQL_TYPE_YEAR: // YEAR field + val = rb_cstr2inum(row[i], 10); + break; + case MYSQL_TYPE_DECIMAL: // DECIMAL or NUMERIC field + case MYSQL_TYPE_NEWDECIMAL: // Precision math DECIMAL or NUMERIC field (MySQL 5.0.3 and up) + val = rb_funcall(cBigDecimal, intern_new, 1, rb_str_new(row[i], fieldLengths[i])); + break; + case MYSQL_TYPE_FLOAT: // FLOAT field + case MYSQL_TYPE_DOUBLE: // DOUBLE or REAL field + val = rb_float_new(strtod(row[i], NULL)); + break; + case MYSQL_TYPE_TIME: { // TIME field + int hour, min, sec, tokens; + tokens = sscanf(row[i], "%2d:%2d:%2d", &hour, &min, &sec); + val = rb_funcall(rb_cTime, intern_utc, 6, INT2NUM(0), INT2NUM(1), INT2NUM(1), INT2NUM(hour), INT2NUM(min), INT2NUM(sec)); + break; + } + case MYSQL_TYPE_TIMESTAMP: // TIMESTAMP field + case MYSQL_TYPE_DATETIME: { // DATETIME field + int year, month, day, hour, min, sec, tokens; + tokens = sscanf(row[i], "%4d-%2d-%2d %2d:%2d:%2d", &year, &month, &day, &hour, &min, &sec); + if (year+month+day+hour+min+sec == 0) { + val = Qnil; + } else { + if (month < 1 || day < 1) { + rb_raise(cMysql2Error, "Invalid date: %s", row[i]); + val = Qnil; + } else { + val = rb_funcall(rb_cTime, intern_utc, 6, INT2NUM(year), INT2NUM(month), INT2NUM(day), INT2NUM(hour), INT2NUM(min), INT2NUM(sec)); + } + } + break; + } + case MYSQL_TYPE_DATE: // DATE field + case MYSQL_TYPE_NEWDATE: { // Newer const used > 5.0 + int year, month, day, tokens; + tokens = sscanf(row[i], "%4d-%2d-%2d", &year, &month, &day); + if (year+month+day == 0) { + val = Qnil; + } else { + if (month < 1 || day < 1) { + rb_raise(cMysql2Error, "Invalid date: %s", row[i]); + val = Qnil; + } else { + val = rb_funcall(cDate, intern_new, 3, INT2NUM(year), INT2NUM(month), INT2NUM(day)); + } + } + break; + } + case MYSQL_TYPE_TINY_BLOB: + case MYSQL_TYPE_MEDIUM_BLOB: + case MYSQL_TYPE_LONG_BLOB: + case MYSQL_TYPE_BLOB: + case MYSQL_TYPE_VAR_STRING: + case MYSQL_TYPE_VARCHAR: + case MYSQL_TYPE_STRING: // CHAR or BINARY field + case MYSQL_TYPE_SET: // SET field + case MYSQL_TYPE_ENUM: // ENUM field + case MYSQL_TYPE_GEOMETRY: // Spatial fielda + default: + val = rb_str_new(row[i], fieldLengths[i]); +#ifdef HAVE_RUBY_ENCODING_H + // rudimentary check for binary content + if ((fields[i].flags & BINARY_FLAG) || fields[i].charsetnr == 63) { + rb_enc_associate(val, binaryEncoding); + } else { + rb_enc_associate(val, utf8Encoding); + if (default_internal_enc) { + val = rb_str_export_to_enc(val, default_internal_enc); + } + } +#endif + break; + } + rb_hash_aset(rowHash, field, val); + } else { + rb_hash_aset(rowHash, field, Qnil); + } + } + return rowHash; +} + +static VALUE rb_mysql_result_fetch_fields(VALUE self) { + mysql2_result_wrapper * wrapper; + unsigned int i = 0; + + GetMysql2Result(self, wrapper); + + if (wrapper->fields == Qnil) { + wrapper->numberOfFields = mysql_num_fields(wrapper->result); + wrapper->fields = rb_ary_new2(wrapper->numberOfFields); + } + + if (RARRAY_LEN(wrapper->fields) != wrapper->numberOfFields) { + for (i=0; inumberOfFields; i++) { + rb_mysql_result_fetch_field(wrapper, i, 0); + } + } + + return wrapper->fields; +} + +static VALUE rb_mysql_result_each(int argc, VALUE * argv, VALUE self) { + VALUE opts, block; + mysql2_result_wrapper * wrapper; + unsigned long i; + + GetMysql2Result(self, wrapper); + + rb_scan_args(argc, argv, "01&", &opts, &block); + + if (wrapper->lastRowProcessed == 0) { + wrapper->numberOfRows = mysql_num_rows(wrapper->result); + if (wrapper->numberOfRows == 0) { + return Qnil; + } + wrapper->rows = rb_ary_new2(wrapper->numberOfRows); + } + + if (wrapper->lastRowProcessed == wrapper->numberOfRows) { + // we've already read the entire dataset from the C result into our + // internal array. Lets hand that over to the user since it's ready to go + for (i = 0; i < wrapper->numberOfRows; i++) { + rb_yield(rb_ary_entry(wrapper->rows, i)); + } + } else { + unsigned long rowsProcessed = 0; + rowsProcessed = RARRAY_LEN(wrapper->rows); + for (i = 0; i < wrapper->numberOfRows; i++) { + VALUE row; + if (i < rowsProcessed) { + row = rb_ary_entry(wrapper->rows, i); + } else { + row = rb_mysql_result_fetch_row(argc, argv, self); + rb_ary_store(wrapper->rows, i, row); + wrapper->lastRowProcessed++; + } + + if (row == Qnil) { + // we don't need the mysql C dataset around anymore, peace it + rb_mysql_result_free_result(wrapper); + return Qnil; + } + + if (block != Qnil) { + rb_yield(row); + } + } + if (wrapper->lastRowProcessed == wrapper->numberOfRows) { + // we don't need the mysql C dataset around anymore, peace it + rb_mysql_result_free_result(wrapper); + } + } + + return wrapper->rows; +} + +/* Mysql2::Result */ +VALUE rb_mysql_result_to_obj(MYSQL_RES * r) { + VALUE obj; + mysql2_result_wrapper * wrapper; + obj = Data_Make_Struct(cMysql2Result, mysql2_result_wrapper, rb_mysql_result_mark, rb_mysql_result_free, wrapper); + wrapper->numberOfFields = 0; + wrapper->numberOfRows = 0; + wrapper->lastRowProcessed = 0; + wrapper->resultFreed = 0; + wrapper->result = r; + wrapper->fields = Qnil; + wrapper->rows = Qnil; + rb_obj_call_init(obj, 0, NULL); + return obj; +} + +void init_mysql2_result() +{ + cBigDecimal = rb_const_get(rb_cObject, rb_intern("BigDecimal")); + cDate = rb_const_get(rb_cObject, rb_intern("Date")); + cDateTime = rb_const_get(rb_cObject, rb_intern("DateTime")); + + cMysql2Result = rb_define_class_under(mMysql2, "Result", rb_cObject); + rb_define_method(cMysql2Result, "each", rb_mysql_result_each, -1); + rb_define_method(cMysql2Result, "fields", rb_mysql_result_fetch_fields, 0); + + sym_symbolize_keys = ID2SYM(rb_intern("symbolize_keys")); + intern_new = rb_intern("new"); + intern_utc = rb_intern("utc"); + +#ifdef HAVE_RUBY_ENCODING_H + utf8Encoding = rb_utf8_encoding(); + binaryEncoding = rb_enc_find("binary"); +#endif +} diff --git a/ext/mysql2/result.h b/ext/mysql2/result.h new file mode 100644 index 0000000..3b9eeab --- /dev/null +++ b/ext/mysql2/result.h @@ -0,0 +1,7 @@ +#ifndef MYSQL2_RESULT_H +#define MYSQL2_RESULT_H + +void init_mysql2_result(); +VALUE rb_mysql_result_to_obj(MYSQL_RES * r); + +#endif diff --git a/mysql2.gemspec b/mysql2.gemspec index af3a3fc..6b6d12c 100644 --- a/mysql2.gemspec +++ b/mysql2.gemspec @@ -32,6 +32,8 @@ Gem::Specification.new do |s| "ext/mysql2/extconf.rb", "ext/mysql2/mysql2_ext.c", "ext/mysql2/mysql2_ext.h", + "ext/mysql2/result.c", + "ext/mysql2/result.h", "lib/active_record/connection_adapters/mysql2_adapter.rb", "lib/arel/engines/sql/compilers/mysql2_compiler.rb", "lib/mysql2.rb",