
Commit 91d8459

vertex-sdk-bot authored and copybara-github committed
feat: add start_time support for BatchReadFeatureValues wrapper methods. Always run BatchRead Dataflow tests.

PiperOrigin-RevId: 520782661
1 parent f66beaa commit 91d8459
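What this commit enables, as a rough usage sketch: all three batch-serve wrappers gain an optional start_time cutoff, and the new unit tests pass a plain datetime for it. The project, dataset, and feature IDs below are illustrative, and serving_feature_ids is assumed to map entity type IDs to lists of feature IDs.

    # Sketch only -- all resource names here are hypothetical.
    import datetime

    from google.cloud import aiplatform

    aiplatform.init(project="my-project", location="us-central1")

    my_featurestore = aiplatform.Featurestore(featurestore_name="my_featurestore_id")
    my_featurestore.batch_serve_to_bq(
        bq_destination_output_uri="bq://my-project.my_dataset.served_features",
        serving_feature_ids={"my_entity_type_id": ["my_feature_id"]},
        read_instances_uri="bq://my-project.my_dataset.read_instances",
        # New in this commit: exclude values generated before this timestamp.
        start_time=datetime.datetime(2023, 3, 1),
    )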

File tree: 2 files changed, +216 / -15 lines

google/cloud/aiplatform/featurestore/featurestore.py

Lines changed: 35 additions & 1 deletion
@@ -20,6 +20,7 @@
 
 from google.auth import credentials as auth_credentials
 from google.protobuf import field_mask_pb2
+from google.protobuf import timestamp_pb2
 
 from google.cloud.aiplatform import base
 from google.cloud.aiplatform.compat.types import (
@@ -31,7 +32,10 @@
 from google.cloud.aiplatform import featurestore
 from google.cloud.aiplatform import initializer
 from google.cloud.aiplatform import utils
-from google.cloud.aiplatform.utils import featurestore_utils, resource_manager_utils
+from google.cloud.aiplatform.utils import (
+    featurestore_utils,
+    resource_manager_utils,
+)
 
 from google.cloud import bigquery
 
@@ -695,6 +699,7 @@ def _validate_and_get_batch_read_feature_values_request(
         read_instances: Union[gca_io.BigQuerySource, gca_io.CsvSource],
         pass_through_fields: Optional[List[str]] = None,
         feature_destination_fields: Optional[Dict[str, str]] = None,
+        start_time: Optional[timestamp_pb2.Timestamp] = None,
     ) -> gca_featurestore_service.BatchReadFeatureValuesRequest:
         """Validates and gets batch_read_feature_values_request
 
@@ -736,6 +741,10 @@ def _validate_and_get_batch_read_feature_values_request(
                 'projects/123/locations/us-central1/featurestores/fs_id/entityTypes/et_id2/features/f_id22': 'bar',
                 }
 
+            start_time (timestamp_pb2.Timestamp):
+                Optional. Excludes Feature values with feature generation timestamp before this timestamp. If not set, retrieve
+                oldest values kept in Feature Store. Timestamp, if present, must not have higher than millisecond precision.
+
         Returns:
             gca_featurestore_service.BatchReadFeatureValuesRequest: batch read feature values request
         """
@@ -819,6 +828,9 @@ def _validate_and_get_batch_read_feature_values_request(
             for pass_through_field in pass_through_fields
         ]
 
+        if start_time is not None:
+            batch_read_feature_values_request.start_time = start_time
+
         return batch_read_feature_values_request
 
     @base.optional_sync(return_input_arg="self")
@@ -829,6 +841,7 @@ def batch_serve_to_bq(
         read_instances_uri: str,
         pass_through_fields: Optional[List[str]] = None,
         feature_destination_fields: Optional[Dict[str, str]] = None,
+        start_time: Optional[timestamp_pb2.Timestamp] = None,
         request_metadata: Optional[Sequence[Tuple[str, str]]] = (),
         serve_request_timeout: Optional[float] = None,
         sync: bool = True,
@@ -903,8 +916,14 @@ def batch_serve_to_bq(
                 'projects/123/locations/us-central1/featurestores/fs_id/entityTypes/et_id1/features/f_id11': 'foo',
                 'projects/123/locations/us-central1/featurestores/fs_id/entityTypes/et_id2/features/f_id22': 'bar',
                 }
+
+            start_time (timestamp_pb2.Timestamp):
+                Optional. Excludes Feature values with feature generation timestamp before this timestamp. If not set, retrieve
+                oldest values kept in Feature Store. Timestamp, if present, must not have higher than millisecond precision.
+
             serve_request_timeout (float):
                 Optional. The timeout for the serve request in seconds.
+
         Returns:
             Featurestore: The featurestore resource object batch read feature values from.
 
@@ -924,6 +943,7 @@ def batch_serve_to_bq(
                 feature_destination_fields=feature_destination_fields,
                 read_instances=read_instances,
                 pass_through_fields=pass_through_fields,
+                start_time=start_time,
             )
         )
 
@@ -942,6 +962,7 @@ def batch_serve_to_gcs(
         read_instances_uri: str,
         pass_through_fields: Optional[List[str]] = None,
         feature_destination_fields: Optional[Dict[str, str]] = None,
+        start_time: Optional[timestamp_pb2.Timestamp] = None,
         request_metadata: Optional[Sequence[Tuple[str, str]]] = (),
         sync: bool = True,
         serve_request_timeout: Optional[float] = None,
@@ -1037,6 +1058,11 @@ def batch_serve_to_gcs(
                 'projects/123/locations/us-central1/featurestores/fs_id/entityTypes/et_id1/features/f_id11': 'foo',
                 'projects/123/locations/us-central1/featurestores/fs_id/entityTypes/et_id2/features/f_id22': 'bar',
                 }
+
+            start_time (timestamp_pb2.Timestamp):
+                Optional. Excludes Feature values with feature generation timestamp before this timestamp. If not set, retrieve
+                oldest values kept in Feature Store. Timestamp, if present, must not have higher than millisecond precision.
+
             serve_request_timeout (float):
                 Optional. The timeout for the serve request in seconds.
 
@@ -1075,6 +1101,7 @@ def batch_serve_to_gcs(
                 feature_destination_fields=feature_destination_fields,
                 read_instances=read_instances,
                 pass_through_fields=pass_through_fields,
+                start_time=start_time,
             )
         )
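A matching sketch for the GCS wrapper, assuming aiplatform.init() has run as in the earlier sketch. Resource names are illustrative; the "tfrecord" destination type mirrors what the new test in the second file exercises.

    # Sketch only -- names are hypothetical.
    import datetime

    from google.cloud import aiplatform

    my_featurestore = aiplatform.Featurestore(featurestore_name="my_featurestore_id")
    my_featurestore.batch_serve_to_gcs(
        gcs_destination_output_uri_prefix="gs://my_bucket/path/to_prefix",
        gcs_destination_type="tfrecord",
        serving_feature_ids={"my_entity_type_id": ["my_feature_id"]},
        read_instances_uri="gs://my_bucket/read_instances.csv",
        start_time=datetime.datetime(2023, 3, 1),
    )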

@@ -1090,6 +1117,7 @@ def batch_serve_to_df(
         read_instances_df: "pd.DataFrame",  # noqa: F821 - skip check for undefined name 'pd'
         pass_through_fields: Optional[List[str]] = None,
         feature_destination_fields: Optional[Dict[str, str]] = None,
+        start_time: Optional[timestamp_pb2.Timestamp] = None,
         request_metadata: Optional[Sequence[Tuple[str, str]]] = (),
         serve_request_timeout: Optional[float] = None,
         bq_dataset_id: Optional[str] = None,
@@ -1182,6 +1210,11 @@ def batch_serve_to_df(
                 for temporarily staging data. If specified, caller must have
                 `bigquery.tables.create` permissions for Dataset.
 
+
+            start_time (timestamp_pb2.Timestamp):
+                Optional. Excludes Feature values with feature generation timestamp before this timestamp. If not set, retrieve
+                oldest values kept in Feature Store. Timestamp, if present, must not have higher than millisecond precision.
+
         Returns:
             pd.DataFrame: The pandas DataFrame containing feature values from batch serving.
 
@@ -1264,6 +1297,7 @@ def batch_serve_to_df(
             feature_destination_fields=feature_destination_fields,
             request_metadata=request_metadata,
             serve_request_timeout=serve_request_timeout,
+            start_time=start_time,
         )
 
         bigquery_storage_read_client = bigquery_storage.BigQueryReadClient(
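And a sketch for the DataFrame wrapper, which stages data through a temporary BigQuery dataset before serving. The read-instances column layout below is hypothetical; only the parameter names come from this diff.

    # Sketch only: serve feature values for instances in a DataFrame,
    # with the new start_time cutoff.
    import datetime

    import pandas as pd

    from google.cloud import aiplatform

    my_featurestore = aiplatform.Featurestore(featurestore_name="my_featurestore_id")
    df = my_featurestore.batch_serve_to_df(
        serving_feature_ids={"my_entity_type_id": ["my_feature_id"]},
        read_instances_df=pd.DataFrame({"my_entity_type_id": ["entity_1"]}),
        start_time=datetime.datetime(2023, 3, 1),
    )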

tests/unit/aiplatform/test_featurestores.py

Lines changed: 181 additions & 14 deletions
@@ -74,14 +74,8 @@
 )
 
 from google.cloud import bigquery
-
-try:
-    from google.cloud import bigquery_storage
-    from google.cloud.bigquery_storage_v1.types import stream as gcbqs_stream
-
-    _USE_BQ_STORAGE = True
-except ImportError:
-    _USE_BQ_STORAGE = False
+from google.cloud import bigquery_storage
+from google.cloud.bigquery_storage_v1.types import stream as gcbqs_stream
 
 from google.cloud import resourcemanager
 
@@ -283,6 +277,7 @@
 _TEST_GCS_SOURCE_TYPE_AVRO = "avro"
 _TEST_GCS_SOURCE_TYPE_INVALID = "json"
 
+_TEST_BATCH_SERVE_START_TIME = datetime.datetime.now()
 _TEST_BQ_DESTINATION_URI = "bq://project.dataset.table_name"
 _TEST_GCS_OUTPUT_URI_PREFIX = "gs://my_bucket/path/to_prefix"
 
@@ -1613,6 +1608,57 @@ def test_batch_serve_to_bq_with_timeout_not_explicitly_set(
             timeout=None,
         )
 
+    @pytest.mark.parametrize("sync", [True, False])
+    @pytest.mark.usefixtures("get_featurestore_mock")
+    def test_batch_serve_to_bq_with_start_time(
+        self, batch_read_feature_values_mock, sync
+    ):
+        aiplatform.init(project=_TEST_PROJECT)
+        my_featurestore = aiplatform.Featurestore(
+            featurestore_name=_TEST_FEATURESTORE_NAME
+        )
+
+        expected_entity_type_specs = [
+            _get_entity_type_spec_proto_with_feature_ids(
+                entity_type_id="my_entity_type_id_1",
+                feature_ids=["my_feature_id_1_1", "my_feature_id_1_2"],
+            ),
+            _get_entity_type_spec_proto_with_feature_ids(
+                entity_type_id="my_entity_type_id_2",
+                feature_ids=["my_feature_id_2_1", "my_feature_id_2_2"],
+            ),
+        ]
+
+        expected_batch_read_feature_values_request = (
+            gca_featurestore_service.BatchReadFeatureValuesRequest(
+                featurestore=my_featurestore.resource_name,
+                destination=gca_featurestore_service.FeatureValueDestination(
+                    bigquery_destination=_TEST_BQ_DESTINATION,
+                ),
+                entity_type_specs=expected_entity_type_specs,
+                bigquery_read_instances=_TEST_BQ_SOURCE,
+                start_time=_TEST_BATCH_SERVE_START_TIME,
+            )
+        )
+
+        my_featurestore.batch_serve_to_bq(
+            bq_destination_output_uri=_TEST_BQ_DESTINATION_URI,
+            serving_feature_ids=_TEST_SERVING_FEATURE_IDS,
+            read_instances_uri=_TEST_BQ_SOURCE_URI,
+            sync=sync,
+            serve_request_timeout=None,
+            start_time=_TEST_BATCH_SERVE_START_TIME,
+        )
+
+        if not sync:
+            my_featurestore.wait()
+
+        batch_read_feature_values_mock.assert_called_once_with(
+            request=expected_batch_read_feature_values_request,
+            metadata=_TEST_REQUEST_METADATA,
+            timeout=None,
+        )
+
     @pytest.mark.parametrize("sync", [True, False])
     @pytest.mark.usefixtures("get_featurestore_mock")
     def test_batch_serve_to_gcs(self, batch_read_feature_values_mock, sync):
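One nuance the new tests rely on: _TEST_BATCH_SERVE_START_TIME is a plain datetime.datetime, yet it is passed where the proto expects a Timestamp. This appears to work because the request types are proto-plus messages, which coerce native datetimes on assignment; a small sanity check of that assumption, with the import path taken from this repo's compat layer:

    # Sanity-check sketch: proto-plus Timestamp fields accept native datetimes.
    import datetime

    from google.cloud.aiplatform.compat.types import (
        featurestore_service as gca_featurestore_service,
    )

    req = gca_featurestore_service.BatchReadFeatureValuesRequest(
        start_time=datetime.datetime(2023, 3, 31, 12, 0, 0),
    )
    # Expected (proto-plus behavior): reading the field back yields a
    # DatetimeWithNanoseconds, a datetime subclass.
    print(type(req.start_time).__name__)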
@@ -1677,9 +1723,58 @@ def test_batch_serve_to_gcs_with_invalid_gcs_destination_type(self):
                 read_instances_uri=_TEST_GCS_CSV_SOURCE_URI,
             )
 
-    @pytest.mark.skipif(
-        _USE_BQ_STORAGE is False, reason="batch_serve_to_df requires bigquery_storage"
-    )
+    @pytest.mark.parametrize("sync", [True, False])
+    @pytest.mark.usefixtures("get_featurestore_mock")
+    def test_batch_serve_to_gcs_with_start_time(
+        self, batch_read_feature_values_mock, sync
+    ):
+        aiplatform.init(project=_TEST_PROJECT)
+        my_featurestore = aiplatform.Featurestore(
+            featurestore_name=_TEST_FEATURESTORE_NAME
+        )
+
+        expected_entity_type_specs = [
+            _get_entity_type_spec_proto_with_feature_ids(
+                entity_type_id="my_entity_type_id_1",
+                feature_ids=["my_feature_id_1_1", "my_feature_id_1_2"],
+            ),
+            _get_entity_type_spec_proto_with_feature_ids(
+                entity_type_id="my_entity_type_id_2",
+                feature_ids=["my_feature_id_2_1", "my_feature_id_2_2"],
+            ),
+        ]
+
+        expected_batch_read_feature_values_request = (
+            gca_featurestore_service.BatchReadFeatureValuesRequest(
+                featurestore=my_featurestore.resource_name,
+                destination=gca_featurestore_service.FeatureValueDestination(
+                    tfrecord_destination=_TEST_TFRECORD_DESTINATION,
+                ),
+                entity_type_specs=expected_entity_type_specs,
+                csv_read_instances=_TEST_CSV_SOURCE,
+                start_time=_TEST_BATCH_SERVE_START_TIME,
+            )
+        )
+
+        my_featurestore.batch_serve_to_gcs(
+            gcs_destination_output_uri_prefix=_TEST_GCS_OUTPUT_URI_PREFIX,
+            gcs_destination_type=_TEST_GCS_DESTINATION_TYPE_TFRECORD,
+            serving_feature_ids=_TEST_SERVING_FEATURE_IDS,
+            read_instances_uri=_TEST_GCS_CSV_SOURCE_URI,
+            sync=sync,
+            serve_request_timeout=None,
+            start_time=_TEST_BATCH_SERVE_START_TIME,
+        )
+
+        if not sync:
+            my_featurestore.wait()
+
+        batch_read_feature_values_mock.assert_called_once_with(
+            request=expected_batch_read_feature_values_request,
+            metadata=_TEST_REQUEST_METADATA,
+            timeout=None,
+        )
+
     @pytest.mark.usefixtures(
         "get_featurestore_mock",
         "bq_init_client_mock",
@@ -1753,9 +1848,6 @@ def test_batch_serve_to_df(self, batch_read_feature_values_mock):
             timeout=None,
         )
 
-    @pytest.mark.skipif(
-        _USE_BQ_STORAGE is False, reason="batch_serve_to_df requires bigquery_storage"
-    )
     @pytest.mark.usefixtures(
         "get_featurestore_mock",
         "bq_init_client_mock",
@@ -1850,6 +1942,81 @@ def test_batch_serve_to_df_user_specified_bq_dataset(
         bq_create_dataset_mock.assert_not_called()
         bq_delete_dataset_mock.assert_not_called()
 
+    @pytest.mark.usefixtures(
+        "get_featurestore_mock",
+        "bq_init_client_mock",
+        "bq_init_dataset_mock",
+        "bq_create_dataset_mock",
+        "bq_load_table_from_dataframe_mock",
+        "bq_delete_dataset_mock",
+        "bqs_init_client_mock",
+        "bqs_create_read_session",
+        "get_project_mock",
+    )
+    @patch("uuid.uuid4", uuid_mock)
+    def test_batch_serve_to_df_with_start_time(self, batch_read_feature_values_mock):
+
+        aiplatform.init(project=_TEST_PROJECT_DIFF)
+
+        my_featurestore = aiplatform.Featurestore(
+            featurestore_name=_TEST_FEATURESTORE_NAME
+        )
+
+        read_instances_df = pd.DataFrame()
+
+        expected_temp_bq_dataset_name = (
+            f"temp_{_TEST_FEATURESTORE_ID}_{uuid.uuid4()}".replace("-", "_")
+        )
+        expected_temp_bq_dataset_id = f"{_TEST_PROJECT}.{expected_temp_bq_dataset_name}"[
+            :1024
+        ]
+        expected_temp_bq_read_instances_table_id = (
+            f"{expected_temp_bq_dataset_id}.read_instances"
+        )
+        expected_temp_bq_batch_serve_table_id = (
+            f"{expected_temp_bq_dataset_id}.batch_serve"
+        )
+
+        expected_entity_type_specs = [
+            _get_entity_type_spec_proto_with_feature_ids(
+                entity_type_id="my_entity_type_id_1",
+                feature_ids=["my_feature_id_1_1", "my_feature_id_1_2"],
+            ),
+            _get_entity_type_spec_proto_with_feature_ids(
+                entity_type_id="my_entity_type_id_2",
+                feature_ids=["my_feature_id_2_1", "my_feature_id_2_2"],
+            ),
+        ]
+
+        expected_batch_read_feature_values_request = (
+            gca_featurestore_service.BatchReadFeatureValuesRequest(
+                featurestore=my_featurestore.resource_name,
+                destination=gca_featurestore_service.FeatureValueDestination(
+                    bigquery_destination=gca_io.BigQueryDestination(
+                        output_uri=f"bq://{expected_temp_bq_batch_serve_table_id}"
+                    ),
+                ),
+                entity_type_specs=expected_entity_type_specs,
+                bigquery_read_instances=gca_io.BigQuerySource(
+                    input_uri=f"bq://{expected_temp_bq_read_instances_table_id}"
+                ),
+                start_time=_TEST_BATCH_SERVE_START_TIME,
+            )
+        )
+
+        my_featurestore.batch_serve_to_df(
+            serving_feature_ids=_TEST_SERVING_FEATURE_IDS,
+            read_instances_df=read_instances_df,
+            serve_request_timeout=None,
+            start_time=_TEST_BATCH_SERVE_START_TIME,
+        )
+
+        batch_read_feature_values_mock.assert_called_once_with(
+            request=expected_batch_read_feature_values_request,
+            metadata=_TEST_REQUEST_METADATA,
+            timeout=None,
+        )
+
 
 @pytest.mark.usefixtures("google_auth_mock")
 class TestEntityType: