Skip to content

Commit af8c898

Browse files
shawn-yang-googlecopybara-github
authored andcommitted
feat: Support for explicitly providing class_methods in Agent Engine config.
PiperOrigin-RevId: 819409737
1 parent b91b63c commit af8c898

File tree

5 files changed

+180
-10
lines changed

5 files changed

+180
-10
lines changed

tests/unit/vertexai/genai/test_agent_engines.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -667,6 +667,16 @@ def register_operations(self) -> Dict[str, List[str]]:
667667
}
668668

669669

670+
_TEST_AGENT_ENGINE_CLASS_METHODS = [
671+
{
672+
"name": "query",
673+
"description": "Simple query method",
674+
"parameters": {"type": "object", "properties": {"input": {"type": "string"}}},
675+
"api_mode": "",
676+
}
677+
]
678+
679+
670680
def _create_empty_fake_package(package_name: str) -> str:
671681
"""Creates a temporary directory structure representing an empty fake Python package.
672682
@@ -1354,6 +1364,7 @@ def test_create_agent_engine_with_env_vars_dict(
13541364
encryption_spec=None,
13551365
agent_server_mode=None,
13561366
labels=None,
1367+
class_methods=None,
13571368
)
13581369
request_mock.assert_called_with(
13591370
"post",
@@ -1435,6 +1446,7 @@ def test_create_agent_engine_with_custom_service_account(
14351446
encryption_spec=None,
14361447
labels=None,
14371448
agent_server_mode=None,
1449+
class_methods=None,
14381450
)
14391451
request_mock.assert_called_with(
14401452
"post",
@@ -1518,6 +1530,7 @@ def test_create_agent_engine_with_experimental_mode(
15181530
encryption_spec=None,
15191531
labels=None,
15201532
agent_server_mode=_genai_types.AgentServerMode.EXPERIMENTAL,
1533+
class_methods=None,
15211534
)
15221535
request_mock.assert_called_with(
15231536
"post",
@@ -1540,6 +1553,85 @@ def test_create_agent_engine_with_experimental_mode(
15401553
None,
15411554
)
15421555

1556+
@mock.patch.object(agent_engines.AgentEngines, "_create_config")
1557+
@mock.patch.object(_agent_engines_utils, "_await_operation")
1558+
def test_create_agent_engine_with_class_methods(
1559+
self,
1560+
mock_await_operation,
1561+
mock_create_config,
1562+
):
1563+
mock_create_config.return_value = {
1564+
"display_name": _TEST_AGENT_ENGINE_DISPLAY_NAME,
1565+
"description": _TEST_AGENT_ENGINE_DESCRIPTION,
1566+
"spec": {
1567+
"package_spec": {
1568+
"python_version": _TEST_PYTHON_VERSION,
1569+
"pickle_object_gcs_uri": _TEST_AGENT_ENGINE_GCS_URI,
1570+
"requirements_gcs_uri": _TEST_AGENT_ENGINE_REQUIREMENTS_GCS_URI,
1571+
},
1572+
"class_methods": _TEST_AGENT_ENGINE_CLASS_METHODS,
1573+
},
1574+
}
1575+
mock_await_operation.return_value = _genai_types.AgentEngineOperation(
1576+
response=_genai_types.ReasoningEngine(
1577+
name=_TEST_AGENT_ENGINE_RESOURCE_NAME,
1578+
spec=_TEST_AGENT_ENGINE_SPEC,
1579+
)
1580+
)
1581+
with mock.patch.object(
1582+
self.client.agent_engines._api_client, "request"
1583+
) as request_mock:
1584+
request_mock.return_value = genai_types.HttpResponse(body="")
1585+
self.client.agent_engines.create(
1586+
agent=self.test_agent,
1587+
config=_genai_types.AgentEngineConfig(
1588+
display_name=_TEST_AGENT_ENGINE_DISPLAY_NAME,
1589+
requirements=_TEST_AGENT_ENGINE_REQUIREMENTS,
1590+
extra_packages=[_TEST_AGENT_ENGINE_EXTRA_PACKAGE_PATH],
1591+
staging_bucket=_TEST_STAGING_BUCKET,
1592+
class_methods=_TEST_AGENT_ENGINE_CLASS_METHODS,
1593+
),
1594+
)
1595+
mock_create_config.assert_called_with(
1596+
mode="create",
1597+
agent=self.test_agent,
1598+
staging_bucket=_TEST_STAGING_BUCKET,
1599+
requirements=_TEST_AGENT_ENGINE_REQUIREMENTS,
1600+
display_name=_TEST_AGENT_ENGINE_DISPLAY_NAME,
1601+
description=None,
1602+
gcs_dir_name=None,
1603+
extra_packages=[_TEST_AGENT_ENGINE_EXTRA_PACKAGE_PATH],
1604+
env_vars=None,
1605+
service_account=None,
1606+
context_spec=None,
1607+
psc_interface_config=None,
1608+
min_instances=None,
1609+
max_instances=None,
1610+
resource_limits=None,
1611+
container_concurrency=None,
1612+
encryption_spec=None,
1613+
labels=None,
1614+
agent_server_mode=None,
1615+
class_methods=_TEST_AGENT_ENGINE_CLASS_METHODS,
1616+
)
1617+
request_mock.assert_called_with(
1618+
"post",
1619+
"reasoningEngines",
1620+
{
1621+
"displayName": _TEST_AGENT_ENGINE_DISPLAY_NAME,
1622+
"description": _TEST_AGENT_ENGINE_DESCRIPTION,
1623+
"spec": {
1624+
"class_methods": _TEST_AGENT_ENGINE_CLASS_METHODS,
1625+
"package_spec": {
1626+
"pickle_object_gcs_uri": _TEST_AGENT_ENGINE_GCS_URI,
1627+
"python_version": _TEST_PYTHON_VERSION,
1628+
"requirements_gcs_uri": _TEST_AGENT_ENGINE_REQUIREMENTS_GCS_URI,
1629+
},
1630+
},
1631+
},
1632+
None,
1633+
)
1634+
15431635
@pytest.mark.usefixtures("caplog")
15441636
@mock.patch.object(_agent_engines_utils, "_prepare")
15451637
@mock.patch.object(_agent_engines_utils, "_await_operation")

vertexai/_genai/_agent_engines_utils.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -566,6 +566,13 @@ def _generate_class_methods_spec_or_raise(
566566
return class_methods_spec
567567

568568

569+
def _class_methods_to_class_methods_spec(
570+
class_methods: List[dict[str, Any]],
571+
) -> List[proto.Message]:
572+
"""Converts a list of class methods to a list of ReasoningEngineSpec.ClassMethod messages."""
573+
return [_to_proto(class_method) for class_method in class_methods]
574+
575+
569576
def _is_pydantic_serializable(param: inspect.Parameter) -> bool:
570577
"""Checks if the parameter is pydantic serializable."""
571578

vertexai/_genai/agent_engines.py

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -870,6 +870,7 @@ def create(
870870
encryption_spec=config.encryption_spec,
871871
agent_server_mode=config.agent_server_mode,
872872
labels=config.labels,
873+
class_methods=config.class_methods,
873874
)
874875
operation = self._create(config=api_config)
875876
# TODO: Use a more specific link.
@@ -928,6 +929,7 @@ def _create_config(
928929
encryption_spec: Optional[genai_types.EncryptionSpecDict] = None,
929930
labels: Optional[dict[str, str]] = None,
930931
agent_server_mode: Optional[types.AgentServerMode] = None,
932+
class_methods: Optional[Sequence[dict[str, Any]]] = None,
931933
) -> types.UpdateAgentEngineConfigDict:
932934
import sys
933935

@@ -975,9 +977,6 @@ def _create_config(
975977
extra_packages = _agent_engines_utils._validate_extra_packages_or_raise(
976978
extra_packages=extra_packages,
977979
)
978-
extra_packages = _agent_engines_utils._validate_extra_packages_or_raise(
979-
extra_packages=extra_packages,
980-
)
981980
# Prepares the Agent Engine for creation/update in Vertex AI. This
982981
# involves packaging and uploading the artifacts for agent_engine,
983982
# requirements and extra_packages to `staging_bucket/gcs_dir_name`.
@@ -1041,13 +1040,27 @@ def _create_config(
10411040
if service_account is not None:
10421041
agent_engine_spec["service_account"] = service_account
10431042
update_masks.append("spec.service_account")
1044-
class_methods = _agent_engines_utils._generate_class_methods_spec_or_raise(
1045-
agent=agent,
1046-
operations=_agent_engines_utils._get_registered_operations(agent=agent),
1047-
)
1043+
1044+
update_masks.append("spec.class_methods")
1045+
class_methods_spec = []
1046+
if class_methods is not None:
1047+
class_methods_spec = (
1048+
_agent_engines_utils._class_methods_to_class_methods_spec(
1049+
class_methods=class_methods
1050+
)
1051+
)
1052+
else:
1053+
class_methods_spec = (
1054+
_agent_engines_utils._generate_class_methods_spec_or_raise(
1055+
agent=agent,
1056+
operations=_agent_engines_utils._get_registered_operations(
1057+
agent=agent
1058+
),
1059+
)
1060+
)
10481061
agent_engine_spec["class_methods"] = [
1049-
_agent_engines_utils._to_dict(class_method)
1050-
for class_method in class_methods
1062+
_agent_engines_utils._to_dict(class_method_spec)
1063+
for class_method_spec in class_methods_spec
10511064
]
10521065

10531066
if agent_server_mode:
@@ -1059,7 +1072,6 @@ def _create_config(
10591072
"agent_server_mode"
10601073
] = agent_server_mode
10611074

1062-
update_masks.append("spec.class_methods")
10631075
agent_engine_spec["agent_framework"] = (
10641076
_agent_engines_utils._get_agent_framework(agent=agent)
10651077
)
@@ -1283,6 +1295,7 @@ def update(
12831295
resource_limits=config.resource_limits,
12841296
container_concurrency=config.container_concurrency,
12851297
labels=config.labels,
1298+
class_methods=config.class_methods,
12861299
)
12871300
operation = self._update(name=name, config=api_config)
12881301
logger.info(

vertexai/_genai/types.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4630,6 +4630,15 @@ class CreateAgentEngineConfig(_common.BaseModel):
46304630
labels: Optional[dict[str, str]] = Field(
46314631
default=None, description="""The labels to be used for the Agent Engine."""
46324632
)
4633+
class_methods: Optional[list[dict[str, Any]]] = Field(
4634+
default=None,
4635+
description="""The class methods to be used for the Agent Engine.
4636+
If specified, they'll override the class methods that are autogenerated by
4637+
default. By default, methods are generated by inspecting the agent object
4638+
and generating a corresponding method for each method defined on the
4639+
agent class.
4640+
""",
4641+
)
46334642

46344643

46354644
class CreateAgentEngineConfigDict(TypedDict, total=False):
@@ -4687,6 +4696,14 @@ class CreateAgentEngineConfigDict(TypedDict, total=False):
46874696
labels: Optional[dict[str, str]]
46884697
"""The labels to be used for the Agent Engine."""
46894698

4699+
class_methods: Optional[list[dict[str, Any]]]
4700+
"""The class methods to be used for the Agent Engine.
4701+
If specified, they'll override the class methods that are autogenerated by
4702+
default. By default, methods are generated by inspecting the agent object
4703+
and generating a corresponding method for each method defined on the
4704+
agent class.
4705+
"""
4706+
46904707

46914708
CreateAgentEngineConfigOrDict = Union[
46924709
CreateAgentEngineConfig, CreateAgentEngineConfigDict
@@ -5244,6 +5261,15 @@ class UpdateAgentEngineConfig(_common.BaseModel):
52445261
labels: Optional[dict[str, str]] = Field(
52455262
default=None, description="""The labels to be used for the Agent Engine."""
52465263
)
5264+
class_methods: Optional[list[dict[str, Any]]] = Field(
5265+
default=None,
5266+
description="""The class methods to be used for the Agent Engine.
5267+
If specified, they'll override the class methods that are autogenerated by
5268+
default. By default, methods are generated by inspecting the agent object
5269+
and generating a corresponding method for each method defined on the
5270+
agent class.
5271+
""",
5272+
)
52475273
update_mask: Optional[str] = Field(
52485274
default=None,
52495275
description="""The update mask to apply. For the `FieldMask` definition, see
@@ -5306,6 +5332,14 @@ class UpdateAgentEngineConfigDict(TypedDict, total=False):
53065332
labels: Optional[dict[str, str]]
53075333
"""The labels to be used for the Agent Engine."""
53085334

