74 | 74 | )
75 | 75 |
76 | 76 | from google.cloud import bigquery
77 | | -
78 | | -try:
79 | | -    from google.cloud import bigquery_storage
80 | | -    from google.cloud.bigquery_storage_v1.types import stream as gcbqs_stream
81 | | -
82 | | -    _USE_BQ_STORAGE = True
83 | | -except ImportError:
84 | | -    _USE_BQ_STORAGE = False
| 77 | +from google.cloud import bigquery_storage
| 78 | +from google.cloud.bigquery_storage_v1.types import stream as gcbqs_stream
85 | 79 |
86 | 80 | from google.cloud import resourcemanager
87 | 81 |

283 | 277 | _TEST_GCS_SOURCE_TYPE_AVRO = "avro"
284 | 278 | _TEST_GCS_SOURCE_TYPE_INVALID = "json"
285 | 279 |
| 280 | +_TEST_BATCH_SERVE_START_TIME = datetime.datetime.now()
286 | 281 | _TEST_BQ_DESTINATION_URI = "bq://project.dataset.table_name"
287 | 282 | _TEST_GCS_OUTPUT_URI_PREFIX = "gs://my_bucket/path/to_prefix"
288 | 283 |

@@ -1613,6 +1608,57 @@ def test_batch_serve_to_bq_with_timeout_not_explicitly_set(
1613 | 1608 |             timeout=None,
1614 | 1609 |         )
1615 | 1610 |
| 1611 | +    @pytest.mark.parametrize("sync", [True, False])
| 1612 | +    @pytest.mark.usefixtures("get_featurestore_mock")
| 1613 | +    def test_batch_serve_to_bq_with_start_time(
| 1614 | +        self, batch_read_feature_values_mock, sync
| 1615 | +    ):
| 1616 | +        aiplatform.init(project=_TEST_PROJECT)
| 1617 | +        my_featurestore = aiplatform.Featurestore(
| 1618 | +            featurestore_name=_TEST_FEATURESTORE_NAME
| 1619 | +        )
| 1620 | +
| 1621 | +        expected_entity_type_specs = [
| 1622 | +            _get_entity_type_spec_proto_with_feature_ids(
| 1623 | +                entity_type_id="my_entity_type_id_1",
| 1624 | +                feature_ids=["my_feature_id_1_1", "my_feature_id_1_2"],
| 1625 | +            ),
| 1626 | +            _get_entity_type_spec_proto_with_feature_ids(
| 1627 | +                entity_type_id="my_entity_type_id_2",
| 1628 | +                feature_ids=["my_feature_id_2_1", "my_feature_id_2_2"],
| 1629 | +            ),
| 1630 | +        ]
| 1631 | +
| 1632 | +        expected_batch_read_feature_values_request = (
| 1633 | +            gca_featurestore_service.BatchReadFeatureValuesRequest(
| 1634 | +                featurestore=my_featurestore.resource_name,
| 1635 | +                destination=gca_featurestore_service.FeatureValueDestination(
| 1636 | +                    bigquery_destination=_TEST_BQ_DESTINATION,
| 1637 | +                ),
| 1638 | +                entity_type_specs=expected_entity_type_specs,
| 1639 | +                bigquery_read_instances=_TEST_BQ_SOURCE,
| 1640 | +                start_time=_TEST_BATCH_SERVE_START_TIME,
| 1641 | +            )
| 1642 | +        )
| 1643 | +
| 1644 | +        my_featurestore.batch_serve_to_bq(
| 1645 | +            bq_destination_output_uri=_TEST_BQ_DESTINATION_URI,
| 1646 | +            serving_feature_ids=_TEST_SERVING_FEATURE_IDS,
| 1647 | +            read_instances_uri=_TEST_BQ_SOURCE_URI,
| 1648 | +            sync=sync,
| 1649 | +            serve_request_timeout=None,
| 1650 | +            start_time=_TEST_BATCH_SERVE_START_TIME,
| 1651 | +        )
| 1652 | +
| 1653 | +        if not sync:
| 1654 | +            my_featurestore.wait()
| 1655 | +
| 1656 | +        batch_read_feature_values_mock.assert_called_once_with(
| 1657 | +            request=expected_batch_read_feature_values_request,
| 1658 | +            metadata=_TEST_REQUEST_METADATA,
| 1659 | +            timeout=None,
| 1660 | +        )
| 1661 | +
1616 | 1662 |     @pytest.mark.parametrize("sync", [True, False])
1617 | 1663 |     @pytest.mark.usefixtures("get_featurestore_mock")
1618 | 1664 |     def test_batch_serve_to_gcs(self, batch_read_feature_values_mock, sync):
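
The hunk above adds test_batch_serve_to_bq_with_start_time, which pins down the BatchReadFeatureValuesRequest built when a caller passes the new start_time argument to batch_serve_to_bq. A minimal usage sketch of that call follows; it is not part of the diff: the project, featurestore ID, and table URIs are placeholders, the dict shape of serving_feature_ids is inferred from the expected entity-type specs in the test, and reading start_time as a lower bound on feature-value timestamps is an assumption.

# Hedged sketch, not part of the change. Placeholder names throughout.
import datetime

from google.cloud import aiplatform

aiplatform.init(project="my-project", location="us-central1")

featurestore = aiplatform.Featurestore(featurestore_name="my_featurestore_id")

featurestore.batch_serve_to_bq(
    bq_destination_output_uri="bq://project.dataset.table_name",
    # Dict shape inferred from the expected entity-type specs asserted above.
    serving_feature_ids={
        "my_entity_type_id_1": ["my_feature_id_1_1", "my_feature_id_1_2"],
        "my_entity_type_id_2": ["my_feature_id_2_1", "my_feature_id_2_2"],
    },
    read_instances_uri="bq://project.dataset.read_instances",
    # Assumption: start_time bounds how far back feature values are read.
    start_time=datetime.datetime.now() - datetime.timedelta(days=7),
)
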
@@ -1677,9 +1723,58 @@ def test_batch_serve_to_gcs_with_invalid_gcs_destination_type(self):
1677 | 1723 |             read_instances_uri=_TEST_GCS_CSV_SOURCE_URI,
1678 | 1724 |         )
1679 | 1725 |
1680 | | -    @pytest.mark.skipif(
1681 | | -        _USE_BQ_STORAGE is False, reason="batch_serve_to_df requires bigquery_storage"
1682 | | -    )
| 1726 | +    @pytest.mark.parametrize("sync", [True, False])
| 1727 | +    @pytest.mark.usefixtures("get_featurestore_mock")
| 1728 | +    def test_batch_serve_to_gcs_with_start_time(
| 1729 | +        self, batch_read_feature_values_mock, sync
| 1730 | +    ):
| 1731 | +        aiplatform.init(project=_TEST_PROJECT)
| 1732 | +        my_featurestore = aiplatform.Featurestore(
| 1733 | +            featurestore_name=_TEST_FEATURESTORE_NAME
| 1734 | +        )
| 1735 | +
| 1736 | +        expected_entity_type_specs = [
| 1737 | +            _get_entity_type_spec_proto_with_feature_ids(
| 1738 | +                entity_type_id="my_entity_type_id_1",
| 1739 | +                feature_ids=["my_feature_id_1_1", "my_feature_id_1_2"],
| 1740 | +            ),
| 1741 | +            _get_entity_type_spec_proto_with_feature_ids(
| 1742 | +                entity_type_id="my_entity_type_id_2",
| 1743 | +                feature_ids=["my_feature_id_2_1", "my_feature_id_2_2"],
| 1744 | +            ),
| 1745 | +        ]
| 1746 | +
| 1747 | +        expected_batch_read_feature_values_request = (
| 1748 | +            gca_featurestore_service.BatchReadFeatureValuesRequest(
| 1749 | +                featurestore=my_featurestore.resource_name,
| 1750 | +                destination=gca_featurestore_service.FeatureValueDestination(
| 1751 | +                    tfrecord_destination=_TEST_TFRECORD_DESTINATION,
| 1752 | +                ),
| 1753 | +                entity_type_specs=expected_entity_type_specs,
| 1754 | +                csv_read_instances=_TEST_CSV_SOURCE,
| 1755 | +                start_time=_TEST_BATCH_SERVE_START_TIME,
| 1756 | +            )
| 1757 | +        )
| 1758 | +
| 1759 | +        my_featurestore.batch_serve_to_gcs(
| 1760 | +            gcs_destination_output_uri_prefix=_TEST_GCS_OUTPUT_URI_PREFIX,
| 1761 | +            gcs_destination_type=_TEST_GCS_DESTINATION_TYPE_TFRECORD,
| 1762 | +            serving_feature_ids=_TEST_SERVING_FEATURE_IDS,
| 1763 | +            read_instances_uri=_TEST_GCS_CSV_SOURCE_URI,
| 1764 | +            sync=sync,
| 1765 | +            serve_request_timeout=None,
| 1766 | +            start_time=_TEST_BATCH_SERVE_START_TIME,
| 1767 | +        )
| 1768 | +
| 1769 | +        if not sync:
| 1770 | +            my_featurestore.wait()
| 1771 | +
| 1772 | +        batch_read_feature_values_mock.assert_called_once_with(
| 1773 | +            request=expected_batch_read_feature_values_request,
| 1774 | +            metadata=_TEST_REQUEST_METADATA,
| 1775 | +            timeout=None,
| 1776 | +        )
| 1777 | +
1683 | 1778 |     @pytest.mark.usefixtures(
1684 | 1779 |         "get_featurestore_mock",
1685 | 1780 |         "bq_init_client_mock",
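
The hunk above adds the matching GCS test, test_batch_serve_to_gcs_with_start_time. A short hedged sketch of the corresponding call, reusing the placeholder featurestore from the previous sketch; the "tfrecord" string is an assumption based on _TEST_GCS_DESTINATION_TYPE_TFRECORD and is not spelled out in this diff.

# Hedged sketch: placeholder bucket paths; destination type string is assumed.
featurestore.batch_serve_to_gcs(
    gcs_destination_output_uri_prefix="gs://my_bucket/path/to_prefix",
    gcs_destination_type="tfrecord",
    serving_feature_ids={"my_entity_type_id_1": ["my_feature_id_1_1"]},
    read_instances_uri="gs://my_bucket/read_instances.csv",
    start_time=datetime.datetime.now() - datetime.timedelta(days=7),
)
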
@@ -1753,9 +1848,6 @@ def test_batch_serve_to_df(self, batch_read_feature_values_mock):
1753 | 1848 |             timeout=None,
1754 | 1849 |         )
1755 | 1850 |
1756 | | -    @pytest.mark.skipif(
1757 | | -        _USE_BQ_STORAGE is False, reason="batch_serve_to_df requires bigquery_storage"
1758 | | -    )
1759 | 1851 |     @pytest.mark.usefixtures(
1760 | 1852 |         "get_featurestore_mock",
1761 | 1853 |         "bq_init_client_mock",

@@ -1850,6 +1942,81 @@ def test_batch_serve_to_df_user_specified_bq_dataset(
1850 | 1942 |         bq_create_dataset_mock.assert_not_called()
1851 | 1943 |         bq_delete_dataset_mock.assert_not_called()
1852 | 1944 |
| 1945 | +    @pytest.mark.usefixtures(
| 1946 | +        "get_featurestore_mock",
| 1947 | +        "bq_init_client_mock",
| 1948 | +        "bq_init_dataset_mock",
| 1949 | +        "bq_create_dataset_mock",
| 1950 | +        "bq_load_table_from_dataframe_mock",
| 1951 | +        "bq_delete_dataset_mock",
| 1952 | +        "bqs_init_client_mock",
| 1953 | +        "bqs_create_read_session",
| 1954 | +        "get_project_mock",
| 1955 | +    )
| 1956 | +    @patch("uuid.uuid4", uuid_mock)
| 1957 | +    def test_batch_serve_to_df_with_start_time(self, batch_read_feature_values_mock):
| 1958 | +
| 1959 | +        aiplatform.init(project=_TEST_PROJECT_DIFF)
| 1960 | +
| 1961 | +        my_featurestore = aiplatform.Featurestore(
| 1962 | +            featurestore_name=_TEST_FEATURESTORE_NAME
| 1963 | +        )
| 1964 | +
| 1965 | +        read_instances_df = pd.DataFrame()
| 1966 | +
| 1967 | +        expected_temp_bq_dataset_name = (
| 1968 | +            f"temp_{_TEST_FEATURESTORE_ID}_{uuid.uuid4()}".replace("-", "_")
| 1969 | +        )
| 1970 | +        expected_temp_bq_dataset_id = f"{_TEST_PROJECT}.{expected_temp_bq_dataset_name}"[
| 1971 | +            :1024
| 1972 | +        ]
| 1973 | +        expected_temp_bq_read_instances_table_id = (
| 1974 | +            f"{expected_temp_bq_dataset_id}.read_instances"
| 1975 | +        )
| 1976 | +        expected_temp_bq_batch_serve_table_id = (
| 1977 | +            f"{expected_temp_bq_dataset_id}.batch_serve"
| 1978 | +        )
| 1979 | +
| 1980 | +        expected_entity_type_specs = [
| 1981 | +            _get_entity_type_spec_proto_with_feature_ids(
| 1982 | +                entity_type_id="my_entity_type_id_1",
| 1983 | +                feature_ids=["my_feature_id_1_1", "my_feature_id_1_2"],
| 1984 | +            ),
| 1985 | +            _get_entity_type_spec_proto_with_feature_ids(
| 1986 | +                entity_type_id="my_entity_type_id_2",
| 1987 | +                feature_ids=["my_feature_id_2_1", "my_feature_id_2_2"],
| 1988 | +            ),
| 1989 | +        ]
| 1990 | +
| 1991 | +        expected_batch_read_feature_values_request = (
| 1992 | +            gca_featurestore_service.BatchReadFeatureValuesRequest(
| 1993 | +                featurestore=my_featurestore.resource_name,
| 1994 | +                destination=gca_featurestore_service.FeatureValueDestination(
| 1995 | +                    bigquery_destination=gca_io.BigQueryDestination(
| 1996 | +                        output_uri=f"bq://{expected_temp_bq_batch_serve_table_id}"
| 1997 | +                    ),
| 1998 | +                ),
| 1999 | +                entity_type_specs=expected_entity_type_specs,
| 2000 | +                bigquery_read_instances=gca_io.BigQuerySource(
| 2001 | +                    input_uri=f"bq://{expected_temp_bq_read_instances_table_id}"
| 2002 | +                ),
| 2003 | +                start_time=_TEST_BATCH_SERVE_START_TIME,
| 2004 | +            )
| 2005 | +        )
| 2006 | +
| 2007 | +        my_featurestore.batch_serve_to_df(
| 2008 | +            serving_feature_ids=_TEST_SERVING_FEATURE_IDS,
| 2009 | +            read_instances_df=read_instances_df,
| 2010 | +            serve_request_timeout=None,
| 2011 | +            start_time=_TEST_BATCH_SERVE_START_TIME,
| 2012 | +        )
| 2013 | +
| 2014 | +        batch_read_feature_values_mock.assert_called_once_with(
| 2015 | +            request=expected_batch_read_feature_values_request,
| 2016 | +            metadata=_TEST_REQUEST_METADATA,
| 2017 | +            timeout=None,
| 2018 | +        )
| 2019 | +
1853 | 2020 |
1854 | 2021 | @pytest.mark.usefixtures("google_auth_mock")
1855 | 2022 | class TestEntityType:
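
The last hunk adds test_batch_serve_to_df_with_start_time, covering the DataFrame path that the deleted skipif previously guarded behind the optional bigquery_storage import. A hedged sketch of that call; the read-instances column names and the DataFrame return value are assumptions, not something this diff asserts (the test itself passes an empty DataFrame because the BigQuery clients are mocked).

# Hedged sketch with placeholder read instances.
import datetime

import pandas as pd

read_instances_df = pd.DataFrame(
    {
        "my_entity_type_id_1": ["entity_1", "entity_2"],  # assumed column naming
        "timestamp": [datetime.datetime.now()] * 2,
    }
)

df = featurestore.batch_serve_to_df(
    serving_feature_ids={
        "my_entity_type_id_1": ["my_feature_id_1_1", "my_feature_id_1_2"],
    },
    read_instances_df=read_instances_df,
    start_time=datetime.datetime.now() - datetime.timedelta(days=7),
)
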