Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions Lib/asyncio/base_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
"""

import collections
import contextvars
import collections.abc
import concurrent.futures
import errno
Expand Down Expand Up @@ -289,6 +290,7 @@ def __init__(self, loop, sockets, protocol_factory, ssl_context, backlog,
self._ssl_shutdown_timeout = ssl_shutdown_timeout
self._serving = False
self._serving_forever_fut = None
self._context = contextvars.copy_context()

def __repr__(self):
return f'<{self.__class__.__name__} sockets={self.sockets!r}>'
Expand Down Expand Up @@ -318,7 +320,7 @@ def _start_serving(self):
self._loop._start_serving(
self._protocol_factory, sock, self._ssl_context,
self, self._backlog, self._ssl_handshake_timeout,
self._ssl_shutdown_timeout)
self._ssl_shutdown_timeout, context=self._context)

def get_loop(self):
return self._loop
Expand Down Expand Up @@ -1211,9 +1213,10 @@ async def _create_connection_transport(
self, sock, protocol_factory, ssl,
server_hostname, server_side=False,
ssl_handshake_timeout=None,
ssl_shutdown_timeout=None):
ssl_shutdown_timeout=None, context=None):

sock.setblocking(False)
context = context if context is not None else contextvars.copy_context()

protocol = protocol_factory()
waiter = self.create_future()
Expand All @@ -1225,7 +1228,7 @@ async def _create_connection_transport(
ssl_handshake_timeout=ssl_handshake_timeout,
ssl_shutdown_timeout=ssl_shutdown_timeout)
else:
transport = self._make_socket_transport(sock, protocol, waiter)
transport = self._make_socket_transport(sock, protocol, waiter, context=context)

try:
await waiter
Expand Down
50 changes: 25 additions & 25 deletions Lib/asyncio/selector_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,10 @@ def __init__(self, selector=None):
self._transports = weakref.WeakValueDictionary()

def _make_socket_transport(self, sock, protocol, waiter=None, *,
extra=None, server=None):
extra=None, server=None, context=None):
self._ensure_fd_no_transport(sock)
return _SelectorSocketTransport(self, sock, protocol, waiter,
extra, server)
extra, server, context=context)

def _make_ssl_transport(
self, rawsock, protocol, sslcontext, waiter=None,
Expand Down Expand Up @@ -159,16 +159,16 @@ def _write_to_self(self):
def _start_serving(self, protocol_factory, sock,
sslcontext=None, server=None, backlog=100,
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT,
ssl_shutdown_timeout=constants.SSL_SHUTDOWN_TIMEOUT):
ssl_shutdown_timeout=constants.SSL_SHUTDOWN_TIMEOUT, context=None):
self._add_reader(sock.fileno(), self._accept_connection,
protocol_factory, sock, sslcontext, server, backlog,
ssl_handshake_timeout, ssl_shutdown_timeout)
ssl_handshake_timeout, ssl_shutdown_timeout, context)

def _accept_connection(
self, protocol_factory, sock,
sslcontext=None, server=None, backlog=100,
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT,
ssl_shutdown_timeout=constants.SSL_SHUTDOWN_TIMEOUT):
ssl_shutdown_timeout=constants.SSL_SHUTDOWN_TIMEOUT, context=None):
# This method is only called once for each event loop tick where the
# listening socket has triggered an EVENT_READ. There may be multiple
# connections waiting for an .accept() so it is called in a loop.
Expand Down Expand Up @@ -204,21 +204,21 @@ def _accept_connection(
self._start_serving,
protocol_factory, sock, sslcontext, server,
backlog, ssl_handshake_timeout,
ssl_shutdown_timeout)
ssl_shutdown_timeout, context)
else:
raise # The event loop will catch, log and ignore it.
else:
extra = {'peername': addr}
accept = self._accept_connection2(
protocol_factory, conn, extra, sslcontext, server,
ssl_handshake_timeout, ssl_shutdown_timeout)
self.create_task(accept)
ssl_handshake_timeout, ssl_shutdown_timeout, context=context)
self.create_task(accept, context=context)

async def _accept_connection2(
self, protocol_factory, conn, extra,
sslcontext=None, server=None,
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT,
ssl_shutdown_timeout=constants.SSL_SHUTDOWN_TIMEOUT):
ssl_shutdown_timeout=constants.SSL_SHUTDOWN_TIMEOUT, context=None):
protocol = None
transport = None
try:
Expand All @@ -233,7 +233,7 @@ async def _accept_connection2(
else:
transport = self._make_socket_transport(
conn, protocol, waiter=waiter, extra=extra,
server=server)
server=server, context=context)

try:
await waiter
Expand Down Expand Up @@ -275,9 +275,9 @@ def _ensure_fd_no_transport(self, fd):
f'File descriptor {fd!r} is used by transport '
f'{transport!r}')

def _add_reader(self, fd, callback, *args):
def _add_reader(self, fd, callback, *args, context=None):
self._check_closed()
handle = events.Handle(callback, args, self, None)
handle = events.Handle(callback, args, self, context=context)
key = self._selector.get_map().get(fd)
if key is None:
self._selector.register(fd, selectors.EVENT_READ,
Expand Down Expand Up @@ -770,7 +770,7 @@ class _SelectorTransport(transports._FlowControlMixin,
# exception)
_sock = None

def __init__(self, loop, sock, protocol, extra=None, server=None):
def __init__(self, loop, sock, protocol, extra=None, server=None, context=None):
super().__init__(extra, loop)
self._extra['socket'] = trsock.TransportSocket(sock)
try:
Expand All @@ -784,7 +784,7 @@ def __init__(self, loop, sock, protocol, extra=None, server=None):
self._extra['peername'] = None
self._sock = sock
self._sock_fd = sock.fileno()

self._context = context
self._protocol_connected = False
self.set_protocol(protocol)

Expand Down Expand Up @@ -866,7 +866,7 @@ def close(self):
if not self._buffer:
self._conn_lost += 1
self._loop._remove_writer(self._sock_fd)
self._loop.call_soon(self._call_connection_lost, None)
self._loop.call_soon(self._call_connection_lost, None, context=self._context)

def __del__(self, _warn=warnings.warn):
if self._sock is not None:
Expand Down Expand Up @@ -899,7 +899,7 @@ def _force_close(self, exc):
self._closing = True
self._loop._remove_reader(self._sock_fd)
self._conn_lost += 1
self._loop.call_soon(self._call_connection_lost, exc)
self._loop.call_soon(self._call_connection_lost, exc, context=self._context)

def _call_connection_lost(self, exc):
try:
Expand All @@ -921,7 +921,7 @@ def get_write_buffer_size(self):
def _add_reader(self, fd, callback, *args):
if not self.is_reading():
return
self._loop._add_reader(fd, callback, *args)
self._loop._add_reader(fd, callback, *args, context=self._context)


class _SelectorSocketTransport(_SelectorTransport):
Expand All @@ -930,10 +930,10 @@ class _SelectorSocketTransport(_SelectorTransport):
_sendfile_compatible = constants._SendfileMode.TRY_NATIVE

def __init__(self, loop, sock, protocol, waiter=None,
extra=None, server=None):

extra=None, server=None, context=None):
assert context is not None
self._read_ready_cb = None
super().__init__(loop, sock, protocol, extra, server)
super().__init__(loop, sock, protocol, extra, server, context)
self._eof = False
self._empty_waiter = None
if _HAS_SENDMSG:
Expand All @@ -945,14 +945,14 @@ def __init__(self, loop, sock, protocol, waiter=None,
# decreases the latency (in some cases significantly.)
base_events._set_nodelay(self._sock)

self._loop.call_soon(self._protocol.connection_made, self)
self._loop.call_soon(self._protocol.connection_made, self, context=context)
# only start reading when connection_made() has been called
self._loop.call_soon(self._add_reader,
self._sock_fd, self._read_ready)
self._sock_fd, self._read_ready, context=context)
if waiter is not None:
# only wake up the waiter when connection_made() has been called
self._loop.call_soon(futures._set_result_unless_cancelled,
waiter, None)
waiter, None, context=context)

def set_protocol(self, protocol):
if isinstance(protocol, protocols.BufferedProtocol):
Expand Down Expand Up @@ -1081,7 +1081,7 @@ def write(self, data):
if not data:
return
# Not all was written; register write handler.
self._loop._add_writer(self._sock_fd, self._write_ready)
self._loop._add_writer(self._sock_fd, self._write_ready, context=self._context)

# Add it to the buffer.
self._buffer.append(data)
Expand Down Expand Up @@ -1185,7 +1185,7 @@ def writelines(self, list_of_data):
self._write_ready()
# If the entire buffer couldn't be written, register a write handler
if self._buffer:
self._loop._add_writer(self._sock_fd, self._write_ready)
self._loop._add_writer(self._sock_fd, self._write_ready, context=self._context)
self._maybe_pause_protocol()

def can_write_eof(self):
Expand Down
Loading