Skip to content

Commit 630cd37

Browse files
authored
fix: openai#1900 fix a bug where SQLAlchemySession could return items in an invalid order (openai#1917)
1 parent d9f1d5f commit 630cd37

File tree

2 files changed

+252
-3
lines changed

2 files changed

+252
-3
lines changed

src/agents/extensions/memory/sqlalchemy_session.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -195,15 +195,21 @@ async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:
195195
stmt = (
196196
select(self._messages.c.message_data)
197197
.where(self._messages.c.session_id == self.session_id)
198-
.order_by(self._messages.c.created_at.asc())
198+
.order_by(
199+
self._messages.c.created_at.asc(),
200+
self._messages.c.id.asc(),
201+
)
199202
)
200203
else:
201204
stmt = (
202205
select(self._messages.c.message_data)
203206
.where(self._messages.c.session_id == self.session_id)
204207
# Use DESC + LIMIT to get the latest N
205208
# then reverse later for chronological order.
206-
.order_by(self._messages.c.created_at.desc())
209+
.order_by(
210+
self._messages.c.created_at.desc(),
211+
self._messages.c.id.desc(),
212+
)
207213
.limit(limit)
208214
)
209215

@@ -278,7 +284,10 @@ async def pop_item(self) -> TResponseInputItem | None:
278284
subq = (
279285
select(self._messages.c.id)
280286
.where(self._messages.c.session_id == self.session_id)
281-
.order_by(self._messages.c.created_at.desc())
287+
.order_by(
288+
self._messages.c.created_at.desc(),
289+
self._messages.c.id.desc(),
290+
)
282291
.limit(1)
283292
)
284293
res = await sess.execute(subq)

tests/extensions/memory/test_sqlalchemy_session.py

Lines changed: 240 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,20 @@
11
from __future__ import annotations
22

3+
import json
4+
from collections.abc import Iterable, Sequence
5+
from contextlib import asynccontextmanager
6+
from datetime import datetime, timedelta
7+
from typing import Any, cast
8+
39
import pytest
10+
from openai.types.responses.response_output_message_param import ResponseOutputMessageParam
11+
from openai.types.responses.response_output_text_param import ResponseOutputTextParam
12+
from openai.types.responses.response_reasoning_item_param import (
13+
ResponseReasoningItemParam,
14+
Summary,
15+
)
16+
from sqlalchemy import select, text, update
17+
from sqlalchemy.sql import Select
418

519
pytest.importorskip("sqlalchemy") # Skip tests if SQLAlchemy is not installed
620

@@ -16,6 +30,40 @@
1630
DB_URL = "sqlite+aiosqlite:///:memory:"
1731

1832

33+
def _make_message_item(item_id: str, text_value: str) -> TResponseInputItem:
    """Build a completed assistant message item carrying a single output-text part.

    Args:
        item_id: Value stored under the item's ``"id"`` key.
        text_value: Text placed in the single ``output_text`` content part.

    Returns:
        The message dict, cast to ``TResponseInputItem`` for the type checker.
    """
    payload: ResponseOutputMessageParam = {
        "id": item_id,
        "type": "message",
        "role": "assistant",
        "status": "completed",
        "content": [
            {
                "type": "output_text",
                "text": text_value,
                "annotations": [],
            }
        ],
    }
    return cast(TResponseInputItem, payload)
47+
48+
49+
def _make_reasoning_item(item_id: str, summary_text: str) -> TResponseInputItem:
    """Build a reasoning item whose summary holds one ``summary_text`` entry.

    Args:
        item_id: Value stored under the item's ``"id"`` key.
        summary_text: Text of the single summary entry.

    Returns:
        The reasoning dict, cast to ``TResponseInputItem`` for the type checker.
    """
    payload: ResponseReasoningItemParam = {
        "id": item_id,
        "type": "reasoning",
        "summary": [{"type": "summary_text", "text": summary_text}],
    }
    return cast(TResponseInputItem, payload)
