Skip to content

Commit af3c9fc

Browse files
committed
refactor: replace cancel_streamed_run tests with run_lifecycle tests
1 parent ccedb5e commit af3c9fc

File tree

2 files changed

+201
-109
lines changed

2 files changed

+201
-109
lines changed

tests/test_cancel_streamed_run.py

Lines changed: 0 additions & 109 deletions
This file was deleted.

tests/test_run_lifecycle.py

Lines changed: 201 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,201 @@
1+
import asyncio
2+
import pytest
3+
from collections.abc import AsyncIterator
4+
5+
from agents.run import Runner, RunConfig
6+
from agents.model_settings import ModelSettings
7+
from agents.models.interface import Model, ModelTracing
8+
from agents.items import TResponseInputItem
9+
from agents.agent_output import AgentOutputSchemaBase
10+
from agents.tool import Tool
11+
from agents.handoffs import Handoff
12+
from openai.types.responses import ResponseStreamEvent
13+
from agents.stream_events import RunUpdatedStreamEvent
14+
15+
16+
# -------------------------
17+
# Minimal fakes used across tests
18+
# -------------------------
19+
20+
class FakeModel(Model):
    """A model stub that streams forever and never finishes.

    Because none of its events is a ResponseCompletedEvent, the runner
    never considers the turn done — ideal for exercising mid-stream
    cancellation.
    """

    async def get_response(self, *a, **k):
        # The non-streaming path is never exercised by these tests.
        raise NotImplementedError

    async def stream_response(
        self,
        system_instructions: str | None,
        input: str | list[TResponseInputItem],
        model_settings: ModelSettings,
        tools: list[Tool],
        output_schema: AgentOutputSchemaBase | None,
        handoffs: list[Handoff],
        tracing: ModelTracing,
        previous_response_id: str | None,
        prompt=None,
    ) -> AsyncIterator[ResponseStreamEvent]:
        # Tick forever: pause briefly, then emit a generic placeholder
        # event that the runner will not treat as a completion.
        while True:
            await asyncio.sleep(0.02)
            yield object()
41+
42+
43+
class MinimalAgent:
    """Bare-bones stand-in exposing only the attributes and hooks Runner touches."""

    def __init__(self, model: Model, name: str = "test-agent"):
        self.name = name
        self.model = model
        self.model_settings = ModelSettings()
        self.output_type = None
        self.hooks = None
        self.handoffs = []
        self.reset_tool_choice = False
        self.input_guardrails: list = []
        self.output_guardrails: list = []

    async def get_system_prompt(self, _):
        # No system prompt for these lifecycle tests.
        return None

    async def get_prompt(self, _):
        return None

    async def get_all_tools(self, _):
        # Tool-free agent: the fakes never request tool calls.
        return []
59+
60+
61+
# -------------------------
62+
# Tests
63+
# -------------------------
64+
65+
@pytest.mark.anyio
async def test_cancel_streamed_run_emits_cancelled_status():
    """When status events are enabled, cancel should emit run.updated(cancelled)."""
    agent = MinimalAgent(model=FakeModel())
    result = Runner.run_streamed(
        starting_agent=agent,
        input="hello world",
        run_config=RunConfig(model=agent.model),
        max_turns=10,
    )
    # Opt-in to status events for this test
    result._emit_status_events = True

    observed: list = []

    async def collect():
        # Record every run.updated status in arrival order.
        async for ev in result.stream_events():
            if isinstance(ev, RunUpdatedStreamEvent):
                observed.append(ev.status)

    collector = asyncio.create_task(collect())

    await asyncio.sleep(0.08)  # let a couple of fake-model ticks elapse
    result.cancel("user-requested")
    await collector

    assert result.is_complete is True
    # The final status reported must be "cancelled".
    assert observed and observed[-1] == "cancelled"
96+
97+
98+
@pytest.mark.anyio
async def test_default_flag_off_emits_no_status_event():
    """By default, no run.updated events should be emitted (back-compat)."""
    agent = MinimalAgent(model=FakeModel())
    result = Runner.run_streamed(agent, input="x", run_config=RunConfig(model=agent.model))
    # Intentionally leave result._emit_status_events at its default (off).
    received = []

    async def drain():
        async for event in result.stream_events():
            if isinstance(event, RunUpdatedStreamEvent):
                received.append(event.status)

    drainer = asyncio.create_task(drain())
    await asyncio.sleep(0.05)
    result.cancel("user")
    await drainer

    # Back-compat: status events are opt-in, so none should have arrived.
    assert received == []
117+
118+
119+
@pytest.mark.anyio
async def test_midstream_cancel_emits_cancelled_status_when_enabled():
    """Cancel while model is streaming yields cancelled when flag is on."""
    agent = MinimalAgent(model=FakeModel())
    result = Runner.run_streamed(agent, input="x", run_config=RunConfig(model=agent.model))
    result._emit_status_events = True

    emitted: list = []

    async def pump():
        async for event in result.stream_events():
            if isinstance(event, RunUpdatedStreamEvent):
                emitted.append(event.status)

    pumper = asyncio.create_task(pump())
    await asyncio.sleep(0.06)  # cancel mid-stream, after a few fake events
    result.cancel("user")
    await pumper

    assert "cancelled" in emitted
138+
139+
140+
@pytest.mark.anyio
async def test_inject_is_consumed_on_next_turn():
    """
    Injected items should be included in a subsequent model turn input.
    We capture the inputs passed into FakeModel each turn and assert presence.
    """
    INJECT_TOKEN = {"role": "user", "content": "INJECTED"}  # match message-style items

    class FakeModelCapture(Model):
        """Streams forever while snapshotting the input it receives."""

        def __init__(self):
            # Each entry is one snapshot (list of items) taken per tick.
            self.inputs = []

        async def get_response(self, *a, **k):
            raise NotImplementedError  # non-stream path not used

        async def stream_response(
            self, system_instructions, input, model_settings, tools,
            output_schema, handoffs, tracing, previous_response_id, prompt=None
        ) -> AsyncIterator[ResponseStreamEvent]:
            # Never complete, so the runner keeps streaming until cancelled.
            while True:
                self.inputs.append(list(input))  # snapshot this tick's input
                yield object()                   # one generic event per tick
                await asyncio.sleep(0.01)

    model = FakeModelCapture()
    agent = MinimalAgent(model=model)

    result = Runner.run_streamed(
        starting_agent=agent,
        input="hello",  # the first snapshot should contain only this
        run_config=RunConfig(model=agent.model),
        max_turns=6,
    )

    async def drive_and_inject():
        # Let at least one baseline snapshot land first.
        await asyncio.sleep(0.05)
        result.inject([INJECT_TOKEN])
        # Allow the runner time for further turns that should see the injection.
        await asyncio.sleep(0.12)
        result.cancel("done")

    driver = asyncio.create_task(drive_and_inject())
    async for _ in result.stream_events():
        pass
    await driver

    # At minimum: one baseline snapshot plus one taken after injection.
    assert len(model.inputs) >= 2

    # The injected message must appear somewhere after the first snapshot.
    later_items = [item for snapshot in model.inputs[1:] for item in snapshot]
    assert any(
        isinstance(item, dict)
        and item.get("role") == "user"
        and item.get("content") == "INJECTED"
        for item in later_items
    ), f"Injected item not present after injection; captured={model.inputs}"

0 commit comments

Comments
 (0)