33import asyncio
44import dataclasses
55import inspect
6+ import contextlib
67from collections .abc import Awaitable
78from dataclasses import dataclass , field
89from typing import TYPE_CHECKING , Any , cast
@@ -225,6 +226,25 @@ def get_model_tracing_impl(
225226 else :
226227 return ModelTracing .ENABLED_WITHOUT_DATA
227228
229+ # --- NEW: helpers for cancellable tool execution ---
230+
231+ async def _await_cancellable (awaitable ):
232+ """Await an awaitable in its own task so CancelledError interrupts promptly."""
233+ task = asyncio .create_task (awaitable )
234+ try :
235+ return await task
236+ except asyncio .CancelledError :
237+ # propagate so run.py can handle terminal cancel
238+ raise
239+
240+ def _maybe_call_cancel_hook (tool_obj ) -> None :
241+ """Best-effort: call a cancel/terminate hook on the tool if present."""
242+ for name in ("cancel" , "terminate" , "stop" ):
243+ cb = getattr (tool_obj , name , None )
244+ if callable (cb ):
245+ with contextlib .suppress (Exception ):
246+ cb ()
247+ break
228248
229249class RunImpl :
230250 @classmethod
@@ -556,24 +576,26 @@ async def run_single_tool(
556576 if config .trace_include_sensitive_data :
557577 span_fn .span_data .input = tool_call .arguments
558578 try :
559- _ , _ , result = await asyncio .gather (
560- hooks .on_tool_start (tool_context , agent , func_tool ),
561- (
562- agent .hooks .on_tool_start (tool_context , agent , func_tool )
563- if agent .hooks
564- else _coro .noop_coroutine ()
565- ),
566- func_tool .on_invoke_tool (tool_context , tool_call .arguments ),
567- )
579+ # run start hooks first (don’t tie them to the cancellable task)
580+ await asyncio .gather (
581+ hooks .on_tool_start (tool_context , agent , func_tool ),
582+ (agent .hooks .on_tool_start (tool_context , agent , func_tool ) if agent .hooks else _coro .noop_coroutine ()),
583+ )
584+
585+ try :
586+ result = await _await_cancellable (
587+ func_tool .on_invoke_tool (tool_context , tool_call .arguments )
588+ )
589+ except asyncio .CancelledError :
590+ _maybe_call_cancel_hook (func_tool )
591+ raise
592+
593+ await asyncio .gather (
594+ hooks .on_tool_end (tool_context , agent , func_tool , result ),
595+ (agent .hooks .on_tool_end (tool_context , agent , func_tool , result ) if agent .hooks else _coro .noop_coroutine ()),
596+ )
597+
568598
569- await asyncio .gather (
570- hooks .on_tool_end (tool_context , agent , func_tool , result ),
571- (
572- agent .hooks .on_tool_end (tool_context , agent , func_tool , result )
573- if agent .hooks
574- else _coro .noop_coroutine ()
575- ),
576- )
577599 except Exception as e :
578600 _error_tracing .attach_error_to_current_span (
579601 SpanError (
@@ -643,44 +665,45 @@ async def execute_computer_actions(
643665 context_wrapper : RunContextWrapper [TContext ],
644666 config : RunConfig ,
645667 ) -> list [RunItem ]:
646- results : list [RunItem ] = []
647- # Need to run these serially, because each action can affect the computer state
648- for action in actions :
649- acknowledged : list [ComputerCallOutputAcknowledgedSafetyCheck ] | None = None
650- if action .tool_call .pending_safety_checks and action .computer_tool .on_safety_check :
651- acknowledged = []
652- for check in action .tool_call .pending_safety_checks :
653- data = ComputerToolSafetyCheckData (
654- ctx_wrapper = context_wrapper ,
655- agent = agent ,
656- tool_call = action .tool_call ,
657- safety_check = check ,
658- )
659- maybe = action .computer_tool .on_safety_check (data )
660- ack = await maybe if inspect .isawaitable (maybe ) else maybe
661- if ack :
662- acknowledged .append (
663- ComputerCallOutputAcknowledgedSafetyCheck (
664- id = check .id ,
665- code = check .code ,
666- message = check .message ,
667- )
668- )
669- else :
670- raise UserError ("Computer tool safety check was not acknowledged" )
671-
672- results .append (
673- await ComputerAction .execute (
674- agent = agent ,
675- action = action ,
676- hooks = hooks ,
677- context_wrapper = context_wrapper ,
678- config = config ,
679- acknowledged_safety_checks = acknowledged ,
680- )
681- )
682-
683- return results
668+ results : list [RunItem ] = []
669+ for action in actions :
670+ acknowledged : list [ComputerCallOutputAcknowledgedSafetyCheck ] | None = None
671+ if action .tool_call .pending_safety_checks and action .computer_tool .on_safety_check :
672+ acknowledged = []
673+ for check in action .tool_call .pending_safety_checks :
674+ data = ComputerToolSafetyCheckData (
675+ ctx_wrapper = context_wrapper ,
676+ agent = agent ,
677+ tool_call = action .tool_call ,
678+ safety_check = check ,
679+ )
680+ maybe = action .computer_tool .on_safety_check (data )
681+ ack = await maybe if inspect .isawaitable (maybe ) else maybe
682+ if ack :
683+ acknowledged .append (ComputerCallOutputAcknowledgedSafetyCheck (
684+ id = check .id , code = check .code , message = check .message
685+ ))
686+ else :
687+ raise UserError ("Computer tool safety check was not acknowledged" )
688+
689+ try :
690+ item = await _await_cancellable (
691+ ComputerAction .execute (
692+ agent = agent ,
693+ action = action ,
694+ hooks = hooks ,
695+ context_wrapper = context_wrapper ,
696+ config = config ,
697+ acknowledged_safety_checks = acknowledged ,
698+ )
699+ )
700+ except asyncio .CancelledError :
701+ _maybe_call_cancel_hook (action .computer_tool )
702+ raise
703+
704+ results .append (item )
705+
706+ return results
684707
685708 @classmethod
686709 async def execute_handoffs (
@@ -1052,16 +1075,23 @@ async def execute(
10521075 else cls ._get_screenshot_sync (action .computer_tool .computer , action .tool_call )
10531076 )
10541077
1055- _ , _ , output = await asyncio .gather (
1078+ # start hooks first
1079+ await asyncio .gather (
10561080 hooks .on_tool_start (context_wrapper , agent , action .computer_tool ),
10571081 (
10581082 agent .hooks .on_tool_start (context_wrapper , agent , action .computer_tool )
10591083 if agent .hooks
10601084 else _coro .noop_coroutine ()
10611085 ),
1062- output_func ,
10631086 )
1064-
1087+ # run the action (screenshot/etc) in a cancellable task
1088+ try :
1089+ output = await _await_cancellable (output_func )
1090+ except asyncio .CancelledError :
1091+ _maybe_call_cancel_hook (action .computer_tool )
1092+ raise
1093+
1094+ # end hooks
10651095 await asyncio .gather (
10661096 hooks .on_tool_end (context_wrapper , agent , action .computer_tool , output ),
10671097 (
@@ -1169,10 +1199,20 @@ async def execute(
11691199 data = call .tool_call ,
11701200 )
11711201 output = call .local_shell_tool .executor (request )
1172- if inspect .isawaitable (output ):
1173- result = await output
1174- else :
1175- result = output
1202+ try :
1203+ if inspect .isawaitable (output ):
1204+ result = await _await_cancellable (output )
1205+ else :
1206+ # If executor returns a sync result, just use it (can’t cancel mid-call)
1207+ result = output
1208+ except asyncio .CancelledError :
1209+ # Best-effort: if the executor or tool exposes a cancel/terminate / kill, call it
1210+ _maybe_call_cancel_hook (call .local_shell_tool )
1211+ # If your executor returns a proc handle (common pattern), adddress it here if needed:
1212+ # with contextlib.suppress(Exception):
1213+ # proc.terminate(); await asyncio.wait_for(proc.wait(), 1.0)
1214+ # proc.kill()
1215+ raise
11761216
11771217 await asyncio .gather (
11781218 hooks .on_tool_end (context_wrapper , agent , call .local_shell_tool , result ),
@@ -1185,7 +1225,7 @@ async def execute(
11851225
11861226 return ToolCallOutputItem (
11871227 agent = agent ,
1188- output = output ,
1228+ output = result ,
11891229 raw_item = {
11901230 "type" : "local_shell_call_output" ,
11911231 "id" : call .tool_call .call_id ,
0 commit comments