1717import copy
1818import json
1919from typing import Any , Dict , Mapping , Optional , Union
20+ import packaging .version
2021
2122
2223class PipelineRuntimeConfigBuilder (object ):
@@ -28,6 +29,7 @@ class PipelineRuntimeConfigBuilder(object):
2829 def __init__ (
2930 self ,
3031 pipeline_root : str ,
32+ schema_version : str ,
3133 parameter_types : Mapping [str , str ],
3234 parameter_values : Optional [Dict [str , Any ]] = None ,
3335 ):
@@ -36,12 +38,15 @@ def __init__(
3638 Args:
3739 pipeline_root (str):
3840 Required. The root of the pipeline outputs.
41+ schema_version (str):
42+ Required. Schema version of the IR. This field determines the fields supported in current version of IR.
3943 parameter_types (Mapping[str, str]):
4044 Required. The mapping from pipeline parameter name to its type.
4145 parameter_values (Dict[str, Any]):
4246 Optional. The mapping from runtime parameter name to its value.
4347 """
4448 self ._pipeline_root = pipeline_root
49+ self ._schema_version = schema_version
4550 self ._parameter_types = parameter_types
4651 self ._parameter_values = copy .deepcopy (parameter_values or {})
4752
@@ -64,6 +69,8 @@ def from_job_spec_json(
6469 .get ("inputDefinitions" , {})
6570 .get ("parameters" , {})
6671 )
72+ schema_version = job_spec ["pipelineSpec" ]["schemaVersion" ]
73+
6774 # 'type' is deprecated in IR and change to 'parameterType'.
6875 parameter_types = {
6976 k : v .get ("parameterType" ) or v .get ("type" )
@@ -72,7 +79,7 @@ def from_job_spec_json(
7279
7380 pipeline_root = runtime_config_spec .get ("gcsOutputDirectory" )
7481 parameter_values = _parse_runtime_parameters (runtime_config_spec )
75- return cls (pipeline_root , parameter_types , parameter_values )
82+ return cls (pipeline_root , schema_version , parameter_types , parameter_values )
7683
7784 def update_pipeline_root (self , pipeline_root : Optional [str ]) -> None :
7885 """Updates pipeline_root value.
@@ -95,9 +102,12 @@ def update_runtime_parameters(
95102 """
96103 if parameter_values :
97104 parameters = dict (parameter_values )
98- for k , v in parameter_values .items ():
99- if isinstance (v , (dict , list , bool )):
100- parameters [k ] = json .dumps (v )
105+ if packaging .version .parse (self ._schema_version ) <= packaging .version .parse (
106+ "2.0.0"
107+ ):
108+ for k , v in parameter_values .items ():
109+ if isinstance (v , (dict , list , bool )):
110+ parameters [k ] = json .dumps (v )
101111 self ._parameter_values .update (parameters )
102112
103113 def build (self ) -> Dict [str , Any ]:
@@ -111,9 +121,15 @@ def build(self) -> Dict[str, Any]:
111121 "Pipeline root must be specified, either during "
112122 "compile time, or when calling the service."
113123 )
124+ if packaging .version .parse (self ._schema_version ) > packaging .version .parse (
125+ "2.0.0"
126+ ):
127+ parameter_values_key = "parameterValues"
128+ else :
129+ parameter_values_key = "parameters"
114130 return {
115131 "gcsOutputDirectory" : self ._pipeline_root ,
116- "parameters" : {
132+ parameter_values_key : {
117133 k : self ._get_vertex_value (k , v )
118134 for k , v in self ._parameter_values .items ()
119135 if v is not None
@@ -122,7 +138,7 @@ def build(self) -> Dict[str, Any]:
122138
123139 def _get_vertex_value (
124140 self , name : str , value : Union [int , float , str , bool , list , dict ]
125- ) -> Dict [ str , Any ]:
141+ ) -> Union [ int , float , str , bool , list , dict ]:
126142 """Converts primitive values into Vertex pipeline Value proto message.
127143
128144 Args:
@@ -147,27 +163,21 @@ def _get_vertex_value(
147163 "pipeline job input definitions." .format (name )
148164 )
149165
150- result = {}
151- if self ._parameter_types [name ] == "INT" :
152- result ["intValue" ] = value
153- elif self ._parameter_types [name ] == "DOUBLE" :
154- result ["doubleValue" ] = value
155- elif self ._parameter_types [name ] == "STRING" :
156- result ["stringValue" ] = value
157- elif self ._parameter_types [name ] == "BOOLEAN" :
158- result ["boolValue" ] = value
159- elif self ._parameter_types [name ] == "NUMBER_DOUBLE" :
160- result ["numberValue" ] = value
161- elif self ._parameter_types [name ] == "NUMBER_INTEGER" :
162- result ["numberValue" ] = value
163- elif self ._parameter_types [name ] == "LIST" :
164- result ["listValue" ] = value
165- elif self ._parameter_types [name ] == "STRUCT" :
166- result ["structValue" ] = value
166+ if packaging .version .parse (self ._schema_version ) <= packaging .version .parse (
167+ "2.0.0"
168+ ):
169+ result = {}
170+ if self ._parameter_types [name ] == "INT" :
171+ result ["intValue" ] = value
172+ elif self ._parameter_types [name ] == "DOUBLE" :
173+ result ["doubleValue" ] = value
174+ elif self ._parameter_types [name ] == "STRING" :
175+ result ["stringValue" ] = value
176+ else :
177+ raise TypeError ("Got unknown type of value: {}" .format (value ))
178+ return result
167179 else :
168- raise TypeError ("Got unknown type of value: {}" .format (value ))
169-
170- return result
180+ return value
171181
172182
173183def _parse_runtime_parameters (
0 commit comments