Skip to content
23 changes: 20 additions & 3 deletions Lib/asyncio/base_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import collections
import collections.abc
import concurrent.futures
import contextvars
import functools
import heapq
import itertools
Expand Down Expand Up @@ -789,7 +790,8 @@ def call_soon_threadsafe(self, callback, *args, context=None):
self._write_to_self()
return handle

def run_in_executor(self, executor, func, *args):
def run_in_executor(self, executor, func, *args, context=None,
retain_context=False):
self._check_closed()
if self._debug:
self._check_callback(func, 'run_in_executor')
Expand All @@ -800,8 +802,23 @@ def run_in_executor(self, executor, func, *args):
if executor is None:
executor = concurrent.futures.ThreadPoolExecutor()
self._default_executor = executor
return futures.wrap_future(
executor.submit(func, *args), loop=self)

if args:
runner = functools.partial(func, *args)
else:
runner = func

if retain_context:
if not isinstance(executor, concurrent.futures.ThreadPoolExecutor):
raise RuntimeError(
'retain_context=True supports only ThreadPoolExecutor')

if context is None:
context = contextvars.copy_context()

runner = functools.partial(context.run, runner)

return futures.wrap_future(executor.submit(runner), loop=self)

def set_default_executor(self, executor):
if not isinstance(executor, concurrent.futures.ThreadPoolExecutor):
Expand Down
56 changes: 56 additions & 0 deletions Lib/test/test_asyncio/test_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import collections.abc
import concurrent.futures
import contextvars
import functools
import io
import os
Expand Down Expand Up @@ -34,6 +35,9 @@
from test import support
from test.support import ALWAYS_EQ, LARGEST, SMALLEST

foo_ctx = contextvars.ContextVar('foo')
foo_ctx.set('bar')


def tearDownModule():
asyncio.set_event_loop_policy(None)
Expand Down Expand Up @@ -367,6 +371,58 @@ def run():
time.sleep(0.4)
self.assertFalse(called)

def test_run_in_executor_hierarchy(self):
def run():
foo_ctx.set('foo')
res = foo_ctx.get()
self.assertEqual(res, 'foo')
return res

f = self.loop.run_in_executor(None, run, retain_context=True)
res = self.loop.run_until_complete(f)
self.assertEqual(res, 'foo')

res = foo_ctx.get()
self.assertEqual(res, 'bar')

def test_run_in_executor_no_context(self):
def run():
return foo_ctx.get()

f = self.loop.run_in_executor(None, run, retain_context=True)
res = self.loop.run_until_complete(f)
self.assertEqual(res, 'bar')

def test_run_in_executor_context(self):
def run():
return foo_ctx.get()

context = contextvars.copy_context()
f = self.loop.run_in_executor(None, run, context=context,
retain_context=True)
res = self.loop.run_until_complete(f)
self.assertEqual(res, 'bar')

def test_run_in_executor_context_args(self):
def run(arg):
return (arg, foo_ctx.get())

context = contextvars.copy_context()
f = self.loop.run_in_executor(None, run, 'yo', context=context,
retain_context=True)
res = self.loop.run_until_complete(f)
self.assertEqual(res, ('yo', 'bar'))

def test_run_in_executor_context_subprocess(self):
def run(arg):
pass

pool = concurrent.futures.ProcessPoolExecutor()
context = contextvars.copy_context()
with self.assertRaises(RuntimeError):
self.loop.run_in_executor(pool, run, retain_context=True)
pool.shutdown()

def test_reader_callback(self):
r, w = socket.socketpair()
r.setblocking(False)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Added support of contextvars for BaseEventLoop.run_in_executor