Skip to content

Commit 6a14d70

Browse files
committed
PYTHON-1044 - Fix up unknown BSON type handing
1 parent 46d9cf9 commit 6a14d70

File tree

3 files changed

+98
-36
lines changed

3 files changed

+98
-36
lines changed

bson/__init__.py

Lines changed: 44 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,14 @@
9393
_UNPACK_TIMESTAMP = struct.Struct("<II").unpack
9494

9595

96-
def _get_int(data, position, dummy0, dummy1):
96+
def _raise_unknown_type(element_type, element_name):
97+
"""Unknown type helper."""
98+
raise InvalidBSON("Detected unknown BSON type %r for fieldname '%s'. Are "
99+
"you using the latest driver version?" % (
100+
element_type, element_name))
101+
102+
103+
def _get_int(data, position, dummy0, dummy1, dummy2):
97104
"""Decode a BSON int32 to python int."""
98105
end = position + 4
99106
return _UNPACK_INT(data[position:end])[0], end
@@ -106,13 +113,13 @@ def _get_c_string(data, position, opts):
106113
opts.unicode_decode_error_handler, True)[0], end + 1
107114

108115

109-
def _get_float(data, position, dummy0, dummy1):
116+
def _get_float(data, position, dummy0, dummy1, dummy2):
110117
"""Decode a BSON double to python float."""
111118
end = position + 8
112119
return _UNPACK_FLOAT(data[position:end])[0], end
113120

114121

115-
def _get_string(data, position, obj_end, opts):
122+
def _get_string(data, position, obj_end, opts, dummy):
116123
"""Decode a BSON string to python unicode string."""
117124
length = _UNPACK_INT(data[position:position + 4])[0]
118125
position += 4
@@ -125,7 +132,7 @@ def _get_string(data, position, obj_end, opts):
125132
opts.unicode_decode_error_handler, True)[0], end + 1
126133

127134

128-
def _get_object(data, position, obj_end, opts):
135+
def _get_object(data, position, obj_end, opts, dummy):
129136
"""Decode a BSON subdocument to opts.document_class or bson.dbref.DBRef."""
130137
obj_size = _UNPACK_INT(data[position:position + 4])[0]
131138
end = position + obj_size - 1
@@ -146,7 +153,7 @@ def _get_object(data, position, obj_end, opts):
146153
return obj, position
147154

148155

149-
def _get_array(data, position, obj_end, opts):
156+
def _get_array(data, position, obj_end, opts, element_name):
150157
"""Decode a BSON array to python list."""
151158
size = _UNPACK_INT(data[position:position + 4])[0]
152159
end = position + size - 1
@@ -166,12 +173,16 @@ def _get_array(data, position, obj_end, opts):
166173
element_type = data[position:position + 1]
167174
# Just skip the keys.
168175
position = index(b'\x00', position) + 1
169-
value, position = getter[element_type](data, position, obj_end, opts)
176+
try:
177+
value, position = getter[element_type](
178+
data, position, obj_end, opts, element_name)
179+
except KeyError:
180+
_raise_unknown_type(element_type, element_name)
170181
append(value)
171182
return result, position + 1
172183

173184

174-
def _get_binary(data, position, dummy, opts):
185+
def _get_binary(data, position, dummy0, opts, dummy1):
175186
"""Decode a BSON binary to bson.binary.Binary or python UUID."""
176187
length, subtype = _UNPACK_LENGTH_SUBTYPE(data[position:position + 5])
177188
position += 5
@@ -203,19 +214,19 @@ def _get_binary(data, position, dummy, opts):
203214
return value, end
204215

205216

206-
def _get_oid(data, position, dummy0, dummy1):
217+
def _get_oid(data, position, dummy0, dummy1, dummy2):
207218
"""Decode a BSON ObjectId to bson.objectid.ObjectId."""
208219
end = position + 12
209220
return ObjectId(data[position:end]), end
210221

211222

212-
def _get_boolean(data, position, dummy0, dummy1):
223+
def _get_boolean(data, position, dummy0, dummy1, dummy2):
213224
"""Decode a BSON true/false to python True/False."""
214225
end = position + 1
215226
return data[position:end] == b"\x01", end
216227

217228

218-
def _get_date(data, position, dummy, opts):
229+
def _get_date(data, position, dummy0, opts, dummy1):
219230
"""Decode a BSON datetime to python datetime.datetime."""
220231
end = position + 8
221232
millis = _UNPACK_LONG(data[position:end])[0]
@@ -233,42 +244,44 @@ def _get_date(data, position, dummy, opts):
233244
return dt, end
234245

235246

