Skip to content

Commit 6fb2cb1

Browse files
committed
Reset Connection._top_xact; disallow managed xacts inside manual ones
1 parent bdef8fc commit 6fb2cb1

File tree

6 files changed

+73
-18
lines changed

6 files changed

+73
-18
lines changed

asyncpg/_testbase.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import inspect
1313
import logging
1414
import os
15+
import re
1516
import time
1617
import unittest
1718

@@ -95,6 +96,30 @@ def assertRunUnder(self, delta):
9596
raise AssertionError(
9697
'running block took longer than {}'.format(delta))
9798

99+
@contextlib.contextmanager
100+
def assertLoopErrorHandlerCalled(self, msg_re: str):
101+
contexts = []
102+
103+
def handler(loop, ctx):
104+
contexts.append(ctx)
105+
106+
old_handler = self.loop.get_exception_handler()
107+
self.loop.set_exception_handler(handler)
108+
try:
109+
yield
110+
111+
for ctx in contexts:
112+
msg = ctx.get('message')
113+
if msg and re.search(msg_re, msg):
114+
return
115+
116+
raise AssertionError(
117+
'no message matching {!r} was logged with '
118+
'loop.call_exception_handler()'.format(msg_re))
119+
120+
finally:
121+
self.loop.set_exception_handler(old_handler)
122+
98123

99124
_default_cluster = None
100125

asyncpg/connection.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -513,11 +513,12 @@ def _get_reset_query(self):
513513
caps = self._server_caps
514514

515515
_reset_query = []
516-
if self._protocol.is_in_transaction():
516+
if self._protocol.is_in_transaction() or self._top_xact is not None:
517517
self._loop.call_exception_handler({
518518
'message': 'Resetting connection with an '
519519
'active transaction {!r}'.format(self)
520520
})
521+
self._top_xact = None
521522
_reset_query.append('ROLLBACK;')
522523
if caps.advisory_locks:
523524
_reset_query.append('SELECT pg_advisory_unlock_all();')

asyncpg/transaction.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,10 @@ async def start(self):
8484
con = self._connection
8585

8686
if con._top_xact is None:
87+
if con._protocol.is_in_transaction():
88+
raise apg_errors.InterfaceError(
89+
'cannot use Connection.transaction() in '
90+
'a manually started transaction')
8791
con._top_xact = self
8892
else:
8993
# Nested transaction block

tests/test_pool.py

Lines changed: 13 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -353,32 +353,28 @@ async def test_pool_release_in_xact(self):
353353
"""Test that Connection.reset() closes any open transaction."""
354354
async with self.create_pool(database='postgres',
355355
min_size=1, max_size=1) as pool:
356-
357-
last_error_ctx = None
358-
359356
async def get_xact_id(con):
360357
return await con.fetchval('select txid_current()')
361358

362-
def logger(loop, ctx):
363-
nonlocal last_error_ctx
364-
last_error_ctx = ctx
365-
366-
async with pool.acquire() as con:
367-
id1 = await get_xact_id(con)
359+
with self.assertLoopErrorHandlerCalled('an active transaction'):
360+
async with pool.acquire() as con:
361+
real_con = con._con # unwrap PoolConnectionProxy
368362

369-
tr = con.transaction()
370-
await tr.start()
363+
id1 = await get_xact_id(con)
371364

372-
id2 = await get_xact_id(con)
373-
self.assertNotEqual(id1, id2)
365+
tr = con.transaction()
366+
self.assertIsNone(con._con._top_xact)
367+
await tr.start()
368+
self.assertIs(real_con._top_xact, tr)
374369

375-
self.loop.set_exception_handler(logger)
370+
id2 = await get_xact_id(con)
371+
self.assertNotEqual(id1, id2)
376372

377-
self.assertIsNotNone(last_error_ctx)
378-
self.assertIn('an active transaction', last_error_ctx['message'])
379-
self.loop.set_exception_handler(None)
373+
self.assertIsNone(real_con._top_xact)
380374

381375
async with pool.acquire() as con:
376+
self.assertIs(con._con, real_con)
377+
self.assertIsNone(con._con._top_xact)
382378
id3 = await get_xact_id(con)
383379
self.assertNotEqual(id2, id3)
384380

tests/test_test.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,3 +33,14 @@ def test_tests_fail_1(self):
3333
suite.run(result)
3434

3535
self.assertIn('ZeroDivisionError', result.errors[0][1])
36+
37+
38+
class TestHelpers(tb.TestCase):
39+
40+
async def test_tests_assertLoopErrorHandlerCalled_01(self):
41+
with self.assertRaisesRegex(AssertionError, r'no message.*was logged'):
42+
with self.assertLoopErrorHandlerCalled('aa'):
43+
self.loop.call_exception_handler({'message': 'bb a bb'})
44+
45+
with self.assertLoopErrorHandlerCalled('aa'):
46+
self.loop.call_exception_handler({'message': 'bbaabb'})

tests/test_transaction.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,3 +139,21 @@ async def test_transaction_interface_errors(self):
139139
async with tr:
140140
async with tr:
141141
pass
142+
143+
async def test_transaction_within_manual_transaction(self):
144+
self.assertIsNone(self.con._top_xact)
145+
146+
await self.con.execute('BEGIN')
147+
148+
tr = self.con.transaction()
149+
self.assertIsNone(self.con._top_xact)
150+
151+
with self.assertRaisesRegex(asyncpg.InterfaceError,
152+
'cannot use Connection.transaction'):
153+
await tr.start()
154+
155+
with self.assertLoopErrorHandlerCalled(
156+
'Resetting connection with an active transaction'):
157+
await self.con.reset()
158+
159+
self.assertIsNone(self.con._top_xact)

0 commit comments

Comments
 (0)