Skip to content

Commit ac33e87

Browse files
vertex-sdk-bot authored and copybara-github committed
feat: add filter option for deploy configuration in Model Garden deploy SDK
PiperOrigin-RevId: 812990100
1 parent 7db9b4f commit ac33e87

File tree

2 files changed

+231
-85
lines changed

2 files changed

+231
-85
lines changed

tests/unit/vertexai/model_garden/test_model_garden.py

Lines changed: 160 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -181,97 +181,106 @@ def get_publisher_model_mock():
181181
with mock.patch.object(
182182
model_garden_service.ModelGardenServiceClient, "get_publisher_model"
183183
) as get_publisher_model_mock:
184-
get_publisher_model_mock.side_effect = [
185-
types.PublisherModel(name=_TEST_PUBLISHER_MODEL_NAME),
186-
types.PublisherModel(
187-
name=_TEST_PUBLISHER_MODEL_NAME,
188-
supported_actions=types.PublisherModel.CallToAction(
189-
multi_deploy_vertex=types.PublisherModel.CallToAction.DeployVertex(
190-
multi_deploy_vertex=[
191-
types.PublisherModel.CallToAction.Deploy(
192-
deploy_task_name="vLLM 32K context",
193-
container_spec=types.ModelContainerSpec(
194-
image_uri=_TEST_IMAGE_URI,
195-
command=["python", "main.py"],
196-
args=["--model-id=gemma-2b"],
197-
env=[
198-
types.EnvVar(name="MODEL_ID", value="gemma-2b")
199-
],
200-
),
201-
dedicated_resources=types.DedicatedResources(
202-
machine_spec=types.MachineSpec(
203-
machine_type="g2-standard-16",
204-
accelerator_type="NVIDIA_L4",
205-
accelerator_count=1,
206-
)
207-
),
184+
error_response = types.PublisherModel(name=_TEST_PUBLISHER_MODEL_NAME)
185+
success_response = types.PublisherModel(
186+
name=_TEST_PUBLISHER_MODEL_NAME,
187+
supported_actions=types.PublisherModel.CallToAction(
188+
multi_deploy_vertex=types.PublisherModel.CallToAction.DeployVertex(
189+
multi_deploy_vertex=[
190+
types.PublisherModel.CallToAction.Deploy(
191+
deploy_task_name="vLLM 32K context",
192+
container_spec=types.ModelContainerSpec(
193+
image_uri=_TEST_IMAGE_URI,
194+
command=["python", "main.py"],
195+
args=["--model-id=gemma-2b"],
196+
env=[types.EnvVar(name="MODEL_ID", value="gemma-2b")],
208197
),
209-
types.PublisherModel.CallToAction.Deploy(
210-
deploy_task_name="vLLM 128K context",
211-
container_spec=types.ModelContainerSpec(
212-
image_uri="us-docker.pkg.dev/vertex-ai/vertex-vision-model-garden-dockers/text-generation-inference-cu121.2-1.py310:latest",
213-
command=["python", "main.py"],
214-
args=["--model-id=gemma-2b"],
215-
env=[
216-
types.EnvVar(name="MODEL_ID", value="gemma-2b")
217-
],
218-
),
219-
dedicated_resources=types.DedicatedResources(
220-
machine_spec=types.MachineSpec(
221-
machine_type="g2-standard-32",
222-
accelerator_type="NVIDIA_L4",
223-
accelerator_count=4,
224-
)
225-
),
198+
dedicated_resources=types.DedicatedResources(
199+
machine_spec=types.MachineSpec(
200+
machine_type="g2-standard-16",
201+
accelerator_type="NVIDIA_L4",
202+
accelerator_count=1,
203+
)
226204
),
227-
]
228-
)
229-
),
205+
),
206+
types.PublisherModel.CallToAction.Deploy(
207+
deploy_task_name="vLLM 128K context",
208+
container_spec=types.ModelContainerSpec(
209+
image_uri="us-docker.pkg.dev/vertex-ai/vertex-vision-model-garden-dockers/text-generation-inference-cu121.2-1.py310:latest",
210+
command=["python", "main.py"],
211+
args=["--model-id=gemma-2b"],
212+
env=[types.EnvVar(name="MODEL_ID", value="gemma-2b")],
213+
),
214+
dedicated_resources=types.DedicatedResources(
215+
machine_spec=types.MachineSpec(
216+
machine_type="g2-standard-32",
217+
accelerator_type="NVIDIA_L4",
218+
accelerator_count=4,
219+
)
220+
),
221+
),
222+
]
223+
)
230224
),
231-
types.PublisherModel(
232-
name=_TEST_MODEL_HUGGING_FACE_RESOURCE_NAME,
233-
supported_actions=types.PublisherModel.CallToAction(
234-
multi_deploy_vertex=types.PublisherModel.CallToAction.DeployVertex(
235-
multi_deploy_vertex=[
236-
types.PublisherModel.CallToAction.Deploy(
237-
container_spec=types.ModelContainerSpec(
238-
image_uri=_TEST_IMAGE_URI,
239-
command=["python", "main.py"],
240-
args=["--model-id=gemma-2b"],
241-
env=[
242-
types.EnvVar(name="MODEL_ID", value="gemma-2b")
243-
],
244-
),
245-
dedicated_resources=types.DedicatedResources(
246-
machine_spec=types.MachineSpec(
247-
machine_type="g2-standard-16",
248-
accelerator_type="NVIDIA_L4",
249-
accelerator_count=1,
250-
)
251-
),
225+
)
226+
hf_success_response = types.PublisherModel(
227+
name=_TEST_MODEL_HUGGING_FACE_RESOURCE_NAME,
228+
supported_actions=types.PublisherModel.CallToAction(
229+
multi_deploy_vertex=types.PublisherModel.CallToAction.DeployVertex(
230+
multi_deploy_vertex=[
231+
types.PublisherModel.CallToAction.Deploy(
232+
container_spec=types.ModelContainerSpec(
233+
image_uri=_TEST_IMAGE_URI,
234+
command=["python", "main.py"],
235+
args=["--model-id=gemma-2b"],
236+
env=[types.EnvVar(name="MODEL_ID", value="gemma-2b")],
252237
),
253-
types.PublisherModel.CallToAction.Deploy(
254-
container_spec=types.ModelContainerSpec(
255-
image_uri="us-docker.pkg.dev/vertex-ai/vertex-vision-model-garden-dockers/text-generation-inference-cu121.2-1.py310:latest",
256-
command=["python", "main.py"],
257-
args=["--model-id=gemma-2b"],
258-
env=[
259-
types.EnvVar(name="MODEL_ID", value="gemma-2b")
260-
],
261-
),
262-
dedicated_resources=types.DedicatedResources(
263-
machine_spec=types.MachineSpec(
264-
machine_type="g2-standard-32",
265-
accelerator_type="NVIDIA_L4",
266-
accelerator_count=4,
267-
)
268-
),
238+
dedicated_resources=types.DedicatedResources(
239+
machine_spec=types.MachineSpec(
240+
machine_type="g2-standard-16",
241+
accelerator_type="NVIDIA_L4",
242+
accelerator_count=1,
243+
)
269244
),
270-
]
271-
)
272-
),
245+
),
246+
types.PublisherModel.CallToAction.Deploy(
247+
container_spec=types.ModelContainerSpec(
248+
image_uri="us-docker.pkg.dev/vertex-ai/vertex-vision-model-garden-dockers/text-generation-inference-cu121.2-1.py310:latest",
249+
command=["python", "main.py"],
250+
args=["--model-id=gemma-2b"],
251+
env=[types.EnvVar(name="MODEL_ID", value="gemma-2b")],
252+
),
253+
dedicated_resources=types.DedicatedResources(
254+
machine_spec=types.MachineSpec(
255+
machine_type="g2-standard-32",
256+
accelerator_type="NVIDIA_L4",
257+
accelerator_count=4,
258+
)
259+
),
260+
),
261+
]
262+
)
273263
),
274-
]
264+
)
265+
266+
call_counts = {}
267+
268+
def side_effect_func(request, *args, **kwargs):
269+
model_name = request.name
270+
if model_name not in call_counts:
271+
call_counts[model_name] = 0
272+
273+
call_counts[model_name] += 1
274+
275+
if model_name == _TEST_HUGGING_FACE_MODEL_FULL_RESOURCE_NAME:
276+
return hf_success_response
277+
278+
if call_counts[model_name] == 1:
279+
return error_response
280+
else:
281+
return success_response
282+
283+
get_publisher_model_mock.side_effect = side_effect_func
275284
yield get_publisher_model_mock
276285

277286

@@ -1239,6 +1248,72 @@ def test_list_deploy_options_concise(self, get_publisher_model_mock):
12391248
)
12401249
)
12411250

1251+
def test_list_deploy_options_with_filters(self, get_publisher_model_mock):
1252+
"""Tests getting the supported deploy options for a model with filters."""
1253+
aiplatform.init(
1254+
project=_TEST_PROJECT,
1255+
location=_TEST_LOCATION,
1256+
)
1257+
model = model_garden.OpenModel(model_name=_TEST_MODEL_FULL_RESOURCE_NAME)
1258+
1259+
expected_message = (
1260+
"Model does not support deployment. "
1261+
"Use `list_deployable_models()` to find supported models."
1262+
)
1263+
with pytest.raises(ValueError) as exception:
1264+
_ = model.list_deploy_options()
1265+
assert str(exception.value) == expected_message
1266+
1267+
# Test serving_container_image_uri_filter
1268+
result = model.list_deploy_options(serving_container_image_uri_filter="vllm")
1269+
assert len(result) == 1
1270+
assert "vllm" in result[0].container_spec.image_uri
1271+
1272+
# Test case-insensitivity for serving_container_image_uri_filter
1273+
result = model.list_deploy_options(serving_container_image_uri_filter="VLLM")
1274+
assert len(result) == 1
1275+
assert "vllm" in result[0].container_spec.image_uri
1276+
1277+
# Test list of strings for serving_container_image_uri_filter
1278+
result = model.list_deploy_options(
1279+
serving_container_image_uri_filter=["vllm", "text-generation-inference"]
1280+
)
1281+
assert len(result) == 2
1282+
1283+
# Test machine_type_filter
1284+
result = model.list_deploy_options(machine_type_filter="g2-standard-16")
1285+
assert len(result) == 1
1286+
assert (
1287+
"g2-standard-16" == result[0].dedicated_resources.machine_spec.machine_type
1288+
)
1289+
1290+
# Test case-insensitivity for machine_type_filter
1291+
result = model.list_deploy_options(machine_type_filter="G2-STANDARD-16")
1292+
assert len(result) == 1
1293+
assert (
1294+
"g2-standard-16" == result[0].dedicated_resources.machine_spec.machine_type
1295+
)
1296+
1297+
# Test accelerator_type_filter
1298+
result = model.list_deploy_options(accelerator_type_filter="L4")
1299+
assert len(result) == 2
1300+
1301+
# Test case-insensitivity for accelerator_type_filter
1302+
result = model.list_deploy_options(accelerator_type_filter="l4")
1303+
assert len(result) == 2
1304+
1305+
# Test combination of filters
1306+
result = model.list_deploy_options(
1307+
serving_container_image_uri_filter="vllm",
1308+
machine_type_filter="g2-standard-16",
1309+
accelerator_type_filter="L4",
1310+
)
1311+
assert len(result) == 1
1312+
1313+
# Test with no match
1314+
with pytest.raises(ValueError):
1315+
model.list_deploy_options(machine_type_filter="non-existent")
1316+
12421317
def test_list_deployable_models(self, list_publisher_models_mock):
12431318
"""Tests getting the supported deploy options for a model."""
12441319
aiplatform.init(

vertexai/model_garden/_model_garden.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -678,12 +678,25 @@ def deploy(
678678
def list_deploy_options(
679679
self,
680680
concise: bool = False,
681+
serving_container_image_uri_filter: Optional[Union[str, List[str]]] = None,
682+
machine_type_filter: Optional[str] = None,
683+
accelerator_type_filter: Optional[str] = None,
681684
) -> Union[str, Sequence[types.PublisherModel.CallToAction.Deploy]]:
682685
"""Lists the verified deploy options for the model.
683686
684687
Args:
685688
concise: If true, returns a human-readable string with container and
686689
machine specs.
690+
serving_container_image_uri_filter: If specified, only return the
691+
deploy options where the serving container image URI contains one of
692+
the specified keyword(s) (e.g., "vllm" or ["vllm", "tgi"]). The
693+
filter is case-insensitive.
694+
machine_type_filter: If specified, only return the deploy options
695+
where the machine type contains one of the specified keyword(s)
696+
(e.g., "n1" or ["n1", "g2"]). The filter is case-insensitive.
697+
accelerator_type_filter: If specified, only return the deploy options
698+
where the accelerator type contains one of the specified keyword(s)
699+
(e.g., "T4" or ["T4", "L4"]). The filter is case-insensitive.
687700
688701
Returns:
689702
A list of deploy options or a concise formatted string.
@@ -704,6 +717,64 @@ def list_deploy_options(
704717
"Use `list_deployable_models()` to find supported models."
705718
)
706719

720+
if serving_container_image_uri_filter:
721+
if isinstance(serving_container_image_uri_filter, str):
722+
serving_container_image_uri_filter = [
723+
serving_container_image_uri_filter
724+
]
725+
serving_container_image_uri_filter = [
726+
f.lower() for f in serving_container_image_uri_filter
727+
]
728+
deploy_options = [
729+
option
730+
for option in deploy_options
731+
if option.container_spec
732+
and any(
733+
f in option.container_spec.image_uri.lower()
734+
for f in serving_container_image_uri_filter
735+
)
736+
]
737+
738+
if machine_type_filter:
739+
filters = (
740+
[machine_type_filter]
741+
if isinstance(machine_type_filter, str)
742+
else machine_type_filter
743+
)
744+
deploy_options = [
745+
option
746+
for option in deploy_options
747+
if option.dedicated_resources
748+
and option.dedicated_resources.machine_spec
749+
and any(
750+
f.lower()
751+
in option.dedicated_resources.machine_spec.machine_type.lower()
752+
for f in filters
753+
)
754+
]
755+
756+
if accelerator_type_filter:
757+
filters = (
758+
[accelerator_type_filter]
759+
if isinstance(accelerator_type_filter, str)
760+
else accelerator_type_filter
761+
)
762+
deploy_options = [
763+
option
764+
for option in deploy_options
765+
if option.dedicated_resources
766+
and option.dedicated_resources.machine_spec
767+
and option.dedicated_resources.machine_spec.accelerator_type
768+
and any(
769+
f.lower()
770+
in option.dedicated_resources.machine_spec.accelerator_type.name.lower()
771+
for f in filters
772+
)
773+
]
774+
775+
if not deploy_options:
776+
raise ValueError("No deploy options found.")
777+
707778
if not concise:
708779
return deploy_options
709780

0 commit comments

Comments (0)