Skip to content

Commit 2cf47c9

Browse files
committed
Add seq2seq job to init file.
1 parent 950b827 commit 2cf47c9

File tree

3 files changed

+12
-8
lines changed

3 files changed

+12
-8
lines changed

google/cloud/aiplatform/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@
6363
CustomPythonPackageTrainingJob,
6464
AutoMLTabularTrainingJob,
6565
AutoMLForecastingTrainingJob,
66+
SequenceToSequencePlusForecastingTrainingJob,
6667
AutoMLImageTrainingJob,
6768
AutoMLTextTrainingJob,
6869
AutoMLVideoTrainingJob,
@@ -116,6 +117,7 @@
116117
"Model",
117118
"ModelEvaluation",
118119
"PipelineJob",
120+
"SequenceToSequencePlusForecastingTrainingJob",
119121
"TabularDataset",
120122
"Tensorboard",
121123
"TensorboardExperiment",

google/cloud/aiplatform/schema.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ class definition:
2323
custom_task = "gs://google-cloud-aiplatform/schema/trainingjob/definition/custom_task_1.0.0.yaml"
2424
automl_tabular = "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_tabular_1.0.0.yaml"
2525
automl_forecasting = "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_time_series_forecasting_1.0.0.yaml"
26-
seq2seq_forecasting = "gs://google-cloud-aiplatform/schema/trainingjob/definition/seq2seq_plus_time_series_forecasting_1.0.0.yaml"
26+
seq2seq_plus_forecasting = "gs://google-cloud-aiplatform/schema/trainingjob/definition/seq2seq_plus_time_series_forecasting_1.0.0.yaml"
2727
automl_image_classification = "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_image_classification_1.0.0.yaml"
2828
automl_image_object_detection = "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_image_object_detection_1.0.0.yaml"
2929
automl_text_classification = "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_text_classification_1.0.0.yaml"

google/cloud/aiplatform/training_jobs.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1693,7 +1693,7 @@ def _training_task_definition(cls) -> str:
16931693
"""A GCS path to the YAML file that defines the training task.
16941694
16951695
The definition files that can be used here are found in
1696-
gs://google-cloud- aiplatform/schema/trainingjob/definition/.
1696+
gs://google-cloud-aiplatform/schema/trainingjob/definition/.
16971697
"""
16981698
pass
16991699

@@ -1907,13 +1907,13 @@ def run(
19071907

19081908
if self._is_waiting_to_run():
19091909
raise RuntimeError(
1910-
f"{self.__class__._model_type} Forecasting Training is already "
1911-
"scheduled to run."
1910+
f"{self._model_type} Forecasting Training is already scheduled "
1911+
"to run."
19121912
)
19131913

19141914
if self._has_run:
19151915
raise RuntimeError(
1916-
f"{self.__class__._model_type} Forecasting Training has already run."
1916+
f"{self._model_type} Forecasting Training has already run."
19171917
)
19181918

19191919
if additional_experiments:
@@ -2218,7 +2218,7 @@ def _run(
22182218
)
22192219

22202220
new_model = self._run_job(
2221-
training_task_definition=self.__class__._training_task_definition,
2221+
training_task_definition=self._training_task_definition,
22222222
training_task_inputs=training_task_inputs_dict,
22232223
dataset=dataset,
22242224
training_fraction_split=training_fraction_split,
@@ -4961,8 +4961,10 @@ def evaluated_data_items_bigquery_uri(self) -> Optional[str]:
49614961

49624962
class SequenceToSequencePlusForecastingTrainingJob(_ForecastingTrainingJob):
49634963
_model_type = "Seq2Seq"
4964-
_training_task_definition = schema.training_job.definition.seq2seq_forecasting
4965-
_supported_training_schemas = (schema.training_job.definition.seq2seq_forecasting,)
4964+
_training_task_definition = schema.training_job.definition.seq2seq_plus_forecasting
4965+
_supported_training_schemas = (
4966+
schema.training_job.definition.seq2seq_plus_forecasting,
4967+
)
49664968

49674969
def __init__(
49684970
self,

0 commit comments

Comments
 (0)