
Commit bc4486d

[Kernel] Enable FusedMoEModularKernel support bias (vllm-project#27754)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
1 parent 0cdbe7b commit bc4486d

File tree: 2 files changed (+15, -30 lines)

vllm/lora/layers/fused_moe.py

Lines changed: 15 additions & 28 deletions

@@ -15,9 +15,7 @@
 from vllm.lora.layers.base import BaseLayerWithLoRA
 from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.fused_moe.config import (
-    FUSED_MOE_UNQUANTIZED_CONFIG,
     _get_config_dtype_str,
-    mxfp4_w4a16_moe_quant_config,
 )
 from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
     modular_marlin_fused_moe,
@@ -26,13 +24,16 @@
     modular_triton_fused_moe,
     try_get_optimal_moe_config,
 )
-from vllm.model_executor.layers.quantization.mxfp4 import Mxfp4Config
 
 
 class FusedMoEWithLoRA(BaseLayerWithLoRA):
     def __init__(self, base_layer: FusedMoE) -> None:
         super().__init__()
         self.base_layer = base_layer
+
+        assert not self.base_layer.use_ep, (
+            "EP support for Fused MoE LoRA is not implemented yet."
+        )
         self.tp_size = get_tensor_model_parallel_world_size()
         self.tp_rank = get_tensor_model_parallel_rank()
         self.device = base_layer.w2_weight.device
@@ -42,17 +43,8 @@ def _inject_lora_into_fused_moe(self):
         moe_state_dict = {}
         top_k = self.base_layer.top_k
 
-        if self.base_layer.quant_config is None:
-            quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
-        elif not isinstance(self.base_layer.quant_config, Mxfp4Config):
-            quant_config = self.base_layer.quant_config
-        else:
-            quant_config = mxfp4_w4a16_moe_quant_config(
-                w1_bias=self.base_layer.w13_bias,
-                w2_bias=self.base_layer.w2_bias,
-                w1_scale=self.base_layer.w13_weight_scale,
-                w2_scale=self.base_layer.w2_weight_scale,
-            )
+        self.base_layer.ensure_moe_quant_config_init()
+        quant_config = self.base_layer.quant_method.moe_quant_config
 
         m_fused_moe_fn = (
             modular_triton_fused_moe(
@@ -69,7 +61,6 @@ def wrapper(*args, **kwargs):
            moe_state_dict["hidden_states"] = kwargs["hidden_states"]
            moe_state_dict["topk_ids"] = kwargs["topk_ids"]
            moe_state_dict["topk_weights"] = kwargs["topk_weights"]
-           moe_state_dict["global_num_experts"] = kwargs["global_num_experts"]
            moe_state_dict["expert_map"] = kwargs["expert_map"]
            moe_state_dict["apply_router_weight_on_input"] = kwargs[
                "apply_router_weight_on_input"
@@ -86,7 +77,7 @@ def wrapper(*args, **kwargs):
            hidden_states = moe_state_dict["hidden_states"]
            topk_weights = moe_state_dict["topk_weights"]
            curr_topk_ids = moe_state_dict["topk_ids"]
-           global_num_experts = moe_state_dict["global_num_experts"]
+
            expert_map = moe_state_dict["expert_map"]
 
            config_dtype = _get_config_dtype_str(
@@ -118,7 +109,7 @@ def wrapper(*args, **kwargs):
                curr_topk_ids,
                num_tokens,
                config["BLOCK_SIZE_M"],
-               global_num_experts,
+               self.base_layer.local_num_experts,
                max_loras,
                expert_map,
            )
@@ -236,14 +227,10 @@ def create_lora_weights(
     ) -> None:
         """Initializes lora matrices."""
 
-        assert not self.base_layer.use_ep, (
-            "EP support for Fused MoE LoRA is not implemented yet."
-        )
-
         self.w1_lora_a_stacked = torch.zeros(
             (
                 max_loras,
-                self.base_layer.global_num_experts,
+                self.base_layer.local_num_experts,
                 lora_config.max_lora_rank,
                 self.base_layer.hidden_size,
             ),
@@ -253,7 +240,7 @@ def create_lora_weights(
         self.w1_lora_b_stacked = torch.zeros(
             (
                 max_loras,
-                self.base_layer.global_num_experts,
+                self.base_layer.local_num_experts,
                 self.base_layer.intermediate_size_per_partition,
                 lora_config.max_lora_rank,
             ),
@@ -264,7 +251,7 @@ def create_lora_weights(
         self.w2_lora_a_stacked = torch.zeros(
             (
                 max_loras,
-                self.base_layer.global_num_experts,
+                self.base_layer.local_num_experts,
                 lora_config.max_lora_rank,
                 self.base_layer.intermediate_size_per_partition,
             ),
@@ -274,7 +261,7 @@ def create_lora_weights(
         self.w2_lora_b_stacked = torch.zeros(
             (
                 max_loras,
-                self.base_layer.global_num_experts,
+                self.base_layer.local_num_experts,
                 self.base_layer.hidden_size,
                 lora_config.max_lora_rank,
             ),
@@ -285,7 +272,7 @@ def create_lora_weights(
         self.w3_lora_a_stacked = torch.zeros(
             (
                 max_loras,
-                self.base_layer.global_num_experts,
+                self.base_layer.local_num_experts,
                 lora_config.max_lora_rank,
                 self.base_layer.hidden_size,
             ),
@@ -295,7 +282,7 @@ def create_lora_weights(
         self.w3_lora_b_stacked = torch.zeros(
             (
                 max_loras,
-                self.base_layer.global_num_experts,
+                self.base_layer.local_num_experts,
                 self.base_layer.intermediate_size_per_partition,
                 lora_config.max_lora_rank,
             ),
@@ -308,7 +295,7 @@ def create_lora_weights(
         self.lora_a_stacked = []
         self.lora_b_stacked = []
         for lora_id in range(max_loras):
-            for experts_id in range(self.base_layer.global_num_experts):
+            for experts_id in range(self.base_layer.local_num_experts):
                # gate_proj,down_proj,up_proj
                self.lora_a_stacked.append(self.w1_lora_a_stacked[lora_id][experts_id])
                self.lora_a_stacked.append(self.w2_lora_a_stacked[lora_id][experts_id])
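
For context, a minimal sketch of the allocation shape this diff switches to. The sizes below (max_loras, local_num_experts, max_lora_rank, hidden_size) are illustrative placeholders, not values taken from the commit; the point is that each stacked LoRA buffer is now sized by the experts local to the current rank rather than by the global expert count.

import torch

# Illustrative placeholder sizes; not values from this commit.
max_loras = 4
local_num_experts = 8   # experts owned by this rank (previously global_num_experts)
max_lora_rank = 16
hidden_size = 1024

# Same layout as w1_lora_a_stacked above: one LoRA-A slice per (lora, local expert).
w1_lora_a_stacked = torch.zeros(
    (max_loras, local_num_experts, max_lora_rank, hidden_size),
    dtype=torch.float16,
)
print(w1_lora_a_stacked.shape)  # torch.Size([4, 8, 16, 1024])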

vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 0 additions & 2 deletions

@@ -672,8 +672,6 @@ def forward_cuda(
                apply_router_weight_on_input=apply_router_weight_on_input,
            )
        elif self.fused_experts is not None:
-           if self.moe.has_bias:
-               raise ValueError("FusedMoEModularKernel does not support bias.")
            result = self.fused_experts(
                hidden_states=x,
                w1=layer.w13_weight,
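
For orientation only, a toy, self-contained illustration of the behavior change in this hunk (not vllm code, names are made up): before the commit the modular-kernel path rejected MoE layers with bias, and after it the guard is gone, so biased layers are dispatched to the kernel like any other.

# Toy sketch of the dispatch change; not the real forward_cuda.
def dispatch_before(has_bias: bool) -> str:
    if has_bias:
        # This is the guard removed by the commit.
        raise ValueError("FusedMoEModularKernel does not support bias.")
    return "fused_experts(...)"

def dispatch_after(has_bias: bool) -> str:
    # The has_bias guard is gone; biased layers use the modular kernel path too.
    return "fused_experts(...)"

print(dispatch_after(has_bias=True))   # fused_experts(...)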
