@@ -1061,6 +1061,9 @@ def from_local_script(
|
1061 | 1061 | accelerator_count: int = 0,
|
1062 | 1062 | boot_disk_type: str = "pd-ssd",
|
1063 | 1063 | boot_disk_size_gb: int = 100,
|
| 1064 | +reduction_server_replica_count: int = 0, |
| 1065 | +reduction_server_machine_type: Optional[str] = None, |
| 1066 | +reduction_server_container_uri: Optional[str] = None, |
1064 | 1067 | base_output_dir: Optional[str] = None,
|
1065 | 1068 | project: Optional[str] = None,
|
1066 | 1069 | location: Optional[str] = None,
|
@@ -1127,6 +1130,13 @@ def from_local_script(
|
1127 | 1130 | boot_disk_size_gb (int):
|
1128 | 1131 | Optional. Size in GB of the boot disk, default is 100GB.
|
1129 | 1132 | boot disk size must be within the range of [100, 64000].
|
| 1133 | +reduction_server_replica_count (int): |
| 1134 | +Optional. The number of reduction server replicas, default is 0. |
| 1135 | +reduction_server_machine_type (str): |
| 1136 | +Optional. The type of machine to use for reduction server. |
| 1137 | +reduction_server_container_uri (str): |
| 1138 | +Optional. The URI of the reduction server container image. |
| 1139 | +See details: https://cloud.google.com/vertex-ai/docs/training/distributed-training#reduce_training_time_with_reduction_server |
1130 | 1140 | base_output_dir (str):
|
1131 | 1141 | Optional. GCS output directory of job. If not provided a
|
1132 | 1142 | timestamped directory in the staging directory will be used.
|
@@ -1181,6 +1191,8 @@ def from_local_script(
|
1181 | 1191 | accelerator_type=accelerator_type,
|
1182 | 1192 | boot_disk_type=boot_disk_type,
|
1183 | 1193 | boot_disk_size_gb=boot_disk_size_gb,
|
| 1194 | +reduction_server_replica_count=reduction_server_replica_count, |
| 1195 | +reduction_server_machine_type=reduction_server_machine_type, |
1184 | 1196 | ).pool_specs
|
1185 | 1197 |
|
1186 | 1198 | python_packager = source_utils._TrainingScriptPythonPackager(
|
@@ -1191,21 +1203,33 @@ def from_local_script(
|
1191 | 1203 | gcs_staging_dir=staging_bucket, project=project, credentials=credentials,
|
1192 | 1204 | )
|
1193 | 1205 |
|
1194 |
| -for spec in worker_pool_specs: |
1195 |
| -spec["python_package_spec"] = { |
1196 |
| -"executor_image_uri": container_uri, |
1197 |
| -"python_module": python_packager.module_name, |
1198 |
| -"package_uris": [package_gcs_uri], |
1199 |
| -} |
1200 |
| - |
1201 |
| -if args: |
1202 |
| -spec["python_package_spec"]["args"] = args |
1203 |
| - |
1204 |
| -if environment_variables: |
1205 |
| -spec["python_package_spec"]["env"] = [ |
1206 |
| -{"name": key, "value": value} |
1207 |
| -for key, value in environment_variables.items() |
1208 |
| -] |
| 1206 | +for spec_order, spec in enumerate(worker_pool_specs): |
| 1207 | + |
| 1208 | +if not spec: |
| 1209 | +continue |
| 1210 | + |
| 1211 | +if ( |
| 1212 | +spec_order == worker_spec_utils._SPEC_ORDERS["server_spec"] |
| 1213 | +and reduction_server_replica_count > 0 |
| 1214 | +): |
| 1215 | +spec["container_spec"] = { |
| 1216 | +"image_uri": reduction_server_container_uri, |
| 1217 | +} |
| 1218 | +else: |
| 1219 | +spec["python_package_spec"] = { |
| 1220 | +"executor_image_uri": container_uri, |
| 1221 | +"python_module": python_packager.module_name, |
| 1222 | +"package_uris": [package_gcs_uri], |
| 1223 | +} |
| 1224 | + |
| 1225 | +if args: |
| 1226 | +spec["python_package_spec"]["args"] = args |
| 1227 | + |
| 1228 | +if environment_variables: |
| 1229 | +spec["python_package_spec"]["env"] = [ |
| 1230 | +{"name": key, "value": value} |
| 1231 | +for key, value in environment_variables.items() |
| 1232 | +] |
1209 | 1233 |
|
1210 | 1234 | return cls(
|
1211 | 1235 | display_name=display_name,
|
|
0 commit comments