236-
def _get_code(data, position, obj_end, opts):
247+
def _get_code(data, position, obj_end, opts, element_name):
237248
"""Decode a BSON code to bson.code.Code."""
238-
code, position = _get_string(data, position, obj_end, opts)
249+
code, position = _get_string(data, position, obj_end, opts, element_name)
239250
return Code(code), position
240251

241252

242-
def _get_code_w_scope(data, position, obj_end, opts):
253+
def _get_code_w_scope(data, position, obj_end, opts, element_name):
243254
"""Decode a BSON code_w_scope to bson.code.Code."""
244-
code, position = _get_string(data, position + 4, obj_end, opts)
245-
scope, position = _get_object(data, position, obj_end, opts)
255+
code, position = _get_string(
256+
data, position + 4, obj_end, opts, element_name)
257+
scope, position = _get_object(data, position, obj_end, opts, element_name)
246258
return Code(code, scope), position
247259

248260

249-
def _get_regex(data, position, dummy0, opts):
261+
def _get_regex(data, position, dummy0, opts, dummy1):
250262
"""Decode a BSON regex to bson.regex.Regex or a python pattern object."""
251263
pattern, position = _get_c_string(data, position, opts)
252264
bson_flags, position = _get_c_string(data, position, opts)
253265
bson_re = Regex(pattern, bson_flags)
254266
return bson_re, position
255267

256268

257-
def _get_ref(data, position, obj_end, opts):
269+
def _get_ref(data, position, obj_end, opts, element_name):
258270
"""Decode (deprecated) BSON DBPointer to bson.dbref.DBRef."""
259-
collection, position = _get_string(data, position, obj_end, opts)
260-
oid, position = _get_oid(data, position, obj_end, opts)
271+
collection, position = _get_string(
272+
data, position, obj_end, opts, element_name)
273+
oid, position = _get_oid(data, position, obj_end, opts, element_name)
261274
return DBRef(collection, oid), position
262275

263276

264-
def _get_timestamp(data, position, dummy0, dummy1):
277+
def _get_timestamp(data, position, dummy0, dummy1, dummy2):
265278
"""Decode a BSON timestamp to bson.timestamp.Timestamp."""
266279
end = position + 8
267280
inc, timestamp = _UNPACK_TIMESTAMP(data[position:end])
268281
return Timestamp(timestamp, inc), end
269282

270283

271-
def _get_int64(data, position, dummy0, dummy1):
284+
def _get_int64(data, position, dummy0, dummy1, dummy2):
272285
"""Decode a BSON int64 to bson.int64.Int64."""
273286
end = position + 8
274287
return Int64(_UNPACK_LONG(data[position:end])[0]), end
@@ -285,11 +298,11 @@ def _get_int64(data, position, dummy0, dummy1):
285298
BSONOBJ: _get_object,
286299
BSONARR: _get_array,
287300
BSONBIN: _get_binary,
288-
BSONUND: lambda w, x, y, z: (None, x), # Deprecated undefined
301+
BSONUND: lambda v, w, x, y, z: (None, w), # Deprecated undefined
289302
BSONOID: _get_oid,
290303
BSONBOO: _get_boolean,
291304
BSONDAT: _get_date,
292-
BSONNUL: lambda w, x, y, z: (None, x),
305+
BSONNUL: lambda v, w, x, y, z: (None, w),
293306
BSONRGX: _get_regex,
294307
BSONREF: _get_ref, # Deprecated DBPointer
295308
BSONCOD: _get_code,
@@ -298,17 +311,21 @@ def _get_int64(data, position, dummy0, dummy1):
298311
BSONINT: _get_int,
299312
BSONTIM: _get_timestamp,
300313
BSONLON: _get_int64,
301-
BSONMIN: lambda w, x, y, z: (MinKey(), x),
302-
BSONMAX: lambda w, x, y, z: (MaxKey(), x)}
314+
BSONMIN: lambda v, w, x, y, z: (MinKey(), w),
315+
BSONMAX: lambda v, w, x, y, z: (MaxKey(), w)}
303316

304317

305318
def _element_to_dict(data, position, obj_end, opts):
306319
"""Decode a single key, value pair."""
307320
element_type = data[position:position + 1]
308321
position += 1
309322
element_name, position = _get_c_string(data, position, opts)
310-
value, position = _ELEMENT_GETTER[element_type](data,
311-
position, obj_end, opts)
323+
try:
324+
value, position = _ELEMENT_GETTER[element_type](data, position,
325+
obj_end, opts,
326+
element_name)
327+
except KeyError:
328+
_raise_unknown_type(element_type, element_name)
312329
return element_name, value, position
313330
if _USE_C:
314331
_element_to_dict = _cbson._element_to_dict

bson/_cbsonmodule.c

