Skip to content

Commit 7f16202

Browse files
committed
fix fixture issue
1 parent 3ae5d39 commit 7f16202

File tree

2 files changed

+48
-56
lines changed

2 files changed

+48
-56
lines changed

google/cloud/aiplatform/training_jobs.py

Lines changed: 30 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1562,12 +1562,10 @@ def _model_upload_fail_string(self) -> str:
15621562

15631563

15641564
class _ForecastingTrainingJob(_TrainingJob):
1565-
"""ABC for Forecasting Training Pipelines.."""
1565+
"""ABC for Forecasting Training Pipelines."""
15661566

15671567
def __init__(
15681568
self,
1569-
model_type: str,
1570-
training_task_definition: str,
15711569
display_name: Optional[str] = None,
15721570
optimization_objective: Optional[str] = None,
15731571
column_specs: Optional[Dict[str, str]] = None,
@@ -1582,15 +1580,6 @@ def __init__(
15821580
"""Constructs a Forecasting Training Job.
15831581
15841582
Args:
1585-
model_type (str): The type of forecasting model.
1586-
training_task_definition (str):
1587-
Required. A Google Cloud Storage path to the
1588-
YAML file that defines the training task which
1589-
is responsible for producing the model artifact,
1590-
and may also include additional auxiliary work.
1591-
The definition files that can be used here are
1592-
found in gs://google-cloud-
1593-
aiplatform/schema/trainingjob/definition/.
15941583
display_name (str):
15951584
Optional. The user-defined name of this TrainingPipeline.
15961585
optimization_objective (str):
@@ -1687,8 +1676,24 @@ def __init__(
16871676

16881677
self._optimization_objective = optimization_objective
16891678
self._additional_experiments = []
1690-
self._model_type = model_type
1691-
self._training_task_definition = training_task_definition
1679+
1680+
@property
1681+
@classmethod
1682+
@abc.abstractmethod
1683+
def _model_type(cls) -> str:
1684+
"""The type of forecasting model."""
1685+
pass
1686+
1687+
@property
1688+
@classmethod
1689+
@abc.abstractmethod
1690+
def _training_task_definition(cls) -> str:
1691+
"""A GCS path to the YAML file that defines the training task.
1692+
1693+
The definition files that can be used here are found in
1694+
gs://google-cloud- aiplatform/schema/trainingjob/definition/.
1695+
"""
1696+
pass
16921697

16931698
def run(
16941699
self,
@@ -1900,13 +1905,14 @@ def run(
19001905

19011906
if self._is_waiting_to_run():
19021907
raise RuntimeError(
1903-
f"{self._model_type} Forecasting Training is already scheduled "
1904-
"to run."
1908+
f"{self.__class__._model_type} Forecasting Training is already "
1909+
"scheduled to run."
19051910
)
19061911

19071912
if self._has_run:
19081913
raise RuntimeError(
1909-
f"{self._model_type} Forecasting Training has " "already run."
1914+
f"{self.__class__._model_type} Forecasting Training has "
1915+
"already run."
19101916
)
19111917

19121918
if additional_experiments:
@@ -2210,7 +2216,7 @@ def _run(
22102216
)
22112217

22122218
new_model = self._run_job(
2213-
training_task_definition=self._training_task_definition,
2219+
training_task_definition=self.__class__._training_task_definition,
22142220
training_task_inputs=training_task_inputs_dict,
22152221
dataset=dataset,
22162222
training_fraction_split=training_fraction_split,
@@ -4590,6 +4596,9 @@ class column_data_types:
45904596

45914597

45924598
class AutoMLForecastingTrainingJob(_ForecastingTrainingJob):
4599+
_model_type = "AutoML"
4600+
_training_task_definition = (
4601+
schema.training_job.definition.automl_forecasting)
45934602
_supported_training_schemas = (schema.training_job.definition.automl_forecasting,)
45944603

45954604
def __init__(
@@ -4692,10 +4701,6 @@ def __init__(
46924701
ValueError: If both column_transformations and column_specs were provided.
46934702
"""
46944703
super().__init__(
4695-
model_type="AutoML",
4696-
training_task_definition=(
4697-
schema.training_job.definition.automl_forecasting
4698-
),
46994704
display_name=display_name,
47004705
optimization_objective=optimization_objective,
47014706
column_specs=column_specs,
@@ -4954,6 +4959,9 @@ def evaluated_data_items_bigquery_uri(self) -> Optional[str]:
49544959

49554960

49564961
class SequenceToSequencePlusForecastingTrainingJob(_ForecastingTrainingJob):
4962+
_model_type = "Seq2Seq"
4963+
_training_task_definition = (
4964+
schema.training_job.definition.seq2seq_forecasting)
49574965
_supported_training_schemas = (schema.training_job.definition.seq2seq_forecasting,)
49584966

49594967
def __init__(
@@ -5056,10 +5064,6 @@ def __init__(
50565064
ValueError: If both column_transformations and column_specs were provided.
50575065
"""
50585066
super().__init__(
5059-
model_type="Seq2Seq",
5060-
training_task_definition=(
5061-
schema.training_job.definition.seq2seq_forecasting
5062-
),
50635067
display_name=display_name,
50645068
optimization_objective=optimization_objective,
50655069
column_specs=column_specs,

tests/unit/aiplatform/test_automl_forecasting_training_jobs.py

Lines changed: 18 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -256,9 +256,8 @@ def setup_method(self):
256256
def teardown_method(self):
257257
initializer.global_pool.shutdown(wait=True)
258258

259+
@pytest.mark.parametrize("sync", [True, False])
259260
@pytest.mark.parametrize(
260-
"sync",
261-
[True, False],
262261
"training_job",
263262
[
264263
training_jobs.AutoMLForecastingTrainingJob,
@@ -327,7 +326,7 @@ def test_run_call_pipeline_service_create(
327326
true_training_pipeline = gca_training_pipeline.TrainingPipeline(
328327
display_name=_TEST_DISPLAY_NAME,
329328
labels=_TEST_LABELS,
330-
training_task_definition=schema.training_job.definition.automl_forecasting,
329+
training_task_definition=training_job._training_task_definition,
331330
training_task_inputs=_TEST_TRAINING_TASK_INPUTS_WITH_ADDITIONAL_EXPERIMENTS,
332331
model_to_upload=true_managed_model,
333332
input_data_config=true_input_data_config,
@@ -353,9 +352,8 @@ def test_run_call_pipeline_service_create(
353352

354353
assert job.state == gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED
355354

355+
@pytest.mark.parametrize("sync", [True, False])
356356
@pytest.mark.parametrize(
357-
"sync",
358-
[True, False],
359357
"training_job",
360358
[
361359
training_jobs.AutoMLForecastingTrainingJob,
@@ -424,7 +422,7 @@ def test_run_call_pipeline_service_create_with_timeout(
424422
true_training_pipeline = gca_training_pipeline.TrainingPipeline(
425423
display_name=_TEST_DISPLAY_NAME,
426424
labels=_TEST_LABELS,
427-
training_task_definition=schema.training_job.definition.automl_forecasting,
425+
training_task_definition=training_job._training_task_definition,
428426
training_task_inputs=_TEST_TRAINING_TASK_INPUTS_WITH_ADDITIONAL_EXPERIMENTS,
429427
model_to_upload=true_managed_model,
430428
input_data_config=true_input_data_config,
@@ -437,9 +435,8 @@ def test_run_call_pipeline_service_create_with_timeout(
437435
)
438436

439437
@pytest.mark.usefixtures("mock_pipeline_service_get")
438+
@pytest.mark.parametrize("sync", [True, False])
440439
@pytest.mark.parametrize(
441-
"sync",
442-
[True, False],
443440
"training_job",
444441
[
445442
training_jobs.AutoMLForecastingTrainingJob,
@@ -502,7 +499,7 @@ def test_run_call_pipeline_if_no_model_display_name_nor_model_labels(
502499
true_training_pipeline = gca_training_pipeline.TrainingPipeline(
503500
display_name=_TEST_DISPLAY_NAME,
504501
labels=_TEST_LABELS,
505-
training_task_definition=schema.training_job.definition.automl_forecasting,
502+
training_task_definition=training_job._training_task_definition,
506503
training_task_inputs=_TEST_TRAINING_TASK_INPUTS,
507504
model_to_upload=true_managed_model,
508505
input_data_config=true_input_data_config,
@@ -515,9 +512,8 @@ def test_run_call_pipeline_if_no_model_display_name_nor_model_labels(
515512
)
516513

517514
@pytest.mark.usefixtures("mock_pipeline_service_get")
515+
@pytest.mark.parametrize("sync", [True, False])
518516
@pytest.mark.parametrize(
519-
"sync",
520-
[True, False],
521517
"training_job",
522518
[
523519
training_jobs.AutoMLForecastingTrainingJob,
@@ -577,7 +573,7 @@ def test_run_call_pipeline_if_set_additional_experiments(
577573

578574
true_training_pipeline = gca_training_pipeline.TrainingPipeline(
579575
display_name=_TEST_DISPLAY_NAME,
580-
training_task_definition=schema.training_job.definition.automl_forecasting,
576+
training_task_definition=training_job._training_task_definition,
581577
training_task_inputs=_TEST_TRAINING_TASK_INPUTS_WITH_ADDITIONAL_EXPERIMENTS,
582578
model_to_upload=true_managed_model,
583579
input_data_config=true_input_data_config,
@@ -594,9 +590,8 @@ def test_run_call_pipeline_if_set_additional_experiments(
594590
"mock_pipeline_service_get",
595591
"mock_model_service_get",
596592
)
593+
@pytest.mark.parametrize("sync", [True, False])
597594
@pytest.mark.parametrize(
598-
"sync",
599-
[True, False],
600595
"training_job",
601596
[
602597
training_jobs.AutoMLForecastingTrainingJob,
@@ -664,9 +659,8 @@ def test_run_called_twice_raises(
664659
sync=sync,
665660
)
666661

662+
@pytest.mark.parametrize("sync", [True, False])
667663
@pytest.mark.parametrize(
668-
"sync",
669-
[True, False],
670664
"training_job",
671665
[
672666
training_jobs.AutoMLForecastingTrainingJob,
@@ -748,9 +742,8 @@ def test_raises_before_run_is_called(
748742
with pytest.raises(RuntimeError):
749743
job.state
750744

745+
@pytest.mark.parametrize("sync", [True, False])
751746
@pytest.mark.parametrize(
752-
"sync",
753-
[True, False],
754747
"training_job",
755748
[
756749
training_jobs.AutoMLForecastingTrainingJob,
@@ -830,7 +823,7 @@ def test_splits_fraction(
830823

831824
true_training_pipeline = gca_training_pipeline.TrainingPipeline(
832825
display_name=_TEST_DISPLAY_NAME,
833-
training_task_definition=schema.training_job.definition.automl_forecasting,
826+
training_task_definition=training_job._training_task_definition,
834827
training_task_inputs=_TEST_TRAINING_TASK_INPUTS,
835828
model_to_upload=true_managed_model,
836829
input_data_config=true_input_data_config,
@@ -843,9 +836,8 @@ def test_splits_fraction(
843836
timeout=None,
844837
)
845838

839+
@pytest.mark.parametrize("sync", [True, False])
846840
@pytest.mark.parametrize(
847-
"sync",
848-
[True, False],
849841
"training_job",
850842
[
851843
training_jobs.AutoMLForecastingTrainingJob,
@@ -927,9 +919,7 @@ def test_splits_timestamp(
927919

928920
true_training_pipeline = gca_training_pipeline.TrainingPipeline(
929921
display_name=_TEST_DISPLAY_NAME,
930-
training_task_definition=(
931-
schema.training_job.definition.automl_forecasting
932-
),
922+
training_task_definition=training_job._training_task_definition,
933923
training_task_inputs=_TEST_TRAINING_TASK_INPUTS,
934924
model_to_upload=true_managed_model,
935925
input_data_config=true_input_data_config,
@@ -942,9 +932,8 @@ def test_splits_timestamp(
942932
timeout=None,
943933
)
944934

935+
@pytest.mark.parametrize("sync", [True, False])
945936
@pytest.mark.parametrize(
946-
"sync",
947-
[True, False],
948937
"training_job",
949938
[
950939
training_jobs.AutoMLForecastingTrainingJob,
@@ -1020,7 +1009,7 @@ def test_splits_predefined(
10201009

10211010
true_training_pipeline = gca_training_pipeline.TrainingPipeline(
10221011
display_name=_TEST_DISPLAY_NAME,
1023-
training_task_definition=schema.training_job.definition.automl_forecasting,
1012+
training_task_definition=training_job._training_task_definition,
10241013
training_task_inputs=_TEST_TRAINING_TASK_INPUTS,
10251014
model_to_upload=true_managed_model,
10261015
input_data_config=true_input_data_config,
@@ -1033,9 +1022,8 @@ def test_splits_predefined(
10331022
timeout=None,
10341023
)
10351024

1025+
@pytest.mark.parametrize("sync", [True, False])
10361026
@pytest.mark.parametrize(
1037-
"sync",
1038-
[True, False],
10391027
"training_job",
10401028
[
10411029
training_jobs.AutoMLForecastingTrainingJob,
@@ -1105,7 +1093,7 @@ def test_splits_default(
11051093

11061094
true_training_pipeline = gca_training_pipeline.TrainingPipeline(
11071095
display_name=_TEST_DISPLAY_NAME,
1108-
training_task_definition=schema.training_job.definition.automl_forecasting,
1096+
training_task_definition=training_job._training_task_definition,
11091097
training_task_inputs=_TEST_TRAINING_TASK_INPUTS,
11101098
model_to_upload=true_managed_model,
11111099
input_data_config=true_input_data_config,

0 commit comments

Comments
 (0)