
Commit 161fb67

Fix moe optimizer broadcast (#8813)
1 parent 157f7d3 commit 161fb67

File tree: 3 files changed, +32 -6 lines changed


paddlenlp/trainer/trainer.py

Lines changed: 3 additions & 1 deletion
@@ -2512,7 +2512,9 @@ def _load_optimizer_and_scheduler(self, checkpoint):
             dist.barrier()
         if self.args.use_expert_parallel:
             opt_state_dict = broadcast_moe_optimizer(
-                opt_state_dict, broadcast_dp=not self.args.should_load_sharding_stage1_model
+                opt_state_dict,
+                model_state_dict=self.model.state_dict(),
+                broadcast_dp=not self.args.should_load_sharding_stage1_model,
             )
         else:
             if not self.args.should_load_sharding_stage1_model:
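
For context: the call site now also passes the model's state dict, so broadcast_moe_optimizer can tell which parameters are replicated (and therefore synchronized) across data-parallel ranks. The sketch below illustrates that selection rule with toy stand-ins; the ToyParam class and the parameter names are illustrative assumptions, not part of this commit.

# Toy illustration (not from the commit): expert-local parameters are
# typically flagged with no_sync=True and must be excluded from the
# data-parallel broadcast, while ordinary parameters stay eligible.

class ToyParam:
    # Stand-in for a Paddle parameter exposing .name and an optional no_sync flag.
    def __init__(self, name, no_sync=False):
        self.name = name
        self.no_sync = no_sync


model_state_dict = {
    "linear.weight": ToyParam("linear_0.w_0"),                      # replicated, keep in sync
    "moe.expert_0.weight": ToyParam("expert_0.w_0", no_sync=True),  # expert-local, skip
}

# Same rule as the new _filter_sync_optimizer_state helper below:
# collect the names of parameters that do not carry the no_sync flag.
sync_vname = [p.name for p in model_state_dict.values() if not getattr(p, "no_sync", False)]
print(sync_vname)  # ['linear_0.w_0']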

paddlenlp/trainer/utils/helper.py

Lines changed: 27 additions & 3 deletions
@@ -229,7 +229,7 @@ def broadcast_dp_optimizer(state_dict):
     return state_dict


-def broadcast_moe_optimizer(state_dict, broadcast_dp=True):
+def broadcast_moe_optimizer(state_dict, model_state_dict=None, broadcast_dp=True):

     try:
         hcg = fleet.get_hybrid_communicate_group()
@@ -242,7 +242,29 @@ def broadcast_moe_optimizer(state_dict, broadcast_dp=True):
     except:
         dp_group = None
         src_rank = 0
-        data_parallel_rank = 0
+        data_parallel_rank = dist.get_rank()
+
+    def _filter_sync_optimizer_state(model_state_dict, opt_state_dict):
+        # get sync name
+        sync_vname = []
+        for k, v in model_state_dict.items():
+            if not getattr(v, "no_sync", False):
+                sync_vname.append(v.name)
+
+        filter_opt_state_dict = {"master_weights": {}}
+        filter_opt_state_dict["LR_Scheduler"] = opt_state_dict.get("LR_Scheduler", {})
+        for op_k, op_v in opt_state_dict.items():
+            if op_k not in ["master_weights", "LR_Scheduler"]:
+                for sync_v in sync_vname:
+                    if op_k.startswith(sync_v):
+                        filter_opt_state_dict[op_k] = op_v
+                        break
+            elif op_k == "master_weights":
+                for k, v in op_v.items():
+                    for sync_v in sync_vname:
+                        if k.startswith(sync_v):
+                            filter_opt_state_dict["master_weights"][k] = v
+        return filter_opt_state_dict

     def _broadcast_moe_optimizer_state(state_dict):
         # boardcast_keys
@@ -272,9 +294,11 @@ def _broadcast_moe_optimizer_state(state_dict):
         return base_state_dict

     if broadcast_dp:
-        base_state_dict = broadcast_dp_optimizer(state_dict)
+        filter_opt_state_dict = _filter_sync_optimizer_state(model_state_dict, state_dict)
+        base_state_dict = broadcast_dp_optimizer(filter_opt_state_dict)
     else:
         base_state_dict = _broadcast_moe_optimizer_state(state_dict)
+
     if data_parallel_rank > 0:
         master_weight = state_dict.pop("master_weights", {})
         base_state_dict.update(state_dict)
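
A self-contained sketch of the filtering step added above: given the broadcast-eligible parameter names, only the matching optimizer entries and master weights are kept before broadcast_dp_optimizer runs. All dictionary values are placeholders, and the <param_name>_moment1_0-style keys are an assumption about how per-parameter optimizer state is named, used here only for illustration.

# Standalone illustration of the filtering performed by _filter_sync_optimizer_state.
# Optimizer-state keys are assumed to be prefixed with the owning parameter's name.

sync_vname = ["linear_0.w_0"]  # parameters replicated across data-parallel ranks

opt_state_dict = {
    "LR_Scheduler": {"last_lr": 1e-4},
    "linear_0.w_0_moment1_0": "moment1 tensor (placeholder)",
    "expert_0.w_0_moment1_0": "expert-local moment (placeholder)",
    "master_weights": {
        "linear_0.w_0": "fp32 master copy (placeholder)",
        "expert_0.w_0": "expert-local master copy (placeholder)",
    },
}

filtered = {"master_weights": {}, "LR_Scheduler": opt_state_dict.get("LR_Scheduler", {})}
for op_k, op_v in opt_state_dict.items():
    if op_k not in ["master_weights", "LR_Scheduler"]:
        # keep only optimizer state that belongs to a sync-eligible parameter
        if any(op_k.startswith(name) for name in sync_vname):
            filtered[op_k] = op_v
    elif op_k == "master_weights":
        filtered["master_weights"] = {
            k: v for k, v in op_v.items() if any(k.startswith(name) for name in sync_vname)
        }

# Expert-local state stays on its own rank; everything else is broadcast.
print(sorted(filtered))                    # ['LR_Scheduler', 'linear_0.w_0_moment1_0', 'master_weights']
print(sorted(filtered["master_weights"]))  # ['linear_0.w_0']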

paddlenlp/transformers/conversion_utils.py

Lines changed: 2 additions & 2 deletions
@@ -1284,7 +1284,7 @@ def merge_tensor_parallel(cls, state_dict, config) -> None:

         if len(name_action_mappings) > 0:
             for x in name_action_mappings.keys():
-                logger.warning(f"key <{x}> need to merge tensor parallel but we can't find in model state.")
+                logger.debug(f"key <{x}> need to merge tensor parallel but we can't find in model state.")

         return state_dict_to_save

@@ -1318,7 +1318,7 @@ def _resolve_prefix_keys(state_keys_base, state_keys_real, ignore_error=False):
                     break
             if key not in state_keys_map:
                 if not ignore_error:
-                    logger.error(f"tensor parallel conversion: could not find name {key} in loaded state dict!")
+                    logger.debug(f"tensor parallel conversion: could not find name {key} in loaded state dict!")
             else:
                 state_keys_real.remove(state_keys_map[key])