5335+
class_methods: Optional[list[dict[str, Any]]]
5336+
"""The class methods to be used for the Agent Engine.
5337+
If specified, they'll override the class methods that are autogenerated by
5338+
default. By default, methods are generated by inspecting the agent object
5339+
and generating a corresponding method for each method defined on the
5340+
agent class.
5341+
"""
5342+
53095343
update_mask: Optional[str]
53105344
"""The update mask to apply. For the `FieldMask` definition, see
53115345
https://protobuf.dev/reference/protobuf/google.protobuf/#field-mask."""
@@ -11806,6 +11840,15 @@ class AgentEngineConfig(_common.BaseModel):
1180611840
agent_server_mode: Optional[AgentServerMode] = Field(
1180711841
default=None, description="""The agent server mode to use for deployment."""
1180811842
)
11843+
class_methods: Optional[list[dict[str, Any]]] = Field(
11844+
default=None,
11845+
description="""The class methods to be used for the Agent Engine.
11846+
If specified, they'll override the class methods that are autogenerated by
11847+
default. By default, methods are generated by inspecting the agent object
11848+
and generating a corresponding method for each method defined on the
11849+
agent class.
11850+
""",
11851+
)
1180911852

1181011853

1181111854
class AgentEngineConfigDict(TypedDict, total=False):
@@ -11892,6 +11935,14 @@ class AgentEngineConfigDict(TypedDict, total=False):
1189211935
agent_server_mode: Optional[AgentServerMode]
1189311936
"""The agent server mode to use for deployment."""
1189411937

11938+
class_methods: Optional[list[dict[str, Any]]]
11939+
"""The class methods to be used for the Agent Engine.
11940+
If specified, they'll override the class methods that are autogenerated by
11941+
default. By default, methods are generated by inspecting the agent object
11942+
and generating a corresponding method for each method defined on the
11943+
agent class.
11944+
"""
11945+
1189511946

1189611947
AgentEngineConfigOrDict = Union[AgentEngineConfig, AgentEngineConfigDict]
1189711948

vertexai/agent_engines/_agent_engines.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1903,3 +1903,10 @@ def _generate_class_methods_spec_or_raise(
19031903
class_methods_spec.append(class_method)
19041904

19051905
return class_methods_spec
1906+
1907+
1908+
def _class_methods_to_class_methods_spec(
1909+
class_methods: List[dict[str, Any]],
1910+
) -> List[proto.Message]:
1911+
"""Converts a list of class methods to a list of ReasoningEngineSpec.ClassMethod messages."""
1912+
return [_utils.to_proto(class_method) for class_method in class_methods]

0 commit comments

Comments
 (0)