Skip to content

Commit 6c689d2

Browse files
committed
fix key to parameters
1 parent 8f50d38 commit 6c689d2

File tree

2 files changed

+22
-56
lines changed

2 files changed

+22
-56
lines changed

google/cloud/aiplatform/utils/pipeline_utils.py

Lines changed: 18 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -64,13 +64,10 @@ def from_job_spec_json(
6464
A PipelineRuntimeConfigBuilder object.
6565
"""
6666
runtime_config_spec = job_spec["runtimeConfig"]
67-
input_definitions = (
68-
job_spec["pipelineSpec"]["root"].get("inputDefinitions") or {}
69-
)
7067
parameter_input_definitions = (
71-
input_definitions.get("parameterValues")
72-
or input_definitions.get("parameters")
73-
or {}
68+
job_spec["pipelineSpec"]["root"]
69+
.get("inputDefinitions", {})
70+
.get("parameters", {})
7471
)
7572
schema_version = job_spec["pipelineSpec"].get("schemaVersion")
7673

@@ -141,7 +138,7 @@ def build(self) -> Dict[str, Any]:
141138

142139
def _get_vertex_value(
143140
self, name: str, value: Union[int, float, str, bool, list, dict]
144-
) -> Dict[str, Any]:
141+
) -> Union[Dict[str, Any], int, float, str, bool, list, dict]:
145142
"""Converts primitive values into Vertex pipeline Value proto message.
146143
147144
Args:
@@ -166,26 +163,21 @@ def _get_vertex_value(
166163
"pipeline job input definitions.".format(name)
167164
)
168165

169-
result = {}
170-
if self._parameter_types[name] == "INT":
171-
result["intValue"] = value
172-
elif self._parameter_types[name] == "DOUBLE":
173-
result["doubleValue"] = value
174-
elif self._parameter_types[name] == "STRING":
175-
result["stringValue"] = value
176-
elif self._parameter_types[name] == "BOOLEAN":
177-
result["boolValue"] = value
178-
elif self._parameter_types[name] == "NUMBER_DOUBLE":
179-
result["numberValue"] = value
180-
elif self._parameter_types[name] == "NUMBER_INTEGER":
181-
result["numberValue"] = value
182-
elif self._parameter_types[name] == "LIST":
183-
result["listValue"] = value
184-
elif self._parameter_types[name] == "STRUCT":
185-
result["structValue"] = value
166+
if packaging.version.parse(self._schema_version) <= packaging.version.parse(
167+
"2.0.0"
168+
):
169+
result = {}
170+
if self._parameter_types[name] == "INT":
171+
result["intValue"] = value
172+
elif self._parameter_types[name] == "DOUBLE":
173+
result["doubleValue"] = value
174+
elif self._parameter_types[name] == "STRING":
175+
result["stringValue"] = value
176+
else:
177+
raise TypeError("Got unknown type of value: {}".format(value))
178+
return result
186179
else:
187-
raise TypeError("Got unknown type of value: {}".format(value))
188-
return result
180+
return value
189181

190182

191183
def _parse_runtime_parameters(

tests/unit/aiplatform/test_pipeline_jobs.py

Lines changed: 4 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@
7878
"root": {
7979
"dag": {"tasks": {}},
8080
"inputDefinitions": {
81-
"parameterValues": {
81+
"parameters": {
8282
"string_param": {"parameterType": "STRING"},
8383
"bool_param": {"parameterType": "BOOLEAN"},
8484
"double_param": {"parameterType": "NUMBER_DOUBLE"},
@@ -98,17 +98,7 @@
9898
"pipelineSpec": _TEST_PIPELINE_SPEC_LEGACY,
9999
}
100100
_TEST_PIPELINE_JOB = {
101-
"runtimeConfig": {
102-
"parameterValues": {
103-
"string_param": "lorem ipsum",
104-
"bool_param": True,
105-
"double_param": 12.34,
106-
"int_param": 5678,
107-
"list_int_param": [123, 456, 789],
108-
"list_string_param": ["lorem", "ipsum"],
109-
"struct_param": {"key1": 12345, "key2": 67890},
110-
},
111-
},
101+
"runtimeConfig": {"parameterValues": _TEST_PIPELINE_PARAMETER_VALUES},
112102
"pipelineSpec": _TEST_PIPELINE_SPEC,
113103
}
114104

@@ -282,15 +272,7 @@ def test_run_call_pipeline_service_create(
282272

283273
expected_runtime_config_dict = {
284274
"gcsOutputDirectory": _TEST_GCS_BUCKET_NAME,
285-
"parameterValues": {
286-
"bool_param": {"boolValue": True},
287-
"double_param": {"numberValue": 12.34},
288-
"int_param": {"numberValue": 5678},
289-
"list_int_param": {"listValue": [123, 456, 789]},
290-
"list_string_param": {"listValue": ["lorem", "ipsum"]},
291-
"struct_param": {"structValue": {"key1": 12345, "key2": 67890}},
292-
"string_param": {"stringValue": "hello world"},
293-
},
275+
"parameterValues": _TEST_PIPELINE_PARAMETER_VALUES,
294276
}
295277
runtime_config = gca_pipeline_job_v1.PipelineJob.RuntimeConfig()._pb
296278
json_format.ParseDict(expected_runtime_config_dict, runtime_config)
@@ -425,15 +407,7 @@ def test_submit_call_pipeline_service_pipeline_job_create(
425407

426408
expected_runtime_config_dict = {
427409
"gcsOutputDirectory": _TEST_GCS_BUCKET_NAME,
428-
"parameterValues": {
429-
"bool_param": {"boolValue": True},
430-
"double_param": {"numberValue": 12.34},
431-
"int_param": {"numberValue": 5678},
432-
"list_int_param": {"listValue": [123, 456, 789]},
433-
"list_string_param": {"listValue": ["lorem", "ipsum"]},
434-
"struct_param": {"structValue": {"key1": 12345, "key2": 67890}},
435-
"string_param": {"stringValue": "hello world"},
436-
},
410+
"parameterValues": _TEST_PIPELINE_PARAMETER_VALUES,
437411
}
438412
runtime_config = gca_pipeline_job_v1.PipelineJob.RuntimeConfig()._pb
439413
json_format.ParseDict(expected_runtime_config_dict, runtime_config)

0 commit comments

Comments
 (0)