Skip to content

Commit fe4dc45

Browse files
committed
Implement conistent handling of timeouts.
The timeout logic is currently a bit of a mess. This commit attempts to tidy things up. Most importantly, the timeout budget is now applied consistently to the _entire_ call, whereas previously multiple consecutive operations used the same timeout value, making it possible for the overall run time to exceed the timeout. Secondly, tighten the validation for timeouts: booleans are not accepted, and neither any value that cannot be converted to float.
1 parent baf5ce7 commit fe4dc45

File tree

5 files changed

+174
-33
lines changed

5 files changed

+174
-33
lines changed

asyncpg/connection.py

Lines changed: 80 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import os
1212
import socket
1313
import struct
14+
import time
1415
import urllib.parse
1516

1617
from . import cursor
@@ -60,6 +61,27 @@ def __init__(self, protocol, transport, loop, addr, opts, *,
6061
self._stmt_cache = collections.OrderedDict()
6162
self._stmts_to_close = set()
6263

64+
if command_timeout is not None:
65+
if isinstance(command_timeout, bool):
66+
raise ValueError(
67+
'invalid command_timeout value: '
68+
'expected non-negative float (got {!r})'.format(
69+
command_timeout))
70+
71+
try:
72+
command_timeout = float(command_timeout)
73+
except ValueError:
74+
raise ValueError(
75+
'invalid command_timeout value: '
76+
'expected non-negative float (got {!r})'.format(
77+
command_timeout)) from None
78+
79+
if command_timeout < 0:
80+
raise ValueError(
81+
'invalid command_timeout value: '
82+
'expected non-negative float (got {!r})'.format(
83+
command_timeout))
84+
6385
self._command_timeout = command_timeout
6486

6587
self._listeners = {}
@@ -187,7 +209,7 @@ async def execute(self, query: str, *args, timeout: float=None) -> str:
187209
if not args:
188210
return await self._protocol.query(query, timeout)
189211

190-
_, status, _ = await self._do_execute(query, args, 0, timeout, True)
212+
_, status, _ = await self._execute(query, args, 0, timeout, True)
191213
return status.decode()
192214

193215
async def executemany(self, command: str, args, timeout: float=None):
@@ -208,8 +230,7 @@ async def executemany(self, command: str, args, timeout: float=None):
208230
209231
.. versionadded:: 0.7.0
210232
"""
211-
stmt = await self._get_statement(command, timeout)
212-
return await self._protocol.bind_execute_many(stmt, args, '', timeout)
233+
return await self._executemany(command, args, timeout)
213234

214235
async def _get_statement(self, query, timeout):
215236
cache = self._stmt_cache_max_size > 0
@@ -281,7 +302,7 @@ async def fetch(self, query, *args, timeout=None) -> list:
281302
282303
:return list: A list of :class:`Record` instances.
283304
"""
284-
return await self._do_execute(query, args, 0, timeout)
305+
return await self._execute(query, args, 0, timeout)
285306

286307
async def fetchval(self, query, *args, column=0, timeout=None):
287308
"""Run a query and return a value in the first row.
@@ -297,7 +318,7 @@ async def fetchval(self, query, *args, column=0, timeout=None):
297318
298319
:return: The value of the specified column of the first record.
299320
"""
300-
data = await self._do_execute(query, args, 1, timeout)
321+
data = await self._execute(query, args, 1, timeout)
301322
if not data:
302323
return None
303324
return data[0][column]
@@ -311,7 +332,7 @@ async def fetchrow(self, query, *args, timeout=None):
311332
312333
:return: The first row as a :class:`Record` instance.
313334
"""
314-
data = await self._do_execute(query, args, 1, timeout)
335+
data = await self._execute(query, args, 1, timeout)
315336
if not data:
316337
return None
317338
return data[0]
@@ -430,7 +451,9 @@ async def _cleanup_stmts(self):
430451
to_close = self._stmts_to_close
431452
self._stmts_to_close = set()
432453
for stmt in to_close:
433-
await self._protocol.close_statement(stmt, False)
454+
# It is imperative that statements are cleaned properly,
455+
# so we ignore the timeout.
456+
await self._protocol.close_statement(stmt, protocol.NO_TIMEOUT)
434457

435458
def _request_portal_name(self):
436459
return self._get_unique_id()
@@ -554,14 +577,29 @@ def _drop_global_statement_cache(self):
554577
else:
555578
self._drop_local_statement_cache()
556579

557-
async def _do_execute(self, query, args, limit, timeout,
558-
return_status=False):
559-
stmt = await self._get_statement(query, timeout)
580+
def _execute(self, query, args, limit, timeout, return_status=False):
581+
executor = lambda stmt, timeout: self._protocol.bind_execute(
582+
stmt, args, '', limit, return_status, timeout)
583+
timeout = self._protocol._get_timeout(timeout)
584+
if timeout is not None:
585+
return self._do_execute_with_timeout(query, executor, timeout)
586+
else:
587+
return self._do_execute(query, executor)
588+
589+
def _executemany(self, query, args, timeout):
590+
executor = lambda stmt, timeout: self._protocol.bind_execute_many(
591+
stmt, args, '', timeout)
592+
timeout = self._protocol._get_timeout(timeout)
593+
if timeout is not None:
594+
return self._do_execute_with_timeout(query, executor, timeout)
595+
else:
596+
return self._do_execute(query, executor)
560597

561-
try:
562-
result = await self._protocol.bind_execute(
563-
stmt, args, '', limit, return_status, timeout)
598+
async def _do_execute(self, query, executor, retry=True):
599+
stmt = await self._get_statement(query, None)
564600

601+
try:
602+
result = await executor(stmt, None)
565603
except exceptions.InvalidCachedStatementError as e:
566604
# PostgreSQL will raise an exception when it detects
567605
# that the result type of the query has changed from
@@ -586,13 +624,38 @@ async def _do_execute(self, query, args, limit, timeout,
586624
# for discussion.
587625
#
588626
self._drop_global_statement_cache()
627+
if self._protocol.is_in_transaction() or not retry:
628+
raise
629+
else:
630+
result = await self._do_execute(
631+
query, executor, retry=False)
632+
633+
return result
634+
635+
async def _do_execute_with_timeout(self, query, executor, timeout,
636+
retry=True):
637+
before = time.monotonic()
638+
stmt = await self._get_statement(query, timeout)
639+
after = time.monotonic()
640+
timeout -= after - before
641+
before = after
642+
643+
try:
644+
try:
645+
result = await executor(stmt, timeout)
646+
finally:
647+
after = time.monotonic()
648+
timeout -= after - before
649+
650+
except exceptions.InvalidCachedStatementError as e:
651+
# See comment in _do_execute().
652+
self._drop_global_statement_cache()
589653

590-
if self._protocol.is_in_transaction():
654+
if self._protocol.is_in_transaction() or not retry:
591655
raise
592656
else:
593-
stmt = await self._get_statement(query, timeout)
594-
result = await self._protocol.bind_execute(
595-
stmt, args, '', limit, return_status, timeout)
657+
result = await self._do_execute_with_timeout(
658+
query, executor, timeout, retry=False)
596659

597660
return result
598661

asyncpg/protocol/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,4 @@
55
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
66

77

8-
from .protocol import Protocol, Record # NOQA
8+
from .protocol import Protocol, Record, NO_TIMEOUT # NOQA

asyncpg/protocol/protocol.pxd

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,8 @@ cdef class BaseProtocol(CoreProtocol):
4747

4848
PreparedStatementState statement
4949

50-
cdef _ensure_clear_state(self)
50+
cdef _get_timeout_impl(self, timeout)
51+
cdef _check_state(self)
5152
cdef _new_waiter(self, timeout)
5253

5354
cdef _on_result__connect(self, object waiter)

asyncpg/protocol/protocol.pyx

Lines changed: 51 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,9 @@ include "coreproto.pyx"
7979
include "prepared_stmt.pyx"
8080

8181

82+
NO_TIMEOUT = object()
83+
84+
8285
cdef class BaseProtocol(CoreProtocol):
8386
def __init__(self, addr, connected_fut, con_args, loop):
8487
CoreProtocol.__init__(self, con_args)
@@ -132,7 +135,8 @@ cdef class BaseProtocol(CoreProtocol):
132135
await self.cancel_sent_waiter
133136
self.cancel_sent_waiter = None
134137

135-
self._ensure_clear_state()
138+
self._check_state()
139+
timeout = self._get_timeout_impl(timeout)
136140

137141
if stmt_name is None:
138142
self.uid_counter += 1
@@ -154,7 +158,8 @@ cdef class BaseProtocol(CoreProtocol):
154158
await self.cancel_sent_waiter
155159
self.cancel_sent_waiter = None
156160

157-
self._ensure_clear_state()
161+
self._check_state()
162+
timeout = self._get_timeout_impl(timeout)
158163

159164
self._bind_execute(
160165
portal_name,
@@ -178,7 +183,8 @@ cdef class BaseProtocol(CoreProtocol):
178183
await self.cancel_sent_waiter
179184
self.cancel_sent_waiter = None
180185

181-
self._ensure_clear_state()
186+
self._check_state()
187+
timeout = self._get_timeout_impl(timeout)
182188

183189
# Make sure the argument sequence is encoded lazily with
184190
# this generator expression to keep the memory pressure under
@@ -209,7 +215,8 @@ cdef class BaseProtocol(CoreProtocol):
209215
await self.cancel_sent_waiter
210216
self.cancel_sent_waiter = None
211217

212-
self._ensure_clear_state()
218+
self._check_state()
219+
timeout = self._get_timeout_impl(timeout)
213220

214221
self._bind(
215222
portal_name,
@@ -231,7 +238,8 @@ cdef class BaseProtocol(CoreProtocol):
231238
await self.cancel_sent_waiter
232239
self.cancel_sent_waiter = None
233240

234-
self._ensure_clear_state()
241+
self._check_state()
242+
timeout = self._get_timeout_impl(timeout)
235243

236244
self._execute(
237245
portal_name,
@@ -251,7 +259,11 @@ cdef class BaseProtocol(CoreProtocol):
251259
await self.cancel_sent_waiter
252260
self.cancel_sent_waiter = None
253261

254-
self._ensure_clear_state()
262+
self._check_state()
263+
# query() needs to call _get_timeout instead of _get_timeout_impl
264+
# for consistent validation, as it is called differently from
265+
# prepare/bind/execute methods.
266+
timeout = self._get_timeout(timeout)
255267

256268
self._simple_query(query)
257269
self.last_query = query
@@ -266,7 +278,8 @@ cdef class BaseProtocol(CoreProtocol):
266278
await self.cancel_sent_waiter
267279
self.cancel_sent_waiter = None
268280

269-
self._ensure_clear_state()
281+
self._check_state()
282+
timeout = self._get_timeout_impl(timeout)
270283

271284
if state.refs != 0:
272285
raise RuntimeError(
@@ -348,7 +361,34 @@ cdef class BaseProtocol(CoreProtocol):
348361
cdef _set_server_parameter(self, name, val):
349362
self.settings.add_setting(name, val)
350363

351-
cdef _ensure_clear_state(self):
364+
def _get_timeout(self, timeout):
365+
if type(timeout) is bool:
366+
raise ValueError(
367+
'invalid timeout value: expected non-negative float '
368+
'(got {!r})'.format(timeout))
369+
elif timeout is not None:
370+
try:
371+
timeout = float(timeout)
372+
except ValueError:
373+
raise ValueError(
374+
'invalid timeout value: expected non-negative float '
375+
'(got {!r})'.format(timeout)) from None
376+
377+
return self._get_timeout_impl(timeout)
378+
379+
cdef inline _get_timeout_impl(self, timeout):
380+
if timeout is None:
381+
timeout = self.connection._command_timeout
382+
elif timeout is NO_TIMEOUT:
383+
timeout = None
384+
else:
385+
timeout = float(timeout)
386+
387+
if timeout is not None and timeout <= 0:
388+
raise asyncio.TimeoutError()
389+
return timeout
390+
391+
cdef _check_state(self):
352392
if self.cancel_waiter is not None:
353393
raise apg_exc.InterfaceError(
354394
'cannot perform operation: another operation is cancelling')
@@ -361,11 +401,9 @@ cdef class BaseProtocol(CoreProtocol):
361401

362402
cdef _new_waiter(self, timeout):
363403
self.waiter = self.create_future()
364-
if timeout is not False:
365-
timeout = timeout or self.connection._command_timeout
366-
if timeout is not None and timeout > 0:
367-
self.timeout_handle = self.connection._loop.call_later(
368-
timeout, self.timeout_callback, self.waiter)
404+
if timeout is not None:
405+
self.timeout_handle = self.connection._loop.call_later(
406+
timeout, self.timeout_callback, self.waiter)
369407
self.waiter.add_done_callback(self.completed_callback)
370408
return self.waiter
371409

tests/test_timeout.py

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,9 @@
66

77

88
import asyncio
9-
import asyncpg
109

10+
import asyncpg
11+
from asyncpg import connection as pg_connection
1112
from asyncpg import _testbase as tb
1213

1314

@@ -108,6 +109,22 @@ async def test_timeout_06(self):
108109

109110
self.assertEqual(await self.con.fetch('select 1'), [(1,)])
110111

112+
async def test_invalid_timeout(self):
113+
for command_timeout in ('a', False, -1):
114+
with self.subTest(command_timeout=command_timeout):
115+
with self.assertRaisesRegex(ValueError,
116+
'invalid command_timeout'):
117+
await self.cluster.connect(
118+
database='postgres', loop=self.loop,
119+
command_timeout=command_timeout)
120+
121+
# Note: negative timeouts are OK for method calls.
122+
for methname in {'fetch', 'fetchrow', 'fetchval', 'execute'}:
123+
for timeout in ('a', False):
124+
with self.subTest(timeout=timeout):
125+
with self.assertRaisesRegex(ValueError, 'invalid timeout'):
126+
await self.con.execute('SELECT 1', timeout=timeout)
127+
111128

112129
class TestConnectionCommandTimeout(tb.ConnectedTestCase):
113130

@@ -123,3 +140,25 @@ async def test_command_timeout_01(self):
123140
meth = getattr(self.con, methname)
124141
await meth('select pg_sleep(10)')
125142
self.assertEqual(await self.con.fetch('select 1'), [(1,)])
143+
144+
145+
class SlowPrepareConnection(pg_connection.Connection):
146+
"""Connection class to test timeouts."""
147+
async def _get_statement(self, query, timeout):
148+
await asyncio.sleep(0.15, loop=self._loop)
149+
return await super()._get_statement(query, timeout)
150+
151+
152+
class TestTimeoutCoversPrepare(tb.ConnectedTestCase):
153+
154+
def getExtraConnectOptions(self):
155+
return {
156+
'__connection_class__': SlowPrepareConnection,
157+
'command_timeout': 0.3
158+
}
159+
160+
async def test_timeout_covers_prepare_01(self):
161+
for methname in {'fetch', 'fetchrow', 'fetchval', 'execute'}:
162+
with self.assertRaises(asyncio.TimeoutError):
163+
meth = getattr(self.con, methname)
164+
await meth('select pg_sleep($1)', 0.2)

0 commit comments

Comments
 (0)