1111import os
1212import socket
1313import struct
14+ import time
1415import urllib .parse
1516
1617from . import cursor
@@ -60,6 +61,27 @@ def __init__(self, protocol, transport, loop, addr, opts, *,
6061 self ._stmt_cache = collections .OrderedDict ()
6162 self ._stmts_to_close = set ()
6263
64+ if command_timeout is not None :
65+ if isinstance (command_timeout , bool ):
66+ raise ValueError (
67+ 'invalid command_timeout value: '
68+ 'expected non-negative float (got {!r})' .format (
69+ command_timeout ))
70+
71+ try :
72+ command_timeout = float (command_timeout )
73+ except ValueError :
74+ raise ValueError (
75+ 'invalid command_timeout value: '
76+ 'expected non-negative float (got {!r})' .format (
77+ command_timeout )) from None
78+
79+ if command_timeout < 0 :
80+ raise ValueError (
81+ 'invalid command_timeout value: '
82+ 'expected non-negative float (got {!r})' .format (
83+ command_timeout ))
84+
6385 self ._command_timeout = command_timeout
6486
6587 self ._listeners = {}
@@ -187,7 +209,7 @@ async def execute(self, query: str, *args, timeout: float=None) -> str:
187209 if not args :
188210 return await self ._protocol .query (query , timeout )
189211
190- _ , status , _ = await self ._do_execute (query , args , 0 , timeout , True )
212+ _ , status , _ = await self ._execute (query , args , 0 , timeout , True )
191213 return status .decode ()
192214
193215 async def executemany (self , command : str , args , timeout : float = None ):
@@ -208,8 +230,7 @@ async def executemany(self, command: str, args, timeout: float=None):
208230
209231 .. versionadded:: 0.7.0
210232 """
211- stmt = await self ._get_statement (command , timeout )
212- return await self ._protocol .bind_execute_many (stmt , args , '' , timeout )
233+ return await self ._executemany (command , args , timeout )
213234
214235 async def _get_statement (self , query , timeout ):
215236 cache = self ._stmt_cache_max_size > 0
@@ -281,7 +302,7 @@ async def fetch(self, query, *args, timeout=None) -> list:
281302
282303 :return list: A list of :class:`Record` instances.
283304 """
284- return await self ._do_execute (query , args , 0 , timeout )
305+ return await self ._execute (query , args , 0 , timeout )
285306
286307 async def fetchval (self , query , * args , column = 0 , timeout = None ):
287308 """Run a query and return a value in the first row.
@@ -297,7 +318,7 @@ async def fetchval(self, query, *args, column=0, timeout=None):
297318
298319 :return: The value of the specified column of the first record.
299320 """
300- data = await self ._do_execute (query , args , 1 , timeout )
321+ data = await self ._execute (query , args , 1 , timeout )
301322 if not data :
302323 return None
303324 return data [0 ][column ]
@@ -311,7 +332,7 @@ async def fetchrow(self, query, *args, timeout=None):
311332
312333 :return: The first row as a :class:`Record` instance.
313334 """
314- data = await self ._do_execute (query , args , 1 , timeout )
335+ data = await self ._execute (query , args , 1 , timeout )
315336 if not data :
316337 return None
317338 return data [0 ]
@@ -430,7 +451,9 @@ async def _cleanup_stmts(self):
430451 to_close = self ._stmts_to_close
431452 self ._stmts_to_close = set ()
432453 for stmt in to_close :
433- await self ._protocol .close_statement (stmt , False )
454+ # It is imperative that statements are cleaned properly,
455+ # so we ignore the timeout.
456+ await self ._protocol .close_statement (stmt , protocol .NO_TIMEOUT )
434457
435458 def _request_portal_name (self ):
436459 return self ._get_unique_id ()
@@ -554,14 +577,29 @@ def _drop_global_statement_cache(self):
554577 else :
555578 self ._drop_local_statement_cache ()
556579
557- async def _do_execute (self , query , args , limit , timeout ,
558- return_status = False ):
559- stmt = await self ._get_statement (query , timeout )
580+ def _execute (self , query , args , limit , timeout , return_status = False ):
581+ executor = lambda stmt , timeout : self ._protocol .bind_execute (
582+ stmt , args , '' , limit , return_status , timeout )
583+ timeout = self ._protocol ._get_timeout (timeout )
584+ if timeout is not None :
585+ return self ._do_execute_with_timeout (query , executor , timeout )
586+ else :
587+ return self ._do_execute (query , executor )
588+
589+ def _executemany (self , query , args , timeout ):
590+ executor = lambda stmt , timeout : self ._protocol .bind_execute_many (
591+ stmt , args , '' , timeout )
592+ timeout = self ._protocol ._get_timeout (timeout )
593+ if timeout is not None :
594+ return self ._do_execute_with_timeout (query , executor , timeout )
595+ else :
596+ return self ._do_execute (query , executor )
560597
561- try :
562- result = await self ._protocol .bind_execute (
563- stmt , args , '' , limit , return_status , timeout )
598+ async def _do_execute (self , query , executor , retry = True ):
599+ stmt = await self ._get_statement (query , None )
564600
601+ try :
602+ result = await executor (stmt , None )
565603 except exceptions .InvalidCachedStatementError as e :
566604 # PostgreSQL will raise an exception when it detects
567605 # that the result type of the query has changed from
@@ -586,13 +624,38 @@ async def _do_execute(self, query, args, limit, timeout,
586624 # for discussion.
587625 #
588626 self ._drop_global_statement_cache ()
627+ if self ._protocol .is_in_transaction () or not retry :
628+ raise
629+ else :
630+ result = await self ._do_execute (
631+ query , executor , retry = False )
632+
633+ return result
634+
635+ async def _do_execute_with_timeout (self , query , executor , timeout ,
636+ retry = True ):
637+ before = time .monotonic ()
638+ stmt = await self ._get_statement (query , timeout )
639+ after = time .monotonic ()
640+ timeout -= after - before
641+ before = after
642+
643+ try :
644+ try :
645+ result = await executor (stmt , timeout )
646+ finally :
647+ after = time .monotonic ()
648+ timeout -= after - before
649+
650+ except exceptions .InvalidCachedStatementError as e :
651+ # See comment in _do_execute().
652+ self ._drop_global_statement_cache ()
589653
590- if self ._protocol .is_in_transaction ():
654+ if self ._protocol .is_in_transaction () or not retry :
591655 raise
592656 else :
593- stmt = await self ._get_statement (query , timeout )
594- result = await self ._protocol .bind_execute (
595- stmt , args , '' , limit , return_status , timeout )
657+ result = await self ._do_execute_with_timeout (
658+ query , executor , timeout , retry = False )
596659
597660 return result
598661
0 commit comments