Skip to content

Commit 5342ca4

Browse files
fix: Task state is not persisted to task store after client disconnect (#472)
# Issue It's been described well in #464 : > When the client-side connection is terminated, the EventConsumer stops processing. As a result, any changes to the task state after the disconnection are not persisted to the TaskStore. The task itself continues running in the background, but its updated state is no longer reflected in the TaskStore. This has been addressed in this PR by simply adding a catch for `(asyncio.CancelledError, GeneratorExit)` in the `on_message_send_stream` method. However, adding that revealed a difference in semantics between Python 3.13+ and <3.13 for `EventQueue.close()`. I have also addressed that. # How it's reproduced [@azyobuzin](https://github.com/azyobuzin) provided a detailed guide on this in #464 . My only addition would be to add loggers for `a2a.server.events.event_queue` and `a2a.server.events.event_consumer` to get a better understanding of what's happening under the hood. # Fix ## Code - Ensure streaming continues persisting events after client disconnect via background consumption by adding a catch for `(asyncio.CancelledError, GeneratorExit)` in the `on_message_send_stream` method. - Align EventQueue.close() behavior on Python ≥3.13 and ≤3.12 (graceful vs. immediate). ## Tests ### Event queue tests (`tests/server/events/test_event_queue.py`) Added/updated tests to verify: - Graceful close on ≥3.13 waits for drain and children. - Immediate close clears queues and propagates. - To support Python 3.10, when simulating ≥3.13 using sys.version_info, inject a dummy queue.shutdown on asyncio.Queue so tests don’t fail on runtimes without it. I've seen this pattern used in existing tests too. ### Request handler tests (`tests/server/request_handlers/test_default_request_handler.py`) Added `test_disconnect_persists_final_task_to_store` which tests the flow described in issue #464 : - Starts streaming, yields first event, then simulates client disconnect. - Background consumer persists the final Task to `InMemoryTaskStore`. - Uses `wait_until` to await disappearance of the specific `background_consume:{task_id}` task, then asserts `TaskState.completed`. ### General Added a cleanup for lingering background tasks, I think it's an improvement for my earlier PR where I've tracked background tasks. # Misc Ruff `0.13.0` now fails the check for unused variables, made some minimal changes as suggested by the linter, irrelevant to the issue: underscores for `_payload` and `_task_manager`. Fixes #464
1 parent 5ec0788 commit 5342ca4

File tree

5 files changed

+183
-36
lines changed

5 files changed

+183
-36
lines changed

src/a2a/client/transports/rest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,7 @@ async def get_task(
206206
context: ClientCallContext | None = None,
207207
) -> Task:
208208
"""Retrieves the current state and history of a specific task."""
209-
payload, modified_kwargs = await self._apply_interceptors(
209+
_payload, modified_kwargs = await self._apply_interceptors(
210210
request.model_dump(mode='json', exclude_none=True),
211211
self._get_http_args(context),
212212
context,

src/a2a/server/events/event_queue.py

Lines changed: 27 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -135,9 +135,18 @@ def tap(self) -> 'EventQueue':
135135
async def close(self, immediate: bool = False) -> None:
136136
"""Closes the queue for future push events and also closes all child queues.
137137
138-
Once closed, no new events can be enqueued. For Python 3.13+, this will trigger
139-
`asyncio.QueueShutDown` when the queue is empty and a consumer tries to dequeue.
140-
For lower versions, the queue will be marked as closed and optionally cleared.
138+
Once closed, no new events can be enqueued. Behavior is consistent across
139+
Python versions:
140+
- Python >= 3.13: Uses `asyncio.Queue.shutdown` to stop the queue. With
141+
`immediate=True` the queue is shut down and pending events are cleared; with
142+
`immediate=False` the queue is shut down and we wait for it to drain via
143+
`queue.join()`.
144+
- Python < 3.13: Emulates the same semantics by clearing on `immediate=True`
145+
or awaiting `queue.join()` on `immediate=False`.
146+
147+
Consumers attempting to dequeue after close on an empty queue will observe
148+
`asyncio.QueueShutDown` on Python >= 3.13 and `asyncio.QueueEmpty` on
149+
Python < 3.13.
141150
142151
Args:
143152
immediate (bool):
@@ -152,23 +161,30 @@ async def close(self, immediate: bool = False) -> None:
152161
return
153162
if not self._is_closed:
154163
self._is_closed = True
155-
# If using python 3.13 or higher, use the shutdown method
164+
# If using python 3.13 or higher, use shutdown but match <3.13 semantics
156165
if sys.version_info >= (3, 13):
157-
self.queue.shutdown(immediate)
158-
for child in self._children:
159-
await child.close(immediate)
166+
if immediate:
167+
# Immediate: stop queue and clear any pending events, then close children
168+
self.queue.shutdown(True)
169+
await self.clear_events(True)
170+
for child in self._children:
171+
await child.close(True)
172+
return
173+
# Graceful: prevent further gets/puts via shutdown, then wait for drain and children
174+
self.queue.shutdown(False)
175+
await asyncio.gather(
176+
self.queue.join(), *(child.close() for child in self._children)
177+
)
160178
# Otherwise, join the queue
161179
else:
162180
if immediate:
163181
await self.clear_events(True)
164182
for child in self._children:
165183
await child.close(immediate)
166184
return
167-
tasks = [asyncio.create_task(self.queue.join())]
168-
tasks.extend(
169-
asyncio.create_task(child.close()) for child in self._children
185+
await asyncio.gather(
186+
self.queue.join(), *(child.close() for child in self._children)
170187
)
171-
await asyncio.wait(tasks, return_when=asyncio.ALL_COMPLETED)
172188

173189
def is_closed(self) -> bool:
174190
"""Checks if the queue is closed."""

src/a2a/server/request_handlers/default_request_handler.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -314,7 +314,7 @@ async def on_message_send(
314314
result (Task or Message).
315315
"""
316316
(
317-
task_manager,
317+
_task_manager,
318318
task_id,
319319
queue,
320320
result_aggregator,
@@ -379,16 +379,16 @@ async def on_message_send_stream(
379379
by the agent.
380380
"""
381381
(
382-
task_manager,
382+
_task_manager,
383383
task_id,
384384
queue,
385385
result_aggregator,
386386
producer_task,
387387
) = await self._setup_message_execution(params, context)
388+
consumer = EventConsumer(queue)
389+
producer_task.add_done_callback(consumer.agent_task_callback)
388390

389391
try:
390-
consumer = EventConsumer(queue)
391-
producer_task.add_done_callback(consumer.agent_task_callback)
392392
async for event in result_aggregator.consume_and_emit(consumer):
393393
if isinstance(event, Task):
394394
self._validate_task_id_match(task_id, event.id)
@@ -397,6 +397,14 @@ async def on_message_send_stream(
397397
task_id, result_aggregator
398398
)
399399
yield event
400+
except (asyncio.CancelledError, GeneratorExit):
401+
# Client disconnected: continue consuming and persisting events in the background
402+
bg_task = asyncio.create_task(
403+
result_aggregator.consume_all(consumer)
404+
)
405+
bg_task.set_name(f'background_consume:{task_id}')
406+
self._track_background_task(bg_task)
407+
raise
400408
finally:
401409
cleanup_task = asyncio.create_task(
402410
self._cleanup_producer(producer_task, task_id)

tests/server/events/test_event_queue.py

Lines changed: 38 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -271,15 +271,7 @@ async def test_tap_creates_child_queue(event_queue: EventQueue) -> None:
271271

272272

273273
@pytest.mark.asyncio
274-
@patch(
275-
'asyncio.wait'
276-
) # To monitor calls to asyncio.wait for older Python versions
277-
@patch(
278-
'asyncio.create_task'
279-
) # To monitor calls to asyncio.create_task for older Python versions
280274
async def test_close_sets_flag_and_handles_internal_queue_old_python(
281-
mock_create_task: MagicMock,
282-
mock_asyncio_wait: AsyncMock,
283275
event_queue: EventQueue,
284276
) -> None:
285277
"""Test close behavior on Python < 3.13 (using queue.join)."""
@@ -290,24 +282,47 @@ async def test_close_sets_flag_and_handles_internal_queue_old_python(
290282
await event_queue.close()
291283

292284
assert event_queue.is_closed() is True
293-
event_queue.queue.join.assert_called_once() # specific to <3.13
294-
mock_create_task.assert_called_once() # create_task for join
295-
mock_asyncio_wait.assert_called_once() # wait for join
285+
event_queue.queue.join.assert_awaited_once() # waited for drain
296286

297287

298288
@pytest.mark.asyncio
299289
async def test_close_sets_flag_and_handles_internal_queue_new_python(
300290
event_queue: EventQueue,
301291
) -> None:
302292
"""Test close behavior on Python >= 3.13 (using queue.shutdown)."""
303-
with patch('sys.version_info', (3, 13, 0)): # Simulate Python 3.13+
304-
# Mock queue.shutdown as it's called in newer versions
305-
event_queue.queue.shutdown = MagicMock() # shutdown is not async
293+
with patch('sys.version_info', (3, 13, 0)):
294+
# Inject a dummy shutdown method for non-3.13 runtimes
295+
from typing import cast
306296

297+
queue = cast('Any', event_queue.queue)
298+
queue.shutdown = MagicMock() # type: ignore[attr-defined]
307299
await event_queue.close()
308-
309300
assert event_queue.is_closed() is True
310-
event_queue.queue.shutdown.assert_called_once() # specific to >=3.13
301+
queue.shutdown.assert_called_once_with(False)
302+
303+
304+
@pytest.mark.asyncio
305+
async def test_close_graceful_py313_waits_for_join_and_children(
306+
event_queue: EventQueue,
307+
) -> None:
308+
"""For Python >=3.13 and immediate=False, close should shutdown(False), then wait for join and children."""
309+
with patch('sys.version_info', (3, 13, 0)):
310+
# Arrange
311+
from typing import cast
312+
313+
q_any = cast('Any', event_queue.queue)
314+
q_any.shutdown = MagicMock() # type: ignore[attr-defined]
315+
event_queue.queue.join = AsyncMock()
316+
317+
child = event_queue.tap()
318+
child.close = AsyncMock()
319+
320+
# Act
321+
await event_queue.close(immediate=False)
322+
323+
# Assert
324+
event_queue.queue.join.assert_awaited_once()
325+
child.close.assert_awaited_once()
311326

312327

313328
@pytest.mark.asyncio
@@ -345,15 +360,18 @@ async def test_close_idempotent(event_queue: EventQueue) -> None:
345360

346361
# Reset for new Python version test
347362
event_queue_new = EventQueue() # New queue for fresh state
348-
with patch('sys.version_info', (3, 13, 0)): # Test with newer version logic
349-
event_queue_new.queue.shutdown = MagicMock()
363+
with patch('sys.version_info', (3, 13, 0)):
364+
from typing import cast
365+
366+
queue = cast('Any', event_queue_new.queue)
367+
queue.shutdown = MagicMock() # type: ignore[attr-defined]
350368
await event_queue_new.close()
351369
assert event_queue_new.is_closed() is True
352-
event_queue_new.queue.shutdown.assert_called_once()
370+
queue.shutdown.assert_called_once()
353371

354372
await event_queue_new.close()
355373
assert event_queue_new.is_closed() is True
356-
event_queue_new.queue.shutdown.assert_called_once() # Still only called once
374+
queue.shutdown.assert_called_once() # Still only called once
357375

358376

359377
@pytest.mark.asyncio

tests/server/request_handlers/test_default_request_handler.py

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import asyncio
2+
import contextlib
23
import logging
34
import time
45

@@ -48,6 +49,7 @@
4849
TaskQueryParams,
4950
TaskState,
5051
TaskStatus,
52+
TaskStatusUpdateEvent,
5153
TextPart,
5254
UnsupportedOperationError,
5355
)
@@ -1331,6 +1333,15 @@ async def single_event_stream():
13311333
mock_result_aggregator_instance.consume_and_emit.return_value = (
13321334
single_event_stream()
13331335
)
1336+
# Signal when background consume_all is started
1337+
bg_started = asyncio.Event()
1338+
1339+
async def mock_consume_all(_consumer):
1340+
bg_started.set()
1341+
# emulate short-running background work
1342+
await asyncio.sleep(0)
1343+
1344+
mock_result_aggregator_instance.consume_all = mock_consume_all
13341345

13351346
produced_task: asyncio.Task | None = None
13361347
cleanup_task: asyncio.Task | None = None
@@ -1367,6 +1378,9 @@ def create_task_spy(coro):
13671378
assert produced_task is not None
13681379
assert cleanup_task is not None
13691380

1381+
# Assert background consume_all started
1382+
await asyncio.wait_for(bg_started.wait(), timeout=0.2)
1383+
13701384
# execute should have started
13711385
await asyncio.wait_for(execute_started.wait(), timeout=0.1)
13721386

@@ -1385,6 +1399,91 @@ def create_task_spy(coro):
13851399
# Running agents is cleared
13861400
assert task_id not in request_handler._running_agents
13871401

1402+
# Cleanup any lingering background tasks started by on_message_send_stream
1403+
# (e.g., background_consume)
1404+
for t in list(request_handler._background_tasks):
1405+
t.cancel()
1406+
with contextlib.suppress(asyncio.CancelledError):
1407+
await t
1408+
1409+
1410+
@pytest.mark.asyncio
1411+
async def test_disconnect_persists_final_task_to_store():
1412+
"""After client disconnect, ensure background consumer persists final Task to store."""
1413+
task_store = InMemoryTaskStore()
1414+
queue_manager = InMemoryQueueManager()
1415+
1416+
# Custom agent that emits a working update then a completed final update
1417+
class FinishingAgent(AgentExecutor):
1418+
def __init__(self):
1419+
self.allow_finish = asyncio.Event()
1420+
1421+
async def execute(
1422+
self, context: RequestContext, event_queue: EventQueue
1423+
):
1424+
from typing import cast
1425+
1426+
updater = TaskUpdater(
1427+
event_queue,
1428+
cast('str', context.task_id),
1429+
cast('str', context.context_id),
1430+
)
1431+
await updater.update_status(TaskState.working)
1432+
await self.allow_finish.wait()
1433+
await updater.update_status(TaskState.completed)
1434+
1435+
async def cancel(
1436+
self, context: RequestContext, event_queue: EventQueue
1437+
):
1438+
return None
1439+
1440+
agent = FinishingAgent()
1441+
1442+
handler = DefaultRequestHandler(
1443+
agent_executor=agent, task_store=task_store, queue_manager=queue_manager
1444+
)
1445+
1446+
params = MessageSendParams(
1447+
message=Message(
1448+
role=Role.user,
1449+
message_id='msg_persist',
1450+
parts=[],
1451+
)
1452+
)
1453+
1454+
# Start streaming and consume the first event (working)
1455+
agen = handler.on_message_send_stream(params, create_server_call_context())
1456+
first = await agen.__anext__()
1457+
if isinstance(first, TaskStatusUpdateEvent):
1458+
assert first.status.state == TaskState.working
1459+
task_id = first.task_id
1460+
else:
1461+
assert (
1462+
isinstance(first, Task) and first.status.state == TaskState.working
1463+
)
1464+
task_id = first.id
1465+
1466+
# Disconnect client
1467+
await asyncio.wait_for(agen.aclose(), timeout=0.1)
1468+
1469+
# Finish agent and allow background consumer to persist final state
1470+
agent.allow_finish.set()
1471+
1472+
# Wait until background_consume task for this task_id is gone
1473+
await wait_until(
1474+
lambda: all(
1475+
not t.get_name().startswith(f'background_consume:{task_id}')
1476+
for t in handler._background_tasks
1477+
),
1478+
timeout=1.0,
1479+
interval=0.01,
1480+
)
1481+
1482+
# Verify task is persisted as completed
1483+
persisted = await task_store.get(task_id, create_server_call_context())
1484+
assert persisted is not None
1485+
assert persisted.status.state == TaskState.completed
1486+
13881487

13891488
async def wait_until(predicate, timeout: float = 0.2, interval: float = 0.0):
13901489
"""Await until predicate() is True or timeout elapses."""
@@ -1505,6 +1604,12 @@ def create_task_spy(coro):
15051604
timeout=0.1,
15061605
)
15071606

1607+
# Cleanup any lingering background tasks
1608+
for t in list(request_handler._background_tasks):
1609+
t.cancel()
1610+
with contextlib.suppress(asyncio.CancelledError):
1611+
await t
1612+
15081613

15091614
@pytest.mark.asyncio
15101615
async def test_on_message_send_stream_task_id_mismatch():

0 commit comments

Comments
 (0)