@@ -43,11 +43,13 @@ class Connection(metaclass=ConnectionMeta):
4343 '_addr' , '_opts' , '_command_timeout' , '_listeners' ,
4444 '_server_version' , '_server_caps' , '_intro_query' ,
4545 '_reset_query' , '_proxy' , '_stmt_exclusive_section' ,
46-  '_ssl_context' )
46+  '_max_cacheable_statement_size'  ,  ' _ssl_context'
4747
4848 def  __init__ (self , protocol , transport , loop , addr , opts , * ,
4949 statement_cache_size , command_timeout ,
50-  max_cached_statement_lifetime , ssl_context ):
50+  max_cached_statement_lifetime ,
51+  max_cacheable_statement_size ,
52+  ssl_context ):
5153 self ._protocol  =  protocol 
5254 self ._transport  =  transport 
5355 self ._loop  =  loop 
@@ -61,6 +63,7 @@ def __init__(self, protocol, transport, loop, addr, opts, *,
6163 self ._opts  =  opts 
6264 self ._ssl_context  =  ssl_context 
6365
66+  self ._max_cacheable_statement_size  =  max_cacheable_statement_size 
6467 self ._stmt_cache  =  _StatementCache (
6568 loop = loop ,
6669 max_size = statement_cache_size ,
@@ -69,22 +72,6 @@ def __init__(self, protocol, transport, loop, addr, opts, *,
6972
7073 self ._stmts_to_close  =  set ()
7174
72-  if  command_timeout  is  not None :
73-  try :
74-  if  isinstance (command_timeout , bool ):
75-  raise  ValueError 
76- 
77-  command_timeout  =  float (command_timeout )
78- 
79-  if  command_timeout  <  0 :
80-  raise  ValueError 
81- 
82-  except  ValueError :
83-  raise  ValueError (
84-  'invalid command_timeout value: ' 
85-  'expected non-negative float (got {!r})' .format (
86-  command_timeout )) from  None 
87- 
8875 self ._command_timeout  =  command_timeout 
8976
9077 self ._listeners  =  {}
@@ -280,7 +267,16 @@ async def _get_statement(self, query, timeout, *, named: bool=False):
280267 if  statement  is  not None :
281268 return  statement 
282269
283-  if  self ._stmt_cache .get_max_size () or  named :
270+  # Only use the cache when: 
271+  # * `statement_cache_size` is greater than 0; 
272+  # * query size is less than `max_cacheable_statement_size`. 
273+  use_cache  =  self ._stmt_cache .get_max_size () >  0 
274+  if  (use_cache  and 
275+  self ._max_cacheable_statement_size  and 
276+  len (query ) >  self ._max_cacheable_statement_size ):
277+  use_cache  =  False 
278+ 
279+  if  use_cache  or  named :
284280 stmt_name  =  self ._get_unique_id ('stmt' )
285281 else :
286282 stmt_name  =  '' 
@@ -295,7 +291,8 @@ async def _get_statement(self, query, timeout, *, named: bool=False):
295291 types  =  await  self ._types_stmt .fetch (list (ready ))
296292 self ._protocol .get_settings ().register_data_types (types )
297293
298-  self ._stmt_cache .put (query , statement )
294+  if  use_cache :
295+  self ._stmt_cache .put (query , statement )
299296
300297 # If we've just created a new statement object, check if there 
301298 # are any statements for GC. 
@@ -721,6 +718,7 @@ async def connect(dsn=None, *,
721718 timeout = 60 ,
722719 statement_cache_size = 100 ,
723720 max_cached_statement_lifetime = 300 ,
721+  max_cacheable_statement_size = 1024  *  15 ,
724722 command_timeout = None ,
725723 ssl = None ,
726724 __connection_class__ = Connection ,
@@ -772,6 +770,11 @@ async def connect(dsn=None, *,
772770 in the cache. Pass ``0`` to allow statements be cached 
773771 indefinitely. 
774772
773+  :param int max_cacheable_statement_size: 
774+  the maximum size of a statement that can be cached (15KiB by 
775+  default). Pass ``0`` to allow all statements to be cached 
776+  regardless of their size. 
777+ 
775778 :param float command_timeout: 
776779 the default timeout for operations on this connection 
777780 (the default is no timeout). 
@@ -807,6 +810,29 @@ async def connect(dsn=None, *,
807810 if  loop  is  None :
808811 loop  =  asyncio .get_event_loop ()
809812
813+  local_vars  =  locals ()
814+  for  var_name  in  {'max_cacheable_statement_size' ,
815+  'max_cached_statement_lifetime' ,
816+  'statement_cache_size' }:
817+  var_val  =  local_vars [var_name ]
818+  if  var_val  is  None  or  isinstance (var_val , bool ) or  var_val  <  0 :
819+  raise  ValueError (
820+  '{} is expected to be greater ' 
821+  'or equal to 0, got {!r}' .format (var_name , var_val ))
822+ 
823+  if  command_timeout  is  not None :
824+  try :
825+  if  isinstance (command_timeout , bool ):
826+  raise  ValueError 
827+  command_timeout  =  float (command_timeout )
828+  if  command_timeout  <  0 :
829+  raise  ValueError 
830+  except  ValueError :
831+  raise  ValueError (
832+  'invalid command_timeout value: ' 
833+  'expected non-negative float (got {!r})' .format (
834+  command_timeout )) from  None 
835+ 
810836 addrs , opts  =  _parse_connect_params (
811837 dsn = dsn , host = host , port = port , user = user , password = password ,
812838 database = database , opts = opts )
@@ -855,6 +881,7 @@ async def connect(dsn=None, *,
855881 pr , tr , loop , addr , opts ,
856882 statement_cache_size = statement_cache_size ,
857883 max_cached_statement_lifetime = max_cached_statement_lifetime ,
884+  max_cacheable_statement_size = max_cacheable_statement_size ,
858885 command_timeout = command_timeout , ssl_context = ssl )
859886
860887 pr .set_connection (con )
0 commit comments