57+
58+
59+
def _item_ids(items: Sequence[TResponseInputItem]) -> list[str]:
60+
result: list[str] = []
61+
for item in items:
62+
item_dict = cast(dict[str, Any], item)
63+
result.append(cast(str, item_dict["id"]))
64+
return result
65+
66+
1967
@pytest.fixture
2068
def agent() -> Agent:
2169
"""Fixture for a basic agent with a fake model."""
@@ -151,3 +199,195 @@ async def test_add_empty_items_list():
151199

152200
items_after_add = await session.get_items()
153201
assert len(items_after_add) == 0
202+
203+
204+
async def test_get_items_same_timestamp_consistent_order():
    """Test that items with identical timestamps keep insertion order.

    Two rows are forced onto the same ``created_at`` value. The session factory
    is then wrapped so that any SELECT on the messages table that orders rows
    without referencing the ``id`` column has its rows reversed before they are
    returned. Queries that do order by ``id`` are passed through untouched.
    """
    session_id = "same_timestamp_test"
    session = SQLAlchemySession.from_url(session_id, url=DB_URL, create_tables=True)
    messages = session._messages

    # One item inserted on its own, then a reasoning/message pair in one call.
    await session.add_items([_make_message_item("older_same_ts", "old")])
    await session.add_items(
        [
            _make_reasoning_item("rs_same_ts", "..."),
            _make_message_item("msg_same_ts", "..."),
        ]
    )

    tied_at = datetime(2025, 10, 15, 17, 26, 39, 132483)
    earlier_at = tied_at - timedelta(milliseconds=1)

    async with session._session_factory() as sess:
        stored = await sess.execute(
            select(messages.c.id, messages.c.message_data).where(
                messages.c.session_id == session.session_id
            )
        )
        # Map each item's JSON "id" to the primary key of the row holding it.
        row_id_by_item_id: dict[str, Any] = {}
        for row_id, message_json in stored.fetchall():
            row_id_by_item_id[json.loads(message_json)["id"]] = row_id

        tied_row_ids = [
            row_id_by_item_id["rs_same_ts"],
            row_id_by_item_id["msg_same_ts"],
        ]
        await sess.execute(
            update(messages)
            .where(messages.c.id.in_(tied_row_ids))
            .values(created_at=tied_at)
        )
        await sess.execute(
            update(messages)
            .where(messages.c.id == row_id_by_item_id["older_same_ts"])
            .values(created_at=earlier_at)
        )
        await sess.commit()

    real_factory = session._session_factory

    class FakeResult:
        """Result stand-in that only supports ``all()``."""

        def __init__(self, rows: Iterable[Any]):
            self._rows = list(rows)

        def all(self) -> list[Any]:
            return list(self._rows)

    def needs_shuffle(statement: Any) -> bool:
        """Tell whether ``statement`` is an ordered messages SELECT lacking an id ordering."""
        if not isinstance(statement, Select):
            return False
        order_clauses = list(statement._order_by_clause)
        if not order_clauses:
            return False

        id_asc = messages.c.id.asc()
        id_desc = messages.c.id.desc()

        def references_id(clause: Any) -> bool:
            try:
                return bool(clause.compare(id_asc) or clause.compare(id_desc))
            except AttributeError:
                return False

        for clause in order_clauses:
            if references_id(clause):
                return False

        # Only shuffle queries that target the messages table.
        from_names = {
            name
            for name in (
                getattr(from_clause, "name", None)
                for from_clause in statement.get_final_froms()
            )
            if isinstance(name, str)
        }
        own_name = getattr(messages, "name", "")
        if not isinstance(own_name, str):
            own_name = ""
        return own_name in from_names

    @asynccontextmanager
    async def shuffled_session():
        async with real_factory() as inner:
            original_execute = inner.execute

            async def execute_with_shuffle(statement: Any, *args: Any, **kwargs: Any) -> Any:
                result = await original_execute(statement, *args, **kwargs)
                if not needs_shuffle(statement):
                    return result
                # Hand the rows back in reverse order.
                return FakeResult(list(result.all())[::-1])

            cast(Any, inner).execute = execute_with_shuffle
            try:
                yield inner
            finally:
                cast(Any, inner).execute = original_execute

    session._session_factory = cast(Any, shuffled_session)
    try:
        everything = await session.get_items()
        assert _item_ids(everything) == ["older_same_ts", "rs_same_ts", "msg_same_ts"]

        newest_pair = await session.get_items(limit=2)
        assert _item_ids(newest_pair) == ["rs_same_ts", "msg_same_ts"]
    finally:
        session._session_factory = real_factory
