Skip to content

Commit ebbb6bb

Browse files
committed
Set timestamps on AG-UI events (#3734)
Adds timestamps to all AG-UI events to fix ordering and timing issues for frontends consuming the event stream. - Override handle_event() to set timestamps on all transformed events - Explicitly set timestamps on lifecycle events (RunStarted, RunFinished, RunError) - Add test_timestamps_are_set() to verify all events have valid timestamps - Update all test assertions to expect timestamp fields
1 parent 2e96d12 commit ebbb6bb

File tree

2 files changed

+336
-93
lines changed

2 files changed

+336
-93
lines changed

pydantic_ai_slim/pydantic_ai/ui/ag_ui/_event_stream.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from dataclasses import dataclass, field
1212
from typing import Final
1313

14+
from ..._utils import now_utc
1415
from ...messages import (
1516
BuiltinToolCallPart,
1617
BuiltinToolReturnPart,
@@ -26,7 +27,7 @@
2627
)
2728
from ...output import OutputDataT
2829
from ...tools import AgentDepsT
29-
from .. import SSE_CONTENT_TYPE, UIEventStream
30+
from .. import SSE_CONTENT_TYPE, NativeEvent, UIEventStream
3031

3132
try:
3233
from ag_ui.core import (
@@ -86,10 +87,22 @@ def content_type(self) -> str:
8687
def encode_event(self, event: BaseEvent) -> str:
8788
return self._event_encoder.encode(event)
8889

90+
@staticmethod
91+
def _get_now_utc_milliseconds() -> int:
92+
return int(now_utc().timestamp() * 1_000)
93+
94+
async def handle_event(self, event: NativeEvent) -> AsyncIterator[BaseEvent]:
95+
"""Override to set timestamps on all AG-UI events."""
96+
async for agui_event in super().handle_event(event):
97+
if agui_event.timestamp is None:
98+
agui_event.timestamp = self._get_now_utc_milliseconds()
99+
yield agui_event
100+
89101
async def before_stream(self) -> AsyncIterator[BaseEvent]:
90102
yield RunStartedEvent(
91103
thread_id=self.run_input.thread_id,
92104
run_id=self.run_input.run_id,
105+
timestamp=self._get_now_utc_milliseconds(),
93106
)
94107

95108
async def before_response(self) -> AsyncIterator[BaseEvent]:
@@ -104,11 +117,12 @@ async def after_stream(self) -> AsyncIterator[BaseEvent]:
104117
yield RunFinishedEvent(
105118
thread_id=self.run_input.thread_id,
106119
run_id=self.run_input.run_id,
120+
timestamp=self._get_now_utc_milliseconds(),
107121
)
108122

109123
async def on_error(self, error: Exception) -> AsyncIterator[BaseEvent]:
110124
self._error = True
111-
yield RunErrorEvent(message=str(error))
125+
yield RunErrorEvent(message=str(error), timestamp=self._get_now_utc_milliseconds())
112126

113127
async def handle_text_start(self, part: TextPart, follows_text: bool = False) -> AsyncIterator[BaseEvent]:
114128
if follows_text:

0 commit comments

Comments
 (0)