Skip to content

Commit 3ec9386

Browse files
authored
feat: Add list() method to all resource nouns (#294)
### [Colab for manual testing](https://colab.research.google.com/drive/14iLNaJEyZaPGebCgJZxgUJJeU9r7z_Fu) ### Summary of Changes - Added a `_list_method` property to `AiPlatformResourceNoun` to store GAPIC method name for each noun - Added a `create_time` and `update_time` property to `AiPlatformResourceNoun` - Added a single `list()` method that takes four optional fields and returns a list of SDK types - All of the fields except `order_by` are available in every GAPIC list methods - Added local sorting for GAPIC list methods that do not take `order_by` - Added 3 unit tests to check correct GAPIC calls and local sorting - Added `aiplatform.init()` to test class setup and dropped it from some unit tests Fixes [b/183498826](http://b/183498826) 🦕
1 parent 674227d commit 3ec9386

File tree

10 files changed

+630
-25
lines changed

10 files changed

+630
-25
lines changed

google/cloud/aiplatform/base.py

Lines changed: 233 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,13 @@
1717

1818
import abc
1919
from concurrent import futures
20+
import datetime
2021
import functools
2122
import inspect
22-
import threading
23-
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Type, Union
2423

2524
import proto
25+
import threading
26+
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Type, Union
2627

2728
from google.auth import credentials as auth_credentials
2829
from google.cloud.aiplatform import initializer
@@ -249,6 +250,12 @@ def _getter_method(cls) -> str:
249250
"""Name of getter method of client class for retrieving the resource."""
250251
pass
251252

253+
@property
254+
@abc.abstractmethod
255+
def _list_method(cls) -> str:
256+
"""Name of list method of client class for listing resources."""
257+
pass
258+
252259
@property
253260
@abc.abstractmethod
254261
def _delete_method(cls) -> str:
@@ -385,6 +392,17 @@ def display_name(self) -> str:
385392
"""Display name of this resource."""
386393
return self._gca_resource.display_name
387394

395+
@property
396+
def create_time(self) -> datetime.datetime:
397+
"""Time this resource was created."""
398+
return self._gca_resource.create_time
399+
400+
@property
401+
def update_time(self) -> datetime.datetime:
402+
"""Time this resource was last updated."""
403+
self._sync_gca_resource()
404+
return self._gca_resource.update_time
405+
388406
def __repr__(self) -> str:
389407
return f"{object.__repr__(self)} \nresource name: {self.resource_name}"
390408

@@ -617,6 +635,219 @@ def _sync_object_with_future_result(
617635
if value:
618636
setattr(self, attribute, value)
619637

638+
def _construct_sdk_resource_from_gapic(
639+
self,
640+
gapic_resource: proto.Message,
641+
project: Optional[str] = None,
642+
location: Optional[str] = None,
643+
credentials: Optional[auth_credentials.Credentials] = None,
644+
) -> AiPlatformResourceNoun:
645+
"""Given a GAPIC resource object, return the SDK representation.
646+
647+
Args:
648+
gapic_resource (proto.Message):
649+
A GAPIC representation of an AI Platform resource, usually
650+
retrieved by a get_* or in a list_* API call.
651+
project (str):
652+
Optional. Project to construct SDK object from. If not set,
653+
project set in aiplatform.init will be used.
654+
location (str):
655+
Optional. Location to construct SDK object from. If not set,
656+
location set in aiplatform.init will be used.
657+
credentials (auth_credentials.Credentials):
658+
Optional. Custom credentials to use to construct SDK object.
659+
Overrides credentials set in aiplatform.init.
660+
661+
Returns:
662+
AiPlatformResourceNoun:
663+
An initialized SDK object that represents GAPIC type.
664+
"""
665+
sdk_resource = self._empty_constructor(
666+
project=project, location=location, credentials=credentials
667+
)
668+
sdk_resource._gca_resource = gapic_resource
669+
return sdk_resource
670+
671+
# TODO(b/144545165): Improve documentation for list filtering once available
672+
# TODO(b/184910159): Expose `page_size` field in list method
673+
@classmethod
674+
def _list(
675+
cls,
676+
cls_filter: Callable[[proto.Message], bool] = lambda _: True,
677+
filter: Optional[str] = None,
678+
order_by: Optional[str] = None,
679+
project: Optional[str] = None,
680+
location: Optional[str] = None,
681+
credentials: Optional[auth_credentials.Credentials] = None,
682+
) -> List[AiPlatformResourceNoun]:
683+
"""Private method to list all instances of this AI Platform Resource,
684+
takes a `cls_filter` arg to filter to a particular SDK resource subclass.
685+
686+
Args:
687+
cls_filter (Callable[[proto.Message], bool]):
688+
A function that takes one argument, a GAPIC resource, and returns
689+
a bool. If the function returns False, that resource will be
690+
excluded from the returned list. Example usage:
691+
cls_filter = lambda obj: obj.metadata in cls.valid_metadatas
692+
filter (str):
693+
Optional. An expression for filtering the results of the request.
694+
For field names both snake_case and camelCase are supported.
695+
order_by (str):
696+
Optional. A comma-separated list of fields to order by, sorted in
697+
ascending order. Use "desc" after a field name for descending.
698+
Supported fields: `display_name`, `create_time`, `update_time`
699+
project (str):
700+
Optional. Project to retrieve list from. If not set, project
701+
set in aiplatform.init will be used.
702+
location (str):
703+
Optional. Location to retrieve list from. If not set, location
704+
set in aiplatform.init will be used.
705+
credentials (auth_credentials.Credentials):
706+
Optional. Custom credentials to use to retrieve list. Overrides
707+
credentials set in aiplatform.init.
708+
709+
Returns:
710+
List[AiPlatformResourceNoun] - A list of SDK resource objects
711+
"""
712+
self = cls._empty_constructor(
713+
project=project, location=location, credentials=credentials
714+
)
715+
716+
# Fetch credentials once and re-use for all `_empty_constructor()` calls
717+
creds = initializer.global_config.credentials
718+
719+
resource_list_method = getattr(self.api_client, self._list_method)
720+
721+
list_request = {
722+
"parent": initializer.global_config.common_location_path(
723+
project=project, location=location
724+
),
725+
"filter": filter,
726+
}
727+
728+
if order_by:
729+
list_request["order_by"] = order_by
730+
731+
resource_list = resource_list_method(request=list_request) or []
732+
733+
return [
734+
self._construct_sdk_resource_from_gapic(
735+
gapic_resource, project=project, location=location, credentials=creds
736+
)
737+
for gapic_resource in resource_list
738+
if cls_filter(gapic_resource)
739+
]
740+
741+
@classmethod
742+
def _list_with_local_order(
743+
cls,
744+
cls_filter: Callable[[proto.Message], bool] = lambda _: True,
745+
filter: Optional[str] = None,
746+
order_by: Optional[str] = None,
747+
project: Optional[str] = None,
748+
location: Optional[str] = None,
749+
credentials: Optional[auth_credentials.Credentials] = None,
750+
) -> List[AiPlatformResourceNoun]:
751+
"""Private method to list all instances of this AI Platform Resource,
752+
takes a `cls_filter` arg to filter to a particular SDK resource subclass.
753+
Provides client-side sorting when a list API doesn't support `order_by`.
754+
755+
Args:
756+
cls_filter (Callable[[proto.Message], bool]):
757+
A function that takes one argument, a GAPIC resource, and returns
758+
a bool. If the function returns False, that resource will be
759+
excluded from the returned list. Example usage:
760+
cls_filter = lambda obj: obj.metadata in cls.valid_metadatas
761+
filter (str):
762+
Optional. An expression for filtering the results of the request.
763+
For field names both snake_case and camelCase are supported.
764+
order_by (str):
765+
Optional. A comma-separated list of fields to order by, sorted in
766+
ascending order. Use "desc" after a field name for descending.
767+
Supported fields: `display_name`, `create_time`, `update_time`
768+
project (str):
769+
Optional. Project to retrieve list from. If not set, project
770+
set in aiplatform.init will be used.
771+
location (str):
772+
Optional. Location to retrieve list from. If not set, location
773+
set in aiplatform.init will be used.
774+
credentials (auth_credentials.Credentials):
775+
Optional. Custom credentials to use to retrieve list. Overrides
776+
credentials set in aiplatform.init.
777+
778+
Returns:
779+
List[AiPlatformResourceNoun] - A list of SDK resource objects
780+
"""
781+
782+
li = cls._list(
783+
cls_filter=cls_filter,
784+
filter=filter,
785+
order_by=None, # This method will handle the ordering locally
786+
project=project,
787+
location=location,
788+
credentials=credentials,
789+
)
790+
791+
desc = "desc" in order_by
792+
order_by = order_by.replace("desc", "")
793+
order_by = order_by.split(",")
794+
795+
li.sort(
796+
key=lambda x: tuple(getattr(x, field.strip()) for field in order_by),
797+
reverse=desc,
798+
)
799+
800+
return li
801+
802+
@classmethod
803+
def list(
804+
cls,
805+
filter: Optional[str] = None,
806+
order_by: Optional[str] = None,
807+
project: Optional[str] = None,
808+
location: Optional[str] = None,
809+
credentials: Optional[auth_credentials.Credentials] = None,
810+
) -> List[AiPlatformResourceNoun]:
811+
"""List all instances of this AI Platform Resource.
812+
813+
Example Usage:
814+
815+
aiplatform.BatchPredictionJobs.list(
816+
filter='state="JOB_STATE_SUCCEEDED" AND display_name="my_job"',
817+
)
818+
819+
aiplatform.Model.list(order_by="create_time desc, display_name")
820+
821+
Args:
822+
filter (str):
823+
Optional. An expression for filtering the results of the request.
824+
For field names both snake_case and camelCase are supported.
825+
order_by (str):
826+
Optional. A comma-separated list of fields to order by, sorted in
827+
ascending order. Use "desc" after a field name for descending.
828+
Supported fields: `display_name`, `create_time`, `update_time`
829+
project (str):
830+
Optional. Project to retrieve list from. If not set, project
831+
set in aiplatform.init will be used.
832+
location (str):
833+
Optional. Location to retrieve list from. If not set, location
834+
set in aiplatform.init will be used.
835+
credentials (auth_credentials.Credentials):
836+
Optional. Custom credentials to use to retrieve list. Overrides
837+
credentials set in aiplatform.init.
838+
839+
Returns:
840+
List[AiPlatformResourceNoun] - A list of SDK resource objects
841+
"""
842+
843+
return cls._list(
844+
filter=filter,
845+
order_by=order_by,
846+
project=project,
847+
location=location,
848+
credentials=credentials,
849+
)
850+
620851
@optional_sync()
621852
def delete(self, sync: bool = True) -> None:
622853
"""Deletes this AI Platform resource. WARNING: This deletion is permament.

google/cloud/aiplatform/datasets/dataset.py

Lines changed: 57 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
# limitations under the License.
1616
#
1717

18-
from typing import Optional, Sequence, Dict, Tuple, Union
18+
from typing import Optional, Sequence, Dict, Tuple, Union, List
1919

2020
from google.api_core import operation
2121
from google.auth import credentials as auth_credentials
@@ -40,9 +40,10 @@ class Dataset(base.AiPlatformResourceNounWithFutureManager):
4040
_is_client_prediction_client = False
4141
_resource_noun = "datasets"
4242
_getter_method = "get_dataset"
43+
_list_method = "list_datasets"
4344
_delete_method = "delete_dataset"
4445

45-
_supported_metadata_schema_uris: Optional[Tuple[str]] = None
46+
_supported_metadata_schema_uris: Tuple[str] = ()
4647

4748
def __init__(
4849
self,
@@ -494,3 +495,57 @@ def export_data(self, output_dir: str) -> Sequence[str]:
494495

495496
def update(self):
496497
raise NotImplementedError("Update dataset has not been implemented yet")
498+
499+
@classmethod
500+
def list(
501+
cls,
502+
filter: Optional[str] = None,
503+
order_by: Optional[str] = None,
504+
project: Optional[str] = None,
505+
location: Optional[str] = None,
506+
credentials: Optional[auth_credentials.Credentials] = None,
507+
) -> List[base.AiPlatformResourceNoun]:
508+
"""List all instances of this Dataset resource.
509+
510+
Example Usage:
511+
512+
aiplatform.TabularDataset.list(
513+
filter='labels.my_key="my_value"',
514+
order_by='display_name'
515+
)
516+
517+
Args:
518+
filter (str):
519+
Optional. An expression for filtering the results of the request.
520+
For field names both snake_case and camelCase are supported.
521+
order_by (str):
522+
Optional. A comma-separated list of fields to order by, sorted in
523+
ascending order. Use "desc" after a field name for descending.
524+
Supported fields: `display_name`, `create_time`, `update_time`
525+
project (str):
526+
Optional. Project to retrieve list from. If not set, project
527+
set in aiplatform.init will be used.
528+
location (str):
529+
Optional. Location to retrieve list from. If not set, location
530+
set in aiplatform.init will be used.
531+
credentials (auth_credentials.Credentials):
532+
Optional. Custom credentials to use to retrieve list. Overrides
533+
credentials set in aiplatform.init.
534+
535+
Returns:
536+
List[base.AiPlatformResourceNoun] - A list of Dataset resource objects
537+
"""
538+
539+
dataset_subclass_filter = (
540+
lambda gapic_obj: gapic_obj.metadata_schema_uri
541+
in cls._supported_metadata_schema_uris
542+
)
543+
544+
return cls._list_with_local_order(
545+
cls_filter=dataset_subclass_filter,
546+
filter=filter,
547+
order_by=order_by,
548+
project=project,
549+
location=location,
550+
credentials=credentials,
551+
)

0 commit comments

Comments
 (0)