diff --git a/ext/cbson/cbson.c b/ext/cbson/cbson.c index e7e9bb7..8a1de48 100644 --- a/ext/cbson/cbson.c +++ b/ext/cbson/cbson.c @@ -139,7 +139,7 @@ static int cmp_char(const void* a, const void* b) { return *(char*)a - *(char*)b; } -static void write_doc(buffer_t buffer, VALUE hash, VALUE check_keys); +static void write_doc(buffer_t buffer, VALUE hash, VALUE check_keys, VALUE move_id); static int write_element(VALUE key, VALUE value, VALUE extra); static VALUE elements_to_hash(const char* buffer, int max); @@ -147,6 +147,10 @@ static VALUE pack_extra(buffer_t buffer, VALUE check_keys) { return rb_ary_new3(2, LL2NUM((long long)buffer), check_keys); } +static VALUE pack_triple(buffer_t buffer, VALUE check_keys, int allow_id) { + return rb_ary_new3(3, LL2NUM((long long)buffer), check_keys, allow_id); +} + static void write_name_and_type(buffer_t buffer, VALUE name, char type) { SAFE_WRITE(buffer, &type, 1); name = TO_UTF8(name); @@ -236,7 +240,7 @@ static int write_element_allow_id(VALUE key, VALUE value, VALUE extra, int allow case T_HASH: { write_name_and_type(buffer, key, 0x03); - write_doc(buffer, value, check_keys); + write_doc(buffer, value, check_keys, Qfalse); break; } case T_ARRAY: @@ -289,7 +293,7 @@ static int write_element_allow_id(VALUE key, VALUE value, VALUE extra, int allow SAFE_WRITE(buffer, (char*)&length, 4); SAFE_WRITE(buffer, RSTRING_PTR(value), length - 1); SAFE_WRITE(buffer, &zero, 1); - write_doc(buffer, rb_funcall(value, rb_intern("scope"), 0), Qfalse); + write_doc(buffer, rb_funcall(value, rb_intern("scope"), 0), Qfalse, Qfalse); total_length = buffer_get_position(buffer) - start_position; SAFE_WRITE_AT_POS(buffer, length_location, (const char*)&total_length, 4); @@ -463,10 +467,11 @@ static int write_element(VALUE key, VALUE value, VALUE extra) { return write_element_allow_id(key, value, extra, 0); } -static void write_doc(buffer_t buffer, VALUE hash, VALUE check_keys) { +static void write_doc(buffer_t buffer, VALUE hash, VALUE check_keys, VALUE move_id) { buffer_position start_position = buffer_get_position(buffer); buffer_position length_location = buffer_save_space(buffer, 4); buffer_position length; + int allow_id; VALUE id_str = rb_str_new2("_id"); VALUE id_sym = ID2SYM(rb_intern("_id")); @@ -474,12 +479,23 @@ static void write_doc(buffer_t buffer, VALUE hash, VALUE check_keys) { rb_raise(rb_eNoMemError, "failed to allocate memory in buffer.c"); } - if (rb_funcall(hash, rb_intern("has_key?"), 1, id_str) == Qtrue) { - VALUE id = rb_hash_aref(hash, id_str); - write_element_allow_id(id_str, id, pack_extra(buffer, check_keys), 1); - } else if (rb_funcall(hash, rb_intern("has_key?"), 1, id_sym) == Qtrue) { - VALUE id = rb_hash_aref(hash, id_sym); - write_element_allow_id(id_sym, id, pack_extra(buffer, check_keys), 1); + // write '_id' first if move_id is true + if(move_id == Qtrue) { + allow_id = 0; + if (rb_funcall(hash, rb_intern("has_key?"), 1, id_str) == Qtrue) { + VALUE id = rb_hash_aref(hash, id_str); + write_element_allow_id(id_str, id, pack_extra(buffer, check_keys), 1); + } else if (rb_funcall(hash, rb_intern("has_key?"), 1, id_sym) == Qtrue) { + VALUE id = rb_hash_aref(hash, id_sym); + write_element_allow_id(id_sym, id, pack_extra(buffer, check_keys), 1); + } + } + else { + allow_id = 1; + if ((rb_funcall(hash, rb_intern("has_key?"), 1, id_str) == Qtrue) && + (rb_funcall(hash, rb_intern("has_key?"), 1, id_sym) == Qtrue)) { + VALUE obj = rb_hash_delete(hash, id_str); + } } // we have to check for an OrderedHash and handle that specially @@ -490,10 +506,10 @@ static void write_doc(buffer_t buffer, VALUE hash, VALUE check_keys) { VALUE key = RARRAY_PTR(keys)[i]; VALUE value = rb_hash_aref(hash, key); - write_element(key, value, pack_extra(buffer, check_keys)); + write_element_allow_id(key, value, pack_extra(buffer, check_keys), allow_id); } } else { - rb_hash_foreach(hash, write_element, pack_extra(buffer, check_keys)); + rb_hash_foreach(hash, write_element_allow_id, pack_triple(buffer, check_keys, allow_id)); } // write null byte and fill in length @@ -509,14 +525,14 @@ static void write_doc(buffer_t buffer, VALUE hash, VALUE check_keys) { SAFE_WRITE_AT_POS(buffer, length_location, (const char*)&length, 4); } -static VALUE method_serialize(VALUE self, VALUE doc, VALUE check_keys) { +static VALUE method_serialize(VALUE self, VALUE doc, VALUE check_keys, VALUE move_id) { VALUE result; buffer_t buffer = buffer_new(); if (buffer == NULL) { rb_raise(rb_eNoMemError, "failed to allocate memory in buffer.c"); } - write_doc(buffer, doc, check_keys); + write_doc(buffer, doc, check_keys, move_id); result = rb_str_new(buffer_get_buffer(buffer), buffer_get_position(buffer)); if (buffer_free(buffer) != 0) { @@ -872,7 +888,7 @@ void Init_cbson() { CBson = rb_define_module("CBson"); ext_version = rb_str_new2(VERSION); rb_define_const(CBson, "VERSION", ext_version); - rb_define_module_function(CBson, "serialize", method_serialize, 2); + rb_define_module_function(CBson, "serialize", method_serialize, 3); rb_define_module_function(CBson, "deserialize", method_deserialize, 1); rb_require("digest/md5"); diff --git a/lib/mongo/collection.rb b/lib/mongo/collection.rb index 124ed74..146c13e 100644 --- a/lib/mongo/collection.rb +++ b/lib/mongo/collection.rb @@ -126,6 +126,8 @@ module Mongo # # @raise [RuntimeError] # if given unknown options + # + # @see Official, Official Docs def find(selector={}, opts={}) fields = opts.delete(:fields) fields = ["_id"] if fields && fields.empty? @@ -198,6 +200,8 @@ module Mongo # If true, check that the save succeeded. OperationFailure # will be raised on an error. Note that a safe check requires an extra # round-trip to the database. + # + # @see Official, Official Docs def save(doc, options={}) if doc.has_key?(:_id) || doc.has_key?('_id') id = doc[:_id] || doc['_id'] @@ -294,7 +298,7 @@ module Mongo update_options += 2 if options[:multi] message.put_int(update_options) message.put_array(BSON.serialize(selector, false).to_a) - message.put_array(BSON.serialize(document, false).to_a) + message.put_array(BSON.serialize(document, false, true).to_a) if options[:safe] @connection.send_message_with_safe_check(Mongo::Constants::OP_UPDATE, message, @db.name, "db.#{@name}.update(#{selector.inspect}, #{document.inspect})") @@ -609,7 +613,7 @@ EOS # Initial byte is 0. message = ByteBuffer.new([0, 0, 0, 0]) BSON_RUBY.serialize_cstr(message, "#{@db.name}.#{collection_name}") - documents.each { |doc| message.put_array(BSON.serialize(doc, check_keys).to_a) } + documents.each { |doc| message.put_array(BSON.serialize(doc, check_keys, true).to_a) } if safe @connection.send_message_with_safe_check(Mongo::Constants::OP_INSERT, message, @db.name, "db.#{collection_name}.insert(#{documents.inspect})") diff --git a/lib/mongo/util/bson_c.rb b/lib/mongo/util/bson_c.rb index d6b349a..b220b35 100644 --- a/lib/mongo/util/bson_c.rb +++ b/lib/mongo/util/bson_c.rb @@ -1,8 +1,8 @@ # A thin wrapper for the CBson class class BSON_C - def self.serialize(obj, check_keys=false) - ByteBuffer.new(CBson.serialize(obj, check_keys)) + def self.serialize(obj, check_keys=false, move_id=false) + ByteBuffer.new(CBson.serialize(obj, check_keys, move_id)) end def self.deserialize(buf=nil) diff --git a/lib/mongo/util/bson_ruby.rb b/lib/mongo/util/bson_ruby.rb index f0f4a57..f6a85f5 100644 --- a/lib/mongo/util/bson_ruby.rb +++ b/lib/mongo/util/bson_ruby.rb @@ -87,15 +87,15 @@ class BSON_RUBY # Serializes an object. # Implemented to ensure an API compatible with BSON extension. - def self.serialize(obj, check_keys=false) - new.serialize(obj, check_keys) + def self.serialize(obj, check_keys=false, move_id=false) + new.serialize(obj, check_keys, move_id) end def self.deserialize(buf=nil) new.deserialize(buf) end - def serialize(obj, check_keys=false) + def serialize(obj, check_keys=false, move_id=false) raise "Document is null" unless obj @buf.rewind @@ -103,14 +103,20 @@ class BSON_RUBY @buf.put_int(0) # Write key/value pairs. Always write _id first if it exists. - if obj.has_key? '_id' - serialize_key_value('_id', obj['_id'], check_keys) - elsif obj.has_key? :_id - serialize_key_value('_id', obj[:_id], check_keys) + if move_id + if obj.has_key? '_id' + serialize_key_value('_id', obj['_id'], check_keys) + elsif obj.has_key? :_id + serialize_key_value('_id', obj[:_id], check_keys) + end + obj.each {|k, v| serialize_key_value(k, v, check_keys) unless k == '_id' || k == :_id } + else + if obj.has_key?('_id') && obj.has_key?(:_id) + obj.delete(:_id) + end + obj.each {|k, v| serialize_key_value(k, v, check_keys) } end - obj.each {|k, v| serialize_key_value(k, v, check_keys) unless k == '_id' || k == :_id } - serialize_eoo_element(@buf) if @buf.size > 4 * 1024 * 1024 raise InvalidDocument, "Document is too large (#{@buf.size}). BSON documents are limited to 4MB (#{4 * 1024 * 1024})." diff --git a/test/test_bson.rb b/test/test_bson.rb index 9f9d1f8..a0d5f4d 100644 --- a/test/test_bson.rb +++ b/test/test_bson.rb @@ -263,19 +263,19 @@ class BSONTest < Test::Unit::TestCase val = OrderedHash.new val['not_id'] = 1 val['_id'] = 2 - roundtrip = BSON.deserialize(BSON.serialize(val).to_a) + roundtrip = BSON.deserialize(BSON.serialize(val, false, true).to_a) assert_kind_of OrderedHash, roundtrip assert_equal '_id', roundtrip.keys.first val = {'a' => 'foo', 'b' => 'bar', :_id => 42, 'z' => 'hello'} - roundtrip = BSON.deserialize(BSON.serialize(val).to_a) + roundtrip = BSON.deserialize(BSON.serialize(val, false, true).to_a) assert_kind_of OrderedHash, roundtrip assert_equal '_id', roundtrip.keys.first end def test_nil_id doc = {"_id" => nil} - assert_equal doc, BSON.deserialize(bson = BSON.serialize(doc).to_a) + assert_equal doc, BSON.deserialize(bson = BSON.serialize(doc, false, true).to_a) end def test_timestamp @@ -394,4 +394,35 @@ class BSONTest < Test::Unit::TestCase end end + def test_move_id + a = OrderedHash.new + a['text'] = 'abc' + a['key'] = 'abc' + a['_id'] = 1 + + assert_equal ")\000\000\000\020_id\000\001\000\000\000\002text" + + "\000\004\000\000\000abc\000\002key\000\004\000\000\000abc\000\000", + BSON.serialize(a, false, true).to_s + assert_equal ")\000\000\000\002text\000\004\000\000\000abc\000\002key" + + "\000\004\000\000\000abc\000\020_id\000\001\000\000\000\000", + BSON.serialize(a, false, false).to_s + end + + def test_move_id_with_nested_doc + b = OrderedHash.new + b['text'] = 'abc' + b['_id'] = 2 + c = OrderedHash.new + c['text'] = 'abc' + c['hash'] = b + c['_id'] = 3 + assert_equal ">\000\000\000\020_id\000\003\000\000\000\002text" + + "\000\004\000\000\000abc\000\003hash\000\034\000\000" + + "\000\002text\000\004\000\000\000abc\000\020_id\000\002\000\000\000\000\000", + BSON.serialize(c, false, true).to_s + assert_equal ">\000\000\000\002text\000\004\000\000\000abc\000\003hash" + + "\000\034\000\000\000\002text\000\004\000\000\000abc\000\020_id" + + "\000\002\000\000\000\000\020_id\000\003\000\000\000\000", + BSON.serialize(c, false, false).to_s + end end