Commit 13053a7

[MoE] fix expert parallel (#9760)
* fix moe uc
1 parent 1afb9b2 commit 13053a7

File tree

3 files changed: 9 additions & 0 deletions


llm/run_finetune.py

Lines changed: 3 additions & 0 deletions

@@ -151,6 +151,9 @@ def main():
         quantization_config=quantization_config,
     )
 
+    if "Qwen2Moe" in str(model_config.architectures) and training_args.data_parallel_degree > 1:
+        training_args.use_expert_parallel = True
+
     LlmMetaConfig.set_llm_config(model_config, training_args)
     model_config.use_fast_layer_norm = model_args.use_fast_layer_norm
 
llm/run_pretrain.py

Lines changed: 3 additions & 0 deletions

@@ -478,6 +478,9 @@ def main():
     except:
         print("Not register llama pp reshard information.")
 
+    if "Qwen2Moe" in str(config.architectures) and training_args.data_parallel_degree > 1:
+        training_args.use_expert_parallel = True
+
     if model_args.continue_training:
         # NOTE(gongenlei): new add
         if training_args.autotuner_benchmark:
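
Both launch scripts add the same guard. As a minimal, self-contained sketch (not code from this repo: maybe_enable_expert_parallel and the SimpleNamespace stand-ins for the model config and TrainingArguments are hypothetical), the logic reduces to:

# Sketch of the guard added in run_finetune.py / run_pretrain.py.
# The check is string-based: any architecture name containing "Qwen2Moe",
# combined with data_parallel_degree > 1, switches expert parallelism on.
from types import SimpleNamespace


def maybe_enable_expert_parallel(config, training_args):
    if "Qwen2Moe" in str(config.architectures) and training_args.data_parallel_degree > 1:
        training_args.use_expert_parallel = True
    return training_args


config = SimpleNamespace(architectures=["Qwen2MoeForCausalLM"])
args = SimpleNamespace(data_parallel_degree=2, use_expert_parallel=False)
print(maybe_enable_expert_parallel(config, args).use_expert_parallel)  # True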

paddlenlp/transformers/moe_layer.py

Lines changed: 3 additions & 0 deletions

@@ -162,12 +162,14 @@ def __init__(
             self.moe_num_experts_per_device = self._parse_moe_expert_parallel(
                 self.moe_num_experts, self.expert_parallel_degree
             )
+            self.is_dummy_moe = False if self.expert_parallel_degree > 1 else True
         else:
             # when moe_group is dummy, we don't need to use all_to_all
             self.moe_group = None
             self.moe_rank = 0
             self.expert_parallel_degree = 1
             self.moe_num_experts_per_device = self.moe_num_experts
+            self.is_dummy_moe = True
 
         self.all_to_all_dropout = all_to_all_dropout
         self.enable_recompute = False

@@ -181,6 +183,7 @@ def __init__(
 
         self.gate = gate
         self.gate.group = self.moe_group
+        self._post_init()
 
     def _parse_moe_expert_parallel(self, moe_num_experts, expert_parallel_degree):
         assert (
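
For context, the flag introduced by this hunk can be reproduced with a small standalone helper (a sketch under the usual even-split assumption; parse_expert_parallel is a hypothetical stand-in for the layer's _parse_moe_expert_parallel, whose body is not shown here): is_dummy_moe is True exactly when no real expert parallelism is in effect.

# Hypothetical sketch, not library code: derive the per-device expert count and
# the is_dummy_moe flag from the total expert count and expert-parallel degree.
def parse_expert_parallel(moe_num_experts: int, expert_parallel_degree: int):
    assert expert_parallel_degree >= 1, "expert_parallel_degree must be positive"
    assert (
        moe_num_experts % expert_parallel_degree == 0
    ), "moe_num_experts must be divisible by expert_parallel_degree"
    moe_num_experts_per_device = moe_num_experts // expert_parallel_degree
    # Degree 1 means every device keeps all experts, i.e. the MoE group is "dummy"
    # and no all_to_all communication is needed (mirrors the else branch above).
    is_dummy_moe = expert_parallel_degree <= 1
    return moe_num_experts_per_device, is_dummy_moe


print(parse_expert_parallel(8, 4))  # (2, False)
print(parse_expert_parallel(8, 1))  # (8, True)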
