Skip to content

Commit 40d6d71

Browse files
committed
feat(run): add cooperative cancellation, inject, and status event handling
- Added cooperative cancel support to RunResultStreaming - Introduced flag for backward compatibility - Differentiated vs run statuses - Updated internal run loop to handle injected items mid-run - Added initial tests (test_cancel_streamed_run.py) for cancellation behavior
1 parent df95141 commit 40d6d71

File tree

6 files changed

+569
-107
lines changed

6 files changed

+569
-107
lines changed

src/agents/_run_impl.py

Lines changed: 103 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import asyncio
44
import dataclasses
55
import inspect
6+
import contextlib
67
from collections.abc import Awaitable
78
from dataclasses import dataclass, field
89
from typing import TYPE_CHECKING, Any, cast
@@ -225,6 +226,25 @@ def get_model_tracing_impl(
225226
else:
226227
return ModelTracing.ENABLED_WITHOUT_DATA
227228

229+
# --- NEW: helpers for cancellable tool execution ---
230+
231+
async def _await_cancellable(awaitable):
232+
"""Await an awaitable in its own task so CancelledError interrupts promptly."""
233+
task = asyncio.create_task(awaitable)
234+
try:
235+
return await task
236+
except asyncio.CancelledError:
237+
# propagate so run.py can handle terminal cancel
238+
raise
239+
240+
def _maybe_call_cancel_hook(tool_obj) -> None:
241+
"""Best-effort: call a cancel/terminate hook on the tool if present."""
242+
for name in ("cancel", "terminate", "stop"):
243+
cb = getattr(tool_obj, name, None)
244+
if callable(cb):
245+
with contextlib.suppress(Exception):
246+
cb()
247+
break
228248

229249
class RunImpl:
230250
@classmethod
@@ -556,24 +576,26 @@ async def run_single_tool(
556576
if config.trace_include_sensitive_data:
557577
span_fn.span_data.input = tool_call.arguments
558578
try:
559-
_, _, result = await asyncio.gather(
560-
hooks.on_tool_start(tool_context, agent, func_tool),
561-
(
562-
agent.hooks.on_tool_start(tool_context, agent, func_tool)
563-
if agent.hooks
564-
else _coro.noop_coroutine()
565-
),
566-
func_tool.on_invoke_tool(tool_context, tool_call.arguments),
567-
)
579+
# run start hooks first (don’t tie them to the cancellable task)
580+
await asyncio.gather(
581+
hooks.on_tool_start(tool_context, agent, func_tool),
582+
(agent.hooks.on_tool_start(tool_context, agent, func_tool) if agent.hooks else _coro.noop_coroutine()),
583+
)
584+
585+
try:
586+
result = await _await_cancellable(
587+
func_tool.on_invoke_tool(tool_context, tool_call.arguments)
588+
)
589+
except asyncio.CancelledError:
590+
_maybe_call_cancel_hook(func_tool)
591+
raise
592+
593+
await asyncio.gather(
594+
hooks.on_tool_end(tool_context, agent, func_tool, result),
595+
(agent.hooks.on_tool_end(tool_context, agent, func_tool, result) if agent.hooks else _coro.noop_coroutine()),
596+
)
597+
568598

569-
await asyncio.gather(
570-
hooks.on_tool_end(tool_context, agent, func_tool, result),
571-
(
572-
agent.hooks.on_tool_end(tool_context, agent, func_tool, result)
573-
if agent.hooks
574-
else _coro.noop_coroutine()
575-
),
576-
)
577599
except Exception as e:
578600
_error_tracing.attach_error_to_current_span(
579601
SpanError(
@@ -643,44 +665,45 @@ async def execute_computer_actions(
643665
context_wrapper: RunContextWrapper[TContext],
644666
config: RunConfig,
645667
) -> list[RunItem]:
646-
results: list[RunItem] = []
647-
# Need to run these serially, because each action can affect the computer state
648-
for action in actions:
649-
acknowledged: list[ComputerCallOutputAcknowledgedSafetyCheck] | None = None
650-
if action.tool_call.pending_safety_checks and action.computer_tool.on_safety_check:
651-
acknowledged = []
652-
for check in action.tool_call.pending_safety_checks:
653-
data = ComputerToolSafetyCheckData(
654-
ctx_wrapper=context_wrapper,
655-
agent=agent,
656-
tool_call=action.tool_call,
657-
safety_check=check,
658-
)
659-
maybe = action.computer_tool.on_safety_check(data)
660-
ack = await maybe if inspect.isawaitable(maybe) else maybe
661-
if ack:
662-
acknowledged.append(
663-
ComputerCallOutputAcknowledgedSafetyCheck(
664-
id=check.id,
665-
code=check.code,
666-
message=check.message,
667-
)
668-
)
669-
else:
670-
raise UserError("Computer tool safety check was not acknowledged")
671-
672-
results.append(
673-
await ComputerAction.execute(
674-
agent=agent,
675-
action=action,
676-
hooks=hooks,
677-
context_wrapper=context_wrapper,
678-
config=config,
679-
acknowledged_safety_checks=acknowledged,
680-
)
681-
)
682-
683-
return results
668+
results: list[RunItem] = []
669+
for action in actions:
670+
acknowledged: list[ComputerCallOutputAcknowledgedSafetyCheck] | None = None
671+
if action.tool_call.pending_safety_checks and action.computer_tool.on_safety_check:
672+
acknowledged = []
673+
for check in action.tool_call.pending_safety_checks:
674+
data = ComputerToolSafetyCheckData(
675+
ctx_wrapper=context_wrapper,
676+
agent=agent,
677+
tool_call=action.tool_call,
678+
safety_check=check,
679+
)
680+
maybe = action.computer_tool.on_safety_check(data)
681+
ack = await maybe if inspect.isawaitable(maybe) else maybe
682+
if ack:
683+
acknowledged.append(ComputerCallOutputAcknowledgedSafetyCheck(
684+
id=check.id, code=check.code, message=check.message
685+
))
686+
else:
687+
raise UserError("Computer tool safety check was not acknowledged")
688+
689+
try:
690+
item = await _await_cancellable(
691+
ComputerAction.execute(
692+
agent=agent,
693+
action=action,
694+
hooks=hooks,
695+
context_wrapper=context_wrapper,
696+
config=config,
697+
acknowledged_safety_checks=acknowledged,
698+
)
699+
)
700+
except asyncio.CancelledError:
701+
_maybe_call_cancel_hook(action.computer_tool)
702+
raise
703+
704+
results.append(item)
705+
706+
return results
684707

685708
@classmethod
686709
async def execute_handoffs(
@@ -1052,16 +1075,23 @@ async def execute(
10521075
else cls._get_screenshot_sync(action.computer_tool.computer, action.tool_call)
10531076
)
10541077

1055-
_, _, output = await asyncio.gather(
1078+
# start hooks first
1079+
await asyncio.gather(
10561080
hooks.on_tool_start(context_wrapper, agent, action.computer_tool),
10571081
(
10581082
agent.hooks.on_tool_start(context_wrapper, agent, action.computer_tool)
10591083
if agent.hooks
10601084
else _coro.noop_coroutine()
10611085
),
1062-
output_func,
10631086
)
1064-
1087+
# run the action (screenshot/etc) in a cancellable task
1088+
try:
1089+
output = await _await_cancellable(output_func)
1090+
except asyncio.CancelledError:
1091+
_maybe_call_cancel_hook(action.computer_tool)
1092+
raise
1093+
1094+
# end hooks
10651095
await asyncio.gather(
10661096
hooks.on_tool_end(context_wrapper, agent, action.computer_tool, output),
10671097
(
@@ -1169,10 +1199,20 @@ async def execute(
11691199
data=call.tool_call,
11701200
)
11711201
output = call.local_shell_tool.executor(request)
1172-
if inspect.isawaitable(output):
1173-
result = await output
1174-
else:
1175-
result = output
1202+
try:
1203+
if inspect.isawaitable(output):
1204+
result = await _await_cancellable(output)
1205+
else:
1206+
# If executor returns a sync result, just use it (can’t cancel mid-call)
1207+
result = output
1208+
except asyncio.CancelledError:
1209+
# Best-effort: if the executor or tool exposes a cancel/terminate / kill, call it
1210+
_maybe_call_cancel_hook(call.local_shell_tool)
1211+
# If your executor returns a proc handle (common pattern), adddress it here if needed:
1212+
# with contextlib.suppress(Exception):
1213+
# proc.terminate(); await asyncio.wait_for(proc.wait(), 1.0)
1214+
# proc.kill()
1215+
raise
11761216

11771217
await asyncio.gather(
11781218
hooks.on_tool_end(context_wrapper, agent, call.local_shell_tool, result),
@@ -1185,7 +1225,7 @@ async def execute(
11851225

11861226
return ToolCallOutputItem(
11871227
agent=agent,
1188-
output=output,
1228+
output=result,
11891229
raw_item={
11901230
"type": "local_shell_call_output",
11911231
"id": call.tool_call.call_id,

src/agents/models/openai_responses.py

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import json
4+
import asyncio
45
from collections.abc import AsyncIterator
56
from dataclasses import dataclass
67
from typing import TYPE_CHECKING, Any, Literal, cast, overload
@@ -171,16 +172,31 @@ async def stream_response(
171172
)
172173

173174
final_response: Response | None = None
174-
175-
async for chunk in stream:
176-
if isinstance(chunk, ResponseCompletedEvent):
177-
final_response = chunk.response
178-
yield chunk
179-
175+
176+
try:
177+
async for chunk in stream: # type: ignore[arg-type] # ensure type checkers relax here
178+
if isinstance(chunk, ResponseCompletedEvent):
179+
final_response = chunk.response
180+
yield chunk
181+
except asyncio.CancelledError:
182+
# Cooperative cancel: ensure the HTTP stream is closed, then propagate
183+
try:
184+
await stream.aclose()
185+
except Exception:
186+
pass
187+
raise
188+
finally:
189+
# Always close the stream if the async iterator exits (normal or error)
190+
try:
191+
await stream.aclose()
192+
except Exception:
193+
pass
194+
180195
if final_response and tracing.include_data():
181196
span_response.span_data.response = final_response
182197
span_response.span_data.input = input
183-
198+
199+
184200
except Exception as e:
185201
span_response.set_error(
186202
SpanError(

src/agents/result.py

Lines changed: 52 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import abc
44
import asyncio
5+
import contextlib
56
from collections.abc import AsyncIterator
67
from dataclasses import dataclass, field
78
from typing import TYPE_CHECKING, Any, cast
@@ -143,6 +144,10 @@ class RunResultStreaming(RunResultBase):
143144
is_complete: bool = False
144145
"""Whether the agent has finished running."""
145146

147+
_emit_status_events: bool = False
148+
"""Whether to emit RunUpdatedStreamEvent status updates (default False for backward compatibility)."""
149+
150+
146151
# Queues that the background run_loop writes to
147152
_event_queue: asyncio.Queue[StreamEvent | QueueCompleteSentinel] = field(
148153
default_factory=asyncio.Queue, repr=False
@@ -164,17 +169,53 @@ def last_agent(self) -> Agent[Any]:
164169
"""
165170
return self.current_agent
166171

167-
def cancel(self) -> None:
168-
"""Cancels the streaming run, stopping all background tasks and marking the run as
169-
complete."""
170-
self._cleanup_tasks() # Cancel all running tasks
171-
self.is_complete = True # Mark the run as complete to stop event streaming
172-
173-
# Optionally, clear the event queue to prevent processing stale events
174-
while not self._event_queue.empty():
175-
self._event_queue.get_nowait()
176-
while not self._input_guardrail_queue.empty():
177-
self._input_guardrail_queue.get_nowait()
172+
173+
def cancel(self, reason: str | None = None) -> None:
174+
# 1) Signal cooperative cancel to the runner
175+
active = getattr(self, "_active_run", None)
176+
if active:
177+
with contextlib.suppress(Exception):
178+
active.cancel(reason)
179+
180+
# 2) Wake any stream_events() consumer immediately
181+
with contextlib.suppress(Exception):
182+
self._event_queue.put_nowait(QueueCompleteSentinel())
183+
184+
# 3) Do NOT cancel the background task; let the loop unwind cooperatively
185+
# task = getattr(self, "_run_impl_task", None)
186+
# if task and not task.done():
187+
# with contextlib.suppress(Exception):
188+
# task.cancel()
189+
190+
# 4) Mark complete; flushing only when status events are disabled
191+
self.is_complete = True
192+
if not getattr(self, "_emit_status_events", False):
193+
with contextlib.suppress(Exception):
194+
while not self._event_queue.empty():
195+
self._event_queue.get_nowait()
196+
self._event_queue.task_done()
197+
with contextlib.suppress(Exception):
198+
while not self._input_guardrail_queue.empty():
199+
self._input_guardrail_queue.get_nowait()
200+
self._input_guardrail_queue.task_done()
201+
202+
203+
def inject(self, items: list[TResponseInputItem]) -> None:
204+
"""
205+
Inject new input items mid-run. They will be consumed at the start of the next step.
206+
"""
207+
active = getattr(self, "_active_run", None)
208+
if active is not None:
209+
try:
210+
active.inject(items)
211+
except Exception:
212+
pass
213+
214+
@property
215+
def active_run(self):
216+
"""Access the underlying ActiveRun handle (may be None early in startup)."""
217+
return getattr(self, "_active_run", None)
218+
178219

179220
async def stream_events(self) -> AsyncIterator[StreamEvent]:
180221
"""Stream deltas for new items as they are generated. We're using the types from the

0 commit comments

Comments
 (0)