Skip to content

Commit 436c2b0

Browse files
tirkarthimiss-islington
authored andcommitted
bpo-36996: Handle async functions when mock.patch is used as a decorator (GH-13562)
Return a coroutine while patching async functions with a decorator. Co-authored-by: Andrew Svetlov <andrew.svetlov@gmail.com> https://bugs.python.org/issue36996
1 parent 71dc7c5 commit 436c2b0

File tree

3 files changed

+74
-27
lines changed

3 files changed

+74
-27
lines changed

Lib/unittest/mock.py

Lines changed: 57 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
__version__ = '1.0'
2727

2828
import asyncio
29+
import contextlib
2930
import io
3031
import inspect
3132
import pprint
@@ -1220,6 +1221,8 @@ def copy(self):
12201221
def __call__(self, func):
12211222
if isinstance(func, type):
12221223
return self.decorate_class(func)
1224+
if inspect.iscoroutinefunction(func):
1225+
return self.decorate_async_callable(func)
12231226
return self.decorate_callable(func)
12241227

12251228

@@ -1237,41 +1240,68 @@ def decorate_class(self, klass):
12371240
return klass
12381241

12391242

1243+
@contextlib.contextmanager
1244+
def decoration_helper(self, patched, args, keywargs):
1245+
extra_args = []
1246+
entered_patchers = []
1247+
patching = None
1248+
1249+
exc_info = tuple()
1250+
try:
1251+
for patching in patched.patchings:
1252+
arg = patching.__enter__()
1253+
entered_patchers.append(patching)
1254+
if patching.attribute_name is not None:
1255+
keywargs.update(arg)
1256+
elif patching.new is DEFAULT:
1257+
extra_args.append(arg)
1258+
1259+
args += tuple(extra_args)
1260+
yield (args, keywargs)
1261+
except:
1262+
if (patching not in entered_patchers and
1263+
_is_started(patching)):
1264+
# the patcher may have been started, but an exception
1265+
# raised whilst entering one of its additional_patchers
1266+
entered_patchers.append(patching)
1267+
# Pass the exception to __exit__
1268+
exc_info = sys.exc_info()
1269+
# re-raise the exception
1270+
raise
1271+
finally:
1272+
for patching in reversed(entered_patchers):
1273+
patching.__exit__(*exc_info)
1274+
1275+
12401276
def decorate_callable(self, func):
1277+
# NB. Keep the method in sync with decorate_async_callable()
12411278
if hasattr(func, 'patchings'):
12421279
func.patchings.append(self)
12431280
return func
12441281

12451282
@wraps(func)
12461283
def patched(*args, **keywargs):
1247-
extra_args = []
1248-
entered_patchers = []
1284+
with self.decoration_helper(patched,
1285+
args,
1286+
keywargs) as (newargs, newkeywargs):
1287+
return func(*newargs, **newkeywargs)
12491288

1250-
exc_info = tuple()
1251-
try:
1252-
for patching in patched.patchings:
1253-
arg = patching.__enter__()
1254-
entered_patchers.append(patching)
1255-
if patching.attribute_name is not None:
1256-
keywargs.update(arg)
1257-
elif patching.new is DEFAULT:
1258-
extra_args.append(arg)
1259-
1260-
args += tuple(extra_args)
1261-
return func(*args, **keywargs)
1262-
except:
1263-
if (patching not in entered_patchers and
1264-
_is_started(patching)):
1265-
# the patcher may have been started, but an exception
1266-
# raised whilst entering one of its additional_patchers
1267-
entered_patchers.append(patching)
1268-
# Pass the exception to __exit__
1269-
exc_info = sys.exc_info()
1270-
# re-raise the exception
1271-
raise
1272-
finally:
1273-
for patching in reversed(entered_patchers):
1274-
patching.__exit__(*exc_info)
1289+
patched.patchings = [self]
1290+
return patched
1291+
1292+
1293+
def decorate_async_callable(self, func):
1294+
# NB. Keep the method in sync with decorate_callable()
1295+
if hasattr(func, 'patchings'):
1296+
func.patchings.append(self)
1297+
return func
1298+
1299+
@wraps(func)
1300+
async def patched(*args, **keywargs):
1301+
with self.decoration_helper(patched,
1302+
args,
1303+
keywargs) as (newargs, newkeywargs):
1304+
return await func(*newargs, **newkeywargs)
12751305

12761306
patched.patchings = [self]
12771307
return patched

Lib/unittest/test/testmock/testasync.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,14 @@ def test_async(mock_method):
6666

6767
test_async()
6868

69+
def test_async_def_patch(self):
70+
@patch(f"{__name__}.async_func", AsyncMock())
71+
async def test_async():
72+
self.assertIsInstance(async_func, AsyncMock)
73+
74+
asyncio.run(test_async())
75+
self.assertTrue(inspect.iscoroutinefunction(async_func))
76+
6977

7078
class AsyncPatchCMTest(unittest.TestCase):
7179
def test_is_async_function_cm(self):
@@ -91,6 +99,14 @@ def test_async():
9199

92100
test_async()
93101

102+
def test_async_def_cm(self):
103+
async def test_async():
104+
with patch(f"{__name__}.async_func", AsyncMock()):
105+
self.assertIsInstance(async_func, AsyncMock)
106+
self.assertTrue(inspect.iscoroutinefunction(async_func))
107+
108+
asyncio.run(test_async())
109+
94110

95111
class AsyncMockTest(unittest.TestCase):
96112
def test_iscoroutinefunction_default(self):
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Handle :func:`unittest.mock.patch` used as a decorator on async functions.

0 commit comments

Comments
 (0)