Skip to content

Commit 900a449

Browse files
ji-yaqinicain
andauthored
fix: add parameters_value in PipelineJob for schema > 2.0.0 (#817)
* feat: update PipelineJob to accept protobuf value * fix tests * address comments * fix: update Pipeline Job parameter values according to schema_version * fix test * fix key to parameters * fix format' * address comments Co-authored-by: nicain <nicain.seattle@gmail.com>
1 parent 2f9a879 commit 900a449

File tree

3 files changed

+204
-82
lines changed

3 files changed

+204
-82
lines changed

google/cloud/aiplatform/utils/pipeline_utils.py

Lines changed: 36 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import copy
1818
import json
1919
from typing import Any, Dict, Mapping, Optional, Union
20+
import packaging.version
2021

2122

2223
class PipelineRuntimeConfigBuilder(object):
@@ -28,6 +29,7 @@ class PipelineRuntimeConfigBuilder(object):
2829
def __init__(
2930
self,
3031
pipeline_root: str,
32+
schema_version: str,
3133
parameter_types: Mapping[str, str],
3234
parameter_values: Optional[Dict[str, Any]] = None,
3335
):
@@ -36,12 +38,15 @@ def __init__(
3638
Args:
3739
pipeline_root (str):
3840
Required. The root of the pipeline outputs.
41+
schema_version (str):
42+
Required. Schema version of the IR. This field determines the fields supported in current version of IR.
3943
parameter_types (Mapping[str, str]):
4044
Required. The mapping from pipeline parameter name to its type.
4145
parameter_values (Dict[str, Any]):
4246
Optional. The mapping from runtime parameter name to its value.
4347
"""
4448
self._pipeline_root = pipeline_root
49+
self._schema_version = schema_version
4550
self._parameter_types = parameter_types
4651
self._parameter_values = copy.deepcopy(parameter_values or {})
4752

@@ -64,6 +69,8 @@ def from_job_spec_json(
6469
.get("inputDefinitions", {})
6570
.get("parameters", {})
6671
)
72+
schema_version = job_spec["pipelineSpec"]["schemaVersion"]
73+
6774
# 'type' is deprecated in IR and change to 'parameterType'.
6875
parameter_types = {
6976
k: v.get("parameterType") or v.get("type")
@@ -72,7 +79,7 @@ def from_job_spec_json(
7279

7380
pipeline_root = runtime_config_spec.get("gcsOutputDirectory")
7481
parameter_values = _parse_runtime_parameters(runtime_config_spec)
75-
return cls(pipeline_root, parameter_types, parameter_values)
82+
return cls(pipeline_root, schema_version, parameter_types, parameter_values)
7683

7784
def update_pipeline_root(self, pipeline_root: Optional[str]) -> None:
7885
"""Updates pipeline_root value.
@@ -95,9 +102,12 @@ def update_runtime_parameters(
95102
"""
96103
if parameter_values:
97104
parameters = dict(parameter_values)
98-
for k, v in parameter_values.items():
99-
if isinstance(v, (dict, list, bool)):
100-
parameters[k] = json.dumps(v)
105+
if packaging.version.parse(self._schema_version) <= packaging.version.parse(
106+
"2.0.0"
107+
):
108+
for k, v in parameter_values.items():
109+
if isinstance(v, (dict, list, bool)):
110+
parameters[k] = json.dumps(v)
101111
self._parameter_values.update(parameters)
102112

103113
def build(self) -> Dict[str, Any]:
@@ -111,9 +121,15 @@ def build(self) -> Dict[str, Any]:
111121
"Pipeline root must be specified, either during "
112122
"compile time, or when calling the service."
113123
)
124+
if packaging.version.parse(self._schema_version) > packaging.version.parse(
125+
"2.0.0"
126+
):
127+
parameter_values_key = "parameterValues"
128+
else:
129+
parameter_values_key = "parameters"
114130
return {
115131
"gcsOutputDirectory": self._pipeline_root,
116-
"parameters": {
132+
parameter_values_key: {
117133
k: self._get_vertex_value(k, v)
118134
for k, v in self._parameter_values.items()
119135
if v is not None
@@ -122,7 +138,7 @@ def build(self) -> Dict[str, Any]:
122138

123139
def _get_vertex_value(
124140
self, name: str, value: Union[int, float, str, bool, list, dict]
125-
) -> Dict[str, Any]:
141+
) -> Union[int, float, str, bool, list, dict]:
126142
"""Converts primitive values into Vertex pipeline Value proto message.
127143
128144
Args:
@@ -147,27 +163,21 @@ def _get_vertex_value(
147163
"pipeline job input definitions.".format(name)
148164
)
149165

150-
result = {}
151-
if self._parameter_types[name] == "INT":
152-
result["intValue"] = value
153-
elif self._parameter_types[name] == "DOUBLE":
154-
result["doubleValue"] = value
155-
elif self._parameter_types[name] == "STRING":
156-
result["stringValue"] = value
157-
elif self._parameter_types[name] == "BOOLEAN":
158-
result["boolValue"] = value
159-
elif self._parameter_types[name] == "NUMBER_DOUBLE":
160-
result["numberValue"] = value
161-
elif self._parameter_types[name] == "NUMBER_INTEGER":
162-
result["numberValue"] = value
163-
elif self._parameter_types[name] == "LIST":
164-
result["listValue"] = value
165-
elif self._parameter_types[name] == "STRUCT":
166-
result["structValue"] = value
166+
if packaging.version.parse(self._schema_version) <= packaging.version.parse(
167+
"2.0.0"
168+
):
169+
result = {}
170+
if self._parameter_types[name] == "INT":
171+
result["intValue"] = value
172+
elif self._parameter_types[name] == "DOUBLE":
173+
result["doubleValue"] = value
174+
elif self._parameter_types[name] == "STRING":
175+
result["stringValue"] = value
176+
else:
177+
raise TypeError("Got unknown type of value: {}".format(value))
178+
return result
167179
else:
168-
raise TypeError("Got unknown type of value: {}".format(value))
169-
170-
return result
180+
return value
171181

172182

173183
def _parse_runtime_parameters(

0 commit comments

Comments
 (0)