3030import pprint
3131import sys
3232import builtins
33+ import contextlib
3334from types import ModuleType , MethodType
3435from functools import wraps , partial
3536
@@ -1243,33 +1244,16 @@ def decorate_callable(self, func):
12431244 @wraps (func )
12441245 def patched (* args , ** keywargs ):
12451246 extra_args = []
1246- entered_patchers = []
1247-
1248- exc_info = tuple ()
1249- try :
1247+ with contextlib .ExitStack () as exit_stack :
12501248 for patching in patched .patchings :
1251- arg = patching .__enter__ ()
1252- entered_patchers .append (patching )
1249+ arg = exit_stack .enter_context (patching )
12531250 if patching .attribute_name is not None :
12541251 keywargs .update (arg )
12551252 elif patching .new is DEFAULT :
12561253 extra_args .append (arg )
12571254
12581255 args += tuple (extra_args )
12591256 return func (* args , ** keywargs )
1260- except :
1261- if (patching not in entered_patchers and
1262- _is_started (patching )):
1263- # the patcher may have been started, but an exception
1264- # raised whilst entering one of its additional_patchers
1265- entered_patchers .append (patching )
1266- # Pass the exception to __exit__
1267- exc_info = sys .exc_info ()
1268- # re-raise the exception
1269- raise
1270- finally :
1271- for patching in reversed (entered_patchers ):
1272- patching .__exit__ (* exc_info )
12731257
12741258 patched .patchings = [self ]
12751259 return patched
@@ -1411,19 +1395,23 @@ def __enter__(self):
14111395
14121396 self .temp_original = original
14131397 self .is_local = local
1414- setattr (self .target , self .attribute , new_attr )
1415- if self .attribute_name is not None :
1416- extra_args = {}
1417- if self .new is DEFAULT :
1418- extra_args [self .attribute_name ] = new
1419- for patching in self .additional_patchers :
1420- arg = patching .__enter__ ()
1421- if patching .new is DEFAULT :
1422- extra_args .update (arg )
1423- return extra_args
1424-
1425- return new
1426-
1398+ self ._exit_stack = contextlib .ExitStack ()
1399+ try :
1400+ setattr (self .target , self .attribute , new_attr )
1401+ if self .attribute_name is not None :
1402+ extra_args = {}
1403+ if self .new is DEFAULT :
1404+ extra_args [self .attribute_name ] = new
1405+ for patching in self .additional_patchers :
1406+ arg = self ._exit_stack .enter_context (patching )
1407+ if patching .new is DEFAULT :
1408+ extra_args .update (arg )
1409+ return extra_args
1410+
1411+ return new
1412+ except :
1413+ if not self .__exit__ (* sys .exc_info ()):
1414+ raise
14271415
14281416 def __exit__ (self , * exc_info ):
14291417 """Undo the patch."""
@@ -1444,9 +1432,9 @@ def __exit__(self, *exc_info):
14441432 del self .temp_original
14451433 del self .is_local
14461434 del self .target
1447- for patcher in reversed ( self .additional_patchers ):
1448- if _is_started ( patcher ):
1449- patcher .__exit__ (* exc_info )
1435+ exit_stack = self ._exit_stack
1436+ del self . _exit_stack
1437+ return exit_stack .__exit__ (* exc_info )
14501438
14511439
14521440 def start (self ):
@@ -1464,7 +1452,7 @@ def stop(self):
14641452 # If the patch hasn't been started this will fail
14651453 pass
14661454
1467- return self .__exit__ ()
1455+ return self .__exit__ (None , None , None )
14681456
14691457
14701458
0 commit comments