309+
310+
311+
async def test_pop_item_same_timestamp_returns_latest():
    """Test that pop_item returns the newest item when timestamps tie.

    Both rows of the session are rewritten to share one ``created_at`` value,
    so the item popped is expected to be the one inserted last.
    """
    session_id = "same_timestamp_pop_test"
    session = SQLAlchemySession.from_url(session_id, url=DB_URL, create_tables=True)

    await session.add_items(
        [
            _make_reasoning_item("rs_pop_same_ts", "..."),
            _make_message_item("msg_pop_same_ts", "..."),
        ]
    )

    # Put every row of this session on the same timestamp.
    force_tie = text(
        "UPDATE agent_messages "
        "SET created_at = :created_at "
        "WHERE session_id = :session_id"
    )
    bind_params = {
        "created_at": "2025-10-15 17:26:39.132483",
        "session_id": session.session_id,
    }
    async with session._session_factory() as sess:
        await sess.execute(force_tie, bind_params)
        await sess.commit()

    popped = await session.pop_item()
    assert popped is not None
    assert cast(dict[str, Any], popped)["id"] == "msg_pop_same_ts"

    leftover = await session.get_items()
    assert _item_ids(leftover) == ["rs_pop_same_ts"]
340+
341+
342+
async def test_get_items_orders_by_id_for_ties():
    """Test that get_items adds id ordering to break timestamp ties.

    Every statement executed through the session is recorded, and the ORDER BY
    clauses of the first two are compared with the expected
    ``created_at``-then-``id`` ordering.
    """
    session_id = "order_by_id_test"
    session = SQLAlchemySession.from_url(session_id, url=DB_URL, create_tables=True)

    seeded = [
        _make_reasoning_item("rs_first", "..."),
        _make_message_item("msg_second", "..."),
    ]
    await session.add_items(seeded)

    real_factory = session._session_factory
    recorded: list[Any] = []

    @asynccontextmanager
    async def wrapped_session():
        async with real_factory() as inner:
            original_execute = inner.execute

            async def recording_execute(statement: Any, *args: Any, **kwargs: Any) -> Any:
                # Capture the statement before delegating to the real execute.
                recorded.append(statement)
                return await original_execute(statement, *args, **kwargs)

            cast(Any, inner).execute = recording_execute
            try:
                yield inner
            finally:
                cast(Any, inner).execute = original_execute

    session._session_factory = cast(Any, wrapped_session)
    try:
        retrieved_full = await session.get_items()
        retrieved_limited = await session.get_items(limit=2)
    finally:
        session._session_factory = real_factory

    def rendered_order_by(statement: Any) -> list[str]:
        return [str(clause) for clause in statement._order_by_clause]

    expected_ascending = [
        "agent_messages.created_at ASC",
        "agent_messages.id ASC",
    ]
    expected_descending = [
        "agent_messages.created_at DESC",
        "agent_messages.id DESC",
    ]

    assert len(recorded) >= 2
    assert rendered_order_by(recorded[0]) == expected_ascending
    assert rendered_order_by(recorded[1]) == expected_descending

    expected_ids = ["rs_first", "msg_second"]
    assert _item_ids(retrieved_full) == expected_ids
    assert _item_ids(retrieved_limited) == expected_ids

0 commit comments

Comments
 (0)