File tree Expand file tree Collapse file tree 1 file changed +9
-2
lines changed
paddlenlp/experimental/transformers Expand file tree Collapse file tree 1 file changed +9
-2
lines changed Original file line number Diff line number Diff line change @@ -1299,8 +1299,12 @@ def init_weight_shape(self, config):
12991299 self .moe_ffn2_weight_shape = [self .config .moe_config .num_experts , self .dim_feedforward , self .embed_dim ]
13001300
13011301 if config .quant_type == "weight_only_int4" :
1302- self .moe_ffn1_weight_shape [2 ] //= 2
1303- self .moe_ffn2_weight_shape [2 ] //= 2
1302+ if config .moe_config .has_shared_expert ():
1303+ self .moe_ffn1_weight_shape [2 ] //= 2
1304+ self .moe_ffn2_weight_shape [1 ] //= 2
1305+ else :
1306+ self .moe_ffn1_weight_shape [2 ] //= 2
1307+ self .moe_ffn2_weight_shape [2 ] //= 2
13041308
13051309 if self .config .moe_config .has_shared_expert ():
13061310 self .shared_expert_ffn1_weight_shape = [
@@ -1315,6 +1319,9 @@ def init_weight_shape(self, config):
13151319 self .embed_dim ,
13161320 1 ,
13171321 ]
1322+ if config .quant_type == "weight_only_int4" :
1323+ self .shared_expert_ffn1_weight_shape [0 ] //= 2
1324+ self .shared_expert_ffn2_weight_shape [0 ] //= 2
13181325
13191326 def compute_qkv_linear (self , ln_out , i ):
13201327 return weight_only_linear (
You can’t perform that action at this time.
0 commit comments