@@ -229,7 +229,7 @@ def broadcast_dp_optimizer(state_dict):
229229 return state_dict
230230
231231
232- def broadcast_moe_optimizer (state_dict , broadcast_dp = True ):
232+ def broadcast_moe_optimizer (state_dict , model_state_dict = None , broadcast_dp = True ):
233233
234234 try :
235235 hcg = fleet .get_hybrid_communicate_group ()
@@ -242,7 +242,29 @@ def broadcast_moe_optimizer(state_dict, broadcast_dp=True):
242242 except :
243243 dp_group = None
244244 src_rank = 0
245- data_parallel_rank = 0
245+ data_parallel_rank = dist .get_rank ()
246+
247+ def _filter_sync_optimizer_state (model_state_dict , opt_state_dict ):
248+ # get sync name
249+ sync_vname = []
250+ for k , v in model_state_dict .items ():
251+ if not getattr (v , "no_sync" , False ):
252+ sync_vname .append (v .name )
253+
254+ filter_opt_state_dict = {"master_weights" : {}}
255+ filter_opt_state_dict ["LR_Scheduler" ] = opt_state_dict .get ("LR_Scheduler" , {})
256+ for op_k , op_v in opt_state_dict .items ():
257+ if op_k not in ["master_weights" , "LR_Scheduler" ]:
258+ for sync_v in sync_vname :
259+ if op_k .startswith (sync_v ):
260+ filter_opt_state_dict [op_k ] = op_v
261+ break
262+ elif op_k == "master_weights" :
263+ for k , v in op_v .items ():
264+ for sync_v in sync_vname :
265+ if k .startswith (sync_v ):
266+ filter_opt_state_dict ["master_weights" ][k ] = v
267+ return filter_opt_state_dict
246268
247269 def _broadcast_moe_optimizer_state (state_dict ):
248270 # boardcast_keys
@@ -272,9 +294,11 @@ def _broadcast_moe_optimizer_state(state_dict):
272294 return base_state_dict
273295
274296 if broadcast_dp :
275- base_state_dict = broadcast_dp_optimizer (state_dict )
297+ filter_opt_state_dict = _filter_sync_optimizer_state (model_state_dict , state_dict )
298+ base_state_dict = broadcast_dp_optimizer (filter_opt_state_dict )
276299 else :
277300 base_state_dict = _broadcast_moe_optimizer_state (state_dict )
301+
278302 if data_parallel_rank > 0 :
279303 master_weight = state_dict .pop ("master_weights" , {})
280304 base_state_dict .update (state_dict )
0 commit comments