Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -534,20 +534,29 @@ def _sync_pp_params_and_moments(self, params, pp_configs):
# sync param and master weight after opt
if pp_group.nranks > 1 and pp_configs and pp_configs.sync_param:
for p in params:
assert hasattr(p, 'color'), f"{p.name} has no color"
color_group = p.color["group"]
src_rank = min(color_group.ranks)
self.syc_param(p, src_rank, color_group, pp_configs.sync_mode)
assert (
hasattr(p, 'color') and 'broadcast_group' in p.color
), f"{p.name} has no color"
broadcast_group = p.color["broadcast_group"]
src_rank = min(broadcast_group.ranks)
self.syc_param(
p, src_rank, broadcast_group, pp_configs.sync_mode
)
self.syc_master_weight(
p, src_rank, color_group, pp_configs.sync_mode
p, src_rank, broadcast_group, pp_configs.sync_mode
)

# Moment sync after opt
if pp_group.nranks > 1 and pp_configs and pp_configs.sync_moment:
for p in params:
color_group = p.color["group"]
src_rank = min(color_group.ranks)
self.syc_moment(p, src_rank, color_group, pp_configs.sync_mode)
assert (
hasattr(p, 'color') and 'broadcast_group' in p.color
), f"{p.name} has no color"
broadcast_group = p.color["broadcast_group"]
src_rank = min(broadcast_group.ranks)
self.syc_moment(
p, src_rank, broadcast_group, pp_configs.sync_mode
)

def _get_mp_sync_params(self, parameters_list):
mp_group = self._hcg.get_model_parallel_group()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -700,16 +700,22 @@ def _construct_shared_comm(self):
"layer": self.shared_layers[layer_name],
}

# Set color for shared parameters to facilitate synchronization of parameters
# and optimizer states after each step
for weight_attr in shared_attrs:
shared_param = getattr(
self.shared_layers[layer_name], weight_attr
)
shared_param.color = {
"color": f"{SHARED_WEIGHT_SYNC_PREFIX}_{comm_key}",
"group": group,
}
if (
hybrid_configs["pp_configs"].sync_moment
or hybrid_configs["pp_configs"].sync_param
):
# Set color for shared parameters to facilitate synchronization of parameters
# and optimizer states after each step
for weight_attr in shared_attrs:
shared_param = getattr(
self.shared_layers[layer_name], weight_attr
)
hcg = fleet.get_hybrid_communicate_group()
shared_param.color = {
"color": f"{SHARED_WEIGHT_SYNC_PREFIX}_{comm_key}",
"group": hcg.get_sharding_parallel_group(),
"broadcast_group": group,
}
return shared_comm

def _synchronize_shared_weights(self):
Expand Down
Loading