Skip to content

Commit ccedb5e

Browse files
authored
Merge pull request #8 from vrtnis/codex/fix-asyncio.cancellederror-handling-in-stream
Handle cooperative cancellation in streamed runs
2 parents 40d6d71 + f408536 commit ccedb5e

File tree

2 files changed

+77
-59
lines changed

2 files changed

+77
-59
lines changed

src/agents/result.py

Lines changed: 21 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -145,8 +145,10 @@ class RunResultStreaming(RunResultBase):
145145
"""Whether the agent has finished running."""
146146

147147
_emit_status_events: bool = False
148-
"""Whether to emit RunUpdatedStreamEvent status updates (default False for backward compatibility)."""
148+
"""Whether to emit RunUpdatedStreamEvent status updates.
149149
150+
Defaults to False for backward compatibility.
151+
"""
150152

151153
# Queues that the background run_loop writes to
152154
_event_queue: asyncio.Queue[StreamEvent | QueueCompleteSentinel] = field(
@@ -169,24 +171,18 @@ def last_agent(self) -> Agent[Any]:
169171
"""
170172
return self.current_agent
171173

172-
173174
def cancel(self, reason: str | None = None) -> None:
174175
# 1) Signal cooperative cancel to the runner
175176
active = getattr(self, "_active_run", None)
176177
if active:
177178
with contextlib.suppress(Exception):
178179
active.cancel(reason)
179-
180-
# 2) Wake any stream_events() consumer immediately
181-
with contextlib.suppress(Exception):
182-
self._event_queue.put_nowait(QueueCompleteSentinel())
183-
184-
# 3) Do NOT cancel the background task; let the loop unwind cooperatively
180+
# 2) Do NOT cancel the background task; let the loop unwind cooperatively
185181
# task = getattr(self, "_run_impl_task", None)
186182
# if task and not task.done():
187183
# with contextlib.suppress(Exception):
188184
# task.cancel()
189-
185+
190186
# 4) Mark complete; flushing only when status events are disabled
191187
self.is_complete = True
192188
if not getattr(self, "_emit_status_events", False):
@@ -199,7 +195,6 @@ def cancel(self, reason: str | None = None) -> None:
199195
self._input_guardrail_queue.get_nowait()
200196
self._input_guardrail_queue.task_done()
201197

202-
203198
def inject(self, items: list[TResponseInputItem]) -> None:
204199
"""
205200
Inject new input items mid-run. They will be consumed at the start of the next step.
@@ -210,12 +205,11 @@ def inject(self, items: list[TResponseInputItem]) -> None:
210205
active.inject(items)
211206
except Exception:
212207
pass
213-
208+
214209
@property
215210
def active_run(self):
216211
"""Access the underlying ActiveRun handle (may be None early in startup)."""
217212
return getattr(self, "_active_run", None)
218-
219213

220214
async def stream_events(self) -> AsyncIterator[StreamEvent]:
221215
"""Stream deltas for new items as they are generated. We're using the types from the
@@ -284,21 +278,33 @@ def _check_errors(self):
284278
# Check the tasks for any exceptions
285279
if self._run_impl_task and self._run_impl_task.done():
286280
run_impl_exc = self._run_impl_task.exception()
287-
if run_impl_exc and isinstance(run_impl_exc, Exception):
281+
if (
282+
run_impl_exc
283+
and isinstance(run_impl_exc, Exception)
284+
and not isinstance(run_impl_exc, asyncio.CancelledError)
285+
):
288286
if isinstance(run_impl_exc, AgentsException) and run_impl_exc.run_data is None:
289287
run_impl_exc.run_data = self._create_error_details()
290288
self._stored_exception = run_impl_exc
291289

292290
if self._input_guardrails_task and self._input_guardrails_task.done():
293291
in_guard_exc = self._input_guardrails_task.exception()
294-
if in_guard_exc and isinstance(in_guard_exc, Exception):
292+
if (
293+
in_guard_exc
294+
and isinstance(in_guard_exc, Exception)
295+
and not isinstance(in_guard_exc, asyncio.CancelledError)
296+
):
295297
if isinstance(in_guard_exc, AgentsException) and in_guard_exc.run_data is None:
296298
in_guard_exc.run_data = self._create_error_details()
297299
self._stored_exception = in_guard_exc
298300

299301
if self._output_guardrails_task and self._output_guardrails_task.done():
300302
out_guard_exc = self._output_guardrails_task.exception()
301-
if out_guard_exc and isinstance(out_guard_exc, Exception):
303+
if (
304+
out_guard_exc
305+
and isinstance(out_guard_exc, Exception)
306+
and not isinstance(out_guard_exc, asyncio.CancelledError)
307+
):
302308
if isinstance(out_guard_exc, AgentsException) and out_guard_exc.run_data is None:
303309
out_guard_exc.run_data = self._create_error_details()
304310
self._stored_exception = out_guard_exc

0 commit comments

Comments
 (0)