
Commit 4fce8c4

ivanmkc and sasha-gitg authored
feat: Added default AutoMLTabularTrainingJob column transformations (#357)
* Added default column_transformation code
* Added docstrings
* Added tests and moved code to tabular_dataset
* Switched to using BigQuery.Table instead of custom SQL query
* Fixed bigquery unit test
* Added GCS test
* Fixed issues with incorrect input config parameter
* Added test for AutoMLTabularTrainingJob for no transformations
* Added comment
* Fixed test
* Ran linter
* Switched from classmethod to staticmethod where applicable and logged column names
* Added extra dataset tests
* Added logging suppression
* Fixed lint errors
* Switched logging filter method

Co-authored-by: sasha-gitg <44654632+sasha-gitg@users.noreply.github.com>
1 parent beb4032 commit 4fce8c4

8 files changed: +413 -23 lines changed

google/cloud/aiplatform/datasets/_datasources.py

Lines changed: 2 additions & 2 deletions
@@ -86,9 +86,9 @@ def __init__(
             raise ValueError("One of gcs_source or bq_source must be set.")

         if gcs_source:
-            dataset_metadata = {"input_config": {"gcs_source": {"uri": gcs_source}}}
+            dataset_metadata = {"inputConfig": {"gcsSource": {"uri": gcs_source}}}
         elif bq_source:
-            dataset_metadata = {"input_config": {"bigquery_source": {"uri": bq_source}}}
+            dataset_metadata = {"inputConfig": {"bigquerySource": {"uri": bq_source}}}

         self._dataset_metadata = dataset_metadata
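For context, the keys now match the REST API's camelCase JSON field names, which the new TabularDataset.column_names property reads back. A rough sketch of the metadata shape this produces (the URIs are illustrative, not from the commit):

    # Illustrative only: dataset metadata shape after this change.
    gcs_metadata = {"inputConfig": {"gcsSource": {"uri": ["gs://example-bucket/data.csv"]}}}
    bq_metadata = {"inputConfig": {"bigquerySource": {"uri": "bq://example-project.dataset.table"}}}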

google/cloud/aiplatform/datasets/tabular_dataset.py

Lines changed: 158 additions & 1 deletion
@@ -15,10 +15,16 @@
 # limitations under the License.
 #

-from typing import Optional, Sequence, Tuple, Union
+import csv
+import logging
+
+from typing import List, Optional, Sequence, Tuple, Union

 from google.auth import credentials as auth_credentials

+from google.cloud import bigquery
+from google.cloud import storage
+
 from google.cloud.aiplatform import datasets
 from google.cloud.aiplatform.datasets import _datasources
 from google.cloud.aiplatform import initializer
@@ -33,6 +39,157 @@ class TabularDataset(datasets._Dataset):
         schema.dataset.metadata.tabular,
     )

+    @property
+    def column_names(self) -> List[str]:
+        """Retrieve the columns for the dataset by extracting it from the Google Cloud Storage or
+        Google BigQuery source.
+
+        Returns:
+            List[str]
+                A list of columns names
+
+        Raises:
+            RuntimeError: When no valid source is found.
+        """
+
+        metadata = self._gca_resource.metadata
+
+        if metadata is None:
+            raise RuntimeError("No metadata found for dataset")
+
+        input_config = metadata.get("inputConfig")
+
+        if input_config is None:
+            raise RuntimeError("No inputConfig found for dataset")
+
+        gcs_source = input_config.get("gcsSource")
+        bq_source = input_config.get("bigquerySource")
+
+        if gcs_source:
+            gcs_source_uris = gcs_source.get("uri")
+
+            if gcs_source_uris and len(gcs_source_uris) > 0:
+                # Lexicographically sort the files
+                gcs_source_uris.sort()
+
+                # Get the first file in sorted list
+                return TabularDataset._retrieve_gcs_source_columns(
+                    self.project, gcs_source_uris[0]
+                )
+        elif bq_source:
+            bq_table_uri = bq_source.get("uri")
+            if bq_table_uri:
+                return TabularDataset._retrieve_bq_source_columns(
+                    self.project, bq_table_uri
+                )
+
+        raise RuntimeError("No valid CSV or BigQuery datasource found.")
+
+    @staticmethod
+    def _retrieve_gcs_source_columns(project: str, gcs_csv_file_path: str) -> List[str]:
+        """Retrieve the columns from a comma-delimited CSV file stored on Google Cloud Storage
+
+        Example Usage:
+
+            column_names = _retrieve_gcs_source_columns(
+                "project_id",
+                "gs://example-bucket/path/to/csv_file"
+            )
+
+            # column_names = ["column_1", "column_2"]
+
+        Args:
+            project (str):
+                Required. Project to initiate the Google Cloud Storage client with.
+            gcs_csv_file_path (str):
+                Required. A full path to a CSV files stored on Google Cloud Storage.
+                Must include "gs://" prefix.
+
+        Returns:
+            List[str]
+                A list of columns names in the CSV file.
+
+        Raises:
+            RuntimeError: When the retrieved CSV file is invalid.
+        """
+
+        gcs_bucket, gcs_blob = utils.extract_bucket_and_prefix_from_gcs_path(
+            gcs_csv_file_path
+        )
+        client = storage.Client(project=project)
+        bucket = client.bucket(gcs_bucket)
+        blob = bucket.blob(gcs_blob)
+
+        # Incrementally download the CSV file until the header is retrieved
+        first_new_line_index = -1
+        start_index = 0
+        increment = 1000
+        line = ""
+
+        try:
+            logger = logging.getLogger("google.resumable_media._helpers")
+            logging_warning_filter = utils.LoggingFilter(logging.INFO)
+            logger.addFilter(logging_warning_filter)
+
+            while first_new_line_index == -1:
+                line += blob.download_as_bytes(
+                    start=start_index, end=start_index + increment
+                ).decode("utf-8")
+                first_new_line_index = line.find("\n")
+                start_index += increment
+
+            header_line = line[:first_new_line_index]
+
+            # Split to make it an iterable
+            header_line = header_line.split("\n")[:1]
+
+            csv_reader = csv.reader(header_line, delimiter=",")
+        except (ValueError, RuntimeError) as err:
+            raise RuntimeError(
+                "There was a problem extracting the headers from the CSV file at '{}': {}".format(
+                    gcs_csv_file_path, err
+                )
+            )
+        finally:
+            logger.removeFilter(logging_warning_filter)
+
+        return next(csv_reader)
+
+    @staticmethod
+    def _retrieve_bq_source_columns(project: str, bq_table_uri: str) -> List[str]:
+        """Retrieve the columns from a table on Google BigQuery
+
+        Example Usage:
+
+            column_names = _retrieve_bq_source_columns(
+                "project_id",
+                "bq://project_id.dataset.table"
+            )
+
+            # column_names = ["column_1", "column_2"]
+
+        Args:
+            project (str):
+                Required. Project to initiate the BigQuery client with.
+            bq_table_uri (str):
+                Required. A URI to a BigQuery table.
+                Can include "bq://" prefix but not required.
+
+        Returns:
+            List[str]
+                A list of columns names in the BigQuery table.
+        """
+
+        # Remove bq:// prefix
+        prefix = "bq://"
+        if bq_table_uri.startswith(prefix):
+            bq_table_uri = bq_table_uri[len(prefix) :]
+
+        client = bigquery.Client(project=project)
+        table = client.get_table(bq_table_uri)
+        schema = table.schema
+        return [schema.name for schema in schema]
+
     @classmethod
     def create(
         cls,
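For orientation, a rough usage sketch of the new property; the project and dataset resource name below are illustrative, not part of the commit:

    from google.cloud import aiplatform

    aiplatform.init(project="example-project", location="us-central1")

    # Hypothetical existing tabular dataset; any TabularDataset resource name works.
    dataset = aiplatform.TabularDataset(
        dataset_name="projects/example-project/locations/us-central1/datasets/1234567890"
    )

    # Reads the header of the lexicographically first CSV file for GCS sources,
    # or the table schema for BigQuery sources.
    print(dataset.column_names)  # e.g. ["sepal_width", "sepal_length", "petal_length", "petal_width", "species"]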

google/cloud/aiplatform/initializer.py

Lines changed: 1 addition & 1 deletion
@@ -170,7 +170,7 @@ def credentials(self) -> Optional[auth_credentials.Credentials]:
         if self._credentials:
             return self._credentials
         logger = logging.getLogger("google.auth._default")
-        logging_warning_filter = utils.LoggingWarningFilter()
+        logging_warning_filter = utils.LoggingFilter(logging.WARNING)
         logger.addFilter(logging_warning_filter)
         credentials, _ = google.auth.default()
         logger.removeFilter(logging_warning_filter)

google/cloud/aiplatform/training_jobs.py

Lines changed: 22 additions & 2 deletions
@@ -130,7 +130,6 @@ def __init__(

         super().__init__(project=project, location=location, credentials=credentials)
         self._display_name = display_name
-        self._project = project
         self._training_encryption_spec = initializer.global_config.get_encryption_spec(
             encryption_spec_key_name=training_encryption_spec_key_name
         )
@@ -2955,10 +2954,31 @@ def _run(

         training_task_definition = schema.training_job.definition.automl_tabular

+        if self._column_transformations is None:
+            _LOGGER.info(
+                "No column transformations provided, so now retrieving columns from dataset in order to set default column transformations."
+            )
+
+            column_names = [
+                column_name
+                for column_name in dataset.column_names
+                if column_name != target_column
+            ]
+            column_transformations = [
+                {"auto": {"column_name": column_name}} for column_name in column_names
+            ]
+
+            _LOGGER.info(
+                "The column transformation of type 'auto' was set for the following columns: %s."
+                % column_names
+            )
+        else:
+            column_transformations = self._column_transformations
+
         training_task_inputs_dict = {
             # required inputs
             "targetColumn": target_column,
-            "transformations": self._column_transformations,
+            "transformations": column_transformations,
             "trainBudgetMilliNodeHours": budget_milli_node_hours,
             # optional inputs
             "weightColumnName": weight_column,

google/cloud/aiplatform/utils.py

Lines changed: 5 additions & 2 deletions
@@ -468,6 +468,9 @@ class PredictionClientWithOverride(ClientWithOverride):
     )


-class LoggingWarningFilter(logging.Filter):
+class LoggingFilter(logging.Filter):
+    def __init__(self, warning_level: int):
+        self._warning_level = warning_level
+
     def filter(self, record):
-        return record.levelname == logging.WARNING
+        return record.levelname == self._warning_level
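A short sketch of how the generalized filter is attached around a noisy call, mirroring the pattern used in tabular_dataset.py above; the logger name comes from that file and the downloading work is elided:

    import logging

    from google.cloud.aiplatform import utils

    logger = logging.getLogger("google.resumable_media._helpers")
    log_filter = utils.LoggingFilter(logging.INFO)  # constructed with the level whose records should be dropped
    logger.addFilter(log_filter)
    try:
        ...  # ranged blob.download_as_bytes() calls happen here
    finally:
        logger.removeFilter(log_filter)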

tests/system/aiplatform/test_dataset.py

Lines changed: 1 addition & 1 deletion
@@ -25,8 +25,8 @@
 from google.api_core import exceptions
 from google.api_core import client_options

-from google.cloud import storage
 from google.cloud import aiplatform
+from google.cloud import storage
 from google.cloud.aiplatform import utils
 from google.cloud.aiplatform import initializer
 from google.cloud.aiplatform_v1beta1.types import dataset as gca_dataset

tests/unit/aiplatform/test_automl_tabular_training_jobs.py

Lines changed: 85 additions & 2 deletions
@@ -34,10 +34,16 @@
 _TEST_DATASET_DISPLAY_NAME = "test-dataset-display-name"
 _TEST_DATASET_NAME = "test-dataset-name"
 _TEST_DISPLAY_NAME = "test-display-name"
-_TEST_TRAINING_CONTAINER_IMAGE = "gcr.io/test-training/container:image"
 _TEST_METADATA_SCHEMA_URI_TABULAR = schema.dataset.metadata.tabular
 _TEST_METADATA_SCHEMA_URI_NONTABULAR = schema.dataset.metadata.image

+_TEST_TRAINING_COLUMN_NAMES = [
+    "sepal_width",
+    "sepal_length",
+    "petal_length",
+    "petal_width",
+]
+
 _TEST_TRAINING_COLUMN_TRANSFORMATIONS = [
     {"auto": {"column_name": "sepal_width"}},
     {"auto": {"column_name": "sepal_length"}},
@@ -169,7 +175,9 @@ def mock_dataset_tabular():
         name=_TEST_DATASET_NAME,
         metadata={},
     )
-    return ds
+    ds.column_names = _TEST_TRAINING_COLUMN_NAMES
+
+    yield ds


 @pytest.fixture
@@ -347,6 +355,81 @@ def test_run_call_pipeline_if_no_model_display_name(
             training_pipeline=true_training_pipeline,
         )

+    @pytest.mark.parametrize("sync", [True, False])
+    # This test checks that default transformations are used if no columns transformations are provided
+    def test_run_call_pipeline_service_create_if_no_column_transformations(
+        self,
+        mock_pipeline_service_create,
+        mock_pipeline_service_get,
+        mock_dataset_tabular,
+        mock_model_service_get,
+        sync,
+    ):
+        aiplatform.init(
+            project=_TEST_PROJECT,
+            staging_bucket=_TEST_BUCKET_NAME,
+            encryption_spec_key_name=_TEST_DEFAULT_ENCRYPTION_KEY_NAME,
+        )
+
+        job = training_jobs.AutoMLTabularTrainingJob(
+            display_name=_TEST_DISPLAY_NAME,
+            optimization_objective=_TEST_TRAINING_OPTIMIZATION_OBJECTIVE_NAME,
+            optimization_prediction_type=_TEST_TRAINING_OPTIMIZATION_PREDICTION_TYPE,
+            column_transformations=None,
+            optimization_objective_recall_value=None,
+            optimization_objective_precision_value=None,
+        )
+
+        model_from_job = job.run(
+            dataset=mock_dataset_tabular,
+            target_column=_TEST_TRAINING_TARGET_COLUMN,
+            model_display_name=_TEST_MODEL_DISPLAY_NAME,
+            training_fraction_split=_TEST_TRAINING_FRACTION_SPLIT,
+            validation_fraction_split=_TEST_VALIDATION_FRACTION_SPLIT,
+            test_fraction_split=_TEST_TEST_FRACTION_SPLIT,
+            predefined_split_column_name=_TEST_PREDEFINED_SPLIT_COLUMN_NAME,
+            weight_column=_TEST_TRAINING_WEIGHT_COLUMN,
+            budget_milli_node_hours=_TEST_TRAINING_BUDGET_MILLI_NODE_HOURS,
+            disable_early_stopping=_TEST_TRAINING_DISABLE_EARLY_STOPPING,
+            sync=sync,
+        )
+
+        if not sync:
+            model_from_job.wait()
+
+        true_fraction_split = gca_training_pipeline.FractionSplit(
+            training_fraction=_TEST_TRAINING_FRACTION_SPLIT,
+            validation_fraction=_TEST_VALIDATION_FRACTION_SPLIT,
+            test_fraction=_TEST_TEST_FRACTION_SPLIT,
+        )
+
+        true_managed_model = gca_model.Model(
+            display_name=_TEST_MODEL_DISPLAY_NAME,
+            encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC,
+        )
+
+        true_input_data_config = gca_training_pipeline.InputDataConfig(
+            fraction_split=true_fraction_split,
+            predefined_split=gca_training_pipeline.PredefinedSplit(
+                key=_TEST_PREDEFINED_SPLIT_COLUMN_NAME
+            ),
+            dataset_id=mock_dataset_tabular.name,
+        )
+
+        true_training_pipeline = gca_training_pipeline.TrainingPipeline(
+            display_name=_TEST_DISPLAY_NAME,
+            training_task_definition=schema.training_job.definition.automl_tabular,
+            training_task_inputs=_TEST_TRAINING_TASK_INPUTS,
+            model_to_upload=true_managed_model,
+            input_data_config=true_input_data_config,
+            encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC,
+        )
+
+        mock_pipeline_service_create.assert_called_once_with(
+            parent=initializer.global_config.common_location_path(),
+            training_pipeline=true_training_pipeline,
+        )
+
     @pytest.mark.usefixtures(
         "mock_pipeline_service_create",
         "mock_pipeline_service_get",
