Commit efd07c0

[Distributed] support fuse optimizer (#9519) (#9777)
1 parent: 7b53fec

1 file changed (+7 lines, -2 lines)

paddlenlp/trainer/training_args.py

Lines changed: 7 additions & 2 deletions
@@ -293,6 +293,7 @@ class TrainingArguments:
             enable_stage1_allgather_overlap, overlap stage1 V2 allgather with next step forward computation. There are some constraints for the overlap, such as the logging_step should be bigger than 1 for allgather overlap forward compute and no other sync could be called during the training for allgather overlap.
             disable_stage1_reduce_avg, replace reduce_avg with original reduce_sum+scale in stage1, which can be used for accuracy verification.
             enable_release_grads, reduce peak memory usage by releasing gradients after each iteration. The creation of gradients will be postponed until backward propagation of the next iteration.
+            enable_fuse_optimizer_states, fuse optimizer states to a single storage.
         recompute (`bool`, *optional*, defaults to `False`):
             Recompute the forward pass to calculate gradients. Used for saving memory.
             Only support for networks with transformer blocks.
@@ -1412,10 +1413,11 @@ def is_segment_parallel_supported():
                 "enable_stage1_broadcast_overlap",
                 "enable_stage1_allgather_overlap",
                 "enable_release_grads",
+                "enable_fuse_optimizer_states",
             ]:
                 raise ValueError(
-                    f"Found unknown pipeline mode config {x}, "
-                    f"accpet config is enable_stage1_tensor_fusion, enable_stage1_overlap, enable_stage2_overlap, split_param, disable_stage1_reduce_avg, enable_stage1_broadcast_overlap, enable_stage1_allgather_overlap."
+                    f"Found unknown sharding mode config {x}, "
+                    f"accpet config is enable_stage1_tensor_fusion, enable_stage1_overlap, enable_stage2_overlap, split_param, disable_stage1_reduce_avg, enable_stage1_broadcast_overlap, enable_stage1_allgather_overlap, enable_release_grads, enable_fuse_optimizer_states."
                 )
         if "disable_stage1_reduce_avg" in sharding_parallel_config:
             assert self.sharding == [
@@ -1441,6 +1443,9 @@ def is_segment_parallel_supported():
         if "enable_release_grads" in sharding_parallel_config:
             strategy.hybrid_configs["sharding_configs"].release_gradients = True

+        if "enable_fuse_optimizer_states" in sharding_parallel_config:
+            strategy.hybrid_configs["sharding_configs"].enable_fuse_optimizer_states = True
+
         if self.pipeline_parallel_degree == 1:
             strategy.hybrid_configs["sharding_configs"].tensor_fusion = (
                 True if "enable_stage1_tensor_fusion" in sharding_parallel_config else False
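
Usage illustration (not part of the commit): with this change, the new option can be requested together with the other stage1 options via sharding_parallel_config in TrainingArguments. The sketch below makes some assumptions; the output directory, the sharding degree, and the space-separated format of the option string are illustrative, not values taken from the commit.

    from paddlenlp.trainer import TrainingArguments

    # Sketch: request optimizer-state fusion alongside stage1 tensor fusion.
    # sharding_parallel_config is assumed to take a space-separated list of the
    # option names validated in the hunk above.
    training_args = TrainingArguments(
        output_dir="./checkpoints",    # hypothetical output directory
        sharding="stage1",             # stage1 sharding, which these options configure
        sharding_parallel_degree=8,    # hypothetical sharding group size
        sharding_parallel_config="enable_stage1_tensor_fusion enable_fuse_optimizer_states",
    )

When the option is present, the trainer sets strategy.hybrid_configs["sharding_configs"].enable_fuse_optimizer_states = True, as shown in the last hunk.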
