diff --git a/lib/mongo/util/bson_ruby.rb b/lib/mongo/util/bson_ruby.rb index f15ab2a..ba50349 100644 --- a/lib/mongo/util/bson_ruby.rb +++ b/lib/mongo/util/bson_ruby.rb @@ -72,6 +72,11 @@ class BSON_RUBY buf.put_array(to_utf8(val.to_s).unpack("C*") << 0) end + def self.serialize_key(buf, key) + raise InvalidDocument, "Key names must not contain the NULL byte" if key.include? 0 + self.serialize_cstr(buf, key) + end + def to_a @buf.to_a end @@ -365,7 +370,7 @@ class BSON_RUBY def serialize_null_element(buf, key) buf.put(NULL) - self.class.serialize_cstr(buf, key) + self.class.serialize_key(buf, key) end def serialize_dbref_element(buf, key, val) @@ -377,7 +382,7 @@ class BSON_RUBY def serialize_binary_element(buf, key, val) buf.put(BINARY) - self.class.serialize_cstr(buf, key) + self.class.serialize_key(buf, key) bytes = val.to_a num_bytes = bytes.length @@ -396,13 +401,13 @@ class BSON_RUBY def serialize_boolean_element(buf, key, val) buf.put(BOOLEAN) - self.class.serialize_cstr(buf, key) + self.class.serialize_key(buf, key) buf.put(val ? 1 : 0) end def serialize_date_element(buf, key, val) buf.put(DATE) - self.class.serialize_cstr(buf, key) + self.class.serialize_key(buf, key) millisecs = (val.to_f * 1000).to_i buf.put_long(millisecs) end @@ -410,7 +415,7 @@ class BSON_RUBY def serialize_number_element(buf, key, val, type) if type == NUMBER buf.put(type) - self.class.serialize_cstr(buf, key) + self.class.serialize_key(buf, key) buf.put_double(val) else if val > 2**64 / 2 - 1 or val < -2**64 / 2 @@ -418,11 +423,11 @@ class BSON_RUBY end if val > 2**32 / 2 - 1 or val < -2**32 / 2 buf.put(NUMBER_LONG) - self.class.serialize_cstr(buf, key) + self.class.serialize_key(buf, key) buf.put_long(val) else buf.put(type) - self.class.serialize_cstr(buf, key) + self.class.serialize_key(buf, key) buf.put_int(val) end end @@ -430,7 +435,7 @@ class BSON_RUBY def serialize_object_element(buf, key, val, check_keys, opcode=OBJECT) buf.put(opcode) - self.class.serialize_cstr(buf, key) + self.class.serialize_key(buf, key) buf.put_array(BSON.new.serialize(val, check_keys).to_a) end @@ -444,7 +449,7 @@ class BSON_RUBY def serialize_regex_element(buf, key, val) buf.put(REGEX) - self.class.serialize_cstr(buf, key) + self.class.serialize_key(buf, key) str = val.to_s.sub(/.*?:/, '')[0..-2] # Turn "(?xxx:yyy)" into "yyy" self.class.serialize_cstr(buf, str) @@ -461,14 +466,14 @@ class BSON_RUBY def serialize_oid_element(buf, key, val) buf.put(OID) - self.class.serialize_cstr(buf, key) + self.class.serialize_key(buf, key) buf.put_array(val.to_a) end def serialize_string_element(buf, key, val, type) buf.put(type) - self.class.serialize_cstr(buf, key) + self.class.serialize_key(buf, key) # Make a hole for the length len_pos = buf.position @@ -488,7 +493,7 @@ class BSON_RUBY def serialize_code_w_scope(buf, key, val) buf.put(CODE_W_SCOPE) - self.class.serialize_cstr(buf, key) + self.class.serialize_key(buf, key) # Make a hole for the length len_pos = buf.position diff --git a/test/test_bson.rb b/test/test_bson.rb index 6e52ce5..2a88230 100644 --- a/test/test_bson.rb +++ b/test/test_bson.rb @@ -321,4 +321,14 @@ class BSONTest < Test::Unit::TestCase assert_equal BSON.serialize(one).to_a, BSON.serialize(dup).to_a end + def test_null_character + doc = {"a" => "\x00"} + + assert_equal doc, BSON.deserialize(BSON.serialize(doc).to_a) + + assert_raise InvalidDocument do + BSON.serialize({"\x00" => "a"}) + end + end + end