File tree

5 files changed

+574
-69
lines changed

5 files changed

+574
-69
lines changed
Original file line numberDiff line numberDiff line change
@@ -1061,6 +1061,9 @@ def from_local_script(
10611061
accelerator_count: int = 0,
10621062
boot_disk_type: str = "pd-ssd",
10631063
boot_disk_size_gb: int = 100,
1064+
reduction_server_replica_count: int = 0,
1065+
reduction_server_machine_type: Optional[str] = None,
1066+
reduction_server_container_uri: Optional[str] = None,
10641067
base_output_dir: Optional[str] = None,
10651068
project: Optional[str] = None,
10661069
location: Optional[str] = None,
@@ -1127,6 +1130,13 @@ def from_local_script(
11271130
boot_disk_size_gb (int):
11281131
Optional. Size in GB of the boot disk, default is 100GB.
11291132
boot disk size must be within the range of [100, 64000].
1133+
reduction_server_replica_count (int):
1134+
Optional. The number of reduction server replicas, default is 0.
1135+
reduction_server_machine_type (str):
1136+
Optional. The type of machine to use for the reduction server.
1137+
reduction_server_container_uri (str):
1138+
Optional. The URI of the reduction server container image.
1139+
See details: https://cloud.google.com/vertex-ai/docs/training/distributed-training#reduce_training_time_with_reduction_server
11301140
base_output_dir (str):
11311141
Optional. GCS output directory of job. If not provided a
11321142
timestamped directory in the staging directory will be used.
@@ -1181,6 +1191,8 @@ def from_local_script(
11811191
accelerator_type=accelerator_type,
11821192
boot_disk_type=boot_disk_type,
11831193
boot_disk_size_gb=boot_disk_size_gb,
1194+
reduction_server_replica_count=reduction_server_replica_count,
1195+
reduction_server_machine_type=reduction_server_machine_type,
11841196
).pool_specs
11851197

11861198
python_packager = source_utils._TrainingScriptPythonPackager(
@@ -1191,21 +1203,33 @@ def from_local_script(
11911203
gcs_staging_dir=staging_bucket, project=project, credentials=credentials,
11921204
)
11931205

1194-
for spec in worker_pool_specs:
1195-
spec["python_package_spec"] = {
1196-
"executor_image_uri": container_uri,
1197-
"python_module": python_packager.module_name,
1198-
"package_uris": [package_gcs_uri],
1199-
}
1200-
1201-
if args:
1202-
spec["python_package_spec"]["args"] = args
1203-
1204-
if environment_variables:
1205-
spec["python_package_spec"]["env"] = [
1206-
{"name": key, "value": value}
1207-
for key, value in environment_variables.items()
1208-
]
1206+
for spec_order, spec in enumerate(worker_pool_specs):
1207+
1208+
if not spec:
1209+
continue
1210+
1211+
if (
1212+
spec_order == worker_spec_utils._SPEC_ORDERS["server_spec"]
1213+
and reduction_server_replica_count > 0
1214+
):
1215+
spec["container_spec"] = {
1216+
"image_uri": reduction_server_container_uri,
1217+
}
1218+
else:
1219+
spec["python_package_spec"] = {
1220+
"executor_image_uri": container_uri,
1221+
"python_module": python_packager.module_name,
1222+
"package_uris": [package_gcs_uri],
1223+
}
1224+
1225+
if args:
1226+
spec["python_package_spec"]["args"] = args
1227+
1228+
if environment_variables:
1229+
spec["python_package_spec"]["env"] = [
1230+
{"name": key, "value": value}
1231+
for key, value in environment_variables.items()
1232+
]
12091233

12101234
return cls(
12111235
display_name=display_name,

0 commit comments

Comments
 (0)