Skip to content

Commit ee49e00

Browse files
authored
feat: Allow users to specify timestamp split for vertex forecasting (#1187)
Thank you for opening a Pull Request! Before submitting your PR, there are a few things you can do to make sure it goes smoothly: - [x] Make sure to open an issue as a [bug/issue](https://github.com/googleapis/python-aiplatform/issues/new/choose) before writing your code! That way we can discuss the change, evaluate designs, and agree on the general idea - [x] Ensure the tests and linter pass - [x] Code coverage does not decrease (if any source code was changed) - [x] Appropriate docs were updated (if necessary) Fixes b/230009255 🦕
1 parent e03f373 commit ee49e00

File tree

2 files changed

+123
-7
lines changed

2 files changed

+123
-7
lines changed

google/cloud/aiplatform/training_jobs.py

Lines changed: 32 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -408,7 +408,8 @@ def _create_input_data_config(
408408
that piece is ignored by the pipeline.
409409
410410
Supported only for tabular and time series Datasets.
411-
This parameter must be used with training_fraction_split, validation_fraction_split and test_fraction_split.
411+
This parameter must be used with training_fraction_split,
412+
validation_fraction_split, and test_fraction_split.
412413
gcs_destination_uri_prefix (str):
413414
Optional. The Google Cloud Storage location.
414415
@@ -669,7 +670,8 @@ def _run_job(
669670
that piece is ignored by the pipeline.
670671
671672
Supported only for tabular and time series Datasets.
672-
This parameter must be used with training_fraction_split, validation_fraction_split and test_fraction_split.
673+
This parameter must be used with training_fraction_split,
674+
validation_fraction_split, and test_fraction_split.
673675
model (~.model.Model):
674676
Optional. Describes the Model that may be uploaded (via
675677
[ModelService.UploadMode][]) by this TrainingPipeline. The
@@ -3487,9 +3489,9 @@ def run(
34873489
`time-offset` = `"Z"` (e.g. 1985-04-12T23:20:50.52Z). If for a
34883490
piece of data the key is not present or has an invalid value,
34893491
that piece is ignored by the pipeline.
3490-
34913492
Supported only for tabular and time series Datasets.
3492-
This parameter must be used with training_fraction_split, validation_fraction_split and test_fraction_split.
3493+
This parameter must be used with training_fraction_split,
3494+
validation_fraction_split, and test_fraction_split.
34933495
weight_column (str):
34943496
Optional. Name of the column that should be used as the weight column.
34953497
Higher values in this column give more importance to the row
@@ -3681,9 +3683,9 @@ def _run(
36813683
`time-offset` = `"Z"` (e.g. 1985-04-12T23:20:50.52Z). If for a
36823684
piece of data the key is not present or has an invalid value,
36833685
that piece is ignored by the pipeline.
3684-
36853686
Supported only for tabular and time series Datasets.
3686-
This parameter must be used with training_fraction_split, validation_fraction_split and test_fraction_split.
3687+
This parameter must be used with training_fraction_split,
3688+
validation_fraction_split, and test_fraction_split.
36873689
weight_column (str):
36883690
Optional. Name of the column that should be used as the weight column.
36893691
Higher values in this column give more importance to the row
@@ -4022,6 +4024,7 @@ def run(
40224024
validation_fraction_split: Optional[float] = None,
40234025
test_fraction_split: Optional[float] = None,
40244026
predefined_split_column_name: Optional[str] = None,
4027+
timestamp_split_column_name: Optional[str] = None,
40254028
weight_column: Optional[str] = None,
40264029
time_series_attribute_columns: Optional[List[str]] = None,
40274030
context_window: Optional[int] = None,
@@ -4106,6 +4109,16 @@ def run(
41064109
ignored by the pipeline.
41074110
41084111
Supported only for tabular and time series Datasets.
4112+
timestamp_split_column_name (str):
4113+
Optional. The key is a name of one of the Dataset's data
4114+
columns. The value of the key (the values in
4115+
the column) must be in RFC 3339 `date-time` format, where
4116+
`time-offset` = `"Z"` (e.g. 1985-04-12T23:20:50.52Z). If for a
4117+
piece of data the key is not present or has an invalid value,
4118+
that piece is ignored by the pipeline.
4119+
Supported only for tabular and time series Datasets.
4120+
This parameter must be used with training_fraction_split,
4121+
validation_fraction_split, and test_fraction_split.
41094122
weight_column (str):
41104123
Optional. Name of the column that should be used as the weight column.
41114124
Higher values in this column give more importance to the row
@@ -4229,6 +4242,7 @@ def run(
42294242
validation_fraction_split=validation_fraction_split,
42304243
test_fraction_split=test_fraction_split,
42314244
predefined_split_column_name=predefined_split_column_name,
4245+
timestamp_split_column_name=timestamp_split_column_name,
42324246
weight_column=weight_column,
42334247
time_series_attribute_columns=time_series_attribute_columns,
42344248
context_window=context_window,
@@ -4260,6 +4274,7 @@ def _run(
42604274
validation_fraction_split: Optional[float] = None,
42614275
test_fraction_split: Optional[float] = None,
42624276
predefined_split_column_name: Optional[str] = None,
4277+
timestamp_split_column_name: Optional[str] = None,
42634278
weight_column: Optional[str] = None,
42644279
time_series_attribute_columns: Optional[List[str]] = None,
42654280
context_window: Optional[int] = None,
@@ -4352,6 +4367,16 @@ def _run(
43524367
ignored by the pipeline.
43534368
43544369
Supported only for tabular and time series Datasets.
4370+
timestamp_split_column_name (str):
4371+
Optional. The key is a name of one of the Dataset's data
4372+
columns. The value of the key (the values in
4373+
the column) must be in RFC 3339 `date-time` format, where
4374+
`time-offset` = `"Z"` (e.g. 1985-04-12T23:20:50.52Z). If for a
4375+
piece of data the key is not present or has an invalid value,
4376+
that piece is ignored by the pipeline.
4377+
Supported only for tabular and time series Datasets.
4378+
This parameter must be used with training_fraction_split,
4379+
validation_fraction_split, and test_fraction_split.
43554380
weight_column (str):
43564381
Optional. Name of the column that should be used as the weight column.
43574382
Higher values in this column give more importance to the row
@@ -4511,7 +4536,7 @@ def _run(
45114536
validation_fraction_split=validation_fraction_split,
45124537
test_fraction_split=test_fraction_split,
45134538
predefined_split_column_name=predefined_split_column_name,
4514-
timestamp_split_column_name=None, # Not supported by AutoMLForecasting
4539+
timestamp_split_column_name=timestamp_split_column_name,
45154540
model=model,
45164541
create_request_timeout=create_request_timeout,
45174542
)

tests/unit/aiplatform/test_automl_forecasting_training_jobs.py

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,7 @@
148148
_TEST_FRACTION_SPLIT_TEST = 0.2
149149

150150
_TEST_SPLIT_PREDEFINED_COLUMN_NAME = "split"
151+
_TEST_SPLIT_TIMESTAMP_COLUMN_NAME = "timestamp"
151152

152153

153154
@pytest.fixture
@@ -768,6 +769,96 @@ def test_splits_fraction(
768769
timeout=None,
769770
)
770771

772+
@pytest.mark.parametrize("sync", [True, False])
773+
def test_splits_timestamp(
774+
self,
775+
mock_pipeline_service_create,
776+
mock_pipeline_service_get,
777+
mock_dataset_time_series,
778+
mock_model_service_get,
779+
sync,
780+
):
781+
"""Initiate aiplatform with encryption key name.
782+
783+
Create and run an AutoML Forecasting training job, verify calls and
784+
return value
785+
"""
786+
787+
aiplatform.init(
788+
project=_TEST_PROJECT,
789+
encryption_spec_key_name=_TEST_DEFAULT_ENCRYPTION_KEY_NAME,
790+
)
791+
792+
job = AutoMLForecastingTrainingJob(
793+
display_name=_TEST_DISPLAY_NAME,
794+
optimization_objective=_TEST_TRAINING_OPTIMIZATION_OBJECTIVE_NAME,
795+
column_transformations=_TEST_TRAINING_COLUMN_TRANSFORMATIONS,
796+
)
797+
798+
model_from_job = job.run(
799+
dataset=mock_dataset_time_series,
800+
training_fraction_split=_TEST_FRACTION_SPLIT_TRAINING,
801+
validation_fraction_split=_TEST_FRACTION_SPLIT_VALIDATION,
802+
test_fraction_split=_TEST_FRACTION_SPLIT_TEST,
803+
timestamp_split_column_name=_TEST_SPLIT_TIMESTAMP_COLUMN_NAME,
804+
target_column=_TEST_TRAINING_TARGET_COLUMN,
805+
time_column=_TEST_TRAINING_TIME_COLUMN,
806+
time_series_identifier_column=_TEST_TRAINING_TIME_SERIES_IDENTIFIER_COLUMN,
807+
unavailable_at_forecast_columns=_TEST_TRAINING_UNAVAILABLE_AT_FORECAST_COLUMNS,
808+
available_at_forecast_columns=_TEST_TRAINING_AVAILABLE_AT_FORECAST_COLUMNS,
809+
forecast_horizon=_TEST_TRAINING_FORECAST_HORIZON,
810+
data_granularity_unit=_TEST_TRAINING_DATA_GRANULARITY_UNIT,
811+
data_granularity_count=_TEST_TRAINING_DATA_GRANULARITY_COUNT,
812+
model_display_name=_TEST_MODEL_DISPLAY_NAME,
813+
weight_column=_TEST_TRAINING_WEIGHT_COLUMN,
814+
time_series_attribute_columns=_TEST_TRAINING_TIME_SERIES_ATTRIBUTE_COLUMNS,
815+
context_window=_TEST_TRAINING_CONTEXT_WINDOW,
816+
budget_milli_node_hours=_TEST_TRAINING_BUDGET_MILLI_NODE_HOURS,
817+
export_evaluated_data_items=_TEST_TRAINING_EXPORT_EVALUATED_DATA_ITEMS,
818+
export_evaluated_data_items_bigquery_destination_uri=_TEST_TRAINING_EXPORT_EVALUATED_DATA_ITEMS_BIGQUERY_DESTINATION_URI,
819+
export_evaluated_data_items_override_destination=_TEST_TRAINING_EXPORT_EVALUATED_DATA_ITEMS_OVERRIDE_DESTINATION,
820+
quantiles=_TEST_TRAINING_QUANTILES,
821+
validation_options=_TEST_TRAINING_VALIDATION_OPTIONS,
822+
sync=sync,
823+
create_request_timeout=None,
824+
)
825+
826+
if not sync:
827+
model_from_job.wait()
828+
829+
true_split = gca_training_pipeline.TimestampSplit(
830+
training_fraction=_TEST_FRACTION_SPLIT_TRAINING,
831+
validation_fraction=_TEST_FRACTION_SPLIT_VALIDATION,
832+
test_fraction=_TEST_FRACTION_SPLIT_TEST,
833+
key=_TEST_SPLIT_TIMESTAMP_COLUMN_NAME,
834+
)
835+
836+
true_managed_model = gca_model.Model(
837+
display_name=_TEST_MODEL_DISPLAY_NAME,
838+
encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC,
839+
)
840+
841+
true_input_data_config = gca_training_pipeline.InputDataConfig(
842+
timestamp_split=true_split, dataset_id=mock_dataset_time_series.name
843+
)
844+
845+
true_training_pipeline = gca_training_pipeline.TrainingPipeline(
846+
display_name=_TEST_DISPLAY_NAME,
847+
training_task_definition=(
848+
schema.training_job.definition.automl_forecasting
849+
),
850+
training_task_inputs=_TEST_TRAINING_TASK_INPUTS,
851+
model_to_upload=true_managed_model,
852+
input_data_config=true_input_data_config,
853+
encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC,
854+
)
855+
856+
mock_pipeline_service_create.assert_called_once_with(
857+
parent=initializer.global_config.common_location_path(),
858+
training_pipeline=true_training_pipeline,
859+
timeout=None,
860+
)
861+
771862
@pytest.mark.parametrize("sync", [True, False])
772863
def test_splits_predefined(
773864
self,

0 commit comments

Comments
 (0)