25 changes: 25 additions & 0 deletions google/cloud/aiplatform/base.py
@@ -48,6 +48,7 @@
from google.cloud.aiplatform.compat.types import encryption_spec as gca_encryption_spec
from google.cloud.aiplatform.constants import base as base_constants
from google.protobuf import json_format
from google.protobuf import field_mask_pb2 as field_mask

# This is the default retry callback to be used with get methods.
_DEFAULT_RETRY = retry.Retry()
@@ -1030,6 +1031,7 @@ def _list(
cls_filter: Callable[[proto.Message], bool] = lambda _: True,
filter: Optional[str] = None,
order_by: Optional[str] = None,
read_mask: Optional[field_mask.FieldMask] = None,
Review comment (Contributor): Could you please add a docstring for this and the one in the next method?
project: Optional[str] = None,
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
@@ -1052,6 +1054,14 @@ def _list(
Optional. A comma-separated list of fields to order by, sorted in
ascending order. Use "desc" after a field name for descending.
Supported fields: `display_name`, `create_time`, `update_time`
read_mask (field_mask.FieldMask):
Optional. A FieldMask with a list of strings passed via `paths`
indicating which fields to return for each resource in the response.
For example, passing
field_mask.FieldMask(paths=["create_time", "update_time"])
as `read_mask` would result in each returned VertexAiResourceNoun
in the result list only having the "create_time" and
"update_time" attributes.
project (str):
Optional. Project to retrieve list from. If not set, project
set in aiplatform.init will be used.
@@ -1067,6 +1077,7 @@
Returns:
List[VertexAiResourceNoun] - A list of SDK resource objects
"""

resource = cls._empty_constructor(
project=project, location=location, credentials=credentials
)
@@ -1083,6 +1094,10 @@
),
}

# `read_mask` is only passed from PipelineJob.list() for now
if read_mask is not None:
list_request["read_mask"] = read_mask

if filter:
list_request["filter"] = filter
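For context, a minimal caller-side sketch of what this plumbing enables — an illustration, not part of the diff, and it assumes aiplatform.init() has already been configured:

from google.cloud import aiplatform
from google.protobuf import field_mask_pb2 as field_mask

# Hedged sketch: pass a FieldMask through the private helper. Only the
# two timestamp fields are populated on each returned object; all other
# attributes are left unset by the server.
mask = field_mask.FieldMask(paths=["create_time", "update_time"])
jobs = aiplatform.PipelineJob._list(read_mask=mask)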

@@ -1105,6 +1120,7 @@ def _list_with_local_order(
cls_filter: Callable[[proto.Message], bool] = lambda _: True,
filter: Optional[str] = None,
order_by: Optional[str] = None,
read_mask: Optional[field_mask.FieldMask] = None,
project: Optional[str] = None,
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
@@ -1127,6 +1143,14 @@
Optional. A comma-separated list of fields to order by, sorted in
ascending order. Use "desc" after a field name for descending.
Supported fields: `display_name`, `create_time`, `update_time`
read_mask (field_mask.FieldMask):
Optional. A FieldMask with a list of strings passed via `paths`
indicating which fields to return for each resource in the response.
For example, passing
field_mask.FieldMask(paths=["create_time", "update_time"])
as `read_mask` would result in each returned VertexAiResourceNoun
in the result list only having the "create_time" and
"update_time" attributes.
project (str):
Optional. Project to retrieve list from. If not set, project
set in aiplatform.init will be used.
@@ -1145,6 +1169,7 @@
cls_filter=cls_filter,
filter=filter,
order_by=None, # This method will handle the ordering locally
read_mask=read_mask,
project=project,
location=location,
credentials=credentials,
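The delegation above fetches with order_by=None and sorts client-side. A rough sketch of that idea — an illustration only, not the SDK's actual implementation — assuming order_by clauses of the form "field" or "field desc":

# Hedged sketch of client-side ordering. Stable sorts are applied from
# the least-significant key to the most significant one, so earlier
# clauses in order_by win ties.
def _sort_locally(resources, order_by):
    for clause in reversed(order_by.split(",")):
        field, _, direction = clause.strip().partition(" ")
        resources.sort(
            key=lambda r, f=field: getattr(r, f),
            reverse=direction.strip() == "desc",
        )
    return resources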
17 changes: 17 additions & 0 deletions google/cloud/aiplatform/constants/pipeline.py
@@ -37,3 +37,20 @@

# Pattern for an Artifact Registry URL.
_VALID_AR_URL = re.compile(r"^https:\/\/([\w-]+)-kfp\.pkg\.dev\/.*")

# Fields to include in returned PipelineJob when enable_simple_view=True in PipelineJob.list()
_READ_MASK_FIELDS = [
"name",
"state",
"display_name",
"pipeline_spec.pipeline_info",
"create_time",
"start_time",
"end_time",
"update_time",
"labels",
"template_uri",
"template_metadata.version",
"job_detail.pipeline_run_context",
"job_detail.pipeline_context",
]
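For reference, FieldMask paths use the proto field names in snake_case, with dots addressing nested messages (e.g. "template_metadata.version"). A minimal sketch of how these constants become a mask, mirroring what PipelineJob.list() does below:

from google.protobuf import field_mask_pb2 as field_mask

# Sketch: only these paths are populated on each PipelineJob
# returned by the List RPC when the mask is applied.
read_mask = field_mask.FieldMask(paths=_READ_MASK_FIELDS)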
24 changes: 24 additions & 0 deletions google/cloud/aiplatform/pipeline_jobs.py
@@ -38,6 +38,7 @@
from google.cloud.aiplatform.utils import yaml_utils
from google.cloud.aiplatform.utils import pipeline_utils
from google.protobuf import json_format
from google.protobuf import field_mask_pb2 as field_mask

from google.cloud.aiplatform.compat.types import (
pipeline_job as gca_pipeline_job,
@@ -56,6 +57,8 @@
# Pattern for an Artifact Registry URL.
_VALID_AR_URL = pipeline_constants._VALID_AR_URL

_READ_MASK_FIELDS = pipeline_constants._READ_MASK_FIELDS


def _get_current_time() -> datetime.datetime:
"""Gets the current timestamp."""
@@ -509,6 +512,7 @@ def list(
cls,
filter: Optional[str] = None,
order_by: Optional[str] = None,
enable_simple_view: Optional[bool] = False,
project: Optional[str] = None,
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
@@ -530,6 +534,17 @@
Optional. A comma-separated list of fields to order by, sorted in
ascending order. Use "desc" after a field name for descending.
Supported fields: `display_name`, `create_time`, `update_time`
enable_simple_view (bool):
Optional. Whether to pass the `read_mask` parameter to the list call.
This will improve the performance of calling list(). However, the
returned PipelineJob list will not include all fields for each PipelineJob.
Review comment (Contributor): I think here we can specify what fields are missing.
Setting this to True will exclude the following fields in your response:
`runtime_config`, `service_account`, `network`, and some subfields of
`pipeline_spec` and `job_detail`. The following fields will be included in
each PipelineJob resource in your response: `name`, `state`, `display_name`,
`pipeline_spec.pipeline_info`, `create_time`, `start_time`, `end_time`,
`update_time`, `labels`, `template_uri`, `template_metadata.version`,
`job_detail.pipeline_run_context`, `job_detail.pipeline_context`.
project (str):
Optional. Project to retrieve list from. If not set, project
set in aiplatform.init will be used.
@@ -544,9 +559,18 @@
List[PipelineJob] - A list of PipelineJob resource objects
"""

read_mask_fields = None

if enable_simple_view:
read_mask_fields = field_mask.FieldMask(paths=_READ_MASK_FIELDS)
_LOGGER.warning(
"By enabling simple view, the PipelineJob resources returned from this method will not contain all fields."
)

return cls._list_with_local_order(
filter=filter,
order_by=order_by,
read_mask=read_mask_fields,
project=project,
location=location,
credentials=credentials,
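Putting it together, a typical call with the simple view — a hedged usage sketch, assuming aiplatform.init(project=..., location=...) has already been called:

from google.cloud import aiplatform

jobs = aiplatform.PipelineJob.list(
    enable_simple_view=True,
    order_by="create_time desc",
)
for job in jobs:
    # Fields outside the read mask (e.g. service_account) stay unset.
    print(job.state, job.create_time)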
13 changes: 13 additions & 0 deletions tests/system/aiplatform/test_pipeline_job.py
@@ -20,6 +20,8 @@
from google.cloud import aiplatform
from tests.system.aiplatform import e2e_base

from google.protobuf.json_format import MessageToDict


@pytest.mark.usefixtures("tear_down_resources")
class TestPipelineJob(e2e_base.TestEndToEnd):
@@ -59,3 +61,14 @@ def training_pipeline(number_of_epochs: int = 10):
shared_state.setdefault("resources", []).append(job)

job.wait()

list_with_read_mask = aiplatform.PipelineJob.list(enable_simple_view=True)
list_without_read_mask = aiplatform.PipelineJob.list()

# enable_simple_view=True should apply the `read_mask` filter to limit PipelineJob fields returned
assert "serviceAccount" in MessageToDict(
list_without_read_mask[0].gca_resource._pb
)
assert "serviceAccount" not in MessageToDict(
list_with_read_mask[0].gca_resource._pb
)
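One detail behind these assertions: MessageToDict converts proto field names to lowerCamelCase by default, which is why the test checks for "serviceAccount" rather than "service_account". A hedged sketch of the snake_case alternative:

from google.protobuf.json_format import MessageToDict

# Sketch: preserving_proto_field_name=True keeps snake_case keys.
as_dict = MessageToDict(
    list_without_read_mask[0].gca_resource._pb,
    preserving_proto_field_name=True,
)
assert "service_account" in as_dict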
57 changes: 57 additions & 0 deletions tests/unit/aiplatform/test_pipeline_jobs.py
@@ -28,6 +28,7 @@
from google.cloud import aiplatform
from google.cloud.aiplatform import base
from google.cloud.aiplatform import initializer
from google.cloud.aiplatform.constants import pipeline as pipeline_constants
from google.cloud.aiplatform_v1 import Context as GapicContext
from google.cloud.aiplatform_v1 import MetadataStore as GapicMetadataStore
from google.cloud.aiplatform.metadata import constants
@@ -37,6 +38,7 @@
from google.cloud.aiplatform.utils import gcs_utils
from google.cloud import storage
from google.protobuf import json_format
from google.protobuf import field_mask_pb2 as field_mask

from google.cloud.aiplatform.compat.services import (
pipeline_service_client,
@@ -62,6 +64,9 @@
_TEST_NETWORK = f"projects/{_TEST_PROJECT}/global/networks/{_TEST_PIPELINE_JOB_ID}"

_TEST_PIPELINE_JOB_NAME = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/pipelineJobs/{_TEST_PIPELINE_JOB_ID}"
_TEST_PIPELINE_JOB_LIST_READ_MASK = field_mask.FieldMask(
paths=pipeline_constants._READ_MASK_FIELDS
)

_TEST_PIPELINE_PARAMETER_VALUES_LEGACY = {"string_param": "hello"}
_TEST_PIPELINE_PARAMETER_VALUES = {
@@ -332,6 +337,17 @@ def mock_pipeline_service_list():
with mock.patch.object(
pipeline_service_client.PipelineServiceClient, "list_pipeline_jobs"
) as mock_list_pipeline_jobs:
mock_list_pipeline_jobs.return_value = [
make_pipeline_job(
gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED
),
make_pipeline_job(
gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED
),
make_pipeline_job(
gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED
),
]
yield mock_list_pipeline_jobs


@@ -1354,6 +1370,47 @@ def test_list_pipeline_job(
request={"parent": _TEST_PARENT}
)

@pytest.mark.usefixtures(
"mock_pipeline_service_create",
"mock_pipeline_service_get",
"mock_pipeline_bucket_exists",
)
@pytest.mark.parametrize(
"job_spec",
[
_TEST_PIPELINE_SPEC_JSON,
_TEST_PIPELINE_SPEC_YAML,
_TEST_PIPELINE_JOB,
_TEST_PIPELINE_SPEC_LEGACY_JSON,
_TEST_PIPELINE_SPEC_LEGACY_YAML,
_TEST_PIPELINE_JOB_LEGACY,
],
)
def test_list_pipeline_job_with_read_mask(
self, mock_pipeline_service_list, mock_load_yaml_and_json
):
aiplatform.init(
project=_TEST_PROJECT,
staging_bucket=_TEST_GCS_BUCKET_NAME,
credentials=_TEST_CREDENTIALS,
)

job = pipeline_jobs.PipelineJob(
display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME,
template_path=_TEST_TEMPLATE_PATH,
job_id=_TEST_PIPELINE_JOB_ID,
)

job.run()
job.list(enable_simple_view=True)

mock_pipeline_service_list.assert_called_once_with(
request={
"parent": _TEST_PARENT,
"read_mask": _TEST_PIPELINE_JOB_LIST_READ_MASK,
},
)

@pytest.mark.usefixtures(
"mock_pipeline_service_create",
"mock_pipeline_service_get",