3 changes: 3 additions & 0 deletions paddle/fluid/framework/distributed_strategy.proto
@@ -97,6 +97,9 @@ message PpConfig {
optional NCCLConfig coll_nccl_config = 15;
optional NCCLConfig p2p_nccl_config = 16;
optional NCCLConfig shared_nccl_config = 17;
optional bool sync_param = 18 [ default = true ];
optional bool sync_moment = 19 [ default = false ];
optional string sync_mode = 20 [ default = 'broadcast' ];
}

message DygraphShardingConfig {
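For reference, the new `PpConfig` fields surface to users as `pp_configs` options on `fleet.DistributedStrategy`. A minimal sketch of enabling them (mirroring the test added at the end of this PR; the parallel degrees here are illustrative):

import paddle.distributed.fleet as fleet

strategy = fleet.DistributedStrategy()
strategy.hybrid_configs = {
    "dp_degree": 1,
    "mp_degree": 1,
    "pp_degree": 2,  # pp_configs only take effect when pipeline parallelism is enabled
}
# Knobs added by this PR; proto defaults are sync_param=True,
# sync_moment=False, sync_mode='broadcast'.
strategy.hybrid_configs["pp_configs"].sync_param = True
strategy.hybrid_configs["pp_configs"].sync_moment = True
strategy.hybrid_configs["pp_configs"].sync_mode = "broadcast"

fleet.init(is_collective=True, strategy=strategy)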
12 changes: 12 additions & 0 deletions python/paddle/base/framework.py
@@ -7905,6 +7905,18 @@ def _copy_to(self, device, blocking):
core.eager.tensor_copy(self, new_param, device, blocking)
return new_param

def __setattr__(self, name, value):
if (
name == 'color'
and hasattr(self, 'color')
and self.color is not None
):
raise AttributeError(
f"Parameter '{self.name}' already has a 'color' attribute (used for distributed sharding parallel grouping) "
f"and cannot be reassigned."
)
super().__setattr__(name, value)

__repr__ = __str__


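To illustrate the new guard on the parameter's `__setattr__`: the first assignment of `color` succeeds, any re-assignment raises. A hypothetical snippet (any `EagerParamBase`, e.g. a `paddle.nn.Linear` weight, is subject to the same check; the dict contents are placeholders):

import paddle

param = paddle.nn.Linear(4, 4).weight  # an EagerParamBase instance
param.color = {"color": "@SHARED_WEIGHT_demo", "group": None}  # first assignment is allowed
try:
    param.color = {"color": "something_else", "group": None}   # re-assignment now raises
except AttributeError as err:
    print(err)  # "... already has a 'color' attribute ... cannot be reassigned."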
python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py
@@ -44,6 +44,8 @@

__all__ = []

SHARED_WEIGHT_SYNC_PREFIX = "@SHARED_WEIGHT"


class HybridParallelClipGrad:
def __init__(self, clip, hcg, split_norm_comm=False, timers=None):
@@ -403,7 +405,15 @@ def _insert_sync(self, sync_var, src, mp_group, sync_mode):
)
)

def _filter_fn(self, param, strategy):
def _pp_filter_fn(self, param):
color = getattr(param, 'color', -1)
if isinstance(color, dict):
color_value = color.get('color', -1)
if SHARED_WEIGHT_SYNC_PREFIX in str(color_value):
return True
return False

def _mp_filter_fn(self, param, strategy):
p_name = param.name
tar_param = strategy.sync_param_name
if param.is_distributed is False:
@@ -412,12 +422,135 @@ def _filter_fn(self, param, strategy):
return True
return False

def _step(self, parameters_list):
def syc_grad(self, param, src_rank, group, sync_mode):
if hasattr(param, "main_grad") and param.main_grad is not None:
assert param.grad is None
self._insert_sync(param.main_grad, src_rank, group, sync_mode)
elif param.grad is not None:
self._insert_sync(param.grad, src_rank, group, sync_mode)

def syc_param(self, param, src_rank, group, sync_mode):
# Param sync after opt
self._insert_sync(param, src_rank, group, sync_mode)

def syc_master_weight(self, param, src_rank, group, sync_mode):
# Master param sync after opt
if (
hasattr(self._inner_opt, "_multi_precision")
and self._inner_opt._multi_precision
and param.name in self._inner_opt._master_weights
):
self._insert_sync(
self._inner_opt._master_weights[param.name],
src_rank,
group,
sync_mode,
)

def syc_moment(self, param, src_rank, group, sync_mode):
_OPTIMIZER_TYPES = (paddle.optimizer.Adam, paddle.optimizer.AdamW)

def recursive_isinstance(opt):
return isinstance(opt, _OPTIMIZER_TYPES) or (
hasattr(opt, "_inner_opt")
and recursive_isinstance(opt._inner_opt)
)

if recursive_isinstance(self._inner_opt):
if (
param.name
in self._inner_opt._accumulators[
self._inner_opt._moment1_acc_str
]
):
moment1 = self._inner_opt._get_accumulator(
self._inner_opt._moment1_acc_str, param
)
self._insert_sync(moment1, src_rank, group, sync_mode)

if (
param.name
in self._inner_opt._accumulators[
self._inner_opt._moment2_acc_str
]
):
moment2 = self._inner_opt._get_accumulator(
self._inner_opt._moment2_acc_str, param
)
self._insert_sync(moment2, src_rank, group, sync_mode)

def _sync_mp_grads(self, params, mp_configs):
mp_group = self._hcg.get_model_parallel_group()
src_rank = self._hcg.get_model_parallel_group_src_rank()

if self.processed_steps < g_profile_optimizer_details_steps:
get_sync_logger().info("Starting hybridoptimizer step")
get_sync_logger().info("Starting mp grad sync")

# Grad sync before opt
if mp_group.nranks > 1 and mp_configs and mp_configs.sync_grad:
for p in params:
self.syc_grad(p, src_rank, mp_group, mp_configs.sync_mode)

if self.processed_steps < g_profile_optimizer_details_steps:
get_sync_logger().info("Finished mp grad sync")

def _sync_mp_params_and_moments(self, params, mp_configs):
mp_group = self._hcg.get_model_parallel_group()
src_rank = self._hcg.get_model_parallel_group_src_rank()

# Sync param and master weight after opt
if mp_group.nranks > 1 and mp_configs and mp_configs.sync_param:
for p in params:
self.syc_param(p, src_rank, mp_group, mp_configs.sync_mode)
self.syc_master_weight(
p, src_rank, mp_group, mp_configs.sync_mode
)

# Moment sync after opt
if mp_group.nranks > 1 and mp_configs and mp_configs.sync_moment:
for p in params:
self.syc_moment(p, src_rank, mp_group, mp_configs.sync_mode)

def _get_pp_sync_params(self, parameters_list):
pp_group = self._hcg.get_pipe_parallel_group()
params = None
pp_configs = None

if pp_group.nranks > 1:
pp_configs = fleet.fleet._user_defined_strategy.hybrid_configs[
"pp_configs"
]

if pp_configs and (pp_configs.sync_param or pp_configs.sync_moment):
params = sorted(
[p for p in parameters_list if self._pp_filter_fn(p)],
key=lambda p: p.name,
)
return params, pp_configs

def _sync_pp_params_and_moments(self, params, pp_configs):
pp_group = self._hcg.get_pipe_parallel_group()

# Sync param and master weight after opt
if pp_group.nranks > 1 and pp_configs and pp_configs.sync_param:
for p in params:
assert hasattr(p, 'color'), f"{p.name} has no color"
color_group = p.color["group"]
src_rank = min(color_group.ranks)
self.syc_param(p, src_rank, color_group, pp_configs.sync_mode)
self.syc_master_weight(
p, src_rank, color_group, pp_configs.sync_mode
)

# Moment sync after opt
if pp_group.nranks > 1 and pp_configs and pp_configs.sync_moment:
for p in params:
color_group = p.color["group"]
src_rank = min(color_group.ranks)
self.syc_moment(p, src_rank, color_group, pp_configs.sync_mode)

def _get_mp_sync_params(self, parameters_list):
mp_group = self._hcg.get_model_parallel_group()
params = None
mp_configs = None

@@ -435,94 +568,29 @@ def _step(self, parameters_list):
[
p
for p in parameters_list
if self._filter_fn(p, fleet.fleet._user_defined_strategy)
if self._mp_filter_fn(p, fleet.fleet._user_defined_strategy)
],
key=lambda p: p.name,
)
return params, mp_configs

def syc_grad(p):
if hasattr(p, "main_grad") and p.main_grad is not None:
assert p.grad is None
self._insert_sync(
p.main_grad, src_rank, mp_group, mp_configs.sync_mode
)
elif p.grad is not None:
self._insert_sync(
p.grad, src_rank, mp_group, mp_configs.sync_mode
)

def _step(self, parameters_list):
if self.processed_steps < g_profile_optimizer_details_steps:
get_sync_logger().info("Starting mp grad sync")
get_sync_logger().info("Starting hybridoptimizer step")

# Grad sync before opt
if mp_group.nranks > 1 and mp_configs and mp_configs.sync_grad:
for p in params:
syc_grad(p)
# Sync non-model-parallel parameters' grads/weights/moments for MP group consistency.
mp_params, mp_configs = self._get_mp_sync_params(parameters_list)
# Sync PP shared params' weights and moments to keep them consistent within each shared-weight group.
# Note: their grads are already synced by the pipeline-parallel pass, so only params and moments are handled here.
pp_params, pp_configs = self._get_pp_sync_params(parameters_list)

if self.processed_steps < g_profile_optimizer_details_steps:
get_sync_logger().info("Finished mp grad sync")
self._sync_mp_grads(mp_params, mp_configs)

self._inner_opt.step()

def syc_param(p):
# Param sync after opt
self._insert_sync(p, src_rank, mp_group, mp_configs.sync_mode)
self._sync_mp_params_and_moments(mp_params, mp_configs)
self._sync_pp_params_and_moments(pp_params, pp_configs)

def syc_master_weight(p):
# Master param sync after opt
if (
hasattr(self._inner_opt, "_multi_precision")
and self._inner_opt._multi_precision
and p.name in self._inner_opt._master_weights
):
self._insert_sync(
self._inner_opt._master_weights[p.name],
src_rank,
mp_group,
mp_configs.sync_mode,
)

# syc param and master weight after opt
if mp_group.nranks > 1 and mp_configs and mp_configs.sync_param:
for p in params:
syc_param(p)
syc_master_weight(p)

def syc_moment(p):
if isinstance(
self._inner_opt,
(paddle.optimizer.Adam, paddle.optimizer.AdamW),
):
if (
p.name
in self._inner_opt._accumulators[
self._inner_opt._moment1_acc_str
]
):
moment1 = self._inner_opt._get_accumulator(
self._inner_opt._moment1_acc_str, p
)
self._insert_sync(
moment1, src_rank, mp_group, mp_configs.sync_mode
)

if (
p.name
in self._inner_opt._accumulators[
self._inner_opt._moment2_acc_str
]
):
moment2 = self._inner_opt._get_accumulator(
self._inner_opt._moment2_acc_str, p
)
self._insert_sync(
moment2, src_rank, mp_group, mp_configs.sync_mode
)

# Moment sync after opt
if mp_group.nranks > 1 and mp_configs and mp_configs.sync_moment:
for p in params:
syc_moment(p)
if self.processed_steps < g_profile_optimizer_details_steps:
get_sync_logger().info("Finishing hybridoptimizer step")
self.processed_steps += 1
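At a high level, each `_insert_sync` call in the PP path is a collective over the shared-weight group with the lowest rank acting as the source. A rough, illustrative sketch of the broadcast case (not the actual `_insert_sync` implementation, which dispatches on `sync_mode`):

import paddle.distributed as dist

def broadcast_sync(tensor, group):
    # Illustration only: make `tensor` identical on every rank of the
    # shared-weight group, using the smallest rank as the source
    # (mirrors src_rank = min(color_group.ranks) above).
    src_rank = min(group.ranks)
    dist.broadcast(tensor, src=src_rank, group=group)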
@@ -50,6 +50,9 @@
import paddle.distributed as dist
from paddle import framework, nn
from paddle.device.cuda.cuda_graphed_layer import CUDAGraphedLayer
from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.hybrid_parallel_optimizer import (
SHARED_WEIGHT_SYNC_PREFIX,
)
from paddle.distributed.fleet.utils.log_util import layer_to_str, logger
from paddle.framework import core
from paddle.incubate.distributed.fleet import recompute_hybrid
@@ -693,9 +696,20 @@ def _construct_shared_comm(self):
shared_comm[comm_key] = {
"ranks": shared_ranks,
"group": group,
"weight_attr": comm_key_to_shared_attrs[comm_key],
"weight_attr": shared_attrs,
"layer": self.shared_layers[layer_name],
}

# Set color for shared parameters to facilitate synchronization of parameters
# and optimizer states after each step
for weight_attr in shared_attrs:
shared_param = getattr(
self.shared_layers[layer_name], weight_attr
)
shared_param.color = {
"color": f"{SHARED_WEIGHT_SYNC_PREFIX}_{comm_key}",
"group": group,
}
return shared_comm

def _synchronize_shared_weights(self):
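Putting both sides together: `_construct_shared_comm` tags each shared parameter with a `color` dict, and `_pp_filter_fn` in the optimizer picks out exactly those parameters. A condensed, self-contained sketch of that handshake (the "embed" suffix and the `None` group are placeholders for the real comm_key and process group):

SHARED_WEIGHT_SYNC_PREFIX = "@SHARED_WEIGHT"

# Shape of the attribute attached to a shared parameter (simplified):
color = {
    "color": f"{SHARED_WEIGHT_SYNC_PREFIX}_embed",
    "group": None,
}

# Optimizer-side check, mirroring _pp_filter_fn:
def is_shared_weight(color_attr):
    if isinstance(color_attr, dict):
        return SHARED_WEIGHT_SYNC_PREFIX in str(color_attr.get("color", -1))
    return False

assert is_shared_weight(color)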
25 changes: 25 additions & 0 deletions test/collective/fleet/hybrid_parallel_shared_weight.py
@@ -243,5 +243,30 @@ def test_pp_model(self):
np.testing.assert_allclose(loss_a.numpy(), loss_b.numpy())


class TestDistEmbeddingTrainingWithSync(TestDistEmbeddingTraining):
def setUp(self):
strategy = fleet.DistributedStrategy()
self.model_parallel_size = 1
self.data_parallel_size = 1
self.pipeline_parallel_size = 2
strategy.hybrid_configs = {
"dp_degree": self.data_parallel_size,
"mp_degree": self.model_parallel_size,
"pp_degree": self.pipeline_parallel_size,
}
strategy.pipeline_configs = {
"accumulate_steps": batch_size // micro_batch_size,
"micro_batch_size": micro_batch_size,
}
strategy.hybrid_configs["pp_configs"].clear_every_step_cache = True
strategy.hybrid_configs["pp_configs"].sync_moment = True
strategy.hybrid_configs["pp_configs"].sync_param = True

fleet.init(is_collective=True, strategy=strategy)

def test_pp_model(self):
super().test_pp_model()


if __name__ == "__main__":
unittest.main()