Skip to content

Commit 9235a7b

Browse files
committed
pool: Track connections and prohibit using them after release.
Connection pool now wraps all connections in `PooledConnectionProxy` objects to raise `InterfaceError` if they are used after being released back to the pool. We also check if connection passed to `pool.release` actually belong to the pool and correctly handle multiple calls to `pool.release` with the same connection object. `PooledConnectionProxy` transparently wraps Connection instances, exposing all Connection public API. `isinstance(asyncpg.connection.Connection)` is `True` for Instances of `PooledConnectionProxy` class.
1 parent 537c8c9 commit 9235a7b

File tree

4 files changed

+202
-22
lines changed

4 files changed

+202
-22
lines changed

asyncpg/connection.py

Lines changed: 26 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,17 @@
2121
from . import transaction
2222

2323

24-
class Connection:
24+
class _ConnectionProxy:
25+
pass
26+
27+
28+
class ConnectionMeta(type):
29+
def __instancecheck__(cls, instance):
30+
mro = type(instance).__mro__
31+
return Connection in mro or _ConnectionProxy in mro
32+
33+
34+
class Connection(metaclass=ConnectionMeta):
2535
"""A representation of a database session.
2636
2737
Connections are created by calling :func:`~asyncpg.connection.connect`.
@@ -32,7 +42,7 @@ class Connection:
3242
'_stmt_cache_max_size', '_stmt_cache', '_stmts_to_close',
3343
'_addr', '_opts', '_command_timeout', '_listeners',
3444
'_server_version', '_server_caps', '_intro_query',
35-
'_reset_query')
45+
'_reset_query', '_proxy')
3646

3747
def __init__(self, protocol, transport, loop, addr, opts, *,
3848
statement_cache_size, command_timeout):
@@ -70,6 +80,7 @@ def __init__(self, protocol, transport, loop, addr, opts, *,
7080
self._intro_query = introspection.INTRO_LOOKUP_TYPES
7181

7282
self._reset_query = None
83+
self._proxy = None
7384

7485
async def add_listener(self, channel, callback):
7586
"""Add a listener for Postgres notifications.
@@ -478,9 +489,13 @@ def _notify(self, pid, channel, payload):
478489
if channel not in self._listeners:
479490
return
480491

