|
38 | 38 | from . import coroutines |
39 | 39 | from . import events |
40 | 40 | from . import futures |
| 41 | +from . import protocols |
41 | 42 | from . import sslproto |
42 | 43 | from . import tasks |
| 44 | +from . import transports |
43 | 45 | from .log import logger |
44 | 46 |
|
45 | 47 |
|
@@ -155,6 +157,75 @@ def _run_until_complete_cb(fut): |
155 | 157 | futures._get_loop(fut).stop() |
156 | 158 |
|
157 | 159 |
|
| 160 | + |
| 161 | +class _SendfileFallbackProtocol(protocols.Protocol): |
| 162 | + def __init__(self, transp): |
| 163 | + if not isinstance(transp, transports._FlowControlMixin): |
| 164 | + raise TypeError("transport should be _FlowControlMixin instance") |
| 165 | + self._transport = transp |
| 166 | + self._proto = transp.get_protocol() |
| 167 | + self._should_resume_reading = transp.is_reading() |
| 168 | + self._should_resume_writing = transp._protocol_paused |
| 169 | + transp.pause_reading() |
| 170 | + transp.set_protocol(self) |
| 171 | + if self._should_resume_writing: |
| 172 | + self._write_ready_fut = self._transport._loop.create_future() |
| 173 | + else: |
| 174 | + self._write_ready_fut = None |
| 175 | + |
| 176 | + async def drain(self): |
| 177 | + if self._transport.is_closing(): |
| 178 | + raise ConnectionError("Connection closed by peer") |
| 179 | + fut = self._write_ready_fut |
| 180 | + if fut is None: |
| 181 | + return |
| 182 | + await fut |
| 183 | + |
| 184 | + def connection_made(self, transport): |
| 185 | + raise RuntimeError("Invalid state: " |
| 186 | + "connection should have been established already.") |
| 187 | + |
| 188 | + def connection_lost(self, exc): |
| 189 | + if self._write_ready_fut is not None: |
| 190 | + # Never happens if peer disconnects after sending the whole content |
| 191 | + # Thus disconnection is always an exception from user perspective |
| 192 | + if exc is None: |
| 193 | + self._write_ready_fut.set_exception( |
| 194 | + ConnectionError("Connection is closed by peer")) |
| 195 | + else: |
| 196 | + self._write_ready_fut.set_exception(exc) |
| 197 | + self._proto.connection_lost(exc) |
| 198 | + |
| 199 | + def pause_writing(self): |
| 200 | + if self._write_ready_fut is not None: |
| 201 | + return |
| 202 | + self._write_ready_fut = self._transport._loop.create_future() |
| 203 | + |
| 204 | + def resume_writing(self): |
| 205 | + if self._write_ready_fut is None: |
| 206 | + return |
| 207 | + self._write_ready_fut.set_result(False) |
| 208 | + self._write_ready_fut = None |
| 209 | + |
| 210 | + def data_received(self, data): |
| 211 | + raise RuntimeError("Invalid state: reading should be paused") |
| 212 | + |
| 213 | + def eof_received(self): |
| 214 | + raise RuntimeError("Invalid state: reading should be paused") |
| 215 | + |
| 216 | + async def restore(self): |
| 217 | + self._transport.set_protocol(self._proto) |
| 218 | + if self._should_resume_reading: |
| 219 | + self._transport.resume_reading() |
| 220 | + if self._write_ready_fut is not None: |
| 221 | + # Cancel the future. |
| 222 | + # Basically it has no effect because protocol is switched back, |
| 223 | + # no code should wait for it anymore. |
| 224 | + self._write_ready_fut.cancel() |
| 225 | + if self._should_resume_writing: |
| 226 | + self._proto.resume_writing() |
| 227 | + |
| 228 | + |
158 | 229 | class Server(events.AbstractServer): |
159 | 230 |
|
160 | 231 | def __init__(self, loop, sockets, protocol_factory, ssl_context, backlog, |
@@ -926,6 +997,77 @@ async def _create_connection_transport( |
926 | 997 |
|
927 | 998 | return transport, protocol |
928 | 999 |
|
| 1000 | + async def sendfile(self, transport, file, offset=0, count=None, |
| 1001 | + *, fallback=True): |
| 1002 | + """Send a file to transport. |
| 1003 | +
|
| 1004 | + Return the total number of bytes which were sent. |
| 1005 | +
|
| 1006 | + The method uses high-performance os.sendfile if available. |
| 1007 | +
|
| 1008 | + file must be a regular file object opened in binary mode. |
| 1009 | +
|
| 1010 | + offset tells from where to start reading the file. If specified, |
| 1011 | + count is the total number of bytes to transmit as opposed to |
| 1012 | + sending the file until EOF is reached. File position is updated on |
| 1013 | + return or also in case of error in which case file.tell() |
| 1014 | + can be used to figure out the number of bytes |
| 1015 | + which were sent. |
| 1016 | +
|
| 1017 | + fallback set to True makes asyncio to manually read and send |
| 1018 | + the file when the platform does not support the sendfile syscall |
| 1019 | + (e.g. Windows or SSL socket on Unix). |
| 1020 | +
|
| 1021 | + Raise SendfileNotAvailableError if the system does not support |
| 1022 | + sendfile syscall and fallback is False. |
| 1023 | + """ |
| 1024 | + if transport.is_closing(): |
| 1025 | + raise RuntimeError("Transport is closing") |
| 1026 | + mode = getattr(transport, '_sendfile_compatible', |
| 1027 | + constants._SendfileMode.UNSUPPORTED) |
| 1028 | + if mode is constants._SendfileMode.UNSUPPORTED: |
| 1029 | + raise RuntimeError( |
| 1030 | + f"sendfile is not supported for transport {transport!r}") |
| 1031 | + if mode is constants._SendfileMode.TRY_NATIVE: |
| 1032 | + try: |
| 1033 | + return await self._sendfile_native(transport, file, |
| 1034 | + offset, count) |
| 1035 | + except events.SendfileNotAvailableError as exc: |
| 1036 | + if not fallback: |
| 1037 | + raise |
| 1038 | + # the mode is FALLBACK or fallback is True |
| 1039 | + return await self._sendfile_fallback(transport, file, |
| 1040 | + offset, count) |
| 1041 | + |
| 1042 | + async def _sendfile_native(self, transp, file, offset, count): |
| 1043 | + raise events.SendfileNotAvailableError( |
| 1044 | + "sendfile syscall is not supported") |
| 1045 | + |
| 1046 | + async def _sendfile_fallback(self, transp, file, offset, count): |
| 1047 | + if offset: |
| 1048 | + file.seek(offset) |
| 1049 | + blocksize = min(count, 16384) if count else 16384 |
| 1050 | + buf = bytearray(blocksize) |
| 1051 | + total_sent = 0 |
| 1052 | + proto = _SendfileFallbackProtocol(transp) |
| 1053 | + try: |
| 1054 | + while True: |
| 1055 | + if count: |
| 1056 | + blocksize = min(count - total_sent, blocksize) |
| 1057 | + if blocksize <= 0: |
| 1058 | + return total_sent |
| 1059 | + view = memoryview(buf)[:blocksize] |
| 1060 | + read = file.readinto(view) |
| 1061 | + if not read: |
| 1062 | + return total_sent # EOF |
| 1063 | + await proto.drain() |
| 1064 | + transp.write(view) |
| 1065 | + total_sent += read |
| 1066 | + finally: |
| 1067 | + if total_sent > 0 and hasattr(file, 'seek'): |
| 1068 | + file.seek(offset + total_sent) |
| 1069 | + await proto.restore() |
| 1070 | + |
929 | 1071 | async def start_tls(self, transport, protocol, sslcontext, *, |
930 | 1072 | server_side=False, |
931 | 1073 | server_hostname=None, |
|
0 commit comments