@@ -1061,6 +1061,9 @@ def from_local_script(
10611061 accelerator_count : int = 0 ,
10621062 boot_disk_type : str = "pd-ssd" ,
10631063 boot_disk_size_gb : int = 100 ,
1064+ reduction_server_replica_count : Optional [int ] = 0 ,
1065+ reduction_server_machine_type : Optional [str ] = None ,
1066+ reduction_server_container_uri : Optional [str ] = None ,
10641067 base_output_dir : Optional [str ] = None ,
10651068 project : Optional [str ] = None ,
10661069 location : Optional [str ] = None ,
@@ -1127,6 +1130,13 @@ def from_local_script(
11271130 boot_disk_size_gb (int):
11281131 Optional. Size in GB of the boot disk, default is 100GB.
11291132 boot disk size must be within the range of [100, 64000].
1133+ reduction_server_replica_count (int):
1134+ Optional. The number of reduction server replicas.
1135+ reduction_server_machine_type (str):
1136+ Optional. The type of machine to use for reduction server.
1137+ reduction_server_container_uri (str):
1138+ Optional. The URI of the reduction server container image.
1139+ See details: https://cloud.google.com/vertex-ai/docs/training/distributed-training#reduce_training_time_with_reduction_server
11301140 base_output_dir (str):
11311141 Optional. GCS output directory of job. If not provided a
11321142 timestamped directory in the staging directory will be used.
@@ -1181,6 +1191,8 @@ def from_local_script(
11811191 accelerator_type = accelerator_type ,
11821192 boot_disk_type = boot_disk_type ,
11831193 boot_disk_size_gb = boot_disk_size_gb ,
1194+ reduction_server_replica_count = reduction_server_replica_count ,
1195+ reduction_server_machine_type = reduction_server_machine_type ,
11841196 ).pool_specs
11851197
11861198 python_packager = source_utils ._TrainingScriptPythonPackager (
@@ -1191,21 +1203,33 @@ def from_local_script(
11911203 gcs_staging_dir = staging_bucket , project = project , credentials = credentials ,
11921204 )
11931205
1194- for spec in worker_pool_specs :
1195- spec ["python_package_spec" ] = {
1196- "executor_image_uri" : container_uri ,
1197- "python_module" : python_packager .module_name ,
1198- "package_uris" : [package_gcs_uri ],
1199- }
1200-
1201- if args :
1202- spec ["python_package_spec" ]["args" ] = args
1203-
1204- if environment_variables :
1205- spec ["python_package_spec" ]["env" ] = [
1206- {"name" : key , "value" : value }
1207- for key , value in environment_variables .items ()
1208- ]
1206+ for spec_order , spec in enumerate (worker_pool_specs ):
1207+
1208+ if not spec :
1209+ continue
1210+
1211+ if (
1212+ spec_order == worker_spec_utils .SPEC_ORDERS ["server_spec" ]
1213+ and reduction_server_replica_count > 0
1214+ ):
1215+ spec ["container_spec" ] = {
1216+ "image_uri" : reduction_server_container_uri ,
1217+ }
1218+ else :
1219+ spec ["python_package_spec" ] = {
1220+ "executor_image_uri" : container_uri ,
1221+ "python_module" : python_packager .module_name ,
1222+ "package_uris" : [package_gcs_uri ],
1223+ }
1224+
1225+ if args :
1226+ spec ["python_package_spec" ]["args" ] = args
1227+
1228+ if environment_variables :
1229+ spec ["python_package_spec" ]["env" ] = [
1230+ {"name" : key , "value" : value }
1231+ for key , value in environment_variables .items ()
1232+ ]
12091233
12101234 return cls (
12111235 display_name = display_name ,
0 commit comments