Lines changed: 36 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1645,7 +1645,7 @@ static PyObject* _cbson_dict_to_bson(PyObject* self, PyObject* args) {
16451645
return result;
16461646
}
16471647

1648-
static PyObject* get_value(PyObject* self, const char* buffer,
1648+
static PyObject* get_value(PyObject* self, PyObject* name, const char* buffer,
16491649
unsigned* position, unsigned char type,
16501650
unsigned max, const codec_options_t* options) {
16511651
struct module_state *state = GETSTATE(self);
@@ -1815,7 +1815,7 @@ static PyObject* get_value(PyObject* self, const char* buffer,
18151815
Py_DECREF(value);
18161816
goto invalid;
18171817
}
1818-
to_append = get_value(self, buffer, position, bson_type,
1818+
to_append = get_value(self, name, buffer, position, bson_type,
18191819
max - (unsigned)key_size, options);
18201820
Py_LeaveRecursiveCall();
18211821
if (!to_append) {
@@ -2343,11 +2343,38 @@ static PyObject* get_value(PyObject* self, const char* buffer,
23432343
}
23442344
default:
23452345
{
2346-
PyObject* InvalidDocument = _error("InvalidDocument");
2347-
if (InvalidDocument) {
2348-
PyErr_SetString(InvalidDocument,
2349-
"no c decoder for this type yet");
2350-
Py_DECREF(InvalidDocument);
2346+
PyObject* InvalidBSON = _error("InvalidBSON");
2347+
if (InvalidBSON) {
2348+
PyObject* bobj = PyBytes_FromFormat("%c", type);
2349+
if (bobj) {
2350+
PyObject* repr = PyObject_Repr(bobj);
2351+
Py_DECREF(bobj);
2352+
/*
2353+
* See http://bugs.python.org/issue22023 for why we can't
2354+
* just use PyUnicode_FromFormat with %S or %R to do this
2355+
* work.
2356+
*/
2357+
if (repr) {
2358+
PyObject* left = PyUnicode_FromString(
2359+
"Detected unknown BSON type ");
2360+
if (left) {
2361+
PyObject* lmsg = PyUnicode_Concat(left, repr);
2362+
Py_DECREF(left);
2363+
if (lmsg) {
2364+
PyObject* errmsg = PyUnicode_FromFormat(
2365+
"%U for fieldname '%U'. Are you using the "
2366+
"latest driver version?", lmsg, name);
2367+
if (errmsg) {
2368+
PyErr_SetObject(InvalidBSON, errmsg);
2369+
Py_DECREF(errmsg);
2370+
}
2371+
Py_DECREF(lmsg);
2372+
}
2373+
}
2374+
Py_DECREF(repr);
2375+
}
2376+
}
2377+
Py_DECREF(InvalidBSON);
23512378
}
23522379
goto invalid;
23532380
}
@@ -2457,10 +2484,10 @@ static int _element_to_dict(PyObject* self, const char* string,
24572484
return -1;
24582485
}
24592486
position += (unsigned)name_length + 1;
2460-
*value = get_value(self, string, &position, type,
2487+
*value = get_value(self, *name, string, &position, type,
24612488
max - position, options);
24622489
if (!*value) {
2463-
Py_DECREF(name);
2490+
Py_DECREF(*name);
24642491
return -1;
24652492
}
24662493
return position;

test/test_bson.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -401,6 +401,24 @@ def test_basic_encode(self):
401401
b"\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0A\x0B\x00"
402402
b"\x00")
403403

404+
def test_unknown_type(self):
405+
# Repr value differs with major python version
406+
part = "type %r for fieldname 'foo'" % (b'\x13',)
407+
docs = [
408+
b'\x0e\x00\x00\x00\x13foo\x00\x01\x00\x00\x00\x00',
409+
(b'\x16\x00\x00\x00\x04foo\x00\x0c\x00\x00\x00\x130'
410+
b'\x00\x01\x00\x00\x00\x00\x00'),
411+
(b' \x00\x00\x00\x04bar\x00\x16\x00\x00\x00\x030\x00\x0e\x00\x00'
412+
b'\x00\x13foo\x00\x01\x00\x00\x00\x00\x00\x00')]
413+
for bs in docs:
414+
try:
415+
bson.BSON(bs).decode()
416+
except Exception as exc:
417+
self.assertTrue(isinstance(exc, InvalidBSON))
418+
self.assertTrue(part in str(exc))
419+
else:
420+
self.fail("Failed to raise an exception.")
421+
404422
def test_dbpointer(self):
405423
# *Note* - DBPointer and DBRef are *not* the same thing. DBPointer
406424
# is a deprecated BSON type. DBRef is a convention that does not

0 commit comments

Comments
 (0)