|  | 
| 1 | 1 | from __future__ import annotations | 
| 2 | 2 | 
 | 
|  | 3 | +import json | 
|  | 4 | +from collections.abc import Iterable, Sequence | 
|  | 5 | +from contextlib import asynccontextmanager | 
|  | 6 | +from datetime import datetime, timedelta | 
|  | 7 | +from typing import Any, cast | 
|  | 8 | + | 
| 3 | 9 | import pytest | 
|  | 10 | +from openai.types.responses.response_output_message_param import ResponseOutputMessageParam | 
|  | 11 | +from openai.types.responses.response_output_text_param import ResponseOutputTextParam | 
|  | 12 | +from openai.types.responses.response_reasoning_item_param import ( | 
|  | 13 | + ResponseReasoningItemParam, | 
|  | 14 | + Summary, | 
|  | 15 | +) | 
|  | 16 | +from sqlalchemy import select, text, update | 
|  | 17 | +from sqlalchemy.sql import Select | 
| 4 | 18 | 
 | 
| 5 | 19 | pytest.importorskip("sqlalchemy") # Skip tests if SQLAlchemy is not installed | 
| 6 | 20 | 
 | 
|  | 
| 16 | 30 | DB_URL = "sqlite+aiosqlite:///:memory:" | 
| 17 | 31 | 
 | 
| 18 | 32 | 
 | 
def _make_message_item(item_id: str, text_value: str) -> TResponseInputItem:
    """Build a completed assistant message item with one output_text part.

    Args:
        item_id: Value stored under the message's "id" key.
        text_value: Text placed in the single content part.

    Returns:
        The message dict, cast to TResponseInputItem.
    """
    # TypedDict constructors return plain dicts; keyword order fixes key order.
    text_part = ResponseOutputTextParam(
        type="output_text",
        text=text_value,
        annotations=[],
    )
    assistant_message = ResponseOutputMessageParam(
        id=item_id,
        type="message",
        role="assistant",
        status="completed",
        content=[text_part],
    )
    return cast(TResponseInputItem, assistant_message)
|  | 47 | + | 
|  | 48 | + | 
def _make_reasoning_item(item_id: str, summary_text: str) -> TResponseInputItem:
    """Build a reasoning item whose summary is a single summary_text entry.

    Args:
        item_id: Value stored under the item's "id" key.
        summary_text: Text of the only summary entry.

    Returns:
        The reasoning dict, cast to TResponseInputItem.
    """
    # TypedDict constructors return plain dicts; keyword order fixes key order.
    only_summary = Summary(type="summary_text", text=summary_text)
    reasoning_item = ResponseReasoningItemParam(
        id=item_id,
        type="reasoning",
        summary=[only_summary],
    )
    return cast(TResponseInputItem, reasoning_item)
|  | 57 | + | 
|  | 58 | + | 
|  | 59 | +def _item_ids(items: Sequence[TResponseInputItem]) -> list[str]: | 
|  | 60 | + result: list[str] = [] | 
|  | 61 | + for item in items: | 
|  | 62 | + item_dict = cast(dict[str, Any], item) | 
|  | 63 | + result.append(cast(str, item_dict["id"])) | 
|  | 64 | + return result | 
|  | 65 | + | 
|  | 66 | + | 
| 19 | 67 | @pytest.fixture | 
| 20 | 68 | def agent() -> Agent: | 
| 21 | 69 |  """Fixture for a basic agent with a fake model.""" | 
| @@ -151,3 +199,195 @@ async def test_add_empty_items_list(): | 
| 151 | 199 | 
 | 
