@@ -256,9 +256,8 @@ def setup_method(self):
256256 def teardown_method (self ):
257257 initializer .global_pool .shutdown (wait = True )
258258
259+ @pytest .mark .parametrize ("sync" , [True , False ])
259260 @pytest .mark .parametrize (
260- "sync" ,
261- [True , False ],
262261 "training_job" ,
263262 [
264263 training_jobs .AutoMLForecastingTrainingJob ,
@@ -327,7 +326,7 @@ def test_run_call_pipeline_service_create(
327326 true_training_pipeline = gca_training_pipeline .TrainingPipeline (
328327 display_name = _TEST_DISPLAY_NAME ,
329328 labels = _TEST_LABELS ,
330- training_task_definition = schema . training_job .definition . automl_forecasting ,
329+ training_task_definition = training_job ._training_task_definition ,
331330 training_task_inputs = _TEST_TRAINING_TASK_INPUTS_WITH_ADDITIONAL_EXPERIMENTS ,
332331 model_to_upload = true_managed_model ,
333332 input_data_config = true_input_data_config ,
@@ -353,9 +352,8 @@ def test_run_call_pipeline_service_create(
353352
354353 assert job .state == gca_pipeline_state .PipelineState .PIPELINE_STATE_SUCCEEDED
355354
355+ @pytest .mark .parametrize ("sync" , [True , False ])
356356 @pytest .mark .parametrize (
357- "sync" ,
358- [True , False ],
359357 "training_job" ,
360358 [
361359 training_jobs .AutoMLForecastingTrainingJob ,
@@ -424,7 +422,7 @@ def test_run_call_pipeline_service_create_with_timeout(
424422 true_training_pipeline = gca_training_pipeline .TrainingPipeline (
425423 display_name = _TEST_DISPLAY_NAME ,
426424 labels = _TEST_LABELS ,
427- training_task_definition = schema . training_job .definition . automl_forecasting ,
425+ training_task_definition = training_job ._training_task_definition ,
428426 training_task_inputs = _TEST_TRAINING_TASK_INPUTS_WITH_ADDITIONAL_EXPERIMENTS ,
429427 model_to_upload = true_managed_model ,
430428 input_data_config = true_input_data_config ,
@@ -437,9 +435,8 @@ def test_run_call_pipeline_service_create_with_timeout(
437435 )
438436
439437 @pytest .mark .usefixtures ("mock_pipeline_service_get" )
438+ @pytest .mark .parametrize ("sync" , [True , False ])
440439 @pytest .mark .parametrize (
441- "sync" ,
442- [True , False ],
443440 "training_job" ,
444441 [
445442 training_jobs .AutoMLForecastingTrainingJob ,
@@ -502,7 +499,7 @@ def test_run_call_pipeline_if_no_model_display_name_nor_model_labels(
502499 true_training_pipeline = gca_training_pipeline .TrainingPipeline (
503500 display_name = _TEST_DISPLAY_NAME ,
504501 labels = _TEST_LABELS ,
505- training_task_definition = schema . training_job .definition . automl_forecasting ,
502+ training_task_definition = training_job ._training_task_definition ,
506503 training_task_inputs = _TEST_TRAINING_TASK_INPUTS ,
507504 model_to_upload = true_managed_model ,
508505 input_data_config = true_input_data_config ,
@@ -515,9 +512,8 @@ def test_run_call_pipeline_if_no_model_display_name_nor_model_labels(
515512 )
516513
517514 @pytest .mark .usefixtures ("mock_pipeline_service_get" )
515+ @pytest .mark .parametrize ("sync" , [True , False ])
518516 @pytest .mark .parametrize (
519- "sync" ,
520- [True , False ],
521517 "training_job" ,
522518 [
523519 training_jobs .AutoMLForecastingTrainingJob ,
@@ -577,7 +573,7 @@ def test_run_call_pipeline_if_set_additional_experiments(
577573
578574 true_training_pipeline = gca_training_pipeline .TrainingPipeline (
579575 display_name = _TEST_DISPLAY_NAME ,
580- training_task_definition = schema . training_job .definition . automl_forecasting ,
576+ training_task_definition = training_job ._training_task_definition ,
581577 training_task_inputs = _TEST_TRAINING_TASK_INPUTS_WITH_ADDITIONAL_EXPERIMENTS ,
582578 model_to_upload = true_managed_model ,
583579 input_data_config = true_input_data_config ,
@@ -594,9 +590,8 @@ def test_run_call_pipeline_if_set_additional_experiments(
594590 "mock_pipeline_service_get" ,
595591 "mock_model_service_get" ,
596592 )
593+ @pytest .mark .parametrize ("sync" , [True , False ])
597594 @pytest .mark .parametrize (
598- "sync" ,
599- [True , False ],
600595 "training_job" ,
601596 [
602597 training_jobs .AutoMLForecastingTrainingJob ,
@@ -664,9 +659,8 @@ def test_run_called_twice_raises(
664659 sync = sync ,
665660 )
666661
662+ @pytest .mark .parametrize ("sync" , [True , False ])
667663 @pytest .mark .parametrize (
668- "sync" ,
669- [True , False ],
670664 "training_job" ,
671665 [
672666 training_jobs .AutoMLForecastingTrainingJob ,
@@ -748,9 +742,8 @@ def test_raises_before_run_is_called(
748742 with pytest .raises (RuntimeError ):
749743 job .state
750744
745+ @pytest .mark .parametrize ("sync" , [True , False ])
751746 @pytest .mark .parametrize (
752- "sync" ,
753- [True , False ],
754747 "training_job" ,
755748 [
756749 training_jobs .AutoMLForecastingTrainingJob ,
@@ -830,7 +823,7 @@ def test_splits_fraction(
830823
831824 true_training_pipeline = gca_training_pipeline .TrainingPipeline (
832825 display_name = _TEST_DISPLAY_NAME ,
833- training_task_definition = schema . training_job .definition . automl_forecasting ,
826+ training_task_definition = training_job ._training_task_definition ,
834827 training_task_inputs = _TEST_TRAINING_TASK_INPUTS ,
835828 model_to_upload = true_managed_model ,
836829 input_data_config = true_input_data_config ,
@@ -843,9 +836,8 @@ def test_splits_fraction(
843836 timeout = None ,
844837 )
845838
839+ @pytest .mark .parametrize ("sync" , [True , False ])
846840 @pytest .mark .parametrize (
847- "sync" ,
848- [True , False ],
849841 "training_job" ,
850842 [
851843 training_jobs .AutoMLForecastingTrainingJob ,
@@ -927,9 +919,7 @@ def test_splits_timestamp(
927919
928920 true_training_pipeline = gca_training_pipeline .TrainingPipeline (
929921 display_name = _TEST_DISPLAY_NAME ,
930- training_task_definition = (
931- schema .training_job .definition .automl_forecasting
932- ),
922+ training_task_definition = training_job ._training_task_definition ,
933923 training_task_inputs = _TEST_TRAINING_TASK_INPUTS ,
934924 model_to_upload = true_managed_model ,
935925 input_data_config = true_input_data_config ,
@@ -942,9 +932,8 @@ def test_splits_timestamp(
942932 timeout = None ,
943933 )
944934
935+ @pytest .mark .parametrize ("sync" , [True , False ])
945936 @pytest .mark .parametrize (
946- "sync" ,
947- [True , False ],
948937 "training_job" ,
949938 [
950939 training_jobs .AutoMLForecastingTrainingJob ,
@@ -1020,7 +1009,7 @@ def test_splits_predefined(
10201009
10211010 true_training_pipeline = gca_training_pipeline .TrainingPipeline (
10221011 display_name = _TEST_DISPLAY_NAME ,
1023- training_task_definition = schema . training_job .definition . automl_forecasting ,
1012+ training_task_definition = training_job ._training_task_definition ,
10241013 training_task_inputs = _TEST_TRAINING_TASK_INPUTS ,
10251014 model_to_upload = true_managed_model ,
10261015 input_data_config = true_input_data_config ,
@@ -1033,9 +1022,8 @@ def test_splits_predefined(
10331022 timeout = None ,
10341023 )
10351024
1025+ @pytest .mark .parametrize ("sync" , [True , False ])
10361026 @pytest .mark .parametrize (
1037- "sync" ,
1038- [True , False ],
10391027 "training_job" ,
10401028 [
10411029 training_jobs .AutoMLForecastingTrainingJob ,
@@ -1105,7 +1093,7 @@ def test_splits_default(
11051093
11061094 true_training_pipeline = gca_training_pipeline .TrainingPipeline (
11071095 display_name = _TEST_DISPLAY_NAME ,
1108- training_task_definition = schema . training_job .definition . automl_forecasting ,
1096+ training_task_definition = training_job ._training_task_definition ,
11091097 training_task_inputs = _TEST_TRAINING_TASK_INPUTS ,
11101098 model_to_upload = true_managed_model ,
11111099 input_data_config = true_input_data_config ,