Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 47 additions & 20 deletions asyncpg/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,13 @@ class Connection(metaclass=ConnectionMeta):
'_addr', '_opts', '_command_timeout', '_listeners',
'_server_version', '_server_caps', '_intro_query',
'_reset_query', '_proxy', '_stmt_exclusive_section',
'_ssl_context')
'_max_cacheable_statement_size', '_ssl_context')

def __init__(self, protocol, transport, loop, addr, opts, *,
statement_cache_size, command_timeout,
max_cached_statement_lifetime, ssl_context):
max_cached_statement_lifetime,
max_cacheable_statement_size,
ssl_context):
self._protocol = protocol
self._transport = transport
self._loop = loop
Expand All @@ -61,6 +63,7 @@ def __init__(self, protocol, transport, loop, addr, opts, *,
self._opts = opts
self._ssl_context = ssl_context

self._max_cacheable_statement_size = max_cacheable_statement_size
self._stmt_cache = _StatementCache(
loop=loop,
max_size=statement_cache_size,
Expand All @@ -69,22 +72,6 @@ def __init__(self, protocol, transport, loop, addr, opts, *,

self._stmts_to_close = set()

if command_timeout is not None:
try:
if isinstance(command_timeout, bool):
raise ValueError

command_timeout = float(command_timeout)

if command_timeout < 0:
raise ValueError

except ValueError:
raise ValueError(
'invalid command_timeout value: '
'expected non-negative float (got {!r})'.format(
command_timeout)) from None

self._command_timeout = command_timeout

self._listeners = {}
Expand Down Expand Up @@ -280,7 +267,16 @@ async def _get_statement(self, query, timeout, *, named: bool=False):
if statement is not None:
return statement

if self._stmt_cache.get_max_size() or named:
# Only use the cache when:
# * `statement_cache_size` is greater than 0;
# * query size is less than `max_cacheable_statement_size`.
use_cache = self._stmt_cache.get_max_size() > 0
if (use_cache and
self._max_cacheable_statement_size and
len(query) > self._max_cacheable_statement_size):
use_cache = False

if use_cache or named:
stmt_name = self._get_unique_id('stmt')
else:
stmt_name = ''
Expand All @@ -295,7 +291,8 @@ async def _get_statement(self, query, timeout, *, named: bool=False):
types = await self._types_stmt.fetch(list(ready))
self._protocol.get_settings().register_data_types(types)

self._stmt_cache.put(query, statement)
if use_cache:
self._stmt_cache.put(query, statement)

# If we've just created a new statement object, check if there
# are any statements for GC.
Expand Down Expand Up @@ -721,6 +718,7 @@ async def connect(dsn=None, *,
timeout=60,
statement_cache_size=100,
max_cached_statement_lifetime=300,
max_cacheable_statement_size=1024 * 15,
command_timeout=None,
ssl=None,
__connection_class__=Connection,
Expand Down Expand Up @@ -772,6 +770,11 @@ async def connect(dsn=None, *,
in the cache. Pass ``0`` to allow statements be cached
indefinitely.

:param int max_cacheable_statement_size:
the maximum size of a statement that can be cached (15KiB by
default). Pass ``0`` to allow all statements to be cached
regardless of their size.

:param float command_timeout:
the default timeout for operations on this connection
(the default is no timeout).
Expand Down Expand Up @@ -807,6 +810,29 @@ async def connect(dsn=None, *,
if loop is None:
loop = asyncio.get_event_loop()

local_vars = locals()
for var_name in {'max_cacheable_statement_size',
'max_cached_statement_lifetime',
'statement_cache_size'}:
var_val = local_vars[var_name]
if var_val is None or isinstance(var_val, bool) or var_val < 0:
raise ValueError(
'{} is expected to be greater '
'or equal to 0, got {!r}'.format(var_name, var_val))

if command_timeout is not None:
try:
if isinstance(command_timeout, bool):
raise ValueError
command_timeout = float(command_timeout)
if command_timeout < 0:
raise ValueError
except ValueError:
raise ValueError(
'invalid command_timeout value: '
'expected non-negative float (got {!r})'.format(
command_timeout)) from None

addrs, opts = _parse_connect_params(
dsn=dsn, host=host, port=port, user=user, password=password,
database=database, opts=opts)
Expand Down Expand Up @@ -855,6 +881,7 @@ async def connect(dsn=None, *,
pr, tr, loop, addr, opts,
statement_cache_size=statement_cache_size,
max_cached_statement_lifetime=max_cached_statement_lifetime,
max_cacheable_statement_size=max_cacheable_statement_size,
command_timeout=command_timeout, ssl_context=ssl)

pr.set_connection(con)
Expand Down
14 changes: 13 additions & 1 deletion tests/test_connect.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ async def test_auth_unsupported(self):
pass


class TestConnectParams(unittest.TestCase):
class TestConnectParams(tb.TestCase):

TESTS = [
{
Expand Down Expand Up @@ -421,6 +421,18 @@ def test_connect_params(self):
for testcase in self.TESTS:
self.run_testcase(testcase)

async def test_connect_args_validation(self):
for val in {-1, 'a', True, False}:
with self.assertRaisesRegex(ValueError, 'non-negative'):
await asyncpg.connect(command_timeout=val)

for arg in {'max_cacheable_statement_size',
'max_cached_statement_lifetime',
'statement_cache_size'}:
for val in {None, -1, True, False}:
with self.assertRaisesRegex(ValueError, 'greater or equal'):
await asyncpg.connect(**{arg: val})


class TestConnection(tb.ConnectedTestCase):

Expand Down
21 changes: 21 additions & 0 deletions tests/test_prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,3 +513,24 @@ async def test_prepare_26_max_lifetime_max_size(self):

# Check that nothing crashes after the initial timeout
await asyncio.sleep(1, loop=self.loop)

@tb.with_connection_options(max_cacheable_statement_size=50)
async def test_prepare_27_max_cacheable_statement_size(self):
cache = self.con._stmt_cache

await self.con.prepare('SELECT 1')
self.assertEqual(len(cache), 1)

# Test that long and explicitly created prepared statements
# are not cached.
await self.con.prepare("SELECT \'" + "a" * 50 + "\'")
self.assertEqual(len(cache), 1)

# Test that implicitly created long prepared statements
# are not cached.
await self.con.fetchval("SELECT \'" + "a" * 50 + "\'")
self.assertEqual(len(cache), 1)

# Test that short prepared statements can still be cached.
await self.con.prepare('SELECT 2')
self.assertEqual(len(cache), 2)