BUG RUBY-15 don't check key names on create_index operations

This commit is contained in:
Mike Dirolf 2009-06-02 09:38:31 -04:00
parent 2743fd39b1
commit d87a7da617
5 changed files with 48 additions and 44 deletions

View File

@ -66,7 +66,7 @@ static int cmp_char(const void* a, const void* b) {
return *(char*)a - *(char*)b; return *(char*)a - *(char*)b;
} }
static void write_doc(bson_buffer* buffer, VALUE hash, VALUE no_dollar_sign); static void write_doc(bson_buffer* buffer, VALUE hash, VALUE check_keys);
static int write_element(VALUE key, VALUE value, VALUE extra); static int write_element(VALUE key, VALUE value, VALUE extra);
static VALUE elements_to_hash(const char* buffer, int max); static VALUE elements_to_hash(const char* buffer, int max);
@ -126,8 +126,8 @@ static void buffer_write_bytes(bson_buffer* buffer, const char* bytes, int size)
buffer->position += size; buffer->position += size;
} }
static VALUE pack_extra(bson_buffer* buffer, VALUE no_dollar_sign) { static VALUE pack_extra(bson_buffer* buffer, VALUE check_keys) {
return rb_ary_new3(2, INT2NUM((int)buffer), no_dollar_sign); return rb_ary_new3(2, INT2NUM((int)buffer), check_keys);
} }
static void write_name_and_type(bson_buffer* buffer, VALUE name, char type) { static void write_name_and_type(bson_buffer* buffer, VALUE name, char type) {
@ -138,7 +138,7 @@ static void write_name_and_type(bson_buffer* buffer, VALUE name, char type) {
static int write_element_allow_id(VALUE key, VALUE value, VALUE extra, int allow_id) { static int write_element_allow_id(VALUE key, VALUE value, VALUE extra, int allow_id) {
bson_buffer* buffer = (bson_buffer*)NUM2INT(rb_ary_entry(extra, 0)); bson_buffer* buffer = (bson_buffer*)NUM2INT(rb_ary_entry(extra, 0));
VALUE no_dollar_sign = rb_ary_entry(extra, 1); VALUE check_keys = rb_ary_entry(extra, 1);
if (TYPE(key) == T_SYMBOL) { if (TYPE(key) == T_SYMBOL) {
// TODO better way to do this... ? // TODO better way to do this... ?
@ -153,7 +153,8 @@ static int write_element_allow_id(VALUE key, VALUE value, VALUE extra, int allow
return ST_CONTINUE; return ST_CONTINUE;
} }
if (no_dollar_sign == Qtrue && RSTRING_LEN(key) > 0 && RSTRING_PTR(key)[0] == '$') { if (check_keys == Qtrue) {
if (RSTRING_LEN(key) > 0 && RSTRING_PTR(key)[0] == '$') {
rb_raise(rb_eRuntimeError, "key must not start with '$'"); rb_raise(rb_eRuntimeError, "key must not start with '$'");
} }
int i; int i;
@ -162,6 +163,7 @@ static int write_element_allow_id(VALUE key, VALUE value, VALUE extra, int allow
rb_raise(rb_eRuntimeError, "key must not contain '.'"); rb_raise(rb_eRuntimeError, "key must not contain '.'");
} }
} }
}
switch(TYPE(value)) { switch(TYPE(value)) {
case T_BIGNUM: case T_BIGNUM:
@ -210,7 +212,7 @@ static int write_element_allow_id(VALUE key, VALUE value, VALUE extra, int allow
case T_HASH: case T_HASH:
{ {
write_name_and_type(buffer, key, 0x03); write_name_and_type(buffer, key, 0x03);
write_doc(buffer, value, no_dollar_sign); write_doc(buffer, value, check_keys);
break; break;
} }
case T_ARRAY: case T_ARRAY:
@ -228,7 +230,7 @@ static int write_element_allow_id(VALUE key, VALUE value, VALUE extra, int allow
char* name; char* name;
asprintf(&name, "%d", i); asprintf(&name, "%d", i);
VALUE key = rb_str_new2(name); VALUE key = rb_str_new2(name);
write_element(key, values[i], pack_extra(buffer, no_dollar_sign)); write_element(key, values[i], pack_extra(buffer, check_keys));
free(name); free(name);
} }
@ -391,19 +393,19 @@ static int write_element(VALUE key, VALUE value, VALUE extra) {
return write_element_allow_id(key, value, extra, 0); return write_element_allow_id(key, value, extra, 0);
} }
static void write_doc(bson_buffer* buffer, VALUE hash, VALUE no_dollar_sign) { static void write_doc(bson_buffer* buffer, VALUE hash, VALUE check_keys) {
int start_position = buffer->position; int start_position = buffer->position;
int length_location = buffer_save_bytes(buffer, 4); int length_location = buffer_save_bytes(buffer, 4);
VALUE key = rb_str_new2("_id"); VALUE key = rb_str_new2("_id");
if (rb_funcall(hash, rb_intern("has_key?"), 1, key) == Qtrue) { if (rb_funcall(hash, rb_intern("has_key?"), 1, key) == Qtrue) {
VALUE id = rb_hash_aref(hash, key); VALUE id = rb_hash_aref(hash, key);
write_element_allow_id(key, id, pack_extra(buffer, no_dollar_sign), 1); write_element_allow_id(key, id, pack_extra(buffer, check_keys), 1);
} }
key = ID2SYM(rb_intern("_id")); key = ID2SYM(rb_intern("_id"));
if (rb_funcall(hash, rb_intern("has_key?"), 1, key) == Qtrue) { if (rb_funcall(hash, rb_intern("has_key?"), 1, key) == Qtrue) {
VALUE id = rb_hash_aref(hash, key); VALUE id = rb_hash_aref(hash, key);
write_element_allow_id(key, id, pack_extra(buffer, no_dollar_sign), 1); write_element_allow_id(key, id, pack_extra(buffer, check_keys), 1);
} }
// we have to check for an OrderedHash and handle that specially // we have to check for an OrderedHash and handle that specially
@ -414,10 +416,10 @@ static void write_doc(bson_buffer* buffer, VALUE hash, VALUE no_dollar_sign) {
VALUE key = RARRAY_PTR(keys)[i]; VALUE key = RARRAY_PTR(keys)[i];
VALUE value = rb_hash_aref(hash, key); VALUE value = rb_hash_aref(hash, key);
write_element(key, value, pack_extra(buffer, no_dollar_sign)); write_element(key, value, pack_extra(buffer, check_keys));
} }
} else { } else {
rb_hash_foreach(hash, write_element, pack_extra(buffer, no_dollar_sign)); rb_hash_foreach(hash, write_element, pack_extra(buffer, check_keys));
} }
// write null byte and fill in length // write null byte and fill in length
@ -426,11 +428,11 @@ static void write_doc(bson_buffer* buffer, VALUE hash, VALUE no_dollar_sign) {
memcpy(buffer->buffer + length_location, &length, 4); memcpy(buffer->buffer + length_location, &length, 4);
} }
static VALUE method_serialize(VALUE self, VALUE doc, VALUE no_dollar_sign) { static VALUE method_serialize(VALUE self, VALUE doc, VALUE check_keys) {
bson_buffer* buffer = buffer_new(); bson_buffer* buffer = buffer_new();
assert(buffer); assert(buffer);
write_doc(buffer, doc, no_dollar_sign); write_doc(buffer, doc, check_keys);
VALUE result = rb_str_new(buffer->buffer, buffer->position); VALUE result = rb_str_new(buffer->buffer, buffer->position);
buffer_free(buffer); buffer_free(buffer);

View File

@ -473,7 +473,7 @@ module XGen
:unique => unique :unique => unique
} }
@semaphore.synchronize { @semaphore.synchronize {
send_to_db(InsertMessage.new(@name, SYSTEM_INDEX_COLLECTION, sel)) send_to_db(InsertMessage.new(@name, SYSTEM_INDEX_COLLECTION, false, sel))
} }
name name
end end
@ -485,7 +485,7 @@ module XGen
@semaphore.synchronize { @semaphore.synchronize {
objects.collect { |o| objects.collect { |o|
o = @pk_factory.create_pk(o) if @pk_factory o = @pk_factory.create_pk(o) if @pk_factory
send_to_db(InsertMessage.new(@name, collection_name, o)) send_to_db(InsertMessage.new(@name, collection_name, true, o))
o o
} }
} }

View File

@ -23,11 +23,11 @@ module XGen
class InsertMessage < Message class InsertMessage < Message
def initialize(db_name, collection_name, *objs) def initialize(db_name, collection_name, check_keys=true, *objs)
super(OP_INSERT) super(OP_INSERT)
write_int(0) write_int(0)
write_string("#{db_name}.#{collection_name}") write_string("#{db_name}.#{collection_name}")
objs.each { |o| write_doc(o, true) } objs.each { |o| write_doc(o, check_keys) }
end end
end end
end end

View File

@ -58,8 +58,8 @@ module XGen
update_message_length update_message_length
end end
def write_doc(hash, no_dollar_sign=false) def write_doc(hash, check_keys=false)
@buf.put_array(BSON.new.serialize(hash, no_dollar_sign).to_a) @buf.put_array(BSON.new.serialize(hash, check_keys).to_a)
update_message_length update_message_length
end end

View File

@ -73,11 +73,11 @@ class BSON
begin begin
require 'mongo_ext/cbson' require 'mongo_ext/cbson'
def serialize(obj, no_dollar_sign=false) def serialize(obj, check_keys=false)
@buf = ByteBuffer.new(CBson.serialize(obj, no_dollar_sign)) @buf = ByteBuffer.new(CBson.serialize(obj, check_keys))
end end
rescue LoadError rescue LoadError
def serialize(obj, no_dollar_sign=false) def serialize(obj, check_keys=false)
raise "Document is null" unless obj raise "Document is null" unless obj
@buf.rewind @buf.rewind
@ -86,12 +86,12 @@ class BSON
# Write key/value pairs. Always write _id first if it exists. # Write key/value pairs. Always write _id first if it exists.
if obj.has_key? '_id' if obj.has_key? '_id'
serialize_key_value('_id', obj['_id'], no_dollar_sign) serialize_key_value('_id', obj['_id'], check_keys)
elsif obj.has_key? :_id elsif obj.has_key? :_id
serialize_key_value('_id', obj[:_id], no_dollar_sign) serialize_key_value('_id', obj[:_id], check_keys)
end end
obj.each {|k, v| serialize_key_value(k, v, no_dollar_sign) unless k == '_id' || k == :_id } obj.each {|k, v| serialize_key_value(k, v, check_keys) unless k == '_id' || k == :_id }
serialize_eoo_element(@buf) serialize_eoo_element(@buf)
@buf.put_int(@buf.size, 0) @buf.put_int(@buf.size, 0)
@ -99,14 +99,16 @@ class BSON
end end
end end
def serialize_key_value(k, v, no_dollar_sign) def serialize_key_value(k, v, check_keys)
k = k.to_s k = k.to_s
if no_dollar_sign and k[0] == ?$ if check_keys
if k[0] == ?$
raise RuntimeError.new("key #{k} must not start with '$'") raise RuntimeError.new("key #{k} must not start with '$'")
end end
if k.include? ?. if k.include? ?.
raise RuntimeError.new("key #{k} must not contain '.'") raise RuntimeError.new("key #{k} must not contain '.'")
end end
end
type = bson_type(v) type = bson_type(v)
case type case type
when STRING, SYMBOL when STRING, SYMBOL
@ -114,11 +116,11 @@ class BSON
when NUMBER, NUMBER_INT when NUMBER, NUMBER_INT
serialize_number_element(@buf, k, v, type) serialize_number_element(@buf, k, v, type)
when OBJECT when OBJECT
serialize_object_element(@buf, k, v, no_dollar_sign) serialize_object_element(@buf, k, v, check_keys)
when OID when OID
serialize_oid_element(@buf, k, v) serialize_oid_element(@buf, k, v)
when ARRAY when ARRAY
serialize_array_element(@buf, k, v, no_dollar_sign) serialize_array_element(@buf, k, v, check_keys)
when REGEX when REGEX
serialize_regex_element(@buf, k, v) serialize_regex_element(@buf, k, v)
when BOOLEAN when BOOLEAN
@ -404,18 +406,18 @@ class BSON
end end
end end
def serialize_object_element(buf, key, val, no_dollar_sign, opcode=OBJECT) def serialize_object_element(buf, key, val, check_keys, opcode=OBJECT)
buf.put(opcode) buf.put(opcode)
self.class.serialize_cstr(buf, key) self.class.serialize_cstr(buf, key)
buf.put_array(BSON.new.serialize(val, no_dollar_sign).to_a) buf.put_array(BSON.new.serialize(val, check_keys).to_a)
end end
def serialize_array_element(buf, key, val, no_dollar_sign) def serialize_array_element(buf, key, val, check_keys)
# Turn array into hash with integer indices as keys # Turn array into hash with integer indices as keys
h = OrderedHash.new h = OrderedHash.new
i = 0 i = 0
val.each { |v| h[i] = v; i += 1 } val.each { |v| h[i] = v; i += 1 }
serialize_object_element(buf, key, h, no_dollar_sign, ARRAY) serialize_object_element(buf, key, h, check_keys, ARRAY)
end end
def serialize_regex_element(buf, key, val) def serialize_regex_element(buf, key, val)