Skip to content

Commit 2508fe9

Browse files
authored
fix: log pipeline completion and raise pipeline failures (#523)
1 parent f6f9a97 commit 2508fe9

File tree

2 files changed

+61
-0
lines changed

2 files changed

+61
-0
lines changed

google/cloud/aiplatform/pipeline_jobs.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,10 @@
4444
]
4545
)
4646

47+
_PIPELINE_ERROR_STATES = set(
48+
[gca_pipeline_state_v1beta1.PipelineState.PIPELINE_STATE_FAILED]
49+
)
50+
4751
# Vertex AI Pipelines service API job name relative name prefix pattern.
4852
_JOB_NAME_PATTERN = "{parent}/pipelineJobs/{job_id}"
4953

@@ -311,6 +315,13 @@ def _block_until_complete(self):
311315
previous_time = current_time
312316
time.sleep(wait)
313317

318+
# Error is only populated when the pipeline state is
319+
# PIPELINE_STATE_FAILED or PIPELINE_STATE_CANCELLED.
320+
if self._gca_resource.state in _PIPELINE_ERROR_STATES:
321+
raise RuntimeError("Job failed with:\n%s" % self._gca_resource.error)
322+
else:
323+
_LOGGER.log_action_completed_against_resource("run", "completed", self)
324+
314325
def cancel(self) -> None:
315326
"""Starts asynchronous cancellation on the PipelineJob. The server
316327
makes a best effort to cancel the job, but success is not guaranteed.

tests/unit/aiplatform/test_pipeline_jobs.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,26 @@ def mock_pipeline_service_get():
133133
yield mock_get_pipeline_job
134134

135135

136+
@pytest.fixture
137+
def mock_pipeline_service_get_with_fail():
138+
with mock.patch.object(
139+
pipeline_service_client_v1beta1.PipelineServiceClient, "get_pipeline_job"
140+
) as mock_get_pipeline_job:
141+
mock_get_pipeline_job.side_effect = [
142+
make_pipeline_job(
143+
gca_pipeline_state_v1beta1.PipelineState.PIPELINE_STATE_RUNNING
144+
),
145+
make_pipeline_job(
146+
gca_pipeline_state_v1beta1.PipelineState.PIPELINE_STATE_RUNNING
147+
),
148+
make_pipeline_job(
149+
gca_pipeline_state_v1beta1.PipelineState.PIPELINE_STATE_FAILED
150+
),
151+
]
152+
153+
yield mock_get_pipeline_job
154+
155+
136156
@pytest.fixture
137157
def mock_pipeline_service_cancel():
138158
with mock.patch.object(
@@ -269,3 +289,33 @@ def test_cancel_pipeline_job_without_running(
269289
job.cancel()
270290

271291
assert e.match(regexp=r"PipelineJob has not been launched")
292+
293+
@pytest.mark.usefixtures(
294+
"mock_pipeline_service_create",
295+
"mock_pipeline_service_get_with_fail",
296+
"mock_load_json",
297+
)
298+
@pytest.mark.parametrize("sync", [True, False])
299+
def test_pipeline_failure_raises(self, sync):
300+
aiplatform.init(
301+
project=_TEST_PROJECT,
302+
staging_bucket=_TEST_GCS_BUCKET_NAME,
303+
location=_TEST_LOCATION,
304+
credentials=_TEST_CREDENTIALS,
305+
)
306+
307+
job = pipeline_jobs.PipelineJob(
308+
display_name=_TEST_PIPELINE_JOB_ID,
309+
template_path=_TEST_TEMPLATE_PATH,
310+
job_id=_TEST_PIPELINE_JOB_ID,
311+
parameter_values=_TEST_PIPELINE_PARAMETER_VALUES,
312+
enable_caching=True,
313+
)
314+
315+
with pytest.raises(RuntimeError):
316+
job.run(
317+
service_account=_TEST_SERVICE_ACCOUNT, network=_TEST_NETWORK, sync=sync,
318+
)
319+
320+
if not sync:
321+
job.wait()

0 commit comments

Comments
 (0)