1717import datetime
1818import sys
1919import tempfile
20+ from collections import OrderedDict
2021from decimal import Decimal
2122from random import random
2223
3637 TypeEncoder , TypeRegistry )
3738from bson .errors import InvalidDocument
3839from bson .int64 import Int64
40+ from bson .raw_bson import RawBSONDocument
3941from bson .py3compat import text_type
4042
4143from gridfs import GridIn , GridOut
4244
4345from pymongo .collection import ReturnDocument
4446from pymongo .errors import DuplicateKeyError
47+ from pymongo .message import _CursorAddress
4548
4649from test import client_context , unittest
4750from test .test_client import IntegrationTest
48- from test .utils import ignore_deprecations
51+ from test .utils import ignore_deprecations , rs_client
4952
5053
5154class DecimalEncoder (TypeEncoder ):
@@ -115,6 +118,14 @@ def transform_bson(self, value):
115118 [UppercaseTextDecoder (),]))
116119
117120
121+ def type_obfuscating_decoder_factory (rt_type ):
122+ class ResumeTokenToNanDecoder (TypeDecoder ):
123+ bson_type = rt_type
124+ def transform_bson (self , value ):
125+ return "NaN"
126+ return ResumeTokenToNanDecoder
127+
128+
118129class CustomBSONTypeTests (object ):
119130 def roundtrip (self , doc ):
120131 bsonbytes = BSON ().encode (doc , codec_options = self .codecopts )
@@ -549,7 +560,7 @@ def test_command_errors_w_custom_type_decoder(self):
549560 def test_find_w_custom_type_decoder (self ):
550561 db = self .db
551562 input_docs = [
552- {'x' : Int64 (k )} for k in [1.0 , 2.0 , 3.0 ]]
563+ {'x' : Int64 (k )} for k in [1 , 2 , 3 ]]
553564 for doc in input_docs :
554565 db .test .insert_one (doc )
555566
@@ -558,6 +569,24 @@ def test_find_w_custom_type_decoder(self):
558569 for doc in test .find ({}, batch_size = 1 ):
559570 self .assertIsInstance (doc ['x' ], UndecipherableInt64Type )
560571
572+ def test_find_w_custom_type_decoder_and_document_class (self ):
573+ def run_test (doc_cls ):
574+ db = self .db
575+ input_docs = [
576+ {'x' : Int64 (k )} for k in [1 , 2 , 3 ]]
577+ for doc in input_docs :
578+ db .test .insert_one (doc )
579+
580+ test = db .get_collection ('test' , codec_options = CodecOptions (
581+ type_registry = TypeRegistry ([UndecipherableIntDecoder ()]),
582+ document_class = doc_cls ))
583+ for doc in test .find ({}, batch_size = 1 ):
584+ self .assertIsInstance (doc , doc_cls )
585+ self .assertIsInstance (doc ['x' ], UndecipherableInt64Type )
586+
587+ for doc_cls in [RawBSONDocument , OrderedDict ]:
588+ run_test (doc_cls )
589+
561590 @client_context .require_version_max (4 , 1 , 0 , - 1 )
562591 def test_group_w_custom_type (self ):
563592 db = self .db
@@ -709,5 +738,155 @@ def test_grid_out_custom_opts(self):
709738 self .assertRaises (AttributeError , setattr , two , attr , 5 )
710739
711740
741+ class ChangeStreamsWCustomTypesTestMixin (object ):
742+ def change_stream (self , * args , ** kwargs ):
743+ return self .watched_target .watch (* args , ** kwargs )
744+
745+ def insert_and_check (self , change_stream , insert_doc ,
746+ expected_doc ):
747+ self .input_target .insert_one (insert_doc )
748+ change = next (change_stream )
749+ self .assertEqual (change ['fullDocument' ], expected_doc )
750+
751+ def kill_change_stream_cursor (self , change_stream ):
752+ # Cause a cursor not found error on the next getMore.
753+ cursor = change_stream ._cursor
754+ address = _CursorAddress (cursor .address , cursor ._CommandCursor__ns )
755+ client = self .input_target .database .client
756+ client ._close_cursor_now (cursor .cursor_id , address )
757+
758+ def test_simple (self ):
759+ codecopts = CodecOptions (type_registry = TypeRegistry ([
760+ UndecipherableIntEncoder (), UppercaseTextDecoder ()]))
761+ self .create_targets (codec_options = codecopts )
762+
763+ input_docs = [
764+ {'_id' : UndecipherableInt64Type (1 ), 'data' : 'hello' },
765+ {'_id' : 2 , 'data' : 'world' },
766+ {'_id' : UndecipherableInt64Type (3 ), 'data' : '!' },]
767+ expected_docs = [
768+ {'_id' : 1 , 'data' : 'HELLO' },
769+ {'_id' : 2 , 'data' : 'WORLD' },
770+ {'_id' : 3 , 'data' : '!' },]
771+
772+ change_stream = self .change_stream ()
773+
774+ self .insert_and_check (change_stream , input_docs [0 ], expected_docs [0 ])
775+ self .kill_change_stream_cursor (change_stream )
776+ self .insert_and_check (change_stream , input_docs [1 ], expected_docs [1 ])
777+ self .kill_change_stream_cursor (change_stream )
778+ self .insert_and_check (change_stream , input_docs [2 ], expected_docs [2 ])
779+
780+ def test_break_resume_token (self ):
781+ # Get one document from a change stream to determine resumeToken type.
782+ self .create_targets ()
783+ change_stream = self .change_stream ()
784+ self .input_target .insert_one ({"data" : "test" })
785+ change = next (change_stream )
786+ resume_token_decoder = type_obfuscating_decoder_factory (
787+ type (change ['_id' ]['_data' ]))
788+
789+ # Custom-decoding the resumeToken type breaks resume tokens.
790+ codecopts = CodecOptions (type_registry = TypeRegistry ([
791+ resume_token_decoder (), UndecipherableIntEncoder ()]))
792+
793+ # Re-create targets, change stream and proceed.
794+ self .create_targets (codec_options = codecopts )
795+
796+ docs = [{'_id' : 1 }, {'_id' : 2 }, {'_id' : 3 }]
797+
798+ change_stream = self .change_stream ()
799+ self .insert_and_check (change_stream , docs [0 ], docs [0 ])
800+ self .kill_change_stream_cursor (change_stream )
801+ self .insert_and_check (change_stream , docs [1 ], docs [1 ])
802+ self .kill_change_stream_cursor (change_stream )
803+ self .insert_and_check (change_stream , docs [2 ], docs [2 ])
804+
805+ def test_document_class (self ):
806+ def run_test (doc_cls ):
807+ codecopts = CodecOptions (type_registry = TypeRegistry ([
808+ UppercaseTextDecoder (), UndecipherableIntEncoder ()]),
809+ document_class = doc_cls )
810+
811+ self .create_targets (codec_options = codecopts )
812+ change_stream = self .change_stream ()
813+
814+ doc = {'a' : UndecipherableInt64Type (101 ), 'b' : 'xyz' }
815+ self .input_target .insert_one (doc )
816+ change = next (change_stream )
817+
818+ self .assertIsInstance (change , doc_cls )
819+ self .assertEqual (change ['fullDocument' ]['a' ], 101 )
820+ self .assertEqual (change ['fullDocument' ]['b' ], 'XYZ' )
821+
822+ for doc_cls in [OrderedDict , RawBSONDocument ]:
823+ run_test (doc_cls )
824+
825+
826+ class TestCollectionChangeStreamsWCustomTypes (
827+ IntegrationTest , ChangeStreamsWCustomTypesTestMixin ):
828+ @classmethod
829+ @client_context .require_version_min (3 , 6 , 0 )
830+ @client_context .require_no_mmap
831+ @client_context .require_no_standalone
832+ def setUpClass (cls ):
833+ super (TestCollectionChangeStreamsWCustomTypes , cls ).setUpClass ()
834+
835+ def tearDown (self ):
836+ self .input_target .drop ()
837+
838+ def create_targets (self , * args , ** kwargs ):
839+ self .watched_target = self .db .get_collection (
840+ 'test' , * args , ** kwargs )
841+ self .input_target = self .watched_target
842+ # Insert a record to ensure db, coll are created.
843+ self .input_target .insert_one ({'data' : 'dummy' })
844+
845+
846+ class TestDatabaseChangeStreamsWCustomTypes (
847+ IntegrationTest , ChangeStreamsWCustomTypesTestMixin ):
848+ @classmethod
849+ @client_context .require_version_min (4 , 0 , 0 )
850+ @client_context .require_no_mmap
851+ @client_context .require_no_standalone
852+ def setUpClass (cls ):
853+ super (TestDatabaseChangeStreamsWCustomTypes , cls ).setUpClass ()
854+
855+ def tearDown (self ):
856+ self .input_target .drop ()
857+ self .client .drop_database (self .watched_target )
858+
859+ def create_targets (self , * args , ** kwargs ):
860+ self .watched_target = self .client .get_database (
861+ self .db .name , * args , ** kwargs )
862+ self .input_target = self .watched_target .test
863+ # Insert a record to ensure db, coll are created.
864+ self .input_target .insert_one ({'data' : 'dummy' })
865+
866+
867+ class TestClusterChangeStreamsWCustomTypes (
868+ IntegrationTest , ChangeStreamsWCustomTypesTestMixin ):
869+ @classmethod
870+ @client_context .require_version_min (4 , 0 , 0 )
871+ @client_context .require_no_mmap
872+ @client_context .require_no_standalone
873+ def setUpClass (cls ):
874+ super (TestClusterChangeStreamsWCustomTypes , cls ).setUpClass ()
875+
876+ def tearDown (self ):
877+ self .input_target .drop ()
878+ self .client .drop_database (self .db )
879+
880+ def create_targets (self , * args , ** kwargs ):
881+ codec_options = kwargs .pop ('codec_options' , None )
882+ if codec_options :
883+ kwargs ['type_registry' ] = codec_options .type_registry
884+ kwargs ['document_class' ] = codec_options .document_class
885+ self .watched_target = rs_client (* args , ** kwargs )
886+ self .input_target = self .watched_target [self .db .name ].test
887+ # Insert a record to ensure db, coll are created.
888+ self .input_target .insert_one ({'data' : 'dummy' })
889+
890+
712891if __name__ == "__main__" :
713892 unittest .main ()
0 commit comments