492+
con_ref = self._proxy
493+
if con_ref is None:
494+
con_ref = self
495+
481496
for cb in self._listeners[channel]:
482497
try:
483-
cb(self, pid, channel, payload)
498+
cb(con_ref, pid, channel, payload)
484499
except Exception as ex:
485500
self._loop.call_exception_handler({
486501
'message': 'Unhandled exception in asyncpg notification '
@@ -517,6 +532,9 @@ def _get_reset_query(self):
517532

518533
return _reset_query
519534

535+
def _set_proxy(self, proxy):
536+
self._proxy = proxy
537+
520538

521539
async def connect(dsn=None, *,
522540
host=None, port=None,
@@ -526,7 +544,7 @@ async def connect(dsn=None, *,
526544
timeout=60,
527545
statement_cache_size=100,
528546
command_timeout=None,
529-
connection_class=Connection,
547+
__connection_class__=Connection,
530548
**opts):
531549
"""A coroutine to establish a connection to a PostgreSQL server.
532550
@@ -564,11 +582,7 @@ async def connect(dsn=None, *,
564582
:param float command_timeout: the default timeout for operations on
565583
this connection (the default is no timeout).
566584
567-
:param builtins.type connection_class: A class used to represent
568-
the connection.
569-
Defaults to :class:`~asyncpg.connection.Connection`.
570-
571-
:return: A *connection_class* instance.
585+
:return: A :class:`~asyncpg.connection.Connection` instance.
572586
573587
Example:
574588
@@ -582,10 +596,6 @@ async def connect(dsn=None, *,
582596
... print(types)
583597
>>> asyncio.get_event_loop().run_until_complete(run())
584598
[<Record typname='bool' typnamespace=11 ...
585-
586-
587-
.. versionadded:: 0.10.0
588-
*connection_class* argument.
589599
"""
590600
if loop is None:
591601
loop = asyncio.get_event_loop()
@@ -629,9 +639,9 @@ async def connect(dsn=None, *,
629639
tr.close()
630640
raise
631641

632-
con = connection_class(pr, tr, loop, addr, opts,
633-
statement_cache_size=statement_cache_size,
634-
command_timeout=command_timeout)
642+
con = __connection_class__(pr, tr, loop, addr, opts,
643+
statement_cache_size=statement_cache_size,
644+
command_timeout=command_timeout)
635645
pr.set_connection(con)
636646
return con
637647

asyncpg/pool.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,78 @@
66

77

88
import asyncio
9+
import functools
910

1011
from . import connection
1112
from . import exceptions
1213

1314

15+
class PooledConnectionProxyMeta(type):
16+
17+
def __new__(mcls, name, bases, dct, *, wrap=False):
18+
if wrap:
19+
def get_wrapper(methname):
20+
meth = getattr(connection.Connection, methname)
21+
22+
def wrapper(self, *args, **kwargs):
23+
return self._dispatch(meth, args, kwargs)
24+
25+
return wrapper
26+
27+
for attrname in dir(connection.Connection):
28+
if attrname.startswith('_') or attrname in dct:
29+
continue
30+
wrapper = get_wrapper(attrname)
31+
wrapper = functools.update_wrapper(
32+
wrapper, getattr(connection.Connection, attrname))
33+
dct[attrname] = wrapper
34+
35+
if '__doc__' not in dct:
36+
dct['__doc__'] = connection.Connection.__doc__
37+
38+
return super().__new__(mcls, name, bases, dct)
39+
40+
41+
class PooledConnectionProxy(connection._ConnectionProxy,
42+
metaclass=PooledConnectionProxyMeta,
43+
wrap=True):
44+
45+
__slots__ = ('_in_pool', '_con', '_owner')
46+
47+
def __init__(self, owner: 'Pool', con: connection.Connection):
48+
self._in_pool = True
49+
self._con = con
50+
self._owner = owner
51+
con._set_proxy(self)
52+
53+
def _unwrap(self) -> connection.Connection:
54+
if not self._in_pool:
55+
raise exceptions.InterfaceError(
56+
'internal asyncpg error: cannot unwrap pooled connection')
57+
58+
self._in_pool = False
59+
con, self._con = self._con, None
60+
con._set_proxy(None)
61+
return con
62+
63+
def _dispatch(self, meth, args, kwargs):
64+
if not self._in_pool:
65+
raise exceptions.InterfaceError(
66+
'cannot call Connection.{}(): '
67+
'connection has been released back to the pool'.format(
68+
meth.__name__))
69+
70+
return meth(self._con, *args, **kwargs)
71+
72+
def __repr__(self):
73+
if self._con is None:
74+
return '<{classname} [released] {id:#x}>'.format(
75+
classname=self.__class__.__name__, id=id(self))
76+
else:
77+
return '<{classname} {con!r} {id:#x}>'.format(
78+
classname=self.__class__.__name__, con=self._con, id=id(self))
79+
80+
1481
class Pool:
1582
"""A connection pool.
1683
@@ -168,6 +235,8 @@ async def _acquire_impl(self):
168235
else:
169236
con = await self._queue.get()
170237

238+
con = PooledConnectionProxy(self, con)
239+
171240
if self._setup is not None:
172241
try:
173242
await self._setup(con)
@@ -179,6 +248,20 @@ async def _acquire_impl(self):
179248

180249
async def release(self, connection):
181250
"""Release a database connection back to the pool."""
251+
252+
if (connection.__class__ is not PooledConnectionProxy or
253+
connection._owner is not self):
254+
raise exceptions.InterfaceError(
255+
'Pool.release() received invalid connection: '
256+
'{connection!r} is not a member of this pool'.format(
257+
connection=connection))
258+
259+
if connection._con is None:
260+
# Already released, do nothing.
261+
return
262+
263+
connection = connection._unwrap()
264+
182265
# Use asyncio.shield() to guarantee that task cancellation
183266
# does not prevent the connection from being returned to the
184267
# pool properly.
@@ -325,6 +408,10 @@ def create_pool(dsn=None, *,
325408
:param loop: An asyncio event loop instance. If ``None``, the default
326409
event loop will be used.
327410
:return: An instance of :class:`~asyncpg.pool.Pool`.
411+
412+
.. versionchanged:: 0.10.0
413+
An :exc:`~asyncpg.exceptions.InterfaceError` will be raised on any
414+
attempted operation on a released connection.
328415
"""
329416
return Pool(dsn,
330417
min_size=min_size, max_size=max_size,

tests/test_connect.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414
import asyncpg
1515
from asyncpg import _testbase as tb
16-
from asyncpg.connection import _parse_connect_params
16+
from asyncpg import connection
1717
from asyncpg.serverversion import split_server_version_string
1818

1919
_system = platform.uname().system
@@ -355,7 +355,7 @@ def run_testcase(self, testcase):
355355
if expected_error:
356356
es.enter_context(self.assertRaisesRegex(*expected_error))
357357

358-
result = _parse_connect_params(
358+
result = connection._parse_connect_params(
359359
dsn=dsn, host=host, port=port, user=user, password=password,
360360
database=database, opts=opts)
361361

@@ -411,3 +411,11 @@ def test_test_connect_params_run_testcase(self):
411411
def test_connect_params(self):
412412
for testcase in self.TESTS:
413413
self.run_testcase(testcase)
414+
415+
416+
class TestConnection(tb.ConnectedTestCase):
417+
418+
async def test_connection_isinstance(self):
419+
self.assertTrue(isinstance(self.con, connection.Connection))
420+
self.assertTrue(isinstance(self.con, object))
421+
self.assertFalse(isinstance(self.con, list))

tests/test_pool.py

Lines changed: 79 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66

77

88
import asyncio
9+
import asyncpg
10+
import inspect
911
import platform
1012
import os
1113
import unittest
@@ -35,7 +37,7 @@ async def reset(self):
3537
class SlowResetConnectionPool(pg_pool.Pool):
3638
async def _connect(self, *args, **kwargs):
3739
return await pg_connection.connect(
38-
*args, connection_class=SlowResetConnection, **kwargs)
40+
*args, __connection_class__=SlowResetConnection, **kwargs)
3941

4042

4143
class TestPool(tb.ConnectedTestCase):
@@ -88,7 +90,7 @@ async def test_pool_04(self):
8890
con.terminate()
8991
await pool.release(con)
9092

91-
async with pool.acquire(timeout=POOL_NOMINAL_TIMEOUT):
93+
async with pool.acquire(timeout=POOL_NOMINAL_TIMEOUT) as con:
9294
con.terminate()
9395

9496
con = await pool.acquire(timeout=POOL_NOMINAL_TIMEOUT)
@@ -127,7 +129,7 @@ async def test_pool_07(self):
127129
cons = set()
128130

129131
async def setup(con):
130-
if con not in cons:
132+
if con._con not in cons: # `con` is `PooledConnectionProxy`.
131133
raise RuntimeError('init was not called before setup')
132134

133135
async def init(con):
@@ -137,7 +139,7 @@ async def init(con):
137139

138140
async def user(pool):
139141
async with pool.acquire() as con:
140-
if con not in cons:
142+
if con._con not in cons: # `con` is `PooledConnectionProxy`.
141143
raise RuntimeError('init was not called')
142144

143145
async with self.create_pool(database='postgres',
@@ -150,6 +152,79 @@ async def user(pool):
150152

151153
self.assertEqual(len(cons), 5)
152154

155+
async def test_pool_08(self):
156+
pool = await self.create_pool(database='postgres',
157+
min_size=1, max_size=1)
158+
159+
con = await pool.acquire(timeout=POOL_NOMINAL_TIMEOUT)
160+
with self.assertRaisesRegex(asyncpg.InterfaceError, 'is not a member'):
161+
await pool.release(con._con)
162+
163+
async def test_pool_09(self):
164+
pool1 = await self.create_pool(database='postgres',
165+
min_size=1, max_size=1)
166+
167+
pool2 = await self.create_pool(database='postgres',
168+
min_size=1, max_size=1)
169+
170+
con = await pool1.acquire(timeout=POOL_NOMINAL_TIMEOUT)
171+
with self.assertRaisesRegex(asyncpg.InterfaceError, 'is not a member'):
172+
await pool2.release(con)
173+
174+
await pool1.close()
175+
await pool2.close()
176+
177+
async def test_pool_10(self):
178+
pool = await self.create_pool(database='postgres',
179+
min_size=1, max_size=1)
180+
181+
con = await pool.acquire()
182+
await pool.release(con)
183+
await pool.release(con)
184+
185+
await pool.close()
186+
187+
async def test_pool_11(self):
188+
pool = await self.create_pool(database='postgres',
189+
min_size=1, max_size=1)
190+
191+
async with pool.acquire() as con:
192+
self.assertIn(repr(con._con), repr(con)) # Test __repr__.
193+
194+
self.assertIn('[released]', repr(con))
195+
196+
with self.assertRaisesRegex(
197+
asyncpg.InterfaceError,
198+
r'cannot call Connection\.execute.*released back to the pool'):
199+
200+
con.execute('select 1')
201+
202+
await pool.close()
203+
204+
async def test_pool_12(self):
205+
pool = await self.create_pool(database='postgres',
206+
min_size=1, max_size=1)
207+
208+
async with pool.acquire() as con:
209+
self.assertTrue(isinstance(con, pg_connection.Connection))
210+
self.assertFalse(isinstance(con, list))
211+
212+
await pool.close()
213+
214+
async def test_pool_13(self):
215+
pool = await self.create_pool(database='postgres',
216+
min_size=1, max_size=1)
217+
218+
async with pool.acquire() as con:
219+
self.assertIn('Execute an SQL command', con.execute.__doc__)
220+
self.assertEqual(con.execute.__name__, 'execute')
221+
222+
self.assertIn(
223+
str(inspect.signature(con.execute))[1:],
224+
str(inspect.signature(pg_connection.Connection.execute)))
225+
226+
await pool.close()
227+
153228
async def test_pool_auth(self):
154229
if not self.cluster.is_managed():
155230
self.skipTest('unmanaged cluster')

0 commit comments

Comments
 (0)