
Commit 7cb6976

feat: Added column_specs, training_encryption_spec_key_name, model_encryption_spec_key_name to AutoMLForecastingTrainingJob.__init__ and various split methods to AutoMLForecastingTrainingJob.run (#647)
* Extracted column_names code from AutoMLTabularTrainingJob for reuse
* Added missing parameters to AutoMLForecast
* Fixed tests and added encryption spec
* Added missing docstrings
* Made _ColumnNamesDataset subclass _Dataset
* Fixed docstrings
* Moved transformations code out of column_names_dataset
* Minor fixes
* Cleanup
* Ran linter
* Fix lint issue
* Removed timestamp_split_column_name from AutoMLForecasting
* Cleaned up typehints
* Fixed test
* Ran linter
* Ran lint
* Added more docstrings for raising exceptions
1 parent 77e58c5 commit 7cb6976
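For orientation, a minimal usage sketch of the new surface. All display names, resource paths, and column values below are placeholders, and the split arguments on run() are assumed to be the standard fraction-based ones; consult the SDK reference for the exact parameter set in your release.

from google.cloud import aiplatform

# Placeholder sketch of the new AutoMLForecastingTrainingJob parameters.
job = aiplatform.AutoMLForecastingTrainingJob(
    display_name="forecast-sales",
    optimization_objective="minimize-rmse",
    # New in this commit: per-column transformation specs.
    column_specs={"sales": "numeric", "region": "categorical"},
    # New in this commit: CMEK keys for the training pipeline and the model.
    training_encryption_spec_key_name="projects/p/locations/l/keyRings/kr/cryptoKeys/k",
    model_encryption_spec_key_name="projects/p/locations/l/keyRings/kr/cryptoKeys/k",
)

model = job.run(
    dataset=aiplatform.TimeSeriesDataset("projects/p/locations/l/datasets/123"),
    target_column="sales",
    time_column="date",
    time_series_identifier_column="store_id",
    available_at_forecast_columns=["date"],
    unavailable_at_forecast_columns=["sales"],
    forecast_horizon=30,
    data_granularity_unit="day",
    data_granularity_count=1,
    # New in this commit: split arguments on run() (assumed fraction-based).
    training_fraction_split=0.8,
    validation_fraction_split=0.1,
    test_fraction_split=0.1,
)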

7 files changed: +777 −287 lines changed


google/cloud/aiplatform/datasets/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -16,6 +16,7 @@
 #
 
 from google.cloud.aiplatform.datasets.dataset import _Dataset
+from google.cloud.aiplatform.datasets.column_names_dataset import _ColumnNamesDataset
 from google.cloud.aiplatform.datasets.tabular_dataset import TabularDataset
 from google.cloud.aiplatform.datasets.time_series_dataset import TimeSeriesDataset
 from google.cloud.aiplatform.datasets.image_dataset import ImageDataset
@@ -25,6 +26,7 @@
 
 __all__ = (
     "_Dataset",
+    "_ColumnNamesDataset",
     "TabularDataset",
     "TimeSeriesDataset",
     "ImageDataset",
google/cloud/aiplatform/datasets/column_names_dataset.py

Lines changed: 250 additions & 0 deletions
@@ -0,0 +1,250 @@
# -*- coding: utf-8 -*-

# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


import csv
import logging
from typing import List, Optional, Set
from google.auth import credentials as auth_credentials

from google.cloud import bigquery
from google.cloud import storage

from google.cloud.aiplatform import utils
from google.cloud.aiplatform import datasets


class _ColumnNamesDataset(datasets._Dataset):
    @property
    def column_names(self) -> List[str]:
"""Retrieve the columns for the dataset by extracting it from the Google Cloud Storage or
35+
Google BigQuery source.
36+
37+
Returns:
38+
List[str]
39+
A list of columns names
40+
41+
Raises:
42+
RuntimeError: When no valid source is found.
43+
"""

        self._assert_gca_resource_is_available()

        metadata = self._gca_resource.metadata

        if metadata is None:
            raise RuntimeError("No metadata found for dataset")

        input_config = metadata.get("inputConfig")

        if input_config is None:
            raise RuntimeError("No inputConfig found for dataset")

        gcs_source = input_config.get("gcsSource")
        bq_source = input_config.get("bigquerySource")

        if gcs_source:
            gcs_source_uris = gcs_source.get("uri")

            if gcs_source_uris and len(gcs_source_uris) > 0:
                # Lexicographically sort the files
                gcs_source_uris.sort()

                # Get the first file in the sorted list
                # TODO(b/193044977): Return as Set instead of List
                return list(
                    self._retrieve_gcs_source_columns(
                        project=self.project,
                        gcs_csv_file_path=gcs_source_uris[0],
                        credentials=self.credentials,
                    )
                )
        elif bq_source:
            bq_table_uri = bq_source.get("uri")
            if bq_table_uri:
                # TODO(b/193044977): Return as Set instead of List
                return list(
                    self._retrieve_bq_source_columns(
                        project=self.project,
                        bq_table_uri=bq_table_uri,
                        credentials=self.credentials,
                    )
                )

        raise RuntimeError("No valid CSV or BigQuery datasource found.")

    @staticmethod
    def _retrieve_gcs_source_columns(
        project: str,
        gcs_csv_file_path: str,
        credentials: Optional[auth_credentials.Credentials] = None,
    ) -> Set[str]:
"""Retrieve the columns from a comma-delimited CSV file stored on Google Cloud Storage
97+
98+
Example Usage:
99+
100+
column_names = _retrieve_gcs_source_columns(
101+
"project_id",
102+
"gs://example-bucket/path/to/csv_file"
103+
)
104+
105+
# column_names = {"column_1", "column_2"}
106+
107+
Args:
108+
project (str):
109+
Required. Project to initiate the Google Cloud Storage client with.
110+
gcs_csv_file_path (str):
111+
Required. A full path to a CSV files stored on Google Cloud Storage.
112+
Must include "gs://" prefix.
113+
credentials (auth_credentials.Credentials):
114+
Credentials to use to with GCS Client.
115+
Returns:
116+
Set[str]
117+
A set of columns names in the CSV file.
118+
119+
Raises:
120+
RuntimeError: When the retrieved CSV file is invalid.
121+
"""

        gcs_bucket, gcs_blob = utils.extract_bucket_and_prefix_from_gcs_path(
            gcs_csv_file_path
        )
        client = storage.Client(project=project, credentials=credentials)
        bucket = client.bucket(gcs_bucket)
        blob = bucket.blob(gcs_blob)

        # Incrementally download the CSV file until the header is retrieved
        first_new_line_index = -1
        start_index = 0
        increment = 1000
        line = ""

        try:
            logger = logging.getLogger("google.resumable_media._helpers")
            logging_warning_filter = utils.LoggingFilter(logging.INFO)
            logger.addFilter(logging_warning_filter)

            while first_new_line_index == -1:
                line += blob.download_as_bytes(
                    start=start_index, end=start_index + increment - 1
                ).decode("utf-8")

                first_new_line_index = line.find("\n")
                start_index += increment

            header_line = line[:first_new_line_index]

            # Split to make it an iterable
            header_line = header_line.split("\n")[:1]

            csv_reader = csv.reader(header_line, delimiter=",")
        except (ValueError, RuntimeError) as err:
            raise RuntimeError(
                "There was a problem extracting the headers from the CSV file at '{}': {}".format(
                    gcs_csv_file_path, err
                )
            )
        finally:
            logger.removeFilter(logging_warning_filter)

        return set(next(csv_reader))

    @staticmethod
    def _get_bq_schema_field_names_recursively(
        schema_field: bigquery.SchemaField,
    ) -> Set[str]:
"""Retrieve the name for a schema field along with ancestor fields.
171+
Nested schema fields are flattened and concatenated with a ".".
172+
Schema fields with child fields are not included, but the children are.
173+
174+
Args:
175+
project (str):
176+
Required. Project to initiate the BigQuery client with.
177+
bq_table_uri (str):
178+
Required. A URI to a BigQuery table.
179+
Can include "bq://" prefix but not required.
180+
credentials (auth_credentials.Credentials):
181+
Credentials to use with BQ Client.
182+
183+
Returns:
184+
Set[str]
185+
A set of columns names in the BigQuery table.
186+
"""

        ancestor_names = {
            nested_field_name
            for field in schema_field.fields
            for nested_field_name in _ColumnNamesDataset._get_bq_schema_field_names_recursively(
                field
            )
        }

        # Only return "leaf nodes", basically any field that doesn't have children
        if len(ancestor_names) == 0:
            return {schema_field.name}
        else:
            return {f"{schema_field.name}.{name}" for name in ancestor_names}

    @staticmethod
    def _retrieve_bq_source_columns(
        project: str,
        bq_table_uri: str,
        credentials: Optional[auth_credentials.Credentials] = None,
    ) -> Set[str]:
"""Retrieve the column names from a table on Google BigQuery
209+
Nested schema fields are flattened and concatenated with a ".".
210+
Schema fields with child fields are not included, but the children are.
211+
212+
Example Usage:
213+
214+
column_names = _retrieve_bq_source_columns(
215+
"project_id",
216+
"bq://project_id.dataset.table"
217+
)
218+
219+
# column_names = {"column_1", "column_2", "column_3.nested_field"}
220+
221+
Args:
222+
project (str):
223+
Required. Project to initiate the BigQuery client with.
224+
bq_table_uri (str):
225+
Required. A URI to a BigQuery table.
226+
Can include "bq://" prefix but not required.
227+
credentials (auth_credentials.Credentials):
228+
Credentials to use with BQ Client.
229+
230+
Returns:
231+
Set[str]
232+
A set of column names in the BigQuery table.
233+
"""

        # Remove bq:// prefix
        prefix = "bq://"
        if bq_table_uri.startswith(prefix):
            bq_table_uri = bq_table_uri[len(prefix) :]

        client = bigquery.Client(project=project, credentials=credentials)
        table = client.get_table(bq_table_uri)
        schema = table.schema

        return {
            field_name
            for field in schema
            for field_name in _ColumnNamesDataset._get_bq_schema_field_names_recursively(
                field
            )
        }
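
To illustrate the flattening behavior, a small worked example (the schema and field names are invented for illustration) of what _get_bq_schema_field_names_recursively returns for a nested RECORD field:

from google.cloud import bigquery

# Invented example schema: a RECORD field with two leaf children.
address = bigquery.SchemaField(
    "address",
    "RECORD",
    fields=[
        bigquery.SchemaField("city", "STRING"),
        bigquery.SchemaField("zip", "STRING"),
    ],
)

# Only leaf fields are returned, prefixed with their parent names:
assert _ColumnNamesDataset._get_bq_schema_field_names_recursively(address) == {
    "address.city",
    "address.zip",
}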
