diff --git a/ext/cbson/cbson.c b/ext/cbson/cbson.c index b29c013..17e5785 100644 --- a/ext/cbson/cbson.c +++ b/ext/cbson/cbson.c @@ -66,7 +66,7 @@ static int cmp_char(const void* a, const void* b) { return *(char*)a - *(char*)b; } -static void write_doc(bson_buffer* buffer, VALUE hash); +static void write_doc(bson_buffer* buffer, VALUE hash, VALUE no_dollar_sign); static int write_element(VALUE key, VALUE value, VALUE extra); static VALUE elements_to_hash(const char* buffer, int max); @@ -126,6 +126,10 @@ static void buffer_write_bytes(bson_buffer* buffer, const char* bytes, int size) buffer->position += size; } +static VALUE pack_extra(bson_buffer* buffer, VALUE no_dollar_sign) { + return rb_ary_new3(2, INT2NUM((int)buffer), no_dollar_sign); +} + static void write_name_and_type(bson_buffer* buffer, VALUE name, char type) { buffer_write_bytes(buffer, &type, 1); buffer_write_bytes(buffer, RSTRING_PTR(name), RSTRING_LEN(name)); @@ -133,7 +137,8 @@ static void write_name_and_type(bson_buffer* buffer, VALUE name, char type) { } static int write_element_allow_id(VALUE key, VALUE value, VALUE extra, int allow_id) { - bson_buffer* buffer = (bson_buffer*)extra; + bson_buffer* buffer = (bson_buffer*)NUM2INT(rb_ary_entry(extra, 0)); + VALUE no_dollar_sign = rb_ary_entry(extra, 1); if (TYPE(key) == T_SYMBOL) { // TODO better way to do this... ? @@ -148,6 +153,16 @@ static int write_element_allow_id(VALUE key, VALUE value, VALUE extra, int allow return ST_CONTINUE; } + if (no_dollar_sign == Qtrue && RSTRING_LEN(key) > 0 && RSTRING_PTR(key)[0] == '$') { + rb_raise(rb_eRuntimeError, "key must not start with '$'"); + } + int i; + for (i = 0; i < RSTRING_LEN(key); i++) { + if (RSTRING_PTR(key)[i] == '.') { + rb_raise(rb_eRuntimeError, "key must not contain '.'"); + } + } + switch(TYPE(value)) { case T_BIGNUM: { @@ -195,7 +210,7 @@ static int write_element_allow_id(VALUE key, VALUE value, VALUE extra, int allow case T_HASH: { write_name_and_type(buffer, key, 0x03); - write_doc(buffer, value); + write_doc(buffer, value, no_dollar_sign); break; } case T_ARRAY: @@ -213,7 +228,7 @@ static int write_element_allow_id(VALUE key, VALUE value, VALUE extra, int allow char* name; asprintf(&name, "%d", i); VALUE key = rb_str_new2(name); - write_element(key, values[i], (VALUE)buffer); + write_element(key, values[i], pack_extra(buffer, no_dollar_sign)); free(name); } @@ -236,7 +251,7 @@ static int write_element_allow_id(VALUE key, VALUE value, VALUE extra, int allow buffer_write_bytes(buffer, (char*)&length, 4); buffer_write_bytes(buffer, RSTRING_PTR(value), length - 1); buffer_write_bytes(buffer, &zero, 1); - write_doc(buffer, rb_funcall(value, rb_intern("scope"), 0)); + write_doc(buffer, rb_funcall(value, rb_intern("scope"), 0), Qfalse); int total_length = buffer->position - start_position; memcpy(buffer->buffer + length_location, &total_length, 4); @@ -302,9 +317,9 @@ static int write_element_allow_id(VALUE key, VALUE value, VALUE extra, int allow int length_location = buffer_save_bytes(buffer, 4); VALUE ns = rb_funcall(value, rb_intern("namespace"), 0); - write_element(rb_str_new2("$ref"), ns, (VALUE)buffer); + write_element(rb_str_new2("$ref"), ns, pack_extra(buffer, Qfalse)); VALUE oid = rb_funcall(value, rb_intern("object_id"), 0); - write_element(rb_str_new2("$id"), oid, (VALUE)buffer); + write_element(rb_str_new2("$id"), oid, pack_extra(buffer, Qfalse)); // write null byte and fill in length buffer_write_bytes(buffer, &zero, 1); @@ -376,19 +391,19 @@ static int write_element(VALUE key, VALUE value, VALUE extra) { return write_element_allow_id(key, value, extra, 0); } -static void write_doc(bson_buffer* buffer, VALUE hash) { +static void write_doc(bson_buffer* buffer, VALUE hash, VALUE no_dollar_sign) { int start_position = buffer->position; int length_location = buffer_save_bytes(buffer, 4); VALUE key = rb_str_new2("_id"); if (rb_funcall(hash, rb_intern("has_key?"), 1, key) == Qtrue) { VALUE id = rb_hash_aref(hash, key); - write_element_allow_id(key, id, (VALUE)buffer, 1); + write_element_allow_id(key, id, pack_extra(buffer, no_dollar_sign), 1); } key = ID2SYM(rb_intern("_id")); if (rb_funcall(hash, rb_intern("has_key?"), 1, key) == Qtrue) { VALUE id = rb_hash_aref(hash, key); - write_element_allow_id(key, id, (VALUE)buffer, 1); + write_element_allow_id(key, id, pack_extra(buffer, no_dollar_sign), 1); } // we have to check for an OrderedHash and handle that specially @@ -398,10 +413,11 @@ static void write_doc(bson_buffer* buffer, VALUE hash) { for(i = 0; i < RARRAY_LEN(keys); i++) { VALUE key = RARRAY_PTR(keys)[i]; VALUE value = rb_hash_aref(hash, key); - write_element(key, value, (VALUE)buffer); + + write_element(key, value, pack_extra(buffer, no_dollar_sign)); } } else { - rb_hash_foreach(hash, write_element, (VALUE)buffer); + rb_hash_foreach(hash, write_element, pack_extra(buffer, no_dollar_sign)); } // write null byte and fill in length @@ -410,11 +426,11 @@ static void write_doc(bson_buffer* buffer, VALUE hash) { memcpy(buffer->buffer + length_location, &length, 4); } -static VALUE method_serialize(VALUE self, VALUE doc) { +static VALUE method_serialize(VALUE self, VALUE doc, VALUE no_dollar_sign) { bson_buffer* buffer = buffer_new(); assert(buffer); - write_doc(buffer, doc); + write_doc(buffer, doc, no_dollar_sign); VALUE result = rb_str_new(buffer->buffer, buffer->position); buffer_free(buffer); @@ -681,6 +697,6 @@ void Init_cbson() { OrderedHash = rb_const_get(rb_cObject, rb_intern("OrderedHash")); VALUE CBson = rb_define_module("CBson"); - rb_define_module_function(CBson, "serialize", method_serialize, 1); + rb_define_module_function(CBson, "serialize", method_serialize, 2); rb_define_module_function(CBson, "deserialize", method_deserialize, 1); } diff --git a/lib/mongo/message/insert_message.rb b/lib/mongo/message/insert_message.rb index b202445..6bda26a 100644 --- a/lib/mongo/message/insert_message.rb +++ b/lib/mongo/message/insert_message.rb @@ -27,7 +27,7 @@ module XGen super(OP_INSERT) write_int(0) write_string("#{db_name}.#{collection_name}") - objs.each { |o| write_doc(o) } + objs.each { |o| write_doc(o, true) } end end end diff --git a/lib/mongo/message/message.rb b/lib/mongo/message/message.rb index 6d55e19..2ab72dd 100644 --- a/lib/mongo/message/message.rb +++ b/lib/mongo/message/message.rb @@ -36,7 +36,7 @@ module XGen @request_id = (@@class_req_id += 1) @response_id = 0 @buf = ByteBuffer.new - + @buf.put_int(16) # holder for length @buf.put_int(@request_id) @buf.put_int(0) # response_to @@ -58,8 +58,8 @@ module XGen update_message_length end - def write_doc(hash) - @buf.put_array(BSON.new.serialize(hash).to_a) + def write_doc(hash, no_dollar_sign=false) + @buf.put_array(BSON.new.serialize(hash, no_dollar_sign).to_a) update_message_length end diff --git a/lib/mongo/util/bson.rb b/lib/mongo/util/bson.rb index 23267d0..f11c31c 100644 --- a/lib/mongo/util/bson.rb +++ b/lib/mongo/util/bson.rb @@ -73,11 +73,11 @@ class BSON begin require 'mongo_ext/cbson' - def serialize(obj) - @buf = ByteBuffer.new(CBson.serialize(obj)) + def serialize(obj, no_dollar_sign=false) + @buf = ByteBuffer.new(CBson.serialize(obj, no_dollar_sign)) end rescue LoadError - def serialize(obj) + def serialize(obj, no_dollar_sign=false) raise "Document is null" unless obj @buf.rewind @@ -86,12 +86,12 @@ class BSON # Write key/value pairs. Always write _id first if it exists. if obj.has_key? '_id' - serialize_key_value('_id', obj['_id']) + serialize_key_value('_id', obj['_id'], no_dollar_sign) elsif obj.has_key? :_id - serialize_key_value('_id', obj[:_id]) + serialize_key_value('_id', obj[:_id], no_dollar_sign) end - obj.each {|k, v| serialize_key_value(k, v) unless k == '_id' || k == :_id } + obj.each {|k, v| serialize_key_value(k, v, no_dollar_sign) unless k == '_id' || k == :_id } serialize_eoo_element(@buf) @buf.put_int(@buf.size, 0) @@ -99,38 +99,45 @@ class BSON end end - def serialize_key_value(k, v) - type = bson_type(v) - case type - when STRING, SYMBOL - serialize_string_element(@buf, k, v, type) - when NUMBER, NUMBER_INT - serialize_number_element(@buf, k, v, type) - when OBJECT - serialize_object_element(@buf, k, v) - when OID - serialize_oid_element(@buf, k, v) - when ARRAY - serialize_array_element(@buf, k, v) - when REGEX - serialize_regex_element(@buf, k, v) - when BOOLEAN - serialize_boolean_element(@buf, k, v) - when DATE - serialize_date_element(@buf, k, v) - when NULL - serialize_null_element(@buf, k) - when REF - serialize_dbref_element(@buf, k, v) - when BINARY - serialize_binary_element(@buf, k, v) - when UNDEFINED - serialize_undefined_element(@buf, k) - when CODE_W_SCOPE - serialize_code_w_scope(@buf, k, v) - else - raise "unhandled type #{type}" - end + def serialize_key_value(k, v, no_dollar_sign) + k = k.to_s + if no_dollar_sign and k[0] == ?$ + raise RuntimeError.new("key #{k} must not start with '$'") + end + if k.include? ?. + raise RuntimeError.new("key #{k} must not contain '.'") + end + type = bson_type(v) + case type + when STRING, SYMBOL + serialize_string_element(@buf, k, v, type) + when NUMBER, NUMBER_INT + serialize_number_element(@buf, k, v, type) + when OBJECT + serialize_object_element(@buf, k, v, no_dollar_sign) + when OID + serialize_oid_element(@buf, k, v) + when ARRAY + serialize_array_element(@buf, k, v, no_dollar_sign) + when REGEX + serialize_regex_element(@buf, k, v) + when BOOLEAN + serialize_boolean_element(@buf, k, v) + when DATE + serialize_date_element(@buf, k, v) + when NULL + serialize_null_element(@buf, k) + when REF + serialize_dbref_element(@buf, k, v) + when BINARY + serialize_binary_element(@buf, k, v) + when UNDEFINED + serialize_undefined_element(@buf, k) + when CODE_W_SCOPE + serialize_code_w_scope(@buf, k, v) + else + raise "unhandled type #{type}" + end end begin @@ -344,7 +351,7 @@ class BSON oh = OrderedHash.new oh['$ref'] = val.namespace oh['$id'] = val.object_id - serialize_object_element(buf, key, oh) + serialize_object_element(buf, key, oh, false) end def serialize_binary_element(buf, key, val) @@ -397,18 +404,18 @@ class BSON end end - def serialize_object_element(buf, key, val, opcode=OBJECT) + def serialize_object_element(buf, key, val, no_dollar_sign, opcode=OBJECT) buf.put(opcode) self.class.serialize_cstr(buf, key) - buf.put_array(BSON.new.serialize(val).to_a) + buf.put_array(BSON.new.serialize(val, no_dollar_sign).to_a) end - def serialize_array_element(buf, key, val) + def serialize_array_element(buf, key, val, no_dollar_sign) # Turn array into hash with integer indices as keys h = OrderedHash.new i = 0 val.each { |v| h[i] = v; i += 1 } - serialize_object_element(buf, key, h, ARRAY) + serialize_object_element(buf, key, h, no_dollar_sign, ARRAY) end def serialize_regex_element(buf, key, val) diff --git a/tests/test_db_api.rb b/tests/test_db_api.rb index 3507401..b0d9d15 100644 --- a/tests/test_db_api.rb +++ b/tests/test_db_api.rb @@ -585,6 +585,44 @@ class DBAPITest < Test::Unit::TestCase assert_equal 2, @@coll.count end + def test_invalid_key_names + @@coll.clear + + @@coll.insert({"hello" => "world"}) + @@coll.insert({"hello" => {"hello" => "world"}}) + + assert_raise RuntimeError do + @@coll.insert({"$hello" => "world"}) + end + assert_raise RuntimeError do + @@coll.insert({"hello" => {"$hello" => "world"}}) + end + + @@coll.insert({"he$llo" => "world"}) + @@coll.insert({"hello" => {"hell$o" => "world"}}) + + assert_raise RuntimeError do + @@coll.insert({".hello" => "world"}) + end + assert_raise RuntimeError do + @@coll.insert({"hello" => {".hello" => "world"}}) + end + assert_raise RuntimeError do + @@coll.insert({"hello." => "world"}) + end + assert_raise RuntimeError do + @@coll.insert({"hello" => {"hello." => "world"}}) + end + assert_raise RuntimeError do + @@coll.insert({"hel.lo" => "world"}) + end + assert_raise RuntimeError do + @@coll.insert({"hello" => {"hel.lo" => "world"}}) + end + + @@coll.modify({"hello" => "world"}, {"$inc" => "hello"}) + end + # TODO this test fails with error message "Undefed Before end of object" # That is a database error. The undefined type may go away.