Commit b06b947

[ROCm][fused_moe][fp4] view weights as torch.float4_e2m1fn_x2 when running aiter fused MoE for FP4 models (vllm-project#27474)
Signed-off-by: zejunchen-zejun <zejun.chen@amd.com>
Parent: 4673e46

1 file changed (+12, -0)

vllm/model_executor/layers/quantization/quark/quark_moe.py

Lines changed: 12 additions & 0 deletions
@@ -458,6 +458,7 @@ def __init__(
 
         self.weight_dtype = self.weight_quant["dtype"].replace("fp", "mxfp")
         self.input_dtype = self.input_quant["dtype"].replace("fp", "mxfp")
+        self.fp4_dtype = getattr(torch, "float4_e2m1fn_x2", None)
 
         self.ocp_mx_scheme = OCP_MX_Scheme.from_quant_dtype(
             self.input_dtype, self.weight_dtype
@@ -581,6 +582,17 @@ def process_weights_after_loading(self, layer):
             w2_weight_scale = layer.w2_weight_scale.view(s0 * s1, -1)
             w2_weight_scale = e8m0_shuffle(w2_weight_scale)
             layer.w2_weight_scale.data = w2_weight_scale.view(s0, s1, -1)
+
+            if self.fp4_dtype is not None:
+                layer.w13_weight = torch.nn.Parameter(
+                    layer.w13_weight.view(self.fp4_dtype),
+                    requires_grad=layer.w13_weight.requires_grad,
+                )
+                layer.w2_weight = torch.nn.Parameter(
+                    layer.w2_weight.view(self.fp4_dtype),
+                    requires_grad=layer.w2_weight.requires_grad,
+                )
+
         torch.cuda.empty_cache()
 
     def get_fused_moe_quant_config(
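
For readers outside the vLLM tree, the pattern the diff relies on can be sketched in isolation. This is a minimal illustration, not the project's code: it assumes a PyTorch build that defines torch.float4_e2m1fn_x2 (a packed dtype whose one-byte elements each carry two FP4 e2m1 values), and DummyLayer with its uint8 weight is a hypothetical stand-in for the real quark_moe layer.

import torch

# Hypothetical stand-in for the MoE layer in quark_moe.py; the shape is
# illustrative only.
class DummyLayer(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        # Packed FP4 weights arrive as raw uint8 bytes (two 4-bit values per
        # byte); integer-dtype Parameters must have requires_grad=False.
        self.w2_weight = torch.nn.Parameter(
            torch.randint(0, 256, (8, 16), dtype=torch.uint8),
            requires_grad=False,
        )

# Feature-detect the packed FP4 dtype, as the diff does: older PyTorch
# builds do not define it, and getattr avoids an AttributeError there.
fp4_dtype = getattr(torch, "float4_e2m1fn_x2", None)

layer = DummyLayer()
if fp4_dtype is not None:
    # view(dtype) reinterprets the existing bytes without copying; since the
    # packed dtype also has a one-byte element size, the shape is unchanged.
    fp4_view = layer.w2_weight.view(fp4_dtype)
    assert fp4_view.shape == layer.w2_weight.shape
    assert fp4_view.data_ptr() == layer.w2_weight.data_ptr()

    # Rewrap as a Parameter, preserving requires_grad, so downstream kernels
    # see the FP4 dtype they expect.
    layer.w2_weight = torch.nn.Parameter(
        fp4_view, requires_grad=layer.w2_weight.requires_grad
    )

In the diff itself, the same rewrap is applied to both w13_weight and w2_weight inside process_weights_after_loading, so the zero-copy dtype change happens once at weight-loading time rather than on every forward pass.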
