From 2d9e10c1922d621899c09003c85c9c750f0e9add Mon Sep 17 00:00:00 2001 From: Kyle Banker Date: Sat, 27 Feb 2010 12:22:34 -0500 Subject: [PATCH] CBson HashWithIndifferentAccess error --- ext/cbson/cbson.c | 54 ++++++++++++++++++++++--------------- ext/cbson/version.h | 2 +- lib/mongo.rb | 2 +- lib/mongo/collection.rb | 4 +-- lib/mongo/util/bson_ruby.rb | 6 ++--- test/bson_test.rb | 27 +++++++++++++++++++ 6 files changed, 67 insertions(+), 28 deletions(-) diff --git a/ext/cbson/cbson.c b/ext/cbson/cbson.c index 05048eb..46a2298 100644 --- a/ext/cbson/cbson.c +++ b/ext/cbson/cbson.c @@ -141,17 +141,14 @@ static int cmp_char(const void* a, const void* b) { } static void write_doc(buffer_t buffer, VALUE hash, VALUE check_keys, VALUE move_id); -static int write_element(VALUE key, VALUE value, VALUE extra); +static int write_element_with_id(VALUE key, VALUE value, VALUE extra); +static int write_element_without_id(VALUE key, VALUE value, VALUE extra); static VALUE elements_to_hash(const char* buffer, int max); static VALUE pack_extra(buffer_t buffer, VALUE check_keys) { return rb_ary_new3(2, LL2NUM((long long)buffer), check_keys); } -static VALUE pack_triple(buffer_t buffer, VALUE check_keys, int allow_id) { - return rb_ary_new3(3, LL2NUM((long long)buffer), check_keys, allow_id); -} - static void write_name_and_type(buffer_t buffer, VALUE name, char type) { SAFE_WRITE(buffer, &type, 1); name = TO_UTF8(name); @@ -159,7 +156,7 @@ static void write_name_and_type(buffer_t buffer, VALUE name, char type) { SAFE_WRITE(buffer, &zero, 1); } -static int write_element_allow_id(VALUE key, VALUE value, VALUE extra, int allow_id) { +static int write_element(VALUE key, VALUE value, VALUE extra, int allow_id) { buffer_t buffer = (buffer_t)NUM2LL(rb_ary_entry(extra, 0)); VALUE check_keys = rb_ary_entry(extra, 1); @@ -173,7 +170,7 @@ static int write_element_allow_id(VALUE key, VALUE value, VALUE extra, int allow rb_raise(rb_eTypeError, "keys must be strings or symbols"); } - if (!allow_id && strcmp("_id", RSTRING_PTR(key)) == 0) { + if (allow_id == 0 && strcmp("_id", RSTRING_PTR(key)) == 0) { return ST_CONTINUE; } @@ -266,7 +263,7 @@ static int write_element_allow_id(VALUE key, VALUE value, VALUE extra, int allow VALUE key; INT2STRING(&name, i); key = rb_str_new2(name); - write_element(key, values[i], pack_extra(buffer, check_keys)); + write_element_with_id(key, values[i], pack_extra(buffer, check_keys)); free(name); } @@ -366,9 +363,9 @@ static int write_element_allow_id(VALUE key, VALUE value, VALUE extra, int allow } ns = rb_funcall(value, rb_intern("namespace"), 0); - write_element(rb_str_new2("$ref"), ns, pack_extra(buffer, Qfalse)); + write_element_with_id(rb_str_new2("$ref"), ns, pack_extra(buffer, Qfalse)); oid = rb_funcall(value, rb_intern("object_id"), 0); - write_element(rb_str_new2("$id"), oid, pack_extra(buffer, Qfalse)); + write_element_with_id(rb_str_new2("$id"), oid, pack_extra(buffer, Qfalse)); // write null byte and fill in length SAFE_WRITE(buffer, &zero, 1); @@ -464,8 +461,12 @@ static int write_element_allow_id(VALUE key, VALUE value, VALUE extra, int allow return ST_CONTINUE; } -static int write_element(VALUE key, VALUE value, VALUE extra) { - return write_element_allow_id(key, value, extra, 0); +static int write_element_without_id(VALUE key, VALUE value, VALUE extra) { + return write_element(key, value, extra, 0); +} + +static int write_element_with_id(VALUE key, VALUE value, VALUE extra) { + return write_element(key, value, extra, 1); } static void write_doc(buffer_t buffer, VALUE hash, VALUE check_keys, VALUE move_id) { @@ -473,6 +474,7 @@ static void write_doc(buffer_t buffer, VALUE hash, VALUE check_keys, VALUE move_ buffer_position length_location = buffer_save_space(buffer, 4); buffer_position length; int allow_id; + int (*write_function)(VALUE, VALUE, VALUE) = NULL; VALUE id_str = rb_str_new2("_id"); VALUE id_sym = ID2SYM(rb_intern("_id")); @@ -480,37 +482,47 @@ static void write_doc(buffer_t buffer, VALUE hash, VALUE check_keys, VALUE move_ rb_raise(rb_eNoMemError, "failed to allocate memory in buffer.c"); } - // write '_id' first if move_id is true + // write '_id' first if move_id is true. then don't allow an id to be written. if(move_id == Qtrue) { allow_id = 0; if (rb_funcall(hash, rb_intern("has_key?"), 1, id_str) == Qtrue) { VALUE id = rb_hash_aref(hash, id_str); - write_element_allow_id(id_str, id, pack_extra(buffer, check_keys), 1); + write_element_with_id(id_str, id, pack_extra(buffer, check_keys)); } else if (rb_funcall(hash, rb_intern("has_key?"), 1, id_sym) == Qtrue) { VALUE id = rb_hash_aref(hash, id_sym); - write_element_allow_id(id_sym, id, pack_extra(buffer, check_keys), 1); + write_element_with_id(id_sym, id, pack_extra(buffer, check_keys)); } } else { allow_id = 1; - if ((rb_funcall(hash, rb_intern("has_key?"), 1, id_str) == Qtrue) && - (rb_funcall(hash, rb_intern("has_key?"), 1, id_sym) == Qtrue)) { - VALUE obj = rb_hash_delete(hash, id_str); + if (strcmp(rb_class2name(RBASIC(hash)->klass), "HashWithIndifferentAccess") != 0) { + if ((rb_funcall(hash, rb_intern("has_key?"), 1, id_str) == Qtrue) && + (rb_funcall(hash, rb_intern("has_key?"), 1, id_sym) == Qtrue)) { + VALUE oid_sym = rb_hash_delete(hash, id_sym); + rb_funcall(hash, rb_intern("[]="), 2, id_str, oid_sym); + } } } + if(allow_id == 1) { + write_function = write_element_with_id; + } + else { + write_function = write_element_without_id; + } + // we have to check for an OrderedHash and handle that specially if (strcmp(rb_class2name(RBASIC(hash)->klass), "OrderedHash") == 0) { VALUE keys = rb_funcall(hash, rb_intern("keys"), 0); int i; - for(i = 0; i < RARRAY_LEN(keys); i++) { + for(i = 0; i < RARRAY_LEN(keys); i++) { VALUE key = RARRAY_PTR(keys)[i]; VALUE value = rb_hash_aref(hash, key); - write_element_allow_id(key, value, pack_extra(buffer, check_keys), allow_id); + write_function(key, value, pack_extra(buffer, check_keys)); } } else { - rb_hash_foreach(hash, write_element_allow_id, pack_triple(buffer, check_keys, allow_id)); + rb_hash_foreach(hash, write_function, pack_extra(buffer, check_keys)); } // write null byte and fill in length diff --git a/ext/cbson/version.h b/ext/cbson/version.h index e5a4d77..04566e6 100644 --- a/ext/cbson/version.h +++ b/ext/cbson/version.h @@ -14,4 +14,4 @@ * limitations under the License. */ -#define VERSION "0.18.3p" +#define VERSION "0.19" diff --git a/lib/mongo.rb b/lib/mongo.rb index 951a5b4..79cf136 100644 --- a/lib/mongo.rb +++ b/lib/mongo.rb @@ -1,7 +1,7 @@ $:.unshift(File.join(File.dirname(__FILE__), '..', 'lib')) module Mongo - VERSION = "0.18.3p" + VERSION = "0.19" end begin diff --git a/lib/mongo/collection.rb b/lib/mongo/collection.rb index 1a0a321..de0c106 100644 --- a/lib/mongo/collection.rb +++ b/lib/mongo/collection.rb @@ -261,7 +261,7 @@ module Mongo message = ByteBuffer.new([0, 0, 0, 0]) BSON_RUBY.serialize_cstr(message, "#{@db.name}.#{@name}") message.put_int(0) - message.put_array(BSON.serialize(selector, false).to_a) + message.put_array(BSON.serialize(selector, false, true).to_a) if opts[:safe] @connection.send_message_with_safe_check(Mongo::Constants::OP_DELETE, message, @db.name, @@ -303,7 +303,7 @@ module Mongo update_options += 1 if options[:upsert] update_options += 2 if options[:multi] message.put_int(update_options) - message.put_array(BSON.serialize(selector, false).to_a) + message.put_array(BSON.serialize(selector, false, true).to_a) message.put_array(BSON.serialize(document, false, true).to_a) if options[:safe] @connection.send_message_with_safe_check(Mongo::Constants::OP_UPDATE, message, @db.name, diff --git a/lib/mongo/util/bson_ruby.rb b/lib/mongo/util/bson_ruby.rb index 61392ca..3a2dbf7 100644 --- a/lib/mongo/util/bson_ruby.rb +++ b/lib/mongo/util/bson_ruby.rb @@ -105,14 +105,14 @@ class BSON_RUBY # Write key/value pairs. Always write _id first if it exists. if move_id if obj.has_key? '_id' - serialize_key_value('_id', obj['_id'], check_keys) + serialize_key_value('_id', obj['_id'], false) elsif obj.has_key? :_id - serialize_key_value('_id', obj[:_id], check_keys) + serialize_key_value('_id', obj[:_id], false) end obj.each {|k, v| serialize_key_value(k, v, check_keys) unless k == '_id' || k == :_id } else if obj.has_key?('_id') && obj.has_key?(:_id) - obj.delete(:_id) + obj['_id'] = obj.delete(:_id) end obj.each {|k, v| serialize_key_value(k, v, check_keys) } end diff --git a/test/bson_test.rb b/test/bson_test.rb index 9413565..f422b64 100644 --- a/test/bson_test.rb +++ b/test/bson_test.rb @@ -356,6 +356,13 @@ class BSONTest < Test::Unit::TestCase assert_equal BSON.serialize(one).to_a, BSON.serialize(dup).to_a end + def test_no_duplicate_id_when_moving_id + dup = {"_id" => "foo", :_id => "foo"} + one = {:_id => "foo"} + + assert_equal BSON.serialize(one, false, true).to_s, BSON.serialize(dup, false, true).to_s + end + def test_null_character doc = {"a" => "\x00"} @@ -399,6 +406,7 @@ class BSONTest < Test::Unit::TestCase a['key'] = 'abc' a['_id'] = 1 + assert_equal ")\000\000\000\020_id\000\001\000\000\000\002text" + "\000\004\000\000\000abc\000\002key\000\004\000\000\000abc\000\000", BSON.serialize(a, false, true).to_s @@ -424,4 +432,23 @@ class BSONTest < Test::Unit::TestCase "\000\002\000\000\000\000\020_id\000\003\000\000\000\000", BSON.serialize(c, false, false).to_s end + + begin + require 'active_support' + rescue LoadError + warn 'Could not test BSON with HashWithIndifferentAccess.' + end + + if defined?(HashWithIndifferentAccess) + def test_keep_id_with_hash_with_indifferent_access + doc = HashWithIndifferentAccess.new + doc[:_id] = ObjectID.new + BSON.serialize(doc, false, false).to_a + assert doc.has_key?("_id") + + doc['_id'] = ObjectID.new + BSON.serialize(doc, false, false).to_a + assert doc.has_key?("_id") + end + end end