Skip to content
Prev Previous commit
Next Next commit
Return error details for streaming + tests
  • Loading branch information
DanieleMorotti authored May 24, 2025
commit f083c48e8a870e83036663a69acb41c9dba71ec7
3 changes: 2 additions & 1 deletion src/agents/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
MaxTurnsExceeded,
ModelBehaviorError,
OutputGuardrailTripwireTriggered,
RunErrorDetails,
UserError,
)
from .guardrail import (
Expand Down Expand Up @@ -44,7 +45,7 @@
from .models.openai_chatcompletions import OpenAIChatCompletionsModel
from .models.openai_provider import OpenAIProvider
from .models.openai_responses import OpenAIResponsesModel
from .result import RunErrorDetails, RunResult, RunResultStreaming
from .result import RunResult, RunResultStreaming
from .run import RunConfig, Runner
from .run_context import RunContextWrapper, TContext
from .stream_events import (
Expand Down
36 changes: 31 additions & 5 deletions src/agents/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,49 @@
from __future__ import annotations

from typing import TYPE_CHECKING
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
from .agent import Agent
from .guardrail import InputGuardrailResult, OutputGuardrailResult
from .result import RunErrorDetails
from .items import ModelResponse, RunItem, TResponseInputItem
from .run_context import RunContextWrapper

from .util._pretty_print import pretty_print_run_error_details


@dataclass
class RunErrorDetails:
"""Data collected from an agent run when an exception occurs."""
input: str | list[TResponseInputItem]
new_items: list[RunItem]
raw_responses: list[ModelResponse]
last_agent: Agent[Any]
context_wrapper: RunContextWrapper[Any]
input_guardrail_results: list[InputGuardrailResult]
output_guardrail_results: list[OutputGuardrailResult]

def __str__(self) -> str:
return pretty_print_run_error_details(self)


class AgentsException(Exception):
"""Base class for all exceptions in the Agents SDK."""
run_data: RunErrorDetails | None

def __init__(self, *args: object) -> None:
super().__init__(*args)
self.run_data = None


class MaxTurnsExceeded(AgentsException):
"""Exception raised when the maximum number of turns is exceeded."""

message: str
run_error_details: RunErrorDetails | None

def __init__(self, message: str, run_error_details: RunErrorDetails | None = None):
def __init__(self, message: str):
self.message = message
self.run_error_details = run_error_details
super().__init__(message)


class ModelBehaviorError(AgentsException):
Expand All @@ -31,6 +55,7 @@ class ModelBehaviorError(AgentsException):

def __init__(self, message: str):
self.message = message
super().__init__(message)


class UserError(AgentsException):
Expand All @@ -40,6 +65,7 @@ class UserError(AgentsException):

def __init__(self, message: str):
self.message = message
super().__init__(message)


class InputGuardrailTripwireTriggered(AgentsException):
Expand Down
111 changes: 67 additions & 44 deletions src/agents/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,12 @@
from ._run_impl import QueueCompleteSentinel
from .agent import Agent
from .agent_output import AgentOutputSchemaBase
from .exceptions import InputGuardrailTripwireTriggered, MaxTurnsExceeded
from .exceptions import (
AgentsException,
InputGuardrailTripwireTriggered,
MaxTurnsExceeded,
RunErrorDetails,
)
from .guardrail import InputGuardrailResult, OutputGuardrailResult
from .items import ItemHelpers, ModelResponse, RunItem, TResponseInputItem
from .logger import logger
Expand All @@ -20,7 +25,6 @@
from .tracing import Trace
from .util._pretty_print import (
pretty_print_result,
pretty_print_run_error_details,
pretty_print_run_result_streaming,
)

Expand Down Expand Up @@ -212,29 +216,79 @@ async def stream_events(self) -> AsyncIterator[StreamEvent]:

def _check_errors(self):
if self.current_turn > self.max_turns:
self._stored_exception = MaxTurnsExceeded(f"Max turns ({self.max_turns}) exceeded")
max_turns_exc = MaxTurnsExceeded(f"Max turns ({self.max_turns}) exceeded")
max_turns_exc.run_data = RunErrorDetails(
input=self.input,
new_items=self.new_items,
raw_responses=self.raw_responses,
last_agent=self.current_agent,
context_wrapper=self.context_wrapper,
input_guardrail_results=self.input_guardrail_results,
output_guardrail_results=self.output_guardrail_results,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this code seems ~identical to lines 236-243 and the ones before, thoughts on extracting into self._create_error_details?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I agree. I also noticed it but I was waiting for your feedback. I'll change this.

Do I also need to add an explanation of these changes to the documentation?

)
self._stored_exception = max_turns_exc

# Fetch all the completed guardrail results from the queue and raise if needed
while not self._input_guardrail_queue.empty():
guardrail_result = self._input_guardrail_queue.get_nowait()
if guardrail_result.output.tripwire_triggered:
self._stored_exception = InputGuardrailTripwireTriggered(guardrail_result)
tripwire_exc = InputGuardrailTripwireTriggered(guardrail_result)
tripwire_exc.run_data = RunErrorDetails(
input=self.input,
new_items=self.new_items,
raw_responses=self.raw_responses,
last_agent=self.current_agent,
context_wrapper=self.context_wrapper,
input_guardrail_results=self.input_guardrail_results,
output_guardrail_results=self.output_guardrail_results,
)
self._stored_exception = tripwire_exc

# Check the tasks for any exceptions
if self._run_impl_task and self._run_impl_task.done():
exc = self._run_impl_task.exception()
if exc and isinstance(exc, Exception):
self._stored_exception = exc
run_impl_exc = self._run_impl_task.exception()
if run_impl_exc and isinstance(run_impl_exc, Exception):
if isinstance(run_impl_exc, AgentsException) and run_impl_exc.run_data is None:
run_impl_exc.run_data = RunErrorDetails(
input=self.input,
new_items=self.new_items,
raw_responses=self.raw_responses,
last_agent=self.current_agent,
context_wrapper=self.context_wrapper,
input_guardrail_results=self.input_guardrail_results,
output_guardrail_results=self.output_guardrail_results,
)
self._stored_exception = run_impl_exc

if self._input_guardrails_task and self._input_guardrails_task.done():
exc = self._input_guardrails_task.exception()
if exc and isinstance(exc, Exception):
self._stored_exception = exc
in_guard_exc = self._input_guardrails_task.exception()
if in_guard_exc and isinstance(in_guard_exc, Exception):
if isinstance(in_guard_exc, AgentsException) and in_guard_exc.run_data is None:
in_guard_exc.run_data = RunErrorDetails(
input=self.input,
new_items=self.new_items,
raw_responses=self.raw_responses,
last_agent=self.current_agent,
context_wrapper=self.context_wrapper,
input_guardrail_results=self.input_guardrail_results,
output_guardrail_results=self.output_guardrail_results,
)
self._stored_exception = in_guard_exc

if self._output_guardrails_task and self._output_guardrails_task.done():
exc = self._output_guardrails_task.exception()
if exc and isinstance(exc, Exception):
self._stored_exception = exc
out_guard_exc = self._output_guardrails_task.exception()
if out_guard_exc and isinstance(out_guard_exc, Exception):
if isinstance(out_guard_exc, AgentsException) and out_guard_exc.run_data is None:
out_guard_exc.run_data = RunErrorDetails(
input=self.input,
new_items=self.new_items,
raw_responses=self.raw_responses,
last_agent=self.current_agent,
context_wrapper=self.context_wrapper,
input_guardrail_results=self.input_guardrail_results,
output_guardrail_results=self.output_guardrail_results,
)
self._stored_exception = out_guard_exc

def _cleanup_tasks(self):
if self._run_impl_task and not self._run_impl_task.done():
Expand All @@ -249,34 +303,3 @@ def _cleanup_tasks(self):
def __str__(self) -> str:
return pretty_print_run_result_streaming(self)


@dataclass
class RunErrorDetails:
input: str | list[TResponseInputItem]
"""The original input items i.e. the items before run() was called. This may be a mutated
version of the input, if there are handoff input filters that mutate the input.
"""

new_items: list[RunItem]
"""The new items generated during the agent run. These include things like new messages, tool
calls and their outputs, etc.
"""

raw_responses: list[ModelResponse]
"""The raw LLM responses generated by the model during the agent run."""

input_guardrail_results: list[InputGuardrailResult]
"""Guardrail results for the input messages."""

context_wrapper: RunContextWrapper[Any]
"""The context wrapper for the agent run."""

_last_agent: Agent[Any]

@property
def last_agent(self) -> Agent[Any]:
"""The last agent that was run."""
return self._last_agent

def __str__(self) -> str:
return pretty_print_run_error_details(self)
34 changes: 21 additions & 13 deletions src/agents/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
MaxTurnsExceeded,
ModelBehaviorError,
OutputGuardrailTripwireTriggered,
RunErrorDetails,
)
from .guardrail import InputGuardrail, InputGuardrailResult, OutputGuardrail, OutputGuardrailResult
from .handoffs import Handoff, HandoffInputFilter, handoff
Expand All @@ -36,7 +37,7 @@
from .model_settings import ModelSettings
from .models.interface import Model, ModelProvider
from .models.multi_provider import MultiProvider
from .result import RunErrorDetails, RunResult, RunResultStreaming
from .result import RunResult, RunResultStreaming
from .run_context import RunContextWrapper, TContext
from .stream_events import AgentUpdatedStreamEvent, RawResponsesStreamEvent
from .tool import Tool
Expand Down Expand Up @@ -286,23 +287,17 @@ async def run(
raise AgentsException(
f"Unknown next step type: {type(turn_result.next_step)}"
)
except Exception as e:
run_error_details = RunErrorDetails(
except AgentsException as exc:
exc.run_data = RunErrorDetails(
input=original_input,
new_items=generated_items,
raw_responses=model_responses,
input_guardrail_results=input_guardrail_results,
last_agent=current_agent,
context_wrapper=context_wrapper,
_last_agent=current_agent
input_guardrail_results=input_guardrail_results,
output_guardrail_results=[]
)
# Re-raise with the error details
if isinstance(e, MaxTurnsExceeded):
raise MaxTurnsExceeded(
f"Max turns ({max_turns}) exceeded",
run_error_details
) from e
else:
raise
raise
finally:
if current_span:
current_span.finish(reset_current=True)
Expand Down Expand Up @@ -629,6 +624,19 @@ async def _run_streamed_impl(
streamed_result._event_queue.put_nowait(QueueCompleteSentinel())
elif isinstance(turn_result.next_step, NextStepRunAgain):
pass
except AgentsException as exc:
streamed_result.is_complete = True
streamed_result._event_queue.put_nowait(QueueCompleteSentinel())
exc.run_data = RunErrorDetails(
input=streamed_result.input,
new_items=streamed_result.new_items,
raw_responses=streamed_result.raw_responses,
last_agent=current_agent,
context_wrapper=context_wrapper,
input_guardrail_results=streamed_result.input_guardrail_results,
output_guardrail_results=streamed_result.output_guardrail_results,
)
raise
except Exception as e:
if current_span:
_error_tracing.attach_error_to_span(
Expand Down
5 changes: 4 additions & 1 deletion src/agents/util/_pretty_print.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
from pydantic import BaseModel

if TYPE_CHECKING:
from ..result import RunErrorDetails, RunResult, RunResultBase, RunResultStreaming
from ..exceptions import RunErrorDetails
from ..result import RunResult, RunResultBase, RunResultStreaming


def _indent(text: str, indent_level: int) -> str:
Expand Down Expand Up @@ -37,6 +38,7 @@ def pretty_print_result(result: "RunResult") -> str:

return output


def pretty_print_run_error_details(result: "RunErrorDetails") -> str:
output = "RunErrorDetails:"
output += f'\n- Last agent: Agent(name="{result.last_agent.name}", ...)'
Expand All @@ -47,6 +49,7 @@ def pretty_print_run_error_details(result: "RunErrorDetails") -> str:

return output


def pretty_print_run_result_streaming(result: "RunResultStreaming") -> str:
output = "RunResultStreaming:"
output += f'\n- Current agent: Agent(name="{result.current_agent.name}", ...)'
Expand Down
44 changes: 44 additions & 0 deletions tests/test_run_error_details.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import json

import pytest

from agents import Agent, MaxTurnsExceeded, RunErrorDetails, Runner

from .fake_model import FakeModel
from .test_responses import get_function_tool, get_function_tool_call, get_text_message


@pytest.mark.asyncio
async def test_run_error_includes_data():
model = FakeModel()
agent = Agent(name="test", model=model, tools=[get_function_tool("foo", "res")])
model.add_multiple_turn_outputs([
[get_text_message("1"), get_function_tool_call("foo", json.dumps({"a": "b"}))],
[get_text_message("done")],
])
with pytest.raises(MaxTurnsExceeded) as exc:
await Runner.run(agent, input="hello", max_turns=1)
data = exc.value.run_data
assert isinstance(data, RunErrorDetails)
assert data.last_agent == agent
assert len(data.raw_responses) == 1
assert len(data.new_items) > 0


@pytest.mark.asyncio
async def test_streamed_run_error_includes_data():
model = FakeModel()
agent = Agent(name="test", model=model, tools=[get_function_tool("foo", "res")])
model.add_multiple_turn_outputs([
[get_text_message("1"), get_function_tool_call("foo", json.dumps({"a": "b"}))],
[get_text_message("done")],
])
result = Runner.run_streamed(agent, input="hello", max_turns=1)
with pytest.raises(MaxTurnsExceeded) as exc:
async for _ in result.stream_events():
pass
data = exc.value.run_data
assert isinstance(data, RunErrorDetails)
assert data.last_agent == agent
assert len(data.raw_responses) == 1
assert len(data.new_items) > 0
4 changes: 0 additions & 4 deletions tests/test_tracing_errors_streamed.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,10 +168,6 @@ async def test_tool_call_error():
"children": [
{
"type": "agent",
"error": {
"message": "Error in agent run",
"data": {"error": "Invalid JSON input for tool foo: bad_json"},
},
"data": {
"name": "test_agent",
"handoffs": [],
Expand Down