null checking for keys and regex patterns, allow nulls for regular strings

This commit is contained in:
Mike Dirolf 2009-12-17 12:17:19 -05:00
parent 9bbaafe03d
commit 65c36ca943
4 changed files with 62 additions and 32 deletions

View File

@ -72,15 +72,26 @@ static VALUE DigestMD5;
#define STR_NEW(p,n) rb_enc_str_new((p), (n), rb_utf8_encoding())
/* MUST call TO_UTF8 before calling write_utf8. */
#define TO_UTF8(string) rb_str_export_to_enc((string), rb_utf8_encoding())
static void write_utf8(buffer_t buffer, VALUE string) {
static void write_utf8(buffer_t buffer, VALUE string, char check_null) {
result_t status = check_string(RSTRING_PTR(string), RSTRING_LEN(string),
0, check_null);
if (status == HAS_NULL) {
buffer_free(buffer);
rb_raise(InvalidDocument, "Key names / regex patterns must not contain the NULL byte");
}
SAFE_WRITE(buffer, RSTRING_PTR(string), RSTRING_LEN(string));
}
#else
#define STR_NEW(p,n) rb_str_new((p), (n))
/* MUST call TO_UTF8 before calling write_utf8. */
#define TO_UTF8(string) (string)
static void write_utf8(buffer_t buffer, VALUE string) {
if (!is_legal_utf8_string(RSTRING_PTR(string), RSTRING_LEN(string))) {
static void write_utf8(buffer_t buffer, VALUE string, char check_null) {
result_t status = check_string(RSTRING_PTR(string), RSTRING_LEN(string),
1, check_null);
if (status == HAS_NULL) {
buffer_free(buffer);
rb_raise(InvalidDocument, "Key names / regex patterns must not contain the NULL byte");
} else if (status == NOT_UTF_8) {
buffer_free(buffer);
rb_raise(InvalidStringEncoding, "String not valid UTF-8");
}
@ -113,9 +124,8 @@ static void write_utf8(buffer_t buffer, VALUE string) {
#endif
// this sucks too.
#ifndef RREGEXP_SRC_PTR
#define RREGEXP_SRC_PTR(r) RREGEXP(r)->str
#define RREGEXP_SRC_LEN(r) RREGEXP(r)->len
#ifndef RREGEXP_SRC
#define RREGEXP_SRC(r) rb_str_new(RREGEXP((r))->str, RREGEXP((r))->len)
#endif
static char zero = 0;
@ -136,7 +146,7 @@ static VALUE pack_extra(buffer_t buffer, VALUE check_keys) {
static void write_name_and_type(buffer_t buffer, VALUE name, char type) {
SAFE_WRITE(buffer, &type, 1);
name = TO_UTF8(name);
write_utf8(buffer, name);
write_utf8(buffer, name, 1);
SAFE_WRITE(buffer, &zero, 1);
}
@ -286,7 +296,7 @@ static int write_element_allow_id(VALUE key, VALUE value, VALUE extra, int allow
value = TO_UTF8(value);
length = RSTRING_LEN(value) + 1;
SAFE_WRITE(buffer, (char*)&length, 4);
write_utf8(buffer, value);
write_utf8(buffer, value, 0);
SAFE_WRITE(buffer, &zero, 1);
break;
}
@ -372,14 +382,14 @@ static int write_element_allow_id(VALUE key, VALUE value, VALUE extra, int allow
}
case T_REGEXP:
{
int length = RREGEXP_SRC_LEN(value);
char* pattern = (char*)RREGEXP_SRC_PTR(value);
VALUE pattern = RREGEXP_SRC(value);
long flags = RREGEXP(value)->ptr->options;
VALUE has_extra;
write_name_and_type(buffer, key, 0x0B);
SAFE_WRITE(buffer, pattern, length);
pattern = TO_UTF8(pattern);
write_utf8(buffer, pattern, 1);
SAFE_WRITE(buffer, &zero, 1);
if (flags & IGNORECASE) {
@ -497,8 +507,8 @@ static VALUE get_value(const char* buffer, int* position, int type) {
case 13:
{
int value_length;
value_length = *(int*)(buffer + *position) - 1;
*position += 4;
value_length = strlen(buffer + *position);
value = STR_NEW(buffer + *position, value_length);
*position += value_length + 1;
break;
@ -508,10 +518,11 @@ static VALUE get_value(const char* buffer, int* position, int type) {
int size;
memcpy(&size, buffer + *position, 4);
if (strcmp(buffer + *position + 5, "$ref") == 0) { // DBRef
int offset = *position + 14;
int offset = *position + 10;
VALUE argv[2];
int collection_length = strlen(buffer + offset);
int collection_length = *(int*)(buffer + offset) - 1;
char id_type;
offset += 4;
argv[0] = STR_NEW(buffer + offset, collection_length);
offset += collection_length + 1;
@ -637,8 +648,8 @@ static VALUE get_value(const char* buffer, int* position, int type) {
{
int collection_length;
VALUE collection, str, oid, id, argv[2];
collection_length = *(int*)(buffer + *position) - 1;
*position += 4;
collection_length = strlen(buffer + *position);
collection = STR_NEW(buffer + *position, collection_length);
*position += collection_length + 1;
@ -664,8 +675,9 @@ static VALUE get_value(const char* buffer, int* position, int type) {
{
int code_length, scope_size;
VALUE code, scope, argv[2];
*position += 8;
code_length = strlen(buffer + *position);
*position += 4;
code_length = *(int*)(buffer + *position) - 1;
*position += 4;
code = STR_NEW(buffer + *position, code_length);
*position += code_length + 1;

View File

@ -14,8 +14,10 @@
* limitations under the License.
*/
#include "encoding_helpers.h"
/*
* Copyright 2001 Unicode, Inc.
* Portions Copyright 2001 Unicode, Inc.
*
* Disclaimer
*
@ -85,23 +87,32 @@ static unsigned char isLegalUTF8(const unsigned char* source, int length) {
return 1;
}
/* --------------------------------------------------------------------- */
/*
* Return whether a string containing UTF-8 is legal.
*/
unsigned char is_legal_utf8_string(const unsigned char* string, const int length) {
result_t check_string(const unsigned char* string, const int length,
const char check_utf8, const char check_null) {
int position = 0;
/* By default we go character by character. Will be different for checking
* UTF-8 */
int sequence_length = 1;
if (!check_utf8 && !check_null) {
return VALID;
}
while (position < length) {
int sequence_length = trailingBytesForUTF8[*(string + position)] + 1;
if (check_null && *(string + position) == 0) {
return HAS_NULL;
}
if (check_utf8) {
sequence_length = trailingBytesForUTF8[*(string + position)] + 1;
if ((position + sequence_length) > length) {
return 0;
return NOT_UTF_8;
}
if (!isLegalUTF8(string + position, sequence_length)) {
return 0;
return NOT_UTF_8;
}
}
position += sequence_length;
}
return 1;
return VALID;
}

View File

@ -17,6 +17,13 @@
#ifndef ENCODING_HELPERS_H
#define ENCODING_HELPERS_H
unsigned char is_legal_utf8_string(const unsigned char* string, const int length);
typedef enum {
VALID,
NOT_UTF_8,
HAS_NULL
} result_t;
result_t check_string(const unsigned char* string, const int length,
const char check_utf8, const char check_null);
#endif

View File

@ -73,7 +73,7 @@ class BSON_RUBY
end
def self.serialize_key(buf, key)
raise InvalidDocument, "Key names must not contain the NULL byte" if key.include? 0
raise InvalidDocument, "Key names / regex patterns must not contain the NULL byte" if key.include? 0
self.serialize_cstr(buf, key)
end