26
26
decode_all ,
27
27
decode_file_iter ,
28
28
decode_iter ,
29
+ RE_TYPE ,
30
+ _BUILT_IN_TYPES ,
29
31
_dict_to_bson ,
30
32
_bson_to_dict )
31
33
from bson .codec_options import (CodecOptions , TypeCodec , TypeDecoder ,
@@ -189,21 +191,24 @@ def run_test(base, attrs, fail):
189
191
else :
190
192
codec ()
191
193
192
- run_test (TypeEncoder , {'python_type' : int ,}, fail = True )
194
+ class MyType (object ):
195
+ pass
196
+
197
+ run_test (TypeEncoder , {'python_type' : MyType ,}, fail = True )
193
198
run_test (TypeEncoder , {'transform_python' : lambda s , x : x }, fail = True )
194
199
run_test (TypeEncoder , {'transform_python' : lambda s , x : x ,
195
- 'python_type' : int }, fail = False )
200
+ 'python_type' : MyType }, fail = False )
196
201
197
202
run_test (TypeDecoder , {'bson_type' : Decimal128 , }, fail = True )
198
203
run_test (TypeDecoder , {'transform_bson' : lambda s , x : x }, fail = True )
199
204
run_test (TypeDecoder , {'transform_bson' : lambda s , x : x ,
200
205
'bson_type' : Decimal128 }, fail = False )
201
206
202
207
run_test (TypeCodec , {'bson_type' : Decimal128 ,
203
- 'python_type' : int }, fail = True )
208
+ 'python_type' : MyType }, fail = True )
204
209
run_test (TypeCodec , {'transform_bson' : lambda s , x : x ,
205
210
'transform_python' : lambda s , x : x }, fail = True )
206
- run_test (TypeCodec , {'python_type' : int ,
211
+ run_test (TypeCodec , {'python_type' : MyType ,
207
212
'transform_python' : lambda s , x : x ,
208
213
'transform_bson' : lambda s , x : x ,
209
214
'bson_type' : Decimal128 }, fail = False )
@@ -215,6 +220,91 @@ def test_type_checks(self):
215
220
self .assertFalse (issubclass (TypeEncoder , TypeDecoder ))
216
221
217
222
223
+ class TestCustomTypeEncoderAndFallbackEncoderTandem (unittest .TestCase ):
224
+ @classmethod
225
+ def setUpClass (cls ):
226
+ class TypeA (object ):
227
+ def __init__ (self , x ):
228
+ self .value = x
229
+
230
+ class TypeB (object ):
231
+ def __init__ (self , x ):
232
+ self .value = x
233
+
234
+ # transforms A, and only A into B
235
+ def fallback_encoder_A2B (value ):
236
+ assert isinstance (value , TypeA )
237
+ return TypeB (value .value )
238
+
239
+ # transforms A, and only A into something encodable
240
+ def fallback_encoder_A2BSON (value ):
241
+ assert isinstance (value , TypeA )
242
+ return value .value
243
+
244
+ # transforms B into something encodable
245
+ class B2BSON (TypeEncoder ):
246
+ python_type = TypeB
247
+ def transform_python (self , value ):
248
+ return value .value
249
+
250
+ # transforms A into B
251
+ # technically, this isn't a proper type encoder as the output is not
252
+ # BSON-encodable.
253
+ class A2B (TypeEncoder ):
254
+ python_type = TypeA
255
+ def transform_python (self , value ):
256
+ return TypeB (value .value )
257
+
258
+ # transforms B into A
259
+ # technically, this isn't a proper type encoder as the output is not
260
+ # BSON-encodable.
261
+ class B2A (TypeEncoder ):
262
+ python_type = TypeB
263
+ def transform_python (self , value ):
264
+ return TypeA (value .value )
265
+
266
+ cls .TypeA = TypeA
267
+ cls .TypeB = TypeB
268
+ cls .fallback_encoder_A2B = staticmethod (fallback_encoder_A2B )
269
+ cls .fallback_encoder_A2BSON = staticmethod (fallback_encoder_A2BSON )
270
+ cls .B2BSON = B2BSON
271
+ cls .B2A = B2A
272
+ cls .A2B = A2B
273
+
274
+ def test_encode_fallback_then_custom (self ):
275
+ codecopts = CodecOptions (type_registry = TypeRegistry (
276
+ [self .B2BSON ()], fallback_encoder = self .fallback_encoder_A2B ))
277
+ testdoc = {'x' : self .TypeA (123 )}
278
+ expected_bytes = BSON .encode ({'x' : 123 })
279
+
280
+ self .assertEqual (BSON .encode (testdoc , codec_options = codecopts ),
281
+ expected_bytes )
282
+
283
+ def test_encode_custom_then_fallback (self ):
284
+ codecopts = CodecOptions (type_registry = TypeRegistry (
285
+ [self .B2A ()], fallback_encoder = self .fallback_encoder_A2BSON ))
286
+ testdoc = {'x' : self .TypeB (123 )}
287
+ expected_bytes = BSON .encode ({'x' : 123 })
288
+
289
+ self .assertEqual (BSON .encode (testdoc , codec_options = codecopts ),
290
+ expected_bytes )
291
+
292
+ def test_chaining_encoders_fails (self ):
293
+ codecopts = CodecOptions (type_registry = TypeRegistry (
294
+ [self .A2B (), self .B2BSON ()]))
295
+
296
+ with self .assertRaises (InvalidDocument ):
297
+ BSON .encode ({'x' : self .TypeA (123 )}, codec_options = codecopts )
298
+
299
+ def test_infinite_loop_exceeds_max_recursion_depth (self ):
300
+ codecopts = CodecOptions (type_registry = TypeRegistry (
301
+ [self .B2A ()], fallback_encoder = self .fallback_encoder_A2B ))
302
+
303
+ # Raises max recursion depth exceeded error
304
+ with self .assertRaises (RuntimeError ):
305
+ BSON .encode ({'x' : self .TypeA (100 )}, codec_options = codecopts )
306
+
307
+
218
308
class TestTypeRegistry (unittest .TestCase ):
219
309
@classmethod
220
310
def setUpClass (cls ):
@@ -347,6 +437,35 @@ def test_type_registry_eq(self):
347
437
self .assertNotEqual (
348
438
TypeRegistry (codec_instances ), TypeRegistry (codec_instances_2 ))
349
439
440
+ def test_builtin_types_override_fails (self ):
441
+ def run_test (base , attrs ):
442
+ msg = ("TypeEncoders cannot change how built-in types "
443
+ "are encoded \(encoder .* transforms type .*\)" )
444
+ for pytype in _BUILT_IN_TYPES :
445
+ attrs .update ({'python_type' : pytype ,
446
+ 'transform_python' : lambda x : x })
447
+ codec = type ('testcodec' , (base , ), attrs )
448
+ codec_instance = codec ()
449
+ with self .assertRaisesRegex (TypeError , msg ):
450
+ TypeRegistry ([codec_instance ,])
451
+
452
+ # Test only some subtypes as not all can be subclassed.
453
+ if pytype in [bool , type (None ), RE_TYPE ,]:
454
+ continue
455
+
456
+ class MyType (pytype ):
457
+ pass
458
+ attrs .update ({'python_type' : MyType ,
459
+ 'transform_python' : lambda x : x })
460
+ codec = type ('testcodec' , (base , ), attrs )
461
+ codec_instance = codec ()
462
+ with self .assertRaisesRegex (TypeError , msg ):
463
+ TypeRegistry ([codec_instance ,])
464
+
465
+ run_test (TypeEncoder , {})
466
+ run_test (TypeCodec , {'bson_type' : Decimal128 ,
467
+ 'transform_bson' : lambda x : x })
468
+
350
469
351
470
if __name__ == "__main__" :
352
471
unittest .main ()
0 commit comments