 from vllm.lora.layers.base import BaseLayerWithLoRA
 from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.fused_moe.config import (
-    FUSED_MOE_UNQUANTIZED_CONFIG,
     _get_config_dtype_str,
-    mxfp4_w4a16_moe_quant_config,
 )
 from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
     modular_marlin_fused_moe,
     modular_triton_fused_moe,
     try_get_optimal_moe_config,
 )
-from vllm.model_executor.layers.quantization.mxfp4 import Mxfp4Config


 class FusedMoEWithLoRA(BaseLayerWithLoRA):
     def __init__(self, base_layer: FusedMoE) -> None:
         super().__init__()
         self.base_layer = base_layer
+
+        assert not self.base_layer.use_ep, (
+            "EP support for Fused MoE LoRA is not implemented yet."
+        )
         self.tp_size = get_tensor_model_parallel_world_size()
         self.tp_rank = get_tensor_model_parallel_rank()
         self.device = base_layer.w2_weight.device
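This hunk moves the expert-parallel guard out of `create_lora_weights` (see the removal further down) into the constructor, so an unsupported EP setup fails as soon as the wrapper is built. A minimal standalone sketch of that fail-fast pattern, using a hypothetical `_Wrapper` class rather than the vLLM layer:

```python
# Minimal sketch (hypothetical class, not the vLLM layer): guard unsupported
# configuration in __init__ so the error surfaces at construction time.

class _Wrapper:
    def __init__(self, use_ep: bool) -> None:
        assert not use_ep, "EP support for Fused MoE LoRA is not implemented yet."
        self.use_ep = use_ep


_Wrapper(use_ep=False)    # constructs fine
# _Wrapper(use_ep=True)   # would raise AssertionError at construction time
```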
@@ -42,17 +43,8 @@ def _inject_lora_into_fused_moe(self):
         moe_state_dict = {}
         top_k = self.base_layer.top_k

-        if self.base_layer.quant_config is None:
-            quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
-        elif not isinstance(self.base_layer.quant_config, Mxfp4Config):
-            quant_config = self.base_layer.quant_config
-        else:
-            quant_config = mxfp4_w4a16_moe_quant_config(
-                w1_bias=self.base_layer.w13_bias,
-                w2_bias=self.base_layer.w2_bias,
-                w1_scale=self.base_layer.w13_weight_scale,
-                w2_scale=self.base_layer.w2_weight_scale,
-            )
+        self.base_layer.ensure_moe_quant_config_init()
+        quant_config = self.base_layer.quant_method.moe_quant_config

         m_fused_moe_fn = (
             modular_triton_fused_moe(
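With the Mxfp4 special case removed, the LoRA wrapper no longer rebuilds a quant config itself; it asks the base layer to initialize its canonical MoE quant config and reads it back from the quant method. A minimal standalone sketch of that pattern, using stand-in classes rather than the real vLLM objects:

```python
# Stand-in classes (illustrative only, not the vLLM API) showing the pattern:
# the layer owns lazy initialization of its quant config, callers just read it.

class _QuantMethod:
    def __init__(self) -> None:
        self.moe_quant_config = None

    def init_config(self) -> None:
        # Build the config once; real code would derive dtype/scales/bias here.
        if self.moe_quant_config is None:
            self.moe_quant_config = {"quant_dtype": None, "per_act_token": False}


class _FusedMoELayer:
    def __init__(self) -> None:
        self.quant_method = _QuantMethod()

    def ensure_moe_quant_config_init(self) -> None:
        self.quant_method.init_config()


base_layer = _FusedMoELayer()
base_layer.ensure_moe_quant_config_init()
quant_config = base_layer.quant_method.moe_quant_config
print(quant_config)  # {'quant_dtype': None, 'per_act_token': False}
```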
@@ -69,7 +61,6 @@ def wrapper(*args, **kwargs):
             moe_state_dict["hidden_states"] = kwargs["hidden_states"]
             moe_state_dict["topk_ids"] = kwargs["topk_ids"]
             moe_state_dict["topk_weights"] = kwargs["topk_weights"]
-            moe_state_dict["global_num_experts"] = kwargs["global_num_experts"]
             moe_state_dict["expert_map"] = kwargs["expert_map"]
             moe_state_dict["apply_router_weight_on_input"] = kwargs[
                 "apply_router_weight_on_input"
@@ -86,7 +77,7 @@ def wrapper(*args, **kwargs):
             hidden_states = moe_state_dict["hidden_states"]
             topk_weights = moe_state_dict["topk_weights"]
             curr_topk_ids = moe_state_dict["topk_ids"]
-            global_num_experts = moe_state_dict["global_num_experts"]
+
             expert_map = moe_state_dict["expert_map"]

             config_dtype = _get_config_dtype_str(
@@ -118,7 +109,7 @@ def wrapper(*args, **kwargs):
                 curr_topk_ids,
                 num_tokens,
                 config["BLOCK_SIZE_M"],
-                global_num_experts,
+                self.base_layer.local_num_experts,
                 max_loras,
                 expert_map,
             )
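The token-sorting path now takes `self.base_layer.local_num_experts` instead of the `global_num_experts` kwarg captured from the modular kernel (with EP asserted off in `__init__`, the two counts coincide). A toy sketch, with made-up shapes rather than the vLLM kernel, of why the expert count matters when bucketing routed tokens per expert:

```python
# Toy example (made-up shapes, not the vLLM kernel): per-expert buckets are
# allocated from the expert count, so it must match the experts on this rank.
import torch

local_num_experts = 4                                # experts held locally
topk_ids = torch.tensor([[0, 2], [1, 2], [3, 0]])    # [num_tokens, top_k]

# Tokens routed to each expert; a block-aligned kernel would pad every bucket
# up to BLOCK_SIZE_M before launching the grouped GEMM.
counts = torch.bincount(topk_ids.flatten(), minlength=local_num_experts)
print(counts)  # tensor([2, 1, 2, 1])
```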
@@ -236,14 +227,10 @@ def create_lora_weights(
     ) -> None:
         """Initializes lora matrices."""

-        assert not self.base_layer.use_ep, (
-            "EP support for Fused MoE LoRA is not implemented yet."
-        )
-
         self.w1_lora_a_stacked = torch.zeros(
             (
                 max_loras,
-                self.base_layer.global_num_experts,
+                self.base_layer.local_num_experts,
                 lora_config.max_lora_rank,
                 self.base_layer.hidden_size,
             ),
@@ -253,7 +240,7 @@ def create_lora_weights(
         self.w1_lora_b_stacked = torch.zeros(
             (
                 max_loras,
-                self.base_layer.global_num_experts,
+                self.base_layer.local_num_experts,
                 self.base_layer.intermediate_size_per_partition,
                 lora_config.max_lora_rank,
             ),
@@ -264,7 +251,7 @@ def create_lora_weights(
         self.w2_lora_a_stacked = torch.zeros(
             (
                 max_loras,
-                self.base_layer.global_num_experts,
+                self.base_layer.local_num_experts,
                 lora_config.max_lora_rank,
                 self.base_layer.intermediate_size_per_partition,
             ),
@@ -274,7 +261,7 @@ def create_lora_weights(
         self.w2_lora_b_stacked = torch.zeros(
             (
                 max_loras,
-                self.base_layer.global_num_experts,
+                self.base_layer.local_num_experts,
                 self.base_layer.hidden_size,
                 lora_config.max_lora_rank,
             ),
@@ -285,7 +272,7 @@ def create_lora_weights(
         self.w3_lora_a_stacked = torch.zeros(
             (
                 max_loras,
-                self.base_layer.global_num_experts,
+                self.base_layer.local_num_experts,
                 lora_config.max_lora_rank,
                 self.base_layer.hidden_size,
             ),
@@ -295,7 +282,7 @@ def create_lora_weights(
         self.w3_lora_b_stacked = torch.zeros(
             (
                 max_loras,
-                self.base_layer.global_num_experts,
+                self.base_layer.local_num_experts,
                 self.base_layer.intermediate_size_per_partition,
                 lora_config.max_lora_rank,
             ),
@@ -308,7 +295,7 @@ def create_lora_weights(
         self.lora_a_stacked = []
         self.lora_b_stacked = []
         for lora_id in range(max_loras):
-            for experts_id in range(self.base_layer.global_num_experts):
+            for experts_id in range(self.base_layer.local_num_experts):
                 # gate_proj,down_proj,up_proj
                 self.lora_a_stacked.append(self.w1_lora_a_stacked[lora_id][experts_id])
                 self.lora_a_stacked.append(self.w2_lora_a_stacked[lora_id][experts_id])
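All six stacked LoRA buffers, and the flattening loop above, are now sized by `local_num_experts`. A shape sketch with placeholder sizes, mirroring the layout these hunks describe:

```python
# Shape sketch (placeholder sizes) of the stacked LoRA buffers: the expert
# dimension is local_num_experts, and per-(lora, expert) slices are flattened
# into lists as in the loop at the end of create_lora_weights.
import torch

max_loras = 2
local_num_experts = 4
max_lora_rank = 8
hidden_size = 32
intermediate_size_per_partition = 16

w1_lora_a_stacked = torch.zeros(
    (max_loras, local_num_experts, max_lora_rank, hidden_size)
)
w1_lora_b_stacked = torch.zeros(
    (max_loras, local_num_experts, intermediate_size_per_partition, max_lora_rank)
)

lora_a_stacked = [
    w1_lora_a_stacked[lora_id][expert_id]
    for lora_id in range(max_loras)
    for expert_id in range(local_num_experts)
]
print(len(lora_a_stacked), lora_a_stacked[0].shape)  # 8 torch.Size([8, 32])
```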