2626__version__ = '1.0'
2727
2828import asyncio
29+ import contextlib
2930import io
3031import inspect
3132import pprint
@@ -1220,6 +1221,8 @@ def copy(self):
12201221 def __call__ (self , func ):
12211222 if isinstance (func , type ):
12221223 return self .decorate_class (func )
1224+ if inspect .iscoroutinefunction (func ):
1225+ return self .decorate_async_callable (func )
12231226 return self .decorate_callable (func )
12241227
12251228
@@ -1237,41 +1240,68 @@ def decorate_class(self, klass):
12371240 return klass
12381241
12391242
1243+ @contextlib .contextmanager
1244+ def decoration_helper (self , patched , args , keywargs ):
1245+ extra_args = []
1246+ entered_patchers = []
1247+ patching = None
1248+
1249+ exc_info = tuple ()
1250+ try :
1251+ for patching in patched .patchings :
1252+ arg = patching .__enter__ ()
1253+ entered_patchers .append (patching )
1254+ if patching .attribute_name is not None :
1255+ keywargs .update (arg )
1256+ elif patching .new is DEFAULT :
1257+ extra_args .append (arg )
1258+
1259+ args += tuple (extra_args )
1260+ yield (args , keywargs )
1261+ except :
1262+ if (patching not in entered_patchers and
1263+ _is_started (patching )):
1264+ # the patcher may have been started, but an exception
1265+ # raised whilst entering one of its additional_patchers
1266+ entered_patchers .append (patching )
1267+ # Pass the exception to __exit__
1268+ exc_info = sys .exc_info ()
1269+ # re-raise the exception
1270+ raise
1271+ finally :
1272+ for patching in reversed (entered_patchers ):
1273+ patching .__exit__ (* exc_info )
1274+
1275+
12401276 def decorate_callable (self , func ):
1277+ # NB. Keep the method in sync with decorate_async_callable()
12411278 if hasattr (func , 'patchings' ):
12421279 func .patchings .append (self )
12431280 return func
12441281
12451282 @wraps (func )
12461283 def patched (* args , ** keywargs ):
1247- extra_args = []
1248- entered_patchers = []
1284+ with self .decoration_helper (patched ,
1285+ args ,
1286+ keywargs ) as (newargs , newkeywargs ):
1287+ return func (* newargs , ** newkeywargs )
12491288
1250- exc_info = tuple ()
1251- try :
1252- for patching in patched .patchings :
1253- arg = patching .__enter__ ()
1254- entered_patchers .append (patching )
1255- if patching .attribute_name is not None :
1256- keywargs .update (arg )
1257- elif patching .new is DEFAULT :
1258- extra_args .append (arg )
1259-
1260- args += tuple (extra_args )
1261- return func (* args , ** keywargs )
1262- except :
1263- if (patching not in entered_patchers and
1264- _is_started (patching )):
1265- # the patcher may have been started, but an exception
1266- # raised whilst entering one of its additional_patchers
1267- entered_patchers .append (patching )
1268- # Pass the exception to __exit__
1269- exc_info = sys .exc_info ()
1270- # re-raise the exception
1271- raise
1272- finally :
1273- for patching in reversed (entered_patchers ):
1274- patching .__exit__ (* exc_info )
1289+ patched .patchings = [self ]
1290+ return patched
1291+
1292+
1293+ def decorate_async_callable (self , func ):
1294+ # NB. Keep the method in sync with decorate_callable()
1295+ if hasattr (func , 'patchings' ):
1296+ func .patchings .append (self )
1297+ return func
1298+
1299+ @wraps (func )
1300+ async def patched (* args , ** keywargs ):
1301+ with self .decoration_helper (patched ,
1302+ args ,
1303+ keywargs ) as (newargs , newkeywargs ):
1304+ return await func (* newargs , ** newkeywargs )
12751305
12761306 patched .patchings = [self ]
12771307 return patched
0 commit comments