Skip to content

Commit 55b7c23

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
feat: GenAI SDK client: Add async Memory and Memory Revisions methods
PiperOrigin-RevId: 821824583
1 parent bbf788a commit 55b7c23

File tree

7 files changed

+432
-22
lines changed

7 files changed

+432
-22
lines changed

tests/unit/vertexai/genai/replays/test_delete_agent_engine_memory.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,15 +23,16 @@
2323

2424
def test_delete_memory(client):
2525
agent_engine = client.agent_engines.create()
26-
operation = client.agent_engines.create_memory(
26+
operation = client.agent_engines.memories.create(
2727
name=agent_engine.api_resource.name,
2828
fact="memory_fact",
2929
scope={"user_id": "123"},
3030
)
3131
memory = operation.response
32-
operation = client.agent_engines.delete_memory(name=memory.name)
32+
operation = client.agent_engines.memories.delete(name=memory.name)
3333
assert isinstance(operation, types.DeleteAgentEngineMemoryOperation)
3434
assert operation.name.startswith(memory.name + "/operations/")
35+
client.agent_engines.delete(name=agent_engine.api_resource.name, force=True)
3536

3637

3738
pytestmark = pytest_helper.setup(
@@ -46,14 +47,16 @@ def test_delete_memory(client):
4647

4748
@pytest.mark.asyncio
4849
async def test_delete_memory_async(client):
49-
# TODO(b/431785750): use async methods for create() and create_memory() when available
5050
agent_engine = client.agent_engines.create()
51-
operation = client.agent_engines.create_memory(
51+
operation = await client.aio.agent_engines.memories.create(
5252
name=agent_engine.api_resource.name,
5353
fact="memory_fact",
5454
scope={"user_id": "123"},
5555
)
5656
memory = operation.response
57-
operation = await client.aio.agent_engines.delete_memory(name=memory.name)
57+
operation = await client.aio.agent_engines.memories.delete(name=memory.name)
5858
assert isinstance(operation, types.DeleteAgentEngineMemoryOperation)
5959
assert operation.name.startswith(memory.name + "/operations/")
60+
await client.aio.agent_engines.delete(
61+
name=agent_engine.api_resource.name, force=True
62+
)

tests/unit/vertexai/genai/replays/test_generate_agent_engine_memories.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,16 @@
1414
#
1515
# pylint: disable=protected-access,bad-continuation,missing-function-docstring
1616

17+
import pytest
18+
19+
1720
from tests.unit.vertexai.genai.replays import pytest_helper
1821
from vertexai._genai import types
1922
from google.genai import types as genai_types
2023

2124

2225
def test_generate_and_rollback_memories(client):
26+
# TODO(): Use prod endpoint once experiment is fully rolled out.
2327
client._api_client._http_options.base_url = (
2428
"https://us-central1-autopush-aiplatform.sandbox.googleapis.com/"
2529
)
@@ -146,3 +150,64 @@ def test_generate_memories_direct_memories_source(client):
146150
globals_for_file=globals(),
147151
test_method="agent_engines.generate_memories",
148152
)
153+
154+
155+
pytest_plugins = ("pytest_asyncio",)
156+
157+
158+
@pytest.mark.asyncio
159+
async def test_generate_and_rollback_memories_async(client):
160+
# TODO(): Use prod endpoint once revisions experiment is fully rolled out.
161+
client._api_client._http_options.base_url = (
162+
"https://us-central1-autopush-aiplatform.sandbox.googleapis.com/"
163+
)
164+
agent_engine = client.agent_engines.create()
165+
await client.aio.agent_engines.memories.generate(
166+
name=agent_engine.api_resource.name,
167+
scope={"user_id": "test-user-id"},
168+
direct_memories_source=types.GenerateMemoriesRequestDirectMemoriesSource(
169+
direct_memories=[
170+
types.GenerateMemoriesRequestDirectMemoriesSourceDirectMemory(
171+
fact="I am a software engineer."
172+
),
173+
types.GenerateMemoriesRequestDirectMemoriesSourceDirectMemory(
174+
fact="I like to write replay tests."
175+
),
176+
]
177+
),
178+
config=types.GenerateAgentEngineMemoriesConfig(wait_for_completion=True),
179+
)
180+
memories_pager = await client.aio.agent_engines.memories.list(
181+
name=agent_engine.api_resource.name
182+
)
183+
memory_list = [item async for item in memories_pager]
184+
assert len(memory_list) >= 1
185+
186+
revisions_pager = await client.aio.agent_engines.memories.revisions.list(
187+
name=memory_list[0].name
188+
)
189+
memory_revisions = [item async for item in revisions_pager]
190+
assert len(memory_revisions) >= 1
191+
revision_name = memory_revisions[0].name
192+
193+
# Update the memory.
194+
client.agent_engines.memories._update(
195+
name=memory_list[0].name,
196+
fact="This is temporary",
197+
scope={"user_id": "test-user-id"},
198+
)
199+
memory = await client.aio.agent_engines.memories.get(name=memory_list[0].name)
200+
assert memory.fact == "This is temporary"
201+
202+
# Rollback to the revision with the original fact that was created by the
203+
# generation request.
204+
await client.aio.agent_engines.memories.rollback(
205+
name=memory_list[0].name,
206+
target_revision_id=revision_name.split("/")[-1],
207+
)
208+
memory = await client.aio.agent_engines.memories.get(name=memory_list[0].name)
209+
assert memory.fact == memory_revisions[0].fact
210+
211+
await client.aio.agent_engines.delete(
212+
name=agent_engine.api_resource.name, force=True
213+
)

tests/unit/vertexai/genai/replays/test_get_agent_engine_memory.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,23 +22,24 @@
2222

2323
def test_get_memory(client):
2424
agent_engine = client.agent_engines.create()
25-
operation = client.agent_engines.create_memory(
25+
operation = client.agent_engines.memories.create(
2626
name=agent_engine.api_resource.name,
2727
fact="memory_fact",
2828
scope={"user_id": "123"},
2929
)
3030
assert isinstance(operation, types.AgentEngineMemoryOperation)
31-
memory = client.agent_engines.get_memory(
31+
memory = client.agent_engines.memories.get(
3232
name=operation.response.name,
3333
)
3434
assert isinstance(memory, types.Memory)
3535
assert memory.name == operation.response.name
36+
client.agent_engines.delete(name=agent_engine.api_resource.name, force=True)
3637

3738

3839
pytestmark = pytest_helper.setup(
3940
file=__file__,
4041
globals_for_file=globals(),
41-
test_method="agent_engines.get_memory",
42+
test_method="agent_engines.memories.get",
4243
)
4344

4445

@@ -47,16 +48,18 @@ def test_get_memory(client):
4748

4849
@pytest.mark.asyncio
4950
async def test_get_memory_async(client):
50-
# TODO(b/431785750): use async methods for create() and create_memory() when available
5151
agent_engine = client.agent_engines.create()
52-
operation = client.agent_engines.create_memory(
52+
operation = await client.aio.agent_engines.memories.create(
5353
name=agent_engine.api_resource.name,
5454
fact="memory_fact",
5555
scope={"user_id": "123"},
5656
)
5757
assert isinstance(operation, types.AgentEngineMemoryOperation)
58-
memory = await client.aio.agent_engines.get_memory(
58+
memory = await client.aio.agent_engines.memories.get(
5959
name=operation.response.name,
6060
)
6161
assert isinstance(memory, types.Memory)
6262
assert memory.name == operation.response.name
63+
await client.aio.agent_engines.delete(
64+
name=agent_engine.api_resource.name, force=True
65+
)

tests/unit/vertexai/genai/replays/test_list_agent_engine_memories.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
#
1515
# pylint: disable=protected-access,bad-continuation,missing-function-docstring
1616

17+
import pytest
18+
1719
from tests.unit.vertexai.genai.replays import pytest_helper
1820
from vertexai._genai import types
1921

@@ -59,5 +61,36 @@ def test_list_memories(client):
5961
pytestmark = pytest_helper.setup(
6062
file=__file__,
6163
globals_for_file=globals(),
62-
test_method="agent_engines.list_memories",
64+
test_method="agent_engines.memories.list",
6365
)
66+
67+
68+
pytest_plugins = ("pytest_asyncio",)
69+
70+
71+
@pytest.mark.asyncio
72+
async def test_async_list_memories(client):
73+
agent_engine = client.agent_engines.create()
74+
pager = await client.aio.agent_engines.memories.list(
75+
name=agent_engine.api_resource.name
76+
)
77+
assert not [item async for item in pager]
78+
79+
await client.aio.agent_engines.memories.create(
80+
name=agent_engine.api_resource.name,
81+
fact="memory_fact_2",
82+
scope={"user_id": "456"},
83+
config={
84+
"wait_for_completion": True,
85+
},
86+
)
87+
pager = await client.aio.agent_engines.memories.list(
88+
name=agent_engine.api_resource.name
89+
)
90+
memory_list = [item async for item in pager]
91+
assert len(memory_list) == 1
92+
assert isinstance(memory_list[0], types.Memory)
93+
94+
await client.aio.agent_engines.delete(
95+
name=agent_engine.api_resource.name, force=True
96+
)

tests/unit/vertexai/genai/replays/test_retrieve_agent_engine_memories.py

Lines changed: 35 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@
1414
#
1515
# pylint: disable=protected-access,bad-continuation,missing-function-docstring
1616

17+
import pytest
18+
19+
1720
from tests.unit.vertexai.genai.replays import pytest_helper
1821
from vertexai._genai import types
1922
from google.genai import pagers
@@ -22,23 +25,23 @@
2225
def test_retrieve_memories_with_similarity_search_params(client):
2326
agent_engine = client.agent_engines.create()
2427
assert not list(
25-
client.agent_engines.retrieve_memories(
28+
client.agent_engines.memories.retrieve(
2629
name=agent_engine.api_resource.name,
2730
scope={"user_id": "123"},
2831
similarity_search_params=types.RetrieveMemoriesRequestSimilaritySearchParams(
2932
search_query="memory_fact_1",
3033
),
3134
)
3235
)
33-
client.agent_engines.create_memory(
36+
client.agent_engines.memories.create(
3437
name=agent_engine.api_resource.name,
3538
fact="memory_fact_1",
3639
scope={"user_id": "123"},
3740
)
3841
assert (
3942
len(
4043
list(
41-
client.agent_engines.retrieve_memories(
44+
client.agent_engines.memories.retrieve(
4245
name=agent_engine.api_resource.name,
4346
scope={"user_id": "123"},
4447
)
@@ -47,20 +50,20 @@ def test_retrieve_memories_with_similarity_search_params(client):
4750
== 1
4851
)
4952
assert not list(
50-
client.agent_engines.retrieve_memories(
53+
client.agent_engines.memories.retrieve(
5154
name=agent_engine.api_resource.name,
5255
scope={"user_id": "456"},
5356
)
5457
)
55-
client.agent_engines.create_memory(
58+
client.agent_engines.memories.create(
5659
name=agent_engine.api_resource.name,
5760
fact="memory_fact_2",
5861
scope={"user_id": "123"},
5962
)
6063
assert (
6164
len(
6265
list(
63-
client.agent_engines.retrieve_memories(
66+
client.agent_engines.memories.retrieve(
6467
name=agent_engine.api_resource.name,
6568
scope={"user_id": "123"},
6669
)
@@ -74,12 +77,12 @@ def test_retrieve_memories_with_similarity_search_params(client):
7477

7578
def test_retrieve_memories_with_simple_retrieval_params(client):
7679
agent_engine = client.agent_engines.create()
77-
client.agent_engines.create_memory(
80+
client.agent_engines.memories.create(
7881
name=agent_engine.api_resource.name,
7982
fact="memory_fact_1",
8083
scope={"user_id": "123"},
8184
)
82-
memories = client.agent_engines.retrieve_memories(
85+
memories = client.agent_engines.memories.retrieve(
8386
name=agent_engine.api_resource.name,
8487
scope={"user_id": "123"},
8588
simple_retrieval_params=types.RetrieveMemoriesRequestSimpleRetrievalParams(
@@ -98,3 +101,27 @@ def test_retrieve_memories_with_simple_retrieval_params(client):
98101
globals_for_file=globals(),
99102
test_method="agent_engines.create_memory",
100103
)
104+
105+
106+
pytest_plugins = ("pytest_asyncio",)
107+
108+
109+
@pytest.mark.asyncio
110+
async def test_retrieve_memories_async(client):
111+
agent_engine = client.agent_engines.create()
112+
operation = await client.aio.agent_engines.memories.create(
113+
name=agent_engine.api_resource.name,
114+
fact="memory_fact",
115+
scope={"user_id": "123"},
116+
)
117+
assert isinstance(operation, types.AgentEngineMemoryOperation)
118+
pager = await client.aio.agent_engines.memories.retrieve(
119+
name=agent_engine.api_resource.name,
120+
scope={"user_id": "123"},
121+
)
122+
memories = [item async for item in pager]
123+
assert len(memories) == 1
124+
assert isinstance(memories[0], types.RetrieveMemoriesResponseRetrievedMemory)
125+
await client.aio.agent_engines.delete(
126+
name=agent_engine.api_resource.name, force=True
127+
)

0 commit comments

Comments
 (0)