Skip to content

Commit 1368f6a

Browse files
yeesiancopybara-github
authored andcommitted
feat: Add memory related methods to AdkApp
PiperOrigin-RevId: 794137925
1 parent 4b7d43e commit 1368f6a

File tree

2 files changed

+155
-0
lines changed

2 files changed

+155
-0
lines changed

tests/unit/vertex_adk/test_reasoning_engine_templates_adk.py

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -370,6 +370,105 @@ def test_delete_session(self):
370370
response0 = app.list_sessions(user_id="test_user_id")
371371
assert not response0.sessions
372372

373+
@pytest.mark.asyncio
374+
async def test_async_add_session_to_memory(self):
375+
app = reasoning_engines.AdkApp(
376+
agent=Agent(name="test_agent", model=_TEST_MODEL)
377+
)
378+
assert app._tmpl_attrs.get("memory_service") is None
379+
session = app.create_session(user_id="test_user_id")
380+
list(
381+
app.stream_query(
382+
user_id="test_user_id",
383+
session_id=session.id,
384+
message="My cat's name is Garfield",
385+
)
386+
)
387+
await app.async_add_session_to_memory(
388+
session=app.get_session(
389+
user_id="test_user_id",
390+
session_id=session.id,
391+
)
392+
)
393+
response = await app.async_search_memory(
394+
user_id="test_user_id",
395+
query="What is my cat's name",
396+
)
397+
assert len(response.memories) >= 1
398+
399+
@pytest.mark.asyncio
400+
async def test_async_add_session_to_memory_dict(self):
401+
app = reasoning_engines.AdkApp(
402+
agent=Agent(name="test_agent", model=_TEST_MODEL)
403+
)
404+
await app.async_add_session_to_memory(
405+
session={
406+
"id": "ca18c25a-644b-4e13-9b24-78c150ec3eb9",
407+
"app_name": "default-app-name",
408+
"user_id": "test_user_id",
409+
"events": [
410+
{
411+
"author": "user",
412+
"content": {
413+
"parts": [{"text": "My cat's name is Garfield"}],
414+
"role": "user",
415+
},
416+
},
417+
{
418+
"author": "my_personal_agent",
419+
"content": {
420+
"parts": [{"text": "Okay, good to know!"}],
421+
"role": "model",
422+
},
423+
},
424+
],
425+
},
426+
)
427+
response = await app.async_search_memory(
428+
user_id="test_user_id",
429+
query="What is my cat's name",
430+
)
431+
assert len(response.memories) >= 1
432+
433+
@pytest.mark.asyncio
434+
async def test_async_search_memory(self):
435+
app = reasoning_engines.AdkApp(
436+
agent=Agent(name="test_agent", model=_TEST_MODEL)
437+
)
438+
response = await app.async_search_memory(
439+
user_id="test_user_id",
440+
query="What is my cat's name",
441+
)
442+
assert not response.memories
443+
await app.async_add_session_to_memory(
444+
session={
445+
"id": "ca18c25a-644b-4e13-9b24-78c150ec3eb9",
446+
"app_name": "default-app-name",
447+
"user_id": "test_user_id",
448+
"events": [
449+
{
450+
"author": "user",
451+
"content": {
452+
"parts": [{"text": "My cat's name is Garfield"}],
453+
"role": "user",
454+
},
455+
},
456+
{
457+
"author": "my_personal_agent",
458+
"content": {
459+
"parts": [{"text": "Okay, good to know!"}],
460+
"role": "model",
461+
},
462+
},
463+
],
464+
},
465+
)
466+
response = await app.async_search_memory(
467+
user_id="test_user_id",
468+
query="What is my cat's name",
469+
)
470+
assert len(response.memories) >= 1
471+
373472
@pytest.mark.usefixtures("caplog")
374473
def test_enable_tracing(
375474
self,

vertexai/preview/reasoning_engines/templates/adk.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,13 @@
5151
except (ImportError, AttributeError):
5252
BaseSessionService = Any
5353

54+
try:
55+
from google.adk.sessions.session import Session
56+
57+
Session = Session
58+
except (ImportError, AttributeError):
59+
Session = Any
60+
5461
try:
5562
from google.adk.artifacts import BaseArtifactService
5663

@@ -961,6 +968,53 @@ def _asyncio_thread_main():
961968
if isinstance(outcome, RuntimeError):
962969
raise outcome from None
963970

971+
async def async_add_session_to_memory(
972+
self,
973+
*,
974+
session: Union["Session", Dict[str, Any]],
975+
):
976+
"""Generates memories.
977+
978+
Args:
979+
session (Union[Session, Dict[str, Any]]):
980+
Required. The session to use for generating memories.
981+
"""
982+
from google.adk.sessions.session import Session
983+
984+
if isinstance(session, Dict):
985+
session = Session.model_validate(session)
986+
elif not isinstance(session, Session):
987+
raise TypeError("session must be a Session object.")
988+
if not session.events:
989+
# Get the latest version of the session in case it was updated.
990+
session = await self.async_get_session(
991+
user_id=session.user_id,
992+
session_id=session.id,
993+
)
994+
if not self._tmpl_attrs.get("memory_service"):
995+
self.set_up()
996+
return await self._tmpl_attrs.get("memory_service").add_session_to_memory(
997+
session=session,
998+
)
999+
1000+
async def async_search_memory(self, *, user_id: str, query: str):
1001+
"""Searches memories for the given user.
1002+
1003+
Args:
1004+
user_id: The id of the user.
1005+
query: The query to match the memories on.
1006+
1007+
Returns:
1008+
A SearchMemoryResponse containing the matching memories.
1009+
"""
1010+
if not self._tmpl_attrs.get("memory_service"):
1011+
self.set_up()
1012+
return await self._tmpl_attrs.get("memory_service").search_memory(
1013+
app_name=self._tmpl_attrs.get("app_name"),
1014+
user_id=user_id,
1015+
query=query,
1016+
)
1017+
9641018
def register_operations(self) -> Dict[str, List[str]]:
9651019
"""Registers the operations of the ADK application."""
9661020
return {
@@ -975,6 +1029,8 @@ def register_operations(self) -> Dict[str, List[str]]:
9751029
"async_list_sessions",
9761030
"async_create_session",
9771031
"async_delete_session",
1032+
"async_add_session_to_memory",
1033+
"async_search_memory",
9781034
],
9791035
"stream": ["stream_query", "streaming_agent_run_with_events"],
9801036
"async_stream": ["async_stream_query"],

0 commit comments

Comments
 (0)