null checking for keys and regex patterns, allow nulls for regular strings

This commit is contained in:
Mike Dirolf 2009-12-17 12:17:19 -05:00
parent 9bbaafe03d
commit 65c36ca943
4 changed files with 62 additions and 32 deletions

View File

@ -72,15 +72,26 @@ static VALUE DigestMD5;
#define STR_NEW(p,n) rb_enc_str_new((p), (n), rb_utf8_encoding()) #define STR_NEW(p,n) rb_enc_str_new((p), (n), rb_utf8_encoding())
/* MUST call TO_UTF8 before calling write_utf8. */ /* MUST call TO_UTF8 before calling write_utf8. */
#define TO_UTF8(string) rb_str_export_to_enc((string), rb_utf8_encoding()) #define TO_UTF8(string) rb_str_export_to_enc((string), rb_utf8_encoding())
static void write_utf8(buffer_t buffer, VALUE string) { static void write_utf8(buffer_t buffer, VALUE string, char check_null) {
result_t status = check_string(RSTRING_PTR(string), RSTRING_LEN(string),
0, check_null);
if (status == HAS_NULL) {
buffer_free(buffer);
rb_raise(InvalidDocument, "Key names / regex patterns must not contain the NULL byte");
}
SAFE_WRITE(buffer, RSTRING_PTR(string), RSTRING_LEN(string)); SAFE_WRITE(buffer, RSTRING_PTR(string), RSTRING_LEN(string));
} }
#else #else
#define STR_NEW(p,n) rb_str_new((p), (n)) #define STR_NEW(p,n) rb_str_new((p), (n))
/* MUST call TO_UTF8 before calling write_utf8. */ /* MUST call TO_UTF8 before calling write_utf8. */
#define TO_UTF8(string) (string) #define TO_UTF8(string) (string)
static void write_utf8(buffer_t buffer, VALUE string) { static void write_utf8(buffer_t buffer, VALUE string, char check_null) {
if (!is_legal_utf8_string(RSTRING_PTR(string), RSTRING_LEN(string))) { result_t status = check_string(RSTRING_PTR(string), RSTRING_LEN(string),
1, check_null);
if (status == HAS_NULL) {
buffer_free(buffer);
rb_raise(InvalidDocument, "Key names / regex patterns must not contain the NULL byte");
} else if (status == NOT_UTF_8) {
buffer_free(buffer); buffer_free(buffer);
rb_raise(InvalidStringEncoding, "String not valid UTF-8"); rb_raise(InvalidStringEncoding, "String not valid UTF-8");
} }
@ -113,9 +124,8 @@ static void write_utf8(buffer_t buffer, VALUE string) {
#endif #endif
// this sucks too. // this sucks too.
#ifndef RREGEXP_SRC_PTR #ifndef RREGEXP_SRC
#define RREGEXP_SRC_PTR(r) RREGEXP(r)->str #define RREGEXP_SRC(r) rb_str_new(RREGEXP((r))->str, RREGEXP((r))->len)
#define RREGEXP_SRC_LEN(r) RREGEXP(r)->len
#endif #endif
static char zero = 0; static char zero = 0;
@ -136,7 +146,7 @@ static VALUE pack_extra(buffer_t buffer, VALUE check_keys) {
static void write_name_and_type(buffer_t buffer, VALUE name, char type) { static void write_name_and_type(buffer_t buffer, VALUE name, char type) {
SAFE_WRITE(buffer, &type, 1); SAFE_WRITE(buffer, &type, 1);
name = TO_UTF8(name); name = TO_UTF8(name);
write_utf8(buffer, name); write_utf8(buffer, name, 1);
SAFE_WRITE(buffer, &zero, 1); SAFE_WRITE(buffer, &zero, 1);
} }
@ -286,7 +296,7 @@ static int write_element_allow_id(VALUE key, VALUE value, VALUE extra, int allow
value = TO_UTF8(value); value = TO_UTF8(value);
length = RSTRING_LEN(value) + 1; length = RSTRING_LEN(value) + 1;
SAFE_WRITE(buffer, (char*)&length, 4); SAFE_WRITE(buffer, (char*)&length, 4);
write_utf8(buffer, value); write_utf8(buffer, value, 0);
SAFE_WRITE(buffer, &zero, 1); SAFE_WRITE(buffer, &zero, 1);
break; break;
} }
@ -372,14 +382,14 @@ static int write_element_allow_id(VALUE key, VALUE value, VALUE extra, int allow
} }
case T_REGEXP: case T_REGEXP:
{ {
int length = RREGEXP_SRC_LEN(value); VALUE pattern = RREGEXP_SRC(value);
char* pattern = (char*)RREGEXP_SRC_PTR(value);
long flags = RREGEXP(value)->ptr->options; long flags = RREGEXP(value)->ptr->options;
VALUE has_extra; VALUE has_extra;
write_name_and_type(buffer, key, 0x0B); write_name_and_type(buffer, key, 0x0B);
SAFE_WRITE(buffer, pattern, length); pattern = TO_UTF8(pattern);
write_utf8(buffer, pattern, 1);
SAFE_WRITE(buffer, &zero, 1); SAFE_WRITE(buffer, &zero, 1);
if (flags & IGNORECASE) { if (flags & IGNORECASE) {
@ -497,8 +507,8 @@ static VALUE get_value(const char* buffer, int* position, int type) {
case 13: case 13:
{ {
int value_length; int value_length;
value_length = *(int*)(buffer + *position) - 1;
*position += 4; *position += 4;
value_length = strlen(buffer + *position);
value = STR_NEW(buffer + *position, value_length); value = STR_NEW(buffer + *position, value_length);
*position += value_length + 1; *position += value_length + 1;
break; break;
@ -508,10 +518,11 @@ static VALUE get_value(const char* buffer, int* position, int type) {
int size; int size;
memcpy(&size, buffer + *position, 4); memcpy(&size, buffer + *position, 4);
if (strcmp(buffer + *position + 5, "$ref") == 0) { // DBRef if (strcmp(buffer + *position + 5, "$ref") == 0) { // DBRef
int offset = *position + 14; int offset = *position + 10;
VALUE argv[2]; VALUE argv[2];
int collection_length = strlen(buffer + offset); int collection_length = *(int*)(buffer + offset) - 1;
char id_type; char id_type;
offset += 4;
argv[0] = STR_NEW(buffer + offset, collection_length); argv[0] = STR_NEW(buffer + offset, collection_length);
offset += collection_length + 1; offset += collection_length + 1;
@ -637,8 +648,8 @@ static VALUE get_value(const char* buffer, int* position, int type) {
{ {
int collection_length; int collection_length;
VALUE collection, str, oid, id, argv[2]; VALUE collection, str, oid, id, argv[2];
collection_length = *(int*)(buffer + *position) - 1;
*position += 4; *position += 4;
collection_length = strlen(buffer + *position);
collection = STR_NEW(buffer + *position, collection_length); collection = STR_NEW(buffer + *position, collection_length);
*position += collection_length + 1; *position += collection_length + 1;
@ -664,8 +675,9 @@ static VALUE get_value(const char* buffer, int* position, int type) {
{ {
int code_length, scope_size; int code_length, scope_size;
VALUE code, scope, argv[2]; VALUE code, scope, argv[2];
*position += 8; *position += 4;
code_length = strlen(buffer + *position); code_length = *(int*)(buffer + *position) - 1;
*position += 4;
code = STR_NEW(buffer + *position, code_length); code = STR_NEW(buffer + *position, code_length);
*position += code_length + 1; *position += code_length + 1;

View File

@ -14,8 +14,10 @@
* limitations under the License. * limitations under the License.
*/ */
#include "encoding_helpers.h"
/* /*
* Copyright 2001 Unicode, Inc. * Portions Copyright 2001 Unicode, Inc.
* *
* Disclaimer * Disclaimer
* *
@ -85,23 +87,32 @@ static unsigned char isLegalUTF8(const unsigned char* source, int length) {
return 1; return 1;
} }
/* --------------------------------------------------------------------- */ result_t check_string(const unsigned char* string, const int length,
const char check_utf8, const char check_null) {
/*
* Return whether a string containing UTF-8 is legal.
*/
unsigned char is_legal_utf8_string(const unsigned char* string, const int length) {
int position = 0; int position = 0;
/* By default we go character by character. Will be different for checking
* UTF-8 */
int sequence_length = 1;
if (!check_utf8 && !check_null) {
return VALID;
}
while (position < length) { while (position < length) {
int sequence_length = trailingBytesForUTF8[*(string + position)] + 1; if (check_null && *(string + position) == 0) {
return HAS_NULL;
}
if (check_utf8) {
sequence_length = trailingBytesForUTF8[*(string + position)] + 1;
if ((position + sequence_length) > length) { if ((position + sequence_length) > length) {
return 0; return NOT_UTF_8;
} }
if (!isLegalUTF8(string + position, sequence_length)) { if (!isLegalUTF8(string + position, sequence_length)) {
return 0; return NOT_UTF_8;
}
} }
position += sequence_length; position += sequence_length;
} }
return 1;
return VALID;
} }

View File

@ -17,6 +17,13 @@
#ifndef ENCODING_HELPERS_H #ifndef ENCODING_HELPERS_H
#define ENCODING_HELPERS_H #define ENCODING_HELPERS_H
unsigned char is_legal_utf8_string(const unsigned char* string, const int length); typedef enum {
VALID,
NOT_UTF_8,
HAS_NULL
} result_t;
result_t check_string(const unsigned char* string, const int length,
const char check_utf8, const char check_null);
#endif #endif

View File

@ -73,7 +73,7 @@ class BSON_RUBY
end end
def self.serialize_key(buf, key) def self.serialize_key(buf, key)
raise InvalidDocument, "Key names must not contain the NULL byte" if key.include? 0 raise InvalidDocument, "Key names / regex patterns must not contain the NULL byte" if key.include? 0
self.serialize_cstr(buf, key) self.serialize_cstr(buf, key)
end end