| 152 | 200 |  items_after_add = await session.get_items() | 
| 153 | 201 |  assert len(items_after_add) == 0 | 
|  | 202 | + | 
|  | 203 | + | 
async def test_get_items_same_timestamp_consistent_order():
    """Check that get_items keeps insertion order when created_at values tie.

    Three items are stored, then their timestamps are rewritten so the last
    two share one value and the first is one millisecond earlier. The session
    factory is then wrapped so that any SELECT on the messages table whose
    ORDER BY contains no id term has its rows reversed. The ids must still
    come back in insertion order, both with and without a limit.
    """
    session = SQLAlchemySession.from_url(
        "same_timestamp_test",
        url=DB_URL,
        create_tables=True,
    )
    messages = session._messages

    await session.add_items([_make_message_item("older_same_ts", "old")])
    await session.add_items(
        [
            _make_reasoning_item("rs_same_ts", "..."),
            _make_message_item("msg_same_ts", "..."),
        ]
    )

    tied_at = datetime(2025, 10, 15, 17, 26, 39, 132483)
    earlier_at = tied_at - timedelta(milliseconds=1)

    async with session._session_factory() as db:
        # Map each stored item's "id" to the primary key of its row.
        lookup = select(messages.c.id, messages.c.message_data).where(
            messages.c.session_id == session.session_id
        )
        fetched = await db.execute(lookup)
        row_id_by_item: dict[str, Any] = {}
        for row_id, payload in fetched.fetchall():
            row_id_by_item[json.loads(payload)["id"]] = row_id

        # Give the two newest rows the same timestamp...
        tie_newest = (
            update(messages)
            .where(
                messages.c.id.in_(
                    [
                        row_id_by_item["rs_same_ts"],
                        row_id_by_item["msg_same_ts"],
                    ]
                )
            )
            .values(created_at=tied_at)
        )
        await db.execute(tie_newest)

        # ...and place the first row just before them.
        age_first = (
            update(messages)
            .where(messages.c.id == row_id_by_item["older_same_ts"])
            .values(created_at=earlier_at)
        )
        await db.execute(age_first)
        await db.commit()

    original_factory = session._session_factory

    class ReversedRows:
        """Minimal stand-in for a result object; only all() is provided."""

        def __init__(self, rows: Iterable[Any]):
            self._rows = list(rows)

        def all(self) -> list[Any]:
            return list(self._rows)

    def needs_shuffle(statement: Any) -> bool:
        """Return True for a messages-table SELECT ordered without an id term."""
        if not isinstance(statement, Select):
            return False

        order_clauses = list(statement._order_by_clause)
        if not order_clauses:
            return False

        id_orderings = (messages.c.id.asc(), messages.c.id.desc())
        for clause in order_clauses:
            try:
                orders_by_id = any(clause.compare(candidate) for candidate in id_orderings)
            except AttributeError:
                # Clauses that cannot be compared count as "not an id ordering".
                orders_by_id = False
            if orders_by_id:
                return False

        # Only shuffle queries that target the messages table.
        from_names = {
            name
            for name in (
                getattr(source, "name", None) for source in statement.get_final_froms()
            )
            if isinstance(name, str)
        }
        messages_name = getattr(messages, "name", "")
        if not isinstance(messages_name, str):
            messages_name = ""
        return messages_name in from_names

    @asynccontextmanager
    async def shuffled_session():
        async with original_factory() as inner:
            plain_execute = inner.execute

            async def execute_with_shuffle(statement: Any, *args: Any, **kwargs: Any) -> Any:
                outcome = await plain_execute(statement, *args, **kwargs)
                if not needs_shuffle(statement):
                    return outcome
                # The "shuffle" is a deterministic reversal of the fetched rows.
                return ReversedRows(reversed(outcome.all()))

            cast(Any, inner).execute = execute_with_shuffle
            try:
                yield inner
            finally:
                cast(Any, inner).execute = plain_execute

    session._session_factory = cast(Any, shuffled_session)
    try:
        everything = await session.get_items()
        assert _item_ids(everything) == ["older_same_ts", "rs_same_ts", "msg_same_ts"]

        newest_pair = await session.get_items(limit=2)
        assert _item_ids(newest_pair) == ["rs_same_ts", "msg_same_ts"]
    finally:
        # Always put the real factory back, even when an assertion fails.
        session._session_factory = original_factory
|  | 309 | + | 
|  | 310 | + | 
async def test_pop_item_same_timestamp_returns_latest():
    """Check that pop_item removes the last-added item when created_at values tie.

    Two items are added in one call, every row of the session is given the
    same created_at value, and pop_item must then return the second item and
    leave only the first one behind.
    """
    session = SQLAlchemySession.from_url(
        "same_timestamp_pop_test",
        url=DB_URL,
        create_tables=True,
    )
    await session.add_items(
        [
            _make_reasoning_item("rs_pop_same_ts", "..."),
            _make_message_item("msg_pop_same_ts", "..."),
        ]
    )

    # Overwrite created_at for all rows of this session with one fixed value.
    force_same_timestamp = text(
        "UPDATE agent_messages "
        "SET created_at = :created_at "
        "WHERE session_id = :session_id"
    )
    bind_values = {
        "created_at": "2025-10-15 17:26:39.132483",
        "session_id": session.session_id,
    }
    async with session._session_factory() as db:
        await db.execute(force_same_timestamp, bind_values)
        await db.commit()

    last_item = await session.pop_item()
    assert last_item is not None
    popped_id = cast(dict[str, Any], last_item)["id"]
    assert popped_id == "msg_pop_same_ts"

    leftover = await session.get_items()
    assert _item_ids(leftover) == ["rs_pop_same_ts"]
|  | 340 | + | 
|  | 341 | + | 
async def test_get_items_orders_by_id_for_ties():
    """Check that get_items orders by id after created_at.

    The session factory is wrapped so every statement passed to execute() is
    recorded. The first recorded statement (unlimited get_items) must order by
    created_at and id ascending; the second (limit=2) must order by both
    descending. Both calls must return the items in insertion order.
    """
    session = SQLAlchemySession.from_url(
        "order_by_id_test",
        url=DB_URL,
        create_tables=True,
    )
    await session.add_items(
        [
            _make_reasoning_item("rs_first", "..."),
            _make_message_item("msg_second", "..."),
        ]
    )

    original_factory = session._session_factory
    seen_statements: list[Any] = []

    @asynccontextmanager
    async def wrapped_session():
        async with original_factory() as inner:
            plain_execute = inner.execute

            async def recording_execute(statement: Any, *args: Any, **kwargs: Any) -> Any:
                # Record first so the statement is kept even if execution raises.
                seen_statements.append(statement)
                return await plain_execute(statement, *args, **kwargs)

            cast(Any, inner).execute = recording_execute
            try:
                yield inner
            finally:
                cast(Any, inner).execute = plain_execute

    def order_by_strings(statement: Any) -> list[str]:
        """Render each ORDER BY clause of the statement as a string."""
        return [str(clause) for clause in statement._order_by_clause]

    session._session_factory = cast(Any, wrapped_session)
    try:
        all_items = await session.get_items()
        limited_items = await session.get_items(limit=2)
    finally:
        # Always put the real factory back, even if a call above fails.
        session._session_factory = original_factory

    assert len(seen_statements) >= 2

    expected_ascending = [
        "agent_messages.created_at ASC",
        "agent_messages.id ASC",
    ]
    assert order_by_strings(seen_statements[0]) == expected_ascending

    expected_descending = [
        "agent_messages.created_at DESC",
        "agent_messages.id DESC",
    ]
    assert order_by_strings(seen_statements[1]) == expected_descending

    assert _item_ids(all_items) == ["rs_first", "msg_second"]
    assert _item_ids(limited_items) == ["rs_first", "msg_second"]
0 commit comments