null checking for keys and regex patterns, allow nulls for regular strings
This commit is contained in:
parent
9bbaafe03d
commit
65c36ca943
|
@ -72,15 +72,26 @@ static VALUE DigestMD5;
|
||||||
#define STR_NEW(p,n) rb_enc_str_new((p), (n), rb_utf8_encoding())
|
#define STR_NEW(p,n) rb_enc_str_new((p), (n), rb_utf8_encoding())
|
||||||
/* MUST call TO_UTF8 before calling write_utf8. */
|
/* MUST call TO_UTF8 before calling write_utf8. */
|
||||||
#define TO_UTF8(string) rb_str_export_to_enc((string), rb_utf8_encoding())
|
#define TO_UTF8(string) rb_str_export_to_enc((string), rb_utf8_encoding())
|
||||||
static void write_utf8(buffer_t buffer, VALUE string) {
|
static void write_utf8(buffer_t buffer, VALUE string, char check_null) {
|
||||||
|
result_t status = check_string(RSTRING_PTR(string), RSTRING_LEN(string),
|
||||||
|
0, check_null);
|
||||||
|
if (status == HAS_NULL) {
|
||||||
|
buffer_free(buffer);
|
||||||
|
rb_raise(InvalidDocument, "Key names / regex patterns must not contain the NULL byte");
|
||||||
|
}
|
||||||
SAFE_WRITE(buffer, RSTRING_PTR(string), RSTRING_LEN(string));
|
SAFE_WRITE(buffer, RSTRING_PTR(string), RSTRING_LEN(string));
|
||||||
}
|
}
|
||||||
#else
|
#else
|
||||||
#define STR_NEW(p,n) rb_str_new((p), (n))
|
#define STR_NEW(p,n) rb_str_new((p), (n))
|
||||||
/* MUST call TO_UTF8 before calling write_utf8. */
|
/* MUST call TO_UTF8 before calling write_utf8. */
|
||||||
#define TO_UTF8(string) (string)
|
#define TO_UTF8(string) (string)
|
||||||
static void write_utf8(buffer_t buffer, VALUE string) {
|
static void write_utf8(buffer_t buffer, VALUE string, char check_null) {
|
||||||
if (!is_legal_utf8_string(RSTRING_PTR(string), RSTRING_LEN(string))) {
|
result_t status = check_string(RSTRING_PTR(string), RSTRING_LEN(string),
|
||||||
|
1, check_null);
|
||||||
|
if (status == HAS_NULL) {
|
||||||
|
buffer_free(buffer);
|
||||||
|
rb_raise(InvalidDocument, "Key names / regex patterns must not contain the NULL byte");
|
||||||
|
} else if (status == NOT_UTF_8) {
|
||||||
buffer_free(buffer);
|
buffer_free(buffer);
|
||||||
rb_raise(InvalidStringEncoding, "String not valid UTF-8");
|
rb_raise(InvalidStringEncoding, "String not valid UTF-8");
|
||||||
}
|
}
|
||||||
|
@ -113,9 +124,8 @@ static void write_utf8(buffer_t buffer, VALUE string) {
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
// this sucks too.
|
// this sucks too.
|
||||||
#ifndef RREGEXP_SRC_PTR
|
#ifndef RREGEXP_SRC
|
||||||
#define RREGEXP_SRC_PTR(r) RREGEXP(r)->str
|
#define RREGEXP_SRC(r) rb_str_new(RREGEXP((r))->str, RREGEXP((r))->len)
|
||||||
#define RREGEXP_SRC_LEN(r) RREGEXP(r)->len
|
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
static char zero = 0;
|
static char zero = 0;
|
||||||
|
@ -136,7 +146,7 @@ static VALUE pack_extra(buffer_t buffer, VALUE check_keys) {
|
||||||
static void write_name_and_type(buffer_t buffer, VALUE name, char type) {
|
static void write_name_and_type(buffer_t buffer, VALUE name, char type) {
|
||||||
SAFE_WRITE(buffer, &type, 1);
|
SAFE_WRITE(buffer, &type, 1);
|
||||||
name = TO_UTF8(name);
|
name = TO_UTF8(name);
|
||||||
write_utf8(buffer, name);
|
write_utf8(buffer, name, 1);
|
||||||
SAFE_WRITE(buffer, &zero, 1);
|
SAFE_WRITE(buffer, &zero, 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -286,7 +296,7 @@ static int write_element_allow_id(VALUE key, VALUE value, VALUE extra, int allow
|
||||||
value = TO_UTF8(value);
|
value = TO_UTF8(value);
|
||||||
length = RSTRING_LEN(value) + 1;
|
length = RSTRING_LEN(value) + 1;
|
||||||
SAFE_WRITE(buffer, (char*)&length, 4);
|
SAFE_WRITE(buffer, (char*)&length, 4);
|
||||||
write_utf8(buffer, value);
|
write_utf8(buffer, value, 0);
|
||||||
SAFE_WRITE(buffer, &zero, 1);
|
SAFE_WRITE(buffer, &zero, 1);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
@ -372,14 +382,14 @@ static int write_element_allow_id(VALUE key, VALUE value, VALUE extra, int allow
|
||||||
}
|
}
|
||||||
case T_REGEXP:
|
case T_REGEXP:
|
||||||
{
|
{
|
||||||
int length = RREGEXP_SRC_LEN(value);
|
VALUE pattern = RREGEXP_SRC(value);
|
||||||
char* pattern = (char*)RREGEXP_SRC_PTR(value);
|
|
||||||
long flags = RREGEXP(value)->ptr->options;
|
long flags = RREGEXP(value)->ptr->options;
|
||||||
VALUE has_extra;
|
VALUE has_extra;
|
||||||
|
|
||||||
write_name_and_type(buffer, key, 0x0B);
|
write_name_and_type(buffer, key, 0x0B);
|
||||||
|
|
||||||
SAFE_WRITE(buffer, pattern, length);
|
pattern = TO_UTF8(pattern);
|
||||||
|
write_utf8(buffer, pattern, 1);
|
||||||
SAFE_WRITE(buffer, &zero, 1);
|
SAFE_WRITE(buffer, &zero, 1);
|
||||||
|
|
||||||
if (flags & IGNORECASE) {
|
if (flags & IGNORECASE) {
|
||||||
|
@ -497,8 +507,8 @@ static VALUE get_value(const char* buffer, int* position, int type) {
|
||||||
case 13:
|
case 13:
|
||||||
{
|
{
|
||||||
int value_length;
|
int value_length;
|
||||||
|
value_length = *(int*)(buffer + *position) - 1;
|
||||||
*position += 4;
|
*position += 4;
|
||||||
value_length = strlen(buffer + *position);
|
|
||||||
value = STR_NEW(buffer + *position, value_length);
|
value = STR_NEW(buffer + *position, value_length);
|
||||||
*position += value_length + 1;
|
*position += value_length + 1;
|
||||||
break;
|
break;
|
||||||
|
@ -508,10 +518,11 @@ static VALUE get_value(const char* buffer, int* position, int type) {
|
||||||
int size;
|
int size;
|
||||||
memcpy(&size, buffer + *position, 4);
|
memcpy(&size, buffer + *position, 4);
|
||||||
if (strcmp(buffer + *position + 5, "$ref") == 0) { // DBRef
|
if (strcmp(buffer + *position + 5, "$ref") == 0) { // DBRef
|
||||||
int offset = *position + 14;
|
int offset = *position + 10;
|
||||||
VALUE argv[2];
|
VALUE argv[2];
|
||||||
int collection_length = strlen(buffer + offset);
|
int collection_length = *(int*)(buffer + offset) - 1;
|
||||||
char id_type;
|
char id_type;
|
||||||
|
offset += 4;
|
||||||
|
|
||||||
argv[0] = STR_NEW(buffer + offset, collection_length);
|
argv[0] = STR_NEW(buffer + offset, collection_length);
|
||||||
offset += collection_length + 1;
|
offset += collection_length + 1;
|
||||||
|
@ -637,8 +648,8 @@ static VALUE get_value(const char* buffer, int* position, int type) {
|
||||||
{
|
{
|
||||||
int collection_length;
|
int collection_length;
|
||||||
VALUE collection, str, oid, id, argv[2];
|
VALUE collection, str, oid, id, argv[2];
|
||||||
|
collection_length = *(int*)(buffer + *position) - 1;
|
||||||
*position += 4;
|
*position += 4;
|
||||||
collection_length = strlen(buffer + *position);
|
|
||||||
collection = STR_NEW(buffer + *position, collection_length);
|
collection = STR_NEW(buffer + *position, collection_length);
|
||||||
*position += collection_length + 1;
|
*position += collection_length + 1;
|
||||||
|
|
||||||
|
@ -664,8 +675,9 @@ static VALUE get_value(const char* buffer, int* position, int type) {
|
||||||
{
|
{
|
||||||
int code_length, scope_size;
|
int code_length, scope_size;
|
||||||
VALUE code, scope, argv[2];
|
VALUE code, scope, argv[2];
|
||||||
*position += 8;
|
*position += 4;
|
||||||
code_length = strlen(buffer + *position);
|
code_length = *(int*)(buffer + *position) - 1;
|
||||||
|
*position += 4;
|
||||||
code = STR_NEW(buffer + *position, code_length);
|
code = STR_NEW(buffer + *position, code_length);
|
||||||
*position += code_length + 1;
|
*position += code_length + 1;
|
||||||
|
|
||||||
|
|
|
@ -14,8 +14,10 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
|
#include "encoding_helpers.h"
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* Copyright 2001 Unicode, Inc.
|
* Portions Copyright 2001 Unicode, Inc.
|
||||||
*
|
*
|
||||||
* Disclaimer
|
* Disclaimer
|
||||||
*
|
*
|
||||||
|
@ -85,23 +87,32 @@ static unsigned char isLegalUTF8(const unsigned char* source, int length) {
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
/* --------------------------------------------------------------------- */
|
result_t check_string(const unsigned char* string, const int length,
|
||||||
|
const char check_utf8, const char check_null) {
|
||||||
/*
|
|
||||||
* Return whether a string containing UTF-8 is legal.
|
|
||||||
*/
|
|
||||||
unsigned char is_legal_utf8_string(const unsigned char* string, const int length) {
|
|
||||||
int position = 0;
|
int position = 0;
|
||||||
|
/* By default we go character by character. Will be different for checking
|
||||||
|
* UTF-8 */
|
||||||
|
int sequence_length = 1;
|
||||||
|
|
||||||
|
if (!check_utf8 && !check_null) {
|
||||||
|
return VALID;
|
||||||
|
}
|
||||||
|
|
||||||
while (position < length) {
|
while (position < length) {
|
||||||
int sequence_length = trailingBytesForUTF8[*(string + position)] + 1;
|
if (check_null && *(string + position) == 0) {
|
||||||
if ((position + sequence_length) > length) {
|
return HAS_NULL;
|
||||||
return 0;
|
|
||||||
}
|
}
|
||||||
if (!isLegalUTF8(string + position, sequence_length)) {
|
if (check_utf8) {
|
||||||
return 0;
|
sequence_length = trailingBytesForUTF8[*(string + position)] + 1;
|
||||||
|
if ((position + sequence_length) > length) {
|
||||||
|
return NOT_UTF_8;
|
||||||
|
}
|
||||||
|
if (!isLegalUTF8(string + position, sequence_length)) {
|
||||||
|
return NOT_UTF_8;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
position += sequence_length;
|
position += sequence_length;
|
||||||
}
|
}
|
||||||
return 1;
|
|
||||||
|
return VALID;
|
||||||
}
|
}
|
||||||
|
|
|
@ -17,6 +17,13 @@
|
||||||
#ifndef ENCODING_HELPERS_H
|
#ifndef ENCODING_HELPERS_H
|
||||||
#define ENCODING_HELPERS_H
|
#define ENCODING_HELPERS_H
|
||||||
|
|
||||||
unsigned char is_legal_utf8_string(const unsigned char* string, const int length);
|
typedef enum {
|
||||||
|
VALID,
|
||||||
|
NOT_UTF_8,
|
||||||
|
HAS_NULL
|
||||||
|
} result_t;
|
||||||
|
|
||||||
|
result_t check_string(const unsigned char* string, const int length,
|
||||||
|
const char check_utf8, const char check_null);
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
|
|
@ -73,7 +73,7 @@ class BSON_RUBY
|
||||||
end
|
end
|
||||||
|
|
||||||
def self.serialize_key(buf, key)
|
def self.serialize_key(buf, key)
|
||||||
raise InvalidDocument, "Key names must not contain the NULL byte" if key.include? 0
|
raise InvalidDocument, "Key names / regex patterns must not contain the NULL byte" if key.include? 0
|
||||||
self.serialize_cstr(buf, key)
|
self.serialize_cstr(buf, key)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue