Skip to content

Commit d5a14ba

Browse files
sararobcopybara-github
authored andcommitted
feat: GenAI SDK client - Add experimental prompt_management module with create_version and get methods
PiperOrigin-RevId: 803150745
1 parent 77a3933 commit d5a14ba

File tree

8 files changed

+3761
-202
lines changed

8 files changed

+3761
-202
lines changed

tests/unit/vertexai/genai/replays/conftest.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -184,8 +184,6 @@ def client(use_vertex, replays_prefix, http_options, request):
184184
os.path.dirname(__file__),
185185
"credentials.json",
186186
)
187-
os.environ["GOOGLE_CLOUD_PROJECT"] = "project-id"
188-
os.environ["GOOGLE_CLOUD_LOCATION"] = "location"
189187
os.environ["VAPO_CONFIG_PATH"] = "gs://dummy-test/dummy-config.json"
190188
os.environ["VAPO_SERVICE_ACCOUNT_PROJECT_NUMBER"] = "1234567890"
191189
os.environ["GCS_BUCKET"] = "test-bucket"
Lines changed: 284 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,284 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
#
15+
# pylint: disable=protected-access,bad-continuation,missing-function-docstring
16+
17+
from tests.unit.vertexai.genai.replays import pytest_helper
18+
from vertexai._genai import types
19+
from google.genai import types as genai_types
20+
21+
22+
TEST_PROMPT_DATASET_ID = "8005484238453342208"
23+
TEST_VARIABLES = [
24+
{"name": genai_types.Part(text="Alice")},
25+
{"name": genai_types.Part(text="Bob")},
26+
]
27+
TEST_RESPONSE_SCHEMA = {
28+
"type": "object",
29+
"properties": {"response": {"type": "string"}},
30+
}
31+
TEST_PROMPT = types.Prompt(
32+
prompt_data=types.PromptData(
33+
contents=[
34+
genai_types.Content(
35+
role="user",
36+
parts=[genai_types.Part(text="Hello, {name}! How are you?")],
37+
)
38+
],
39+
safety_settings=[
40+
genai_types.SafetySetting(
41+
category="HARM_CATEGORY_DANGEROUS_CONTENT",
42+
threshold="BLOCK_MEDIUM_AND_ABOVE",
43+
method="SEVERITY",
44+
),
45+
],
46+
generation_config=genai_types.GenerationConfig(
47+
temperature=0.1,
48+
candidate_count=1,
49+
top_p=0.95,
50+
top_k=40,
51+
response_modalities=["TEXT"],
52+
response_schema=TEST_RESPONSE_SCHEMA,
53+
),
54+
system_instruction=genai_types.Content(
55+
parts=[genai_types.Part(text="Please answer in a short sentence.")]
56+
),
57+
tools=[
58+
genai_types.Tool(
59+
google_search_retrieval=genai_types.GoogleSearchRetrieval(
60+
dynamic_retrieval_config=genai_types.DynamicRetrievalConfig(
61+
mode="MODE_DYNAMIC"
62+
)
63+
)
64+
),
65+
],
66+
tool_config=genai_types.ToolConfig(
67+
retrieval_config=genai_types.RetrievalConfig(
68+
lat_lng=genai_types.LatLng(latitude=37.7749, longitude=-122.4194)
69+
)
70+
),
71+
model="gemini-2.0-flash-001",
72+
variables=TEST_VARIABLES,
73+
),
74+
)
75+
TEST_CONFIG = types.CreatePromptConfig(
76+
prompt_display_name="my_prompt",
77+
version_display_name="my_version",
78+
)
79+
80+
81+
def test_create_dataset(client):
82+
create_dataset_operation = client.prompt_management._create_dataset_resource(
83+
config=types.CreateDatasetConfig(should_return_http_response=True),
84+
name="projects/vertex-sdk-dev/locations/us-central1",
85+
display_name="test display name",
86+
metadata_schema_uri="gs://google-cloud-aiplatform/schema/dataset/metadata/text_prompt_1.0.0.yaml",
87+
metadata={
88+
"promptType": "freeform",
89+
"promptApiSchema": {
90+
"multimodalPrompt": {
91+
"promptMessage": {
92+
"contents": [
93+
{
94+
"role": "user",
95+
"parts": [{"text": "Hello, {name}! How are you?"}],
96+
}
97+
],
98+
"safety_settings": [
99+
{
100+
"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
101+
"threshold": "BLOCK_MEDIUM_AND_ABOVE",
102+
"method": "SEVERITY",
103+
}
104+
],
105+
"generation_config": {"temperature": 0.1},
106+
"model": "projects/vertex-sdk-dev/locations/us-central1/publishers/google/models/gemini-2.0-flash-001",
107+
"system_instruction": {
108+
"role": "user",
109+
"parts": [{"text": "Please answer in a short sentence."}],
110+
},
111+
}
112+
},
113+
"apiSchemaVersion": "1.0.0",
114+
"executions": [
115+
{
116+
"arguments": {
117+
"name": {"partList": {"parts": [{"text": "Alice"}]}}
118+
}
119+
},
120+
{"arguments": {"name": {"partList": {"parts": [{"text": "Bob"}]}}}},
121+
],
122+
},
123+
},
124+
model_reference="gemini-2.0-flash-001",
125+
)
126+
assert isinstance(create_dataset_operation, types.CreateDatasetOperationMetadata)
127+
assert create_dataset_operation.sdk_http_response.body is not None
128+
129+
130+
def test_create_dataset_version(client):
131+
dataset_version_resource = (
132+
client.prompt_management._create_dataset_version_resource(
133+
dataset_name=TEST_PROMPT_DATASET_ID,
134+
display_name="my new version yay",
135+
)
136+
)
137+
assert isinstance(
138+
dataset_version_resource, types.CreateDatasetVersionOperationMetadata
139+
)
140+
141+
142+
def test_create_version_e2e(client):
143+
prompt_resource = client.prompt_management.create_version(
144+
prompt=TEST_PROMPT,
145+
config=TEST_CONFIG,
146+
)
147+
assert isinstance(prompt_resource, types.Prompt)
148+
assert isinstance(prompt_resource.dataset, types.Dataset)
149+
150+
# Test local prompt resource is the same after calling get()
151+
retrieved_prompt = client.prompt_management.get(prompt_id=prompt_resource.prompt_id)
152+
assert (
153+
retrieved_prompt.prompt_data.system_instruction
154+
== prompt_resource.prompt_data.system_instruction
155+
)
156+
assert (
157+
retrieved_prompt.prompt_data.variables[0]["name"].text
158+
== TEST_VARIABLES[0]["name"].text
159+
)
160+
assert (
161+
retrieved_prompt.prompt_data.generation_config.temperature
162+
== prompt_resource.prompt_data.generation_config.temperature
163+
)
164+
assert (
165+
retrieved_prompt.prompt_data.safety_settings
166+
== prompt_resource.prompt_data.safety_settings
167+
)
168+
assert retrieved_prompt.prompt_data.model == prompt_resource.prompt_data.model
169+
assert (
170+
retrieved_prompt.prompt_data.tool_config
171+
== prompt_resource.prompt_data.tool_config
172+
)
173+
assert (
174+
retrieved_prompt.prompt_data.generation_config
175+
== prompt_resource.prompt_data.generation_config
176+
)
177+
178+
# Test calling create_version again uses dataset from local Prompt resource.
179+
prompt_resource_2 = client.prompt_management.create_version(
180+
prompt=TEST_PROMPT,
181+
config=types.CreatePromptConfig(
182+
version_display_name="my_version",
183+
),
184+
)
185+
assert prompt_resource_2.dataset.name == prompt_resource.dataset.name
186+
187+
188+
def test_create_version_in_existing_dataset(client):
189+
prompt_resource = client.prompt_management.create_version(
190+
prompt=TEST_PROMPT,
191+
config=types.CreatePromptConfig(
192+
prompt_id=TEST_PROMPT_DATASET_ID,
193+
prompt_display_name=TEST_CONFIG.prompt_display_name,
194+
version_display_name="my_version_existing_dataset",
195+
),
196+
)
197+
assert isinstance(prompt_resource, types.Prompt)
198+
assert isinstance(prompt_resource.dataset, types.Dataset)
199+
assert isinstance(prompt_resource.dataset_version, types.DatasetVersion)
200+
assert prompt_resource.dataset.name.endswith(TEST_PROMPT_DATASET_ID)
201+
202+
203+
def test_create_version_with_version_name(client):
204+
version_name = "a_new_version_yay"
205+
prompt_resource = client.prompt_management.create_version(
206+
prompt=TEST_PROMPT,
207+
config=types.CreatePromptConfig(
208+
version_display_name=version_name,
209+
),
210+
)
211+
assert isinstance(prompt_resource, types.Prompt)
212+
assert isinstance(prompt_resource.dataset, types.Dataset)
213+
assert isinstance(prompt_resource.dataset_version, types.DatasetVersion)
214+
assert prompt_resource.dataset_version.display_name == version_name
215+
216+
217+
def test_create_version_with_file_data(client):
218+
version_name = "prompt with file data"
219+
220+
audio_file_part = genai_types.Part(
221+
file_data=genai_types.FileData(
222+
file_uri="https://generativelanguage.googleapis.com/v1beta/files/57w3vpfomj71",
223+
mime_type="video/mp4",
224+
),
225+
)
226+
227+
prompt_resource = client.prompt_management.create_version(
228+
prompt=types.Prompt(
229+
prompt_data=types.PromptData(
230+
contents=[
231+
genai_types.Content(
232+
parts=[
233+
audio_file_part,
234+
genai_types.Part(text="What is this recording about?"),
235+
]
236+
)
237+
],
238+
system_instruction=genai_types.Content(
239+
parts=[genai_types.Part(text="Answer in three sentences.")]
240+
),
241+
generation_config=genai_types.GenerationConfig(temperature=0.1),
242+
safety_settings=[
243+
genai_types.SafetySetting(
244+
category="HARM_CATEGORY_DANGEROUS_CONTENT",
245+
threshold="BLOCK_MEDIUM_AND_ABOVE",
246+
method="SEVERITY",
247+
)
248+
],
249+
model="gemini-2.0-flash-001",
250+
),
251+
),
252+
config=types.CreatePromptConfig(
253+
version_display_name=version_name,
254+
prompt_display_name="my prompt with file data",
255+
),
256+
)
257+
assert isinstance(prompt_resource, types.Prompt)
258+
assert isinstance(prompt_resource.dataset, types.Dataset)
259+
assert isinstance(prompt_resource.dataset_version, types.DatasetVersion)
260+
assert prompt_resource.dataset_version.display_name == version_name
261+
262+
# Confirm file data is preserved when we retrieve the prompt.
263+
retrieved_prompt = client.prompt_management.get(
264+
prompt_id=prompt_resource.prompt_id,
265+
)
266+
assert (
267+
retrieved_prompt.prompt_data.contents[0].parts[0].file_data.file_uri
268+
== audio_file_part.file_data.file_uri
269+
)
270+
assert (
271+
retrieved_prompt.prompt_data.contents[0].parts[0].file_data.display_name
272+
== audio_file_part.file_data.display_name
273+
)
274+
275+
# Test assemble_contents on the prompt works.
276+
contents = retrieved_prompt.assemble_contents()
277+
assert contents[0] == prompt_resource.prompt_data.contents[0]
278+
279+
280+
pytestmark = pytest_helper.setup(
281+
file=__file__,
282+
globals_for_file=globals(),
283+
test_method="prompt_management.create_version",
284+
)
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
#
15+
# pylint: disable=protected-access,bad-continuation,missing-function-docstring
16+
17+
from tests.unit.vertexai.genai.replays import pytest_helper
18+
from vertexai._genai import types
19+
20+
21+
def test_get_dataset_operation(client):
22+
dataset_operation = client.prompt_management._get_dataset_operation(
23+
config=types.GetDatasetOperationConfig(should_return_http_response=True),
24+
dataset_id="6550997480673116160",
25+
operation_id="5108504762664353792",
26+
)
27+
assert dataset_operation.sdk_http_response.body is not None
28+
29+
30+
pytestmark = pytest_helper.setup(
31+
file=__file__,
32+
globals_for_file=globals(),
33+
test_method="prompt_management._get_dataset_operation",
34+
)

0 commit comments

Comments
 (0)