17
17
import copy
18
18
import json
19
19
from typing import Any , Dict , Mapping , Optional , Union
20
+ from google .cloud .aiplatform .compat .types import pipeline_failure_policy
20
21
import packaging .version
21
22
22
23
@@ -32,6 +33,7 @@ def __init__(
32
33
schema_version : str ,
33
34
parameter_types : Mapping [str , str ],
34
35
parameter_values : Optional [Dict [str , Any ]] = None ,
36
+ failure_policy : Optional [pipeline_failure_policy .PipelineFailurePolicy ] = None ,
35
37
):
36
38
"""Creates a PipelineRuntimeConfigBuilder object.
37
39
@@ -44,11 +46,20 @@ def __init__(
44
46
Required. The mapping from pipeline parameter name to its type.
45
47
parameter_values (Dict[str, Any]):
46
48
Optional. The mapping from runtime parameter name to its value.
49
+ failure_policy (pipeline_failure_policy.PipelineFailurePolicy):
50
+ Optional. Represents the failure policy of a pipeline. Currently, the
51
+ default of a pipeline is that the pipeline will continue to
52
+ run until no more tasks can be executed, also known as
53
+ PIPELINE_FAILURE_POLICY_FAIL_SLOW. However, if a pipeline is
54
+ set to PIPELINE_FAILURE_POLICY_FAIL_FAST, it will stop
55
+ scheduling any new tasks when a task has failed. Any
56
+ scheduled tasks will continue to completion.
47
57
"""
48
58
self ._pipeline_root = pipeline_root
49
59
self ._schema_version = schema_version
50
60
self ._parameter_types = parameter_types
51
61
self ._parameter_values = copy .deepcopy (parameter_values or {})
62
+ self ._failure_policy = failure_policy
52
63
53
64
@classmethod
54
65
def from_job_spec_json (
@@ -80,7 +91,14 @@ def from_job_spec_json(
80
91
81
92
pipeline_root = runtime_config_spec .get ("gcsOutputDirectory" )
82
93
parameter_values = _parse_runtime_parameters (runtime_config_spec )
83
- return cls (pipeline_root , schema_version , parameter_types , parameter_values )
94
+ failure_policy = runtime_config_spec .get ("failurePolicy" )
95
+ return cls (
96
+ pipeline_root ,
97
+ schema_version ,
98
+ parameter_types ,
99
+ parameter_values ,
100
+ failure_policy ,
101
+ )
84
102
85
103
def update_pipeline_root (self , pipeline_root : Optional [str ]) -> None :
86
104
"""Updates pipeline_root value.
@@ -111,6 +129,24 @@ def update_runtime_parameters(
111
129
parameters [k ] = json .dumps (v )
112
130
self ._parameter_values .update (parameters )
113
131
132
+ def update_failure_policy (self , failure_policy : Optional [str ] = None ) -> None :
133
+ """Merges runtime failure policy.
134
+
135
+ Args:
136
+ failure_policy (str):
137
+ Optional. The failure policy - "slow" or "fast".
138
+
139
+ Raises:
140
+ ValueError: if failure_policy is not valid.
141
+ """
142
+ if failure_policy :
143
+ if failure_policy in _FAILURE_POLICY_TO_ENUM_VALUE :
144
+ self ._failure_policy = _FAILURE_POLICY_TO_ENUM_VALUE [failure_policy ]
145
+ else :
146
+ raise ValueError (
147
+ f'failure_policy should be either "slow" or "fast", but got: "{ failure_policy } ".'
148
+ )
149
+
114
150
def build (self ) -> Dict [str , Any ]:
115
151
"""Build a RuntimeConfig proto.
116
152
@@ -128,7 +164,8 @@ def build(self) -> Dict[str, Any]:
128
164
parameter_values_key = "parameterValues"
129
165
else :
130
166
parameter_values_key = "parameters"
131
- return {
167
+
168
+ runtime_config = {
132
169
"gcsOutputDirectory" : self ._pipeline_root ,
133
170
parameter_values_key : {
134
171
k : self ._get_vertex_value (k , v )
@@ -137,6 +174,11 @@ def build(self) -> Dict[str, Any]:
137
174
},
138
175
}
139
176
177
+ if self ._failure_policy :
178
+ runtime_config ["failurePolicy" ] = self ._failure_policy
179
+
180
+ return runtime_config
181
+
140
182
def _get_vertex_value (
141
183
self , name : str , value : Union [int , float , str , bool , list , dict ]
142
184
) -> Union [int , float , str , bool , list , dict ]:
@@ -205,3 +247,10 @@ def _parse_runtime_parameters(
205
247
else :
206
248
raise TypeError ("Got unknown type of value: {}" .format (value ))
207
249
return result
250
+
251
+
252
+ _FAILURE_POLICY_TO_ENUM_VALUE = {
253
+ "slow" : pipeline_failure_policy .PipelineFailurePolicy .PIPELINE_FAILURE_POLICY_FAIL_SLOW ,
254
+ "fast" : pipeline_failure_policy .PipelineFailurePolicy .PIPELINE_FAILURE_POLICY_FAIL_FAST ,
255
+ None : pipeline_failure_policy .PipelineFailurePolicy .PIPELINE_FAILURE_POLICY_UNSPECIFIED ,
256
+ }
0 commit comments