@@ -482,6 +482,24 @@ def test_create():
482482 # Index wasn't created, only the default index on _id
483483 self .assertEqual (1 , len (db .test .index_information ()))
484484
485+ # Get the plan dynamically because the explain format will change.
486+ def get_plan_stage (self , root , stage ):
487+ if root .get ('stage' ) == stage :
488+ return root
489+ elif "inputStage" in root :
490+ return self .get_plan_stage (root ['inputStage' ], stage )
491+ elif "inputStages" in root :
492+ for i in root ['inputStages' ]:
493+ stage = self .get_plan_stage (i , stage )
494+ if stage :
495+ return stage
496+ elif "shards" in root :
497+ for i in root ['shards' ]:
498+ stage = self .get_plan_stage (i ['winningPlan' ], stage )
499+ if stage :
500+ return stage
501+ return {}
502+
485503 @client_context .require_version_min (3 , 1 , 9 , - 1 )
486504 def test_index_filter (self ):
487505 db = self .db
@@ -506,34 +524,40 @@ def test_index_filter(self):
506524 db .test .insert_one ({"x" : 6 , "a" : 1 })
507525
508526 # Operations that use the partial index.
509- explain = db .test .find (
510- {"x" : 6 , "a" : 1 }).explain ()['queryPlanner' ]['winningPlan' ]
511- self .assertEqual ("x_1" , explain .get ('inputStage' , {}).get ('indexName' ))
512- self .assertTrue (explain .get ('inputStage' , {}).get ('isPartial' ))
513- explain = db .test .find (
514- {"x" : {"$gt" : 1 }, "a" : 1 }).explain ()['queryPlanner' ]['winningPlan' ]
515- self .assertEqual ("x_1" , explain .get ('inputStage' , {}).get ('indexName' ))
516- self .assertTrue (explain .get ('inputStage' , {}).get ('isPartial' ))
517- explain = db .test .find (
518- {"x" : 6 ,
519- "a" : {"$lte" : 1 }}).explain ()['queryPlanner' ]['winningPlan' ]
520- self .assertEqual ("x_1" , explain .get ('inputStage' , {}).get ('indexName' ))
521- self .assertTrue (explain .get ('inputStage' , {}).get ('isPartial' ))
527+ explain = db .test .find ({"x" : 6 , "a" : 1 }).explain ()
528+ stage = self .get_plan_stage (explain ['queryPlanner' ]['winningPlan' ],
529+ 'IXSCAN' )
530+ self .assertEqual ("x_1" , stage .get ('indexName' ))
531+ self .assertTrue (stage .get ('isPartial' ))
532+
533+ explain = db .test .find ({"x" : {"$gt" : 1 }, "a" : 1 }).explain ()
534+ stage = self .get_plan_stage (explain ['queryPlanner' ]['winningPlan' ],
535+ 'IXSCAN' )
536+ self .assertEqual ("x_1" , stage .get ('indexName' ))
537+ self .assertTrue (stage .get ('isPartial' ))
538+
539+ explain = db .test .find ({"x" : 6 , "a" : {"$lte" : 1 }}).explain ()
540+ stage = self .get_plan_stage (explain ['queryPlanner' ]['winningPlan' ],
541+ 'IXSCAN' )
542+ self .assertEqual ("x_1" , stage .get ('indexName' ))
543+ self .assertTrue (stage .get ('isPartial' ))
522544
523545 # Operations that do not use the partial index.
524- explain = db .test .find (
525- {"x" : 6 ,
526- "a" : {"$lte" : 1.6 }}).explain ()['queryPlanner' ]['winningPlan' ]
527- self .assertEqual ("COLLSCAN" , explain .get ('stage' ))
528- explain = db .test .find (
529- {"x" : 6 }).explain ()['queryPlanner' ]['winningPlan' ]
530- self .assertEqual ("COLLSCAN" , explain .get ('stage' ))
546+ explain = db .test .find ({"x" : 6 , "a" : {"$lte" : 1.6 }}).explain ()
547+ stage = self .get_plan_stage (explain ['queryPlanner' ]['winningPlan' ],
548+ 'COLLSCAN' )
549+ self .assertNotEqual ({}, stage )
550+ explain = db .test .find ({"x" : 6 }).explain ()
551+ stage = self .get_plan_stage (explain ['queryPlanner' ]['winningPlan' ],
552+ 'COLLSCAN' )
553+ self .assertNotEqual ({}, stage )
531554
532555 # Test drop_indexes.
533556 db .test .drop_index ("x_1" )
534- explain = db .test .find (
535- {"x" : 6 , "a" : 1 }).explain ()['queryPlanner' ]['winningPlan' ]
536- self .assertEqual ("COLLSCAN" , explain .get ('stage' ))
557+ explain = db .test .find ({"x" : 6 , "a" : 1 }).explain ()
558+ stage = self .get_plan_stage (explain ['queryPlanner' ]['winningPlan' ],
559+ 'COLLSCAN' )
560+ self .assertNotEqual ({}, stage )
537561
538562 def test_field_selection (self ):
539563 db = self .db
0 commit comments