Skip to content

Commit dd8840a

Browse files
yeesiancopybara-github
authored andcommitted
feat: Allow list of events to be passed to AdkApp.async_stream_query
PiperOrigin-RevId: 845350483
1 parent 6d5404d commit dd8840a

File tree

2 files changed

+144
-1
lines changed

2 files changed

+144
-1
lines changed

tests/unit/vertex_adk/test_reasoning_engine_templates_adk.py

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,83 @@ def __init__(self, name: str, model: str):
7878
"streaming_mode": "sse",
7979
"max_llm_calls": 500,
8080
}
81+
_TEST_SESSION_EVENTS = [
82+
{
83+
"author": "user",
84+
"content": {
85+
"parts": [
86+
{
87+
"text": "What is the exchange rate from US dollars to "
88+
"Swedish krona on 2025-09-25?"
89+
}
90+
],
91+
"role": "user",
92+
},
93+
"id": "8967297909049524224",
94+
"invocationId": "e-308f65d7-a99f-41e3-b80d-40feb5f1b065",
95+
"timestamp": 1765832134.629513,
96+
},
97+
{
98+
"author": "currency_exchange_agent",
99+
"content": {
100+
"parts": [
101+
{
102+
"functionCall": {
103+
"args": {
104+
"currency_date": "2025-09-25",
105+
"currency_from": "USD",
106+
"currency_to": "SEK",
107+
},
108+
"id": "adk-136738ad-9e57-4cfb-8e23-b0f3e50a37d7",
109+
"name": "get_exchange_rate",
110+
}
111+
}
112+
],
113+
"role": "model",
114+
},
115+
"id": "3155402589927899136",
116+
"invocationId": "e-308f65d7-a99f-41e3-b80d-40feb5f1b065",
117+
"timestamp": 1765832134.723713,
118+
},
119+
{
120+
"author": "currency_exchange_agent",
121+
"content": {
122+
"parts": [
123+
{
124+
"functionResponse": {
125+
"id": "adk-136738ad-9e57-4cfb-8e23-b0f3e50a37d7",
126+
"name": "get_exchange_rate",
127+
"response": {
128+
"amount": 1,
129+
"base": "USD",
130+
"date": "2025-09-25",
131+
"rates": {"SEK": 9.4118},
132+
},
133+
}
134+
}
135+
],
136+
"role": "user",
137+
},
138+
"id": "1678221912150376448",
139+
"invocationId": "e-308f65d7-a99f-41e3-b80d-40feb5f1b065",
140+
"timestamp": 1765832135.764961,
141+
},
142+
{
143+
"author": "currency_exchange_agent",
144+
"content": {
145+
"parts": [
146+
{
147+
"text": "The exchange rate from US dollars to Swedish "
148+
"krona on 2025-09-25 is 1 USD to 9.4118 SEK."
149+
}
150+
],
151+
"role": "model",
152+
},
153+
"id": "2470855446567583744",
154+
"invocationId": "e-308f65d7-a99f-41e3-b80d-40feb5f1b065",
155+
"timestamp": 1765832135.853299,
156+
},
157+
]
81158

82159

83160
@pytest.fixture(scope="module")
@@ -392,6 +469,46 @@ async def test_async_stream_query(self):
392469
events.append(event)
393470
assert len(events) == 1
394471

472+
@pytest.mark.asyncio
473+
async def test_async_stream_query_with_empty_session_events(self):
474+
app = reasoning_engines.AdkApp(
475+
agent=Agent(name=_TEST_AGENT_NAME, model=_TEST_MODEL)
476+
)
477+
assert app._tmpl_attrs.get("runner") is None
478+
app.set_up()
479+
app._tmpl_attrs["runner"] = _MockRunner()
480+
events = []
481+
async for event in app.async_stream_query(
482+
user_id=_TEST_USER_ID,
483+
session_events=[],
484+
message="test message",
485+
):
486+
events.append(event)
487+
assert app._tmpl_attrs.get("session_service") is not None
488+
sessions = app.list_sessions(user_id=_TEST_USER_ID)
489+
assert len(sessions.sessions) == 1
490+
491+
@pytest.mark.asyncio
492+
async def test_async_stream_query_with_session_events(
493+
self,
494+
):
495+
app = reasoning_engines.AdkApp(
496+
agent=Agent(name=_TEST_AGENT_NAME, model=_TEST_MODEL)
497+
)
498+
assert app._tmpl_attrs.get("runner") is None
499+
app.set_up()
500+
app._tmpl_attrs["runner"] = _MockRunner()
501+
events = []
502+
async for event in app.async_stream_query(
503+
user_id=_TEST_USER_ID,
504+
session_events=_TEST_SESSION_EVENTS,
505+
message="on the day after that?",
506+
):
507+
events.append(event)
508+
assert app._tmpl_attrs.get("session_service") is not None
509+
sessions = app.list_sessions(user_id=_TEST_USER_ID)
510+
assert len(sessions.sessions) == 1
511+
395512
@pytest.mark.asyncio
396513
@mock.patch.dict(
397514
os.environ,

vertexai/agent_engines/templates/adk.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -932,6 +932,7 @@ async def async_stream_query(
932932
message: Union[str, Dict[str, Any]],
933933
user_id: str,
934934
session_id: Optional[str] = None,
935+
session_events: Optional[List[Dict[str, Any]]] = None,
935936
run_config: Optional[Dict[str, Any]] = None,
936937
**kwargs,
937938
) -> AsyncIterable[Dict[str, Any]]:
@@ -944,7 +945,11 @@ async def async_stream_query(
944945
Required. The ID of the user.
945946
session_id (str):
946947
Optional. The ID of the session. If not provided, a new
947-
session will be created for the user.
948+
session will be created for the user. If this is specified, then
949+
`session_events` will be ignored.
950+
session_events (Optional[List[Dict[str, Any]]]):
951+
Optional. The session events to use for the query. This will be
952+
used to initialize the session if `session_id` is not provided.
948953
run_config (Optional[Dict[str, Any]]):
949954
Optional. The run config to use for the query. If you want to
950955
pass in a `run_config` pydantic object, you can pass in a dict
@@ -955,6 +960,11 @@ async def async_stream_query(
955960
956961
Yields:
957962
Event dictionaries asynchronously.
963+
964+
Raises:
965+
TypeError: If message is not a string or a dictionary representing
966+
a Content object.
967+
ValueError: If both session_id and session_events are specified.
958968
"""
959969
from vertexai.agent_engines import _utils
960970
from google.genai import types
@@ -971,9 +981,25 @@ async def async_stream_query(
971981

972982
if not self._tmpl_attrs.get("runner"):
973983
self.set_up()
984+
if session_id and session_events:
985+
raise ValueError(
986+
"Only one of session_id and session_events should be specified."
987+
)
974988
if not session_id:
975989
session = await self.async_create_session(user_id=user_id)
976990
session_id = session.id
991+
if session_events is not None:
992+
# We allow for session_events to be an empty list.
993+
from google.adk.events.event import Event
994+
995+
session_service = self._tmpl_attrs.get("session_service")
996+
for event in session_events:
997+
if not isinstance(event, Event):
998+
event = Event.model_validate(event)
999+
await session_service.append_event(
1000+
session=session,
1001+
event=event,
1002+
)
9771003

9781004
run_config = _validate_run_config(run_config)
9791005
if run_config:

0 commit comments

Comments
 (0)