don't allow invalid key names on inserts
This commit is contained in:
parent
6e9a5194f9
commit
71d7ff726b
|
@ -66,7 +66,7 @@ static int cmp_char(const void* a, const void* b) {
|
||||||
return *(char*)a - *(char*)b;
|
return *(char*)a - *(char*)b;
|
||||||
}
|
}
|
||||||
|
|
||||||
static void write_doc(bson_buffer* buffer, VALUE hash);
|
static void write_doc(bson_buffer* buffer, VALUE hash, VALUE no_dollar_sign);
|
||||||
static int write_element(VALUE key, VALUE value, VALUE extra);
|
static int write_element(VALUE key, VALUE value, VALUE extra);
|
||||||
static VALUE elements_to_hash(const char* buffer, int max);
|
static VALUE elements_to_hash(const char* buffer, int max);
|
||||||
|
|
||||||
|
@ -126,6 +126,10 @@ static void buffer_write_bytes(bson_buffer* buffer, const char* bytes, int size)
|
||||||
buffer->position += size;
|
buffer->position += size;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static VALUE pack_extra(bson_buffer* buffer, VALUE no_dollar_sign) {
|
||||||
|
return rb_ary_new3(2, INT2NUM((int)buffer), no_dollar_sign);
|
||||||
|
}
|
||||||
|
|
||||||
static void write_name_and_type(bson_buffer* buffer, VALUE name, char type) {
|
static void write_name_and_type(bson_buffer* buffer, VALUE name, char type) {
|
||||||
buffer_write_bytes(buffer, &type, 1);
|
buffer_write_bytes(buffer, &type, 1);
|
||||||
buffer_write_bytes(buffer, RSTRING_PTR(name), RSTRING_LEN(name));
|
buffer_write_bytes(buffer, RSTRING_PTR(name), RSTRING_LEN(name));
|
||||||
|
@ -133,7 +137,8 @@ static void write_name_and_type(bson_buffer* buffer, VALUE name, char type) {
|
||||||
}
|
}
|
||||||
|
|
||||||
static int write_element_allow_id(VALUE key, VALUE value, VALUE extra, int allow_id) {
|
static int write_element_allow_id(VALUE key, VALUE value, VALUE extra, int allow_id) {
|
||||||
bson_buffer* buffer = (bson_buffer*)extra;
|
bson_buffer* buffer = (bson_buffer*)NUM2INT(rb_ary_entry(extra, 0));
|
||||||
|
VALUE no_dollar_sign = rb_ary_entry(extra, 1);
|
||||||
|
|
||||||
if (TYPE(key) == T_SYMBOL) {
|
if (TYPE(key) == T_SYMBOL) {
|
||||||
// TODO better way to do this... ?
|
// TODO better way to do this... ?
|
||||||
|
@ -148,6 +153,16 @@ static int write_element_allow_id(VALUE key, VALUE value, VALUE extra, int allow
|
||||||
return ST_CONTINUE;
|
return ST_CONTINUE;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (no_dollar_sign == Qtrue && RSTRING_LEN(key) > 0 && RSTRING_PTR(key)[0] == '$') {
|
||||||
|
rb_raise(rb_eRuntimeError, "key must not start with '$'");
|
||||||
|
}
|
||||||
|
int i;
|
||||||
|
for (i = 0; i < RSTRING_LEN(key); i++) {
|
||||||
|
if (RSTRING_PTR(key)[i] == '.') {
|
||||||
|
rb_raise(rb_eRuntimeError, "key must not contain '.'");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
switch(TYPE(value)) {
|
switch(TYPE(value)) {
|
||||||
case T_BIGNUM:
|
case T_BIGNUM:
|
||||||
{
|
{
|
||||||
|
@ -195,7 +210,7 @@ static int write_element_allow_id(VALUE key, VALUE value, VALUE extra, int allow
|
||||||
case T_HASH:
|
case T_HASH:
|
||||||
{
|
{
|
||||||
write_name_and_type(buffer, key, 0x03);
|
write_name_and_type(buffer, key, 0x03);
|
||||||
write_doc(buffer, value);
|
write_doc(buffer, value, no_dollar_sign);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case T_ARRAY:
|
case T_ARRAY:
|
||||||
|
@ -213,7 +228,7 @@ static int write_element_allow_id(VALUE key, VALUE value, VALUE extra, int allow
|
||||||
char* name;
|
char* name;
|
||||||
asprintf(&name, "%d", i);
|
asprintf(&name, "%d", i);
|
||||||
VALUE key = rb_str_new2(name);
|
VALUE key = rb_str_new2(name);
|
||||||
write_element(key, values[i], (VALUE)buffer);
|
write_element(key, values[i], pack_extra(buffer, no_dollar_sign));
|
||||||
free(name);
|
free(name);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -236,7 +251,7 @@ static int write_element_allow_id(VALUE key, VALUE value, VALUE extra, int allow
|
||||||
buffer_write_bytes(buffer, (char*)&length, 4);
|
buffer_write_bytes(buffer, (char*)&length, 4);
|
||||||
buffer_write_bytes(buffer, RSTRING_PTR(value), length - 1);
|
buffer_write_bytes(buffer, RSTRING_PTR(value), length - 1);
|
||||||
buffer_write_bytes(buffer, &zero, 1);
|
buffer_write_bytes(buffer, &zero, 1);
|
||||||
write_doc(buffer, rb_funcall(value, rb_intern("scope"), 0));
|
write_doc(buffer, rb_funcall(value, rb_intern("scope"), 0), Qfalse);
|
||||||
|
|
||||||
int total_length = buffer->position - start_position;
|
int total_length = buffer->position - start_position;
|
||||||
memcpy(buffer->buffer + length_location, &total_length, 4);
|
memcpy(buffer->buffer + length_location, &total_length, 4);
|
||||||
|
@ -302,9 +317,9 @@ static int write_element_allow_id(VALUE key, VALUE value, VALUE extra, int allow
|
||||||
int length_location = buffer_save_bytes(buffer, 4);
|
int length_location = buffer_save_bytes(buffer, 4);
|
||||||
|
|
||||||
VALUE ns = rb_funcall(value, rb_intern("namespace"), 0);
|
VALUE ns = rb_funcall(value, rb_intern("namespace"), 0);
|
||||||
write_element(rb_str_new2("$ref"), ns, (VALUE)buffer);
|
write_element(rb_str_new2("$ref"), ns, pack_extra(buffer, Qfalse));
|
||||||
VALUE oid = rb_funcall(value, rb_intern("object_id"), 0);
|
VALUE oid = rb_funcall(value, rb_intern("object_id"), 0);
|
||||||
write_element(rb_str_new2("$id"), oid, (VALUE)buffer);
|
write_element(rb_str_new2("$id"), oid, pack_extra(buffer, Qfalse));
|
||||||
|
|
||||||
// write null byte and fill in length
|
// write null byte and fill in length
|
||||||
buffer_write_bytes(buffer, &zero, 1);
|
buffer_write_bytes(buffer, &zero, 1);
|
||||||
|
@ -376,19 +391,19 @@ static int write_element(VALUE key, VALUE value, VALUE extra) {
|
||||||
return write_element_allow_id(key, value, extra, 0);
|
return write_element_allow_id(key, value, extra, 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
static void write_doc(bson_buffer* buffer, VALUE hash) {
|
static void write_doc(bson_buffer* buffer, VALUE hash, VALUE no_dollar_sign) {
|
||||||
int start_position = buffer->position;
|
int start_position = buffer->position;
|
||||||
int length_location = buffer_save_bytes(buffer, 4);
|
int length_location = buffer_save_bytes(buffer, 4);
|
||||||
|
|
||||||
VALUE key = rb_str_new2("_id");
|
VALUE key = rb_str_new2("_id");
|
||||||
if (rb_funcall(hash, rb_intern("has_key?"), 1, key) == Qtrue) {
|
if (rb_funcall(hash, rb_intern("has_key?"), 1, key) == Qtrue) {
|
||||||
VALUE id = rb_hash_aref(hash, key);
|
VALUE id = rb_hash_aref(hash, key);
|
||||||
write_element_allow_id(key, id, (VALUE)buffer, 1);
|
write_element_allow_id(key, id, pack_extra(buffer, no_dollar_sign), 1);
|
||||||
}
|
}
|
||||||
key = ID2SYM(rb_intern("_id"));
|
key = ID2SYM(rb_intern("_id"));
|
||||||
if (rb_funcall(hash, rb_intern("has_key?"), 1, key) == Qtrue) {
|
if (rb_funcall(hash, rb_intern("has_key?"), 1, key) == Qtrue) {
|
||||||
VALUE id = rb_hash_aref(hash, key);
|
VALUE id = rb_hash_aref(hash, key);
|
||||||
write_element_allow_id(key, id, (VALUE)buffer, 1);
|
write_element_allow_id(key, id, pack_extra(buffer, no_dollar_sign), 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
// we have to check for an OrderedHash and handle that specially
|
// we have to check for an OrderedHash and handle that specially
|
||||||
|
@ -398,10 +413,11 @@ static void write_doc(bson_buffer* buffer, VALUE hash) {
|
||||||
for(i = 0; i < RARRAY_LEN(keys); i++) {
|
for(i = 0; i < RARRAY_LEN(keys); i++) {
|
||||||
VALUE key = RARRAY_PTR(keys)[i];
|
VALUE key = RARRAY_PTR(keys)[i];
|
||||||
VALUE value = rb_hash_aref(hash, key);
|
VALUE value = rb_hash_aref(hash, key);
|
||||||
write_element(key, value, (VALUE)buffer);
|
|
||||||
|
write_element(key, value, pack_extra(buffer, no_dollar_sign));
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
rb_hash_foreach(hash, write_element, (VALUE)buffer);
|
rb_hash_foreach(hash, write_element, pack_extra(buffer, no_dollar_sign));
|
||||||
}
|
}
|
||||||
|
|
||||||
// write null byte and fill in length
|
// write null byte and fill in length
|
||||||
|
@ -410,11 +426,11 @@ static void write_doc(bson_buffer* buffer, VALUE hash) {
|
||||||
memcpy(buffer->buffer + length_location, &length, 4);
|
memcpy(buffer->buffer + length_location, &length, 4);
|
||||||
}
|
}
|
||||||
|
|
||||||
static VALUE method_serialize(VALUE self, VALUE doc) {
|
static VALUE method_serialize(VALUE self, VALUE doc, VALUE no_dollar_sign) {
|
||||||
bson_buffer* buffer = buffer_new();
|
bson_buffer* buffer = buffer_new();
|
||||||
assert(buffer);
|
assert(buffer);
|
||||||
|
|
||||||
write_doc(buffer, doc);
|
write_doc(buffer, doc, no_dollar_sign);
|
||||||
|
|
||||||
VALUE result = rb_str_new(buffer->buffer, buffer->position);
|
VALUE result = rb_str_new(buffer->buffer, buffer->position);
|
||||||
buffer_free(buffer);
|
buffer_free(buffer);
|
||||||
|
@ -681,6 +697,6 @@ void Init_cbson() {
|
||||||
OrderedHash = rb_const_get(rb_cObject, rb_intern("OrderedHash"));
|
OrderedHash = rb_const_get(rb_cObject, rb_intern("OrderedHash"));
|
||||||
|
|
||||||
VALUE CBson = rb_define_module("CBson");
|
VALUE CBson = rb_define_module("CBson");
|
||||||
rb_define_module_function(CBson, "serialize", method_serialize, 1);
|
rb_define_module_function(CBson, "serialize", method_serialize, 2);
|
||||||
rb_define_module_function(CBson, "deserialize", method_deserialize, 1);
|
rb_define_module_function(CBson, "deserialize", method_deserialize, 1);
|
||||||
}
|
}
|
||||||
|
|
|
@ -27,7 +27,7 @@ module XGen
|
||||||
super(OP_INSERT)
|
super(OP_INSERT)
|
||||||
write_int(0)
|
write_int(0)
|
||||||
write_string("#{db_name}.#{collection_name}")
|
write_string("#{db_name}.#{collection_name}")
|
||||||
objs.each { |o| write_doc(o) }
|
objs.each { |o| write_doc(o, true) }
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
|
@ -58,8 +58,8 @@ module XGen
|
||||||
update_message_length
|
update_message_length
|
||||||
end
|
end
|
||||||
|
|
||||||
def write_doc(hash)
|
def write_doc(hash, no_dollar_sign=false)
|
||||||
@buf.put_array(BSON.new.serialize(hash).to_a)
|
@buf.put_array(BSON.new.serialize(hash, no_dollar_sign).to_a)
|
||||||
update_message_length
|
update_message_length
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|
|
@ -73,11 +73,11 @@ class BSON
|
||||||
|
|
||||||
begin
|
begin
|
||||||
require 'mongo_ext/cbson'
|
require 'mongo_ext/cbson'
|
||||||
def serialize(obj)
|
def serialize(obj, no_dollar_sign=false)
|
||||||
@buf = ByteBuffer.new(CBson.serialize(obj))
|
@buf = ByteBuffer.new(CBson.serialize(obj, no_dollar_sign))
|
||||||
end
|
end
|
||||||
rescue LoadError
|
rescue LoadError
|
||||||
def serialize(obj)
|
def serialize(obj, no_dollar_sign=false)
|
||||||
raise "Document is null" unless obj
|
raise "Document is null" unless obj
|
||||||
|
|
||||||
@buf.rewind
|
@buf.rewind
|
||||||
|
@ -86,12 +86,12 @@ class BSON
|
||||||
|
|
||||||
# Write key/value pairs. Always write _id first if it exists.
|
# Write key/value pairs. Always write _id first if it exists.
|
||||||
if obj.has_key? '_id'
|
if obj.has_key? '_id'
|
||||||
serialize_key_value('_id', obj['_id'])
|
serialize_key_value('_id', obj['_id'], no_dollar_sign)
|
||||||
elsif obj.has_key? :_id
|
elsif obj.has_key? :_id
|
||||||
serialize_key_value('_id', obj[:_id])
|
serialize_key_value('_id', obj[:_id], no_dollar_sign)
|
||||||
end
|
end
|
||||||
|
|
||||||
obj.each {|k, v| serialize_key_value(k, v) unless k == '_id' || k == :_id }
|
obj.each {|k, v| serialize_key_value(k, v, no_dollar_sign) unless k == '_id' || k == :_id }
|
||||||
|
|
||||||
serialize_eoo_element(@buf)
|
serialize_eoo_element(@buf)
|
||||||
@buf.put_int(@buf.size, 0)
|
@buf.put_int(@buf.size, 0)
|
||||||
|
@ -99,7 +99,14 @@ class BSON
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
def serialize_key_value(k, v)
|
def serialize_key_value(k, v, no_dollar_sign)
|
||||||
|
k = k.to_s
|
||||||
|
if no_dollar_sign and k[0] == ?$
|
||||||
|
raise RuntimeError.new("key #{k} must not start with '$'")
|
||||||
|
end
|
||||||
|
if k.include? ?.
|
||||||
|
raise RuntimeError.new("key #{k} must not contain '.'")
|
||||||
|
end
|
||||||
type = bson_type(v)
|
type = bson_type(v)
|
||||||
case type
|
case type
|
||||||
when STRING, SYMBOL
|
when STRING, SYMBOL
|
||||||
|
@ -107,11 +114,11 @@ class BSON
|
||||||
when NUMBER, NUMBER_INT
|
when NUMBER, NUMBER_INT
|
||||||
serialize_number_element(@buf, k, v, type)
|
serialize_number_element(@buf, k, v, type)
|
||||||
when OBJECT
|
when OBJECT
|
||||||
serialize_object_element(@buf, k, v)
|
serialize_object_element(@buf, k, v, no_dollar_sign)
|
||||||
when OID
|
when OID
|
||||||
serialize_oid_element(@buf, k, v)
|
serialize_oid_element(@buf, k, v)
|
||||||
when ARRAY
|
when ARRAY
|
||||||
serialize_array_element(@buf, k, v)
|
serialize_array_element(@buf, k, v, no_dollar_sign)
|
||||||
when REGEX
|
when REGEX
|
||||||
serialize_regex_element(@buf, k, v)
|
serialize_regex_element(@buf, k, v)
|
||||||
when BOOLEAN
|
when BOOLEAN
|
||||||
|
@ -344,7 +351,7 @@ class BSON
|
||||||
oh = OrderedHash.new
|
oh = OrderedHash.new
|
||||||
oh['$ref'] = val.namespace
|
oh['$ref'] = val.namespace
|
||||||
oh['$id'] = val.object_id
|
oh['$id'] = val.object_id
|
||||||
serialize_object_element(buf, key, oh)
|
serialize_object_element(buf, key, oh, false)
|
||||||
end
|
end
|
||||||
|
|
||||||
def serialize_binary_element(buf, key, val)
|
def serialize_binary_element(buf, key, val)
|
||||||
|
@ -397,18 +404,18 @@ class BSON
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
def serialize_object_element(buf, key, val, opcode=OBJECT)
|
def serialize_object_element(buf, key, val, no_dollar_sign, opcode=OBJECT)
|
||||||
buf.put(opcode)
|
buf.put(opcode)
|
||||||
self.class.serialize_cstr(buf, key)
|
self.class.serialize_cstr(buf, key)
|
||||||
buf.put_array(BSON.new.serialize(val).to_a)
|
buf.put_array(BSON.new.serialize(val, no_dollar_sign).to_a)
|
||||||
end
|
end
|
||||||
|
|
||||||
def serialize_array_element(buf, key, val)
|
def serialize_array_element(buf, key, val, no_dollar_sign)
|
||||||
# Turn array into hash with integer indices as keys
|
# Turn array into hash with integer indices as keys
|
||||||
h = OrderedHash.new
|
h = OrderedHash.new
|
||||||
i = 0
|
i = 0
|
||||||
val.each { |v| h[i] = v; i += 1 }
|
val.each { |v| h[i] = v; i += 1 }
|
||||||
serialize_object_element(buf, key, h, ARRAY)
|
serialize_object_element(buf, key, h, no_dollar_sign, ARRAY)
|
||||||
end
|
end
|
||||||
|
|
||||||
def serialize_regex_element(buf, key, val)
|
def serialize_regex_element(buf, key, val)
|
||||||
|
|
|
@ -585,6 +585,44 @@ class DBAPITest < Test::Unit::TestCase
|
||||||
assert_equal 2, @@coll.count
|
assert_equal 2, @@coll.count
|
||||||
end
|
end
|
||||||
|
|
||||||
|
def test_invalid_key_names
|
||||||
|
@@coll.clear
|
||||||
|
|
||||||
|
@@coll.insert({"hello" => "world"})
|
||||||
|
@@coll.insert({"hello" => {"hello" => "world"}})
|
||||||
|
|
||||||
|
assert_raise RuntimeError do
|
||||||
|
@@coll.insert({"$hello" => "world"})
|
||||||
|
end
|
||||||
|
assert_raise RuntimeError do
|
||||||
|
@@coll.insert({"hello" => {"$hello" => "world"}})
|
||||||
|
end
|
||||||
|
|
||||||
|
@@coll.insert({"he$llo" => "world"})
|
||||||
|
@@coll.insert({"hello" => {"hell$o" => "world"}})
|
||||||
|
|
||||||
|
assert_raise RuntimeError do
|
||||||
|
@@coll.insert({".hello" => "world"})
|
||||||
|
end
|
||||||
|
assert_raise RuntimeError do
|
||||||
|
@@coll.insert({"hello" => {".hello" => "world"}})
|
||||||
|
end
|
||||||
|
assert_raise RuntimeError do
|
||||||
|
@@coll.insert({"hello." => "world"})
|
||||||
|
end
|
||||||
|
assert_raise RuntimeError do
|
||||||
|
@@coll.insert({"hello" => {"hello." => "world"}})
|
||||||
|
end
|
||||||
|
assert_raise RuntimeError do
|
||||||
|
@@coll.insert({"hel.lo" => "world"})
|
||||||
|
end
|
||||||
|
assert_raise RuntimeError do
|
||||||
|
@@coll.insert({"hello" => {"hel.lo" => "world"}})
|
||||||
|
end
|
||||||
|
|
||||||
|
@@coll.modify({"hello" => "world"}, {"$inc" => "hello"})
|
||||||
|
end
|
||||||
|
|
||||||
# TODO this test fails with error message "Undefed Before end of object"
|
# TODO this test fails with error message "Undefed Before end of object"
|
||||||
# That is a database error. The undefined type may go away.
|
# That is a database error. The undefined type may go away.
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue