@@ -133,6 +133,26 @@ def mock_pipeline_service_get():
133133 yield mock_get_pipeline_job
134134
135135
136+ @pytest .fixture
137+ def mock_pipeline_service_get_with_fail ():
138+ with mock .patch .object (
139+ pipeline_service_client_v1beta1 .PipelineServiceClient , "get_pipeline_job"
140+ ) as mock_get_pipeline_job :
141+ mock_get_pipeline_job .side_effect = [
142+ make_pipeline_job (
143+ gca_pipeline_state_v1beta1 .PipelineState .PIPELINE_STATE_RUNNING
144+ ),
145+ make_pipeline_job (
146+ gca_pipeline_state_v1beta1 .PipelineState .PIPELINE_STATE_RUNNING
147+ ),
148+ make_pipeline_job (
149+ gca_pipeline_state_v1beta1 .PipelineState .PIPELINE_STATE_FAILED
150+ ),
151+ ]
152+
153+ yield mock_get_pipeline_job
154+
155+
136156@pytest .fixture
137157def mock_pipeline_service_cancel ():
138158 with mock .patch .object (
@@ -269,3 +289,33 @@ def test_cancel_pipeline_job_without_running(
269289 job .cancel ()
270290
271291 assert e .match (regexp = r"PipelineJob has not been launched" )
292+
293+ @pytest .mark .usefixtures (
294+ "mock_pipeline_service_create" ,
295+ "mock_pipeline_service_get_with_fail" ,
296+ "mock_load_json" ,
297+ )
298+ @pytest .mark .parametrize ("sync" , [True , False ])
299+ def test_pipeline_failure_raises (self , sync ):
300+ aiplatform .init (
301+ project = _TEST_PROJECT ,
302+ staging_bucket = _TEST_GCS_BUCKET_NAME ,
303+ location = _TEST_LOCATION ,
304+ credentials = _TEST_CREDENTIALS ,
305+ )
306+
307+ job = pipeline_jobs .PipelineJob (
308+ display_name = _TEST_PIPELINE_JOB_ID ,
309+ template_path = _TEST_TEMPLATE_PATH ,
310+ job_id = _TEST_PIPELINE_JOB_ID ,
311+ parameter_values = _TEST_PIPELINE_PARAMETER_VALUES ,
312+ enable_caching = True ,
313+ )
314+
315+ with pytest .raises (RuntimeError ):
316+ job .run (
317+ service_account = _TEST_SERVICE_ACCOUNT , network = _TEST_NETWORK , sync = sync ,
318+ )
319+
320+ if not sync :
321+ job .wait ()
0 commit comments