@@ -29,26 +29,26 @@ class PipelineRuntimeConfigBuilder(object):
2929 def __init__ (
3030 self ,
3131 pipeline_root : str ,
32+ schema_version : str ,
3233 parameter_types : Mapping [str , str ],
3334 parameter_values : Optional [Dict [str , Any ]] = None ,
34- schema_version : Optional [str ] = None ,
3535 ):
3636 """Creates a PipelineRuntimeConfigBuilder object.
3737
3838 Args:
3939 pipeline_root (str):
4040 Required. The root of the pipeline outputs.
41+ schema_version (str):
42+ Required. Schema version of the IR. This field determines the fields supported in current version of IR.
4143 parameter_types (Mapping[str, str]):
4244 Required. The mapping from pipeline parameter name to its type.
4345 parameter_values (Dict[str, Any]):
4446 Optional. The mapping from runtime parameter name to its value.
45- schema_version (str):
46- Optional. Schema version of the IR. This field determines the fields supported in current version of IR.
4747 """
4848 self ._pipeline_root = pipeline_root
49+ self ._schema_version = schema_version
4950 self ._parameter_types = parameter_types
5051 self ._parameter_values = copy .deepcopy (parameter_values or {})
51- self ._schema_version = schema_version or "2.0.0"
5252
5353 @classmethod
5454 def from_job_spec_json (
@@ -64,15 +64,12 @@ def from_job_spec_json(
6464 A PipelineRuntimeConfigBuilder object.
6565 """
6666 runtime_config_spec = job_spec ["runtimeConfig" ]
67- input_definitions = (
68- job_spec ["pipelineSpec" ]["root" ].get ("inputDefinitions" ) or {}
69- )
7067 parameter_input_definitions = (
71- input_definitions . get ( "parameterValues" )
72- or input_definitions .get ("parameters" )
73- or {}
68+ job_spec [ "pipelineSpec" ][ "root" ]
69+ .get ("inputDefinitions" , {} )
70+ . get ( "parameters" , {})
7471 )
75- schema_version = job_spec ["pipelineSpec" ]. get ( "schemaVersion" )
72+ schema_version = job_spec ["pipelineSpec" ][ "schemaVersion" ]
7673
7774 # 'type' is deprecated in IR and change to 'parameterType'.
7875 parameter_types = {
@@ -82,7 +79,7 @@ def from_job_spec_json(
8279
8380 pipeline_root = runtime_config_spec .get ("gcsOutputDirectory" )
8481 parameter_values = _parse_runtime_parameters (runtime_config_spec )
85- return cls (pipeline_root , parameter_types , parameter_values , schema_version )
82+ return cls (pipeline_root , schema_version , parameter_types , parameter_values )
8683
8784 def update_pipeline_root (self , pipeline_root : Optional [str ]) -> None :
8885 """Updates pipeline_root value.
@@ -141,7 +138,7 @@ def build(self) -> Dict[str, Any]:
141138
142139 def _get_vertex_value (
143140 self , name : str , value : Union [int , float , str , bool , list , dict ]
144- ) -> Dict [str , Any ]:
141+ ) -> Union [ Dict [str , Any ], int , float , str , bool , list , dict ]:
145142 """Converts primitive values into Vertex pipeline Value proto message.
146143
147144 Args:
@@ -166,26 +163,21 @@ def _get_vertex_value(
166163 "pipeline job input definitions." .format (name )
167164 )
168165
169- result = {}
170- if self ._parameter_types [name ] == "INT" :
171- result ["intValue" ] = value
172- elif self ._parameter_types [name ] == "DOUBLE" :
173- result ["doubleValue" ] = value
174- elif self ._parameter_types [name ] == "STRING" :
175- result ["stringValue" ] = value
176- elif self ._parameter_types [name ] == "BOOLEAN" :
177- result ["boolValue" ] = value
178- elif self ._parameter_types [name ] == "NUMBER_DOUBLE" :
179- result ["numberValue" ] = value
180- elif self ._parameter_types [name ] == "NUMBER_INTEGER" :
181- result ["numberValue" ] = value
182- elif self ._parameter_types [name ] == "LIST" :
183- result ["listValue" ] = value
184- elif self ._parameter_types [name ] == "STRUCT" :
185- result ["structValue" ] = value
166+ if packaging .version .parse (self ._schema_version ) <= packaging .version .parse (
167+ "2.0.0"
168+ ):
169+ result = {}
170+ if self ._parameter_types [name ] == "INT" :
171+ result ["intValue" ] = value
172+ elif self ._parameter_types [name ] == "DOUBLE" :
173+ result ["doubleValue" ] = value
174+ elif self ._parameter_types [name ] == "STRING" :
175+ result ["stringValue" ] = value
176+ else :
177+ raise TypeError ("Got unknown type of value: {}" .format (value ))
178+ return result
186179 else :
187- raise TypeError ("Got unknown type of value: {}" .format (value ))
188- return result
180+ return value
189181
190182
191183def _parse_runtime_parameters (
0 commit comments