Skip to content

Commit c24251f

Browse files
authored
fix: change default replica count to 1 for custom training job classes (#579)
1 parent 6a99b12 commit c24251f

File tree

2 files changed

+4
-17
lines changed

2 files changed

+4
-17
lines changed

google/cloud/aiplatform/training_jobs.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1107,7 +1107,7 @@ def network(self) -> Optional[str]:
11071107
def _prepare_and_validate_run(
11081108
self,
11091109
model_display_name: Optional[str] = None,
1110-
replica_count: int = 0,
1110+
replica_count: int = 1,
11111111
machine_type: str = "n1-standard-4",
11121112
accelerator_type: str = "ACCELERATOR_TYPE_UNSPECIFIED",
11131113
accelerator_count: int = 0,
@@ -1521,7 +1521,7 @@ def run(
15211521
bigquery_destination: Optional[str] = None,
15221522
args: Optional[List[Union[str, float, int]]] = None,
15231523
environment_variables: Optional[Dict[str, str]] = None,
1524-
replica_count: int = 0,
1524+
replica_count: int = 1,
15251525
machine_type: str = "n1-standard-4",
15261526
accelerator_type: str = "ACCELERATOR_TYPE_UNSPECIFIED",
15271527
accelerator_count: int = 0,
@@ -2143,7 +2143,7 @@ def run(
21432143
bigquery_destination: Optional[str] = None,
21442144
args: Optional[List[Union[str, float, int]]] = None,
21452145
environment_variables: Optional[Dict[str, str]] = None,
2146-
replica_count: int = 0,
2146+
replica_count: int = 1,
21472147
machine_type: str = "n1-standard-4",
21482148
accelerator_type: str = "ACCELERATOR_TYPE_UNSPECIFIED",
21492149
accelerator_count: int = 0,
@@ -4095,7 +4095,7 @@ def run(
40954095
bigquery_destination: Optional[str] = None,
40964096
args: Optional[List[Union[str, float, int]]] = None,
40974097
environment_variables: Optional[Dict[str, str]] = None,
4098-
replica_count: int = 0,
4098+
replica_count: int = 1,
40994099
machine_type: str = "n1-standard-4",
41004100
accelerator_type: str = "ACCELERATOR_TYPE_UNSPECIFIED",
41014101
accelerator_count: int = 0,

tests/unit/aiplatform/test_training_jobs.py

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -652,7 +652,6 @@ def test_run_call_pipeline_service_create_with_tabular_dataset(
652652
network=_TEST_NETWORK,
653653
args=_TEST_RUN_ARGS,
654654
environment_variables=_TEST_ENVIRONMENT_VARIABLES,
655-
replica_count=1,
656655
machine_type=_TEST_MACHINE_TYPE,
657656
accelerator_type=_TEST_ACCELERATOR_TYPE,
658657
accelerator_count=_TEST_ACCELERATOR_COUNT,
@@ -825,7 +824,6 @@ def test_run_call_pipeline_service_create_with_bigquery_destination(
825824
bigquery_destination=_TEST_BIGQUERY_DESTINATION,
826825
args=_TEST_RUN_ARGS,
827826
environment_variables=_TEST_ENVIRONMENT_VARIABLES,
828-
replica_count=1,
829827
machine_type=_TEST_MACHINE_TYPE,
830828
accelerator_type=_TEST_ACCELERATOR_TYPE,
831829
accelerator_count=_TEST_ACCELERATOR_COUNT,
@@ -1099,7 +1097,6 @@ def test_run_call_pipeline_service_create_with_no_dataset(
10991097
base_output_dir=_TEST_BASE_OUTPUT_DIR,
11001098
args=_TEST_RUN_ARGS,
11011099
environment_variables=_TEST_ENVIRONMENT_VARIABLES,
1102-
replica_count=1,
11031100
machine_type=_TEST_MACHINE_TYPE,
11041101
accelerator_type=_TEST_ACCELERATOR_TYPE,
11051102
accelerator_count=_TEST_ACCELERATOR_COUNT,
@@ -1628,7 +1625,6 @@ def test_run_call_pipeline_service_create_with_nontabular_dataset(
16281625
annotation_schema_uri=_TEST_ANNOTATION_SCHEMA_URI,
16291626
base_output_dir=_TEST_BASE_OUTPUT_DIR,
16301627
args=_TEST_RUN_ARGS,
1631-
replica_count=1,
16321628
machine_type=_TEST_MACHINE_TYPE,
16331629
accelerator_type=_TEST_ACCELERATOR_TYPE,
16341630
accelerator_count=_TEST_ACCELERATOR_COUNT,
@@ -1870,7 +1866,6 @@ def test_run_call_pipeline_service_create_with_tabular_dataset(
18701866
base_output_dir=_TEST_BASE_OUTPUT_DIR,
18711867
args=_TEST_RUN_ARGS,
18721868
environment_variables=_TEST_ENVIRONMENT_VARIABLES,
1873-
replica_count=1,
18741869
machine_type=_TEST_MACHINE_TYPE,
18751870
accelerator_type=_TEST_ACCELERATOR_TYPE,
18761871
accelerator_count=_TEST_ACCELERATOR_COUNT,
@@ -2032,7 +2027,6 @@ def test_run_call_pipeline_service_create_with_bigquery_destination(
20322027
base_output_dir=_TEST_BASE_OUTPUT_DIR,
20332028
bigquery_destination=_TEST_BIGQUERY_DESTINATION,
20342029
args=_TEST_RUN_ARGS,
2035-
replica_count=1,
20362030
machine_type=_TEST_MACHINE_TYPE,
20372031
accelerator_type=_TEST_ACCELERATOR_TYPE,
20382032
accelerator_count=_TEST_ACCELERATOR_COUNT,
@@ -2294,7 +2288,6 @@ def test_run_call_pipeline_service_create_with_no_dataset(
22942288
model_from_job = job.run(
22952289
base_output_dir=_TEST_BASE_OUTPUT_DIR,
22962290
args=_TEST_RUN_ARGS,
2297-
replica_count=1,
22982291
machine_type=_TEST_MACHINE_TYPE,
22992292
accelerator_type=_TEST_ACCELERATOR_TYPE,
23002293
accelerator_count=_TEST_ACCELERATOR_COUNT,
@@ -2674,7 +2667,6 @@ def test_run_call_pipeline_service_create_with_nontabular_dataset(
26742667
service_account=_TEST_SERVICE_ACCOUNT,
26752668
network=_TEST_NETWORK,
26762669
args=_TEST_RUN_ARGS,
2677-
replica_count=1,
26782670
machine_type=_TEST_MACHINE_TYPE,
26792671
accelerator_type=_TEST_ACCELERATOR_TYPE,
26802672
accelerator_count=_TEST_ACCELERATOR_COUNT,
@@ -3112,7 +3104,6 @@ def test_run_call_pipeline_service_create_with_tabular_dataset(
31123104
network=_TEST_NETWORK,
31133105
args=_TEST_RUN_ARGS,
31143106
environment_variables=_TEST_ENVIRONMENT_VARIABLES,
3115-
replica_count=1,
31163107
machine_type=_TEST_MACHINE_TYPE,
31173108
accelerator_type=_TEST_ACCELERATOR_TYPE,
31183109
accelerator_count=_TEST_ACCELERATOR_COUNT,
@@ -3273,7 +3264,6 @@ def test_run_call_pipeline_service_create_with_tabular_dataset_without_model_dis
32733264
# model_display_name=_TEST_MODEL_DISPLAY_NAME,
32743265
base_output_dir=_TEST_BASE_OUTPUT_DIR,
32753266
args=_TEST_RUN_ARGS,
3276-
replica_count=1,
32773267
machine_type=_TEST_MACHINE_TYPE,
32783268
accelerator_type=_TEST_ACCELERATOR_TYPE,
32793269
accelerator_count=_TEST_ACCELERATOR_COUNT,
@@ -3426,7 +3416,6 @@ def test_run_call_pipeline_service_create_with_bigquery_destination(
34263416
base_output_dir=_TEST_BASE_OUTPUT_DIR,
34273417
bigquery_destination=_TEST_BIGQUERY_DESTINATION,
34283418
args=_TEST_RUN_ARGS,
3429-
replica_count=1,
34303419
machine_type=_TEST_MACHINE_TYPE,
34313420
accelerator_type=_TEST_ACCELERATOR_TYPE,
34323421
accelerator_count=_TEST_ACCELERATOR_COUNT,
@@ -3693,7 +3682,6 @@ def test_run_call_pipeline_service_create_with_no_dataset(
36933682
model_display_name=_TEST_MODEL_DISPLAY_NAME,
36943683
base_output_dir=_TEST_BASE_OUTPUT_DIR,
36953684
args=_TEST_RUN_ARGS,
3696-
replica_count=1,
36973685
machine_type=_TEST_MACHINE_TYPE,
36983686
accelerator_type=_TEST_ACCELERATOR_TYPE,
36993687
accelerator_count=_TEST_ACCELERATOR_COUNT,
@@ -4080,7 +4068,6 @@ def test_run_call_pipeline_service_create_with_nontabular_dataset(
40804068
annotation_schema_uri=_TEST_ANNOTATION_SCHEMA_URI,
40814069
base_output_dir=_TEST_BASE_OUTPUT_DIR,
40824070
args=_TEST_RUN_ARGS,
4083-
replica_count=1,
40844071
machine_type=_TEST_MACHINE_TYPE,
40854072
accelerator_type=_TEST_ACCELERATOR_TYPE,
40864073
accelerator_count=_TEST_ACCELERATOR_COUNT,

0 commit comments

Comments
 (0)