File tree: 1 file changed (+12, −0 lines)
Modified file: vllm/model_executor/layers/quantization/quark

@@ -458,6 +458,7 @@ def __init__(
458458
459459 self .weight_dtype = self .weight_quant ["dtype" ].replace ("fp" , "mxfp" )
460460 self .input_dtype = self .input_quant ["dtype" ].replace ("fp" , "mxfp" )
461+ self .fp4_dtype = getattr (torch , "float4_e2m1fn_x2" , None )
461462
462463 self .ocp_mx_scheme = OCP_MX_Scheme .from_quant_dtype (
463464 self .input_dtype , self .weight_dtype
@@ -581,6 +582,17 @@ def process_weights_after_loading(self, layer):
581582 w2_weight_scale = layer .w2_weight_scale .view (s0 * s1 , - 1 )
582583 w2_weight_scale = e8m0_shuffle (w2_weight_scale )
583584 layer .w2_weight_scale .data = w2_weight_scale .view (s0 , s1 , - 1 )
585+
586+ if self .fp4_dtype is not None :
587+ layer .w13_weight = torch .nn .Parameter (
588+ layer .w13_weight .view (self .fp4_dtype ),
589+ requires_grad = layer .w13_weight .requires_grad ,
590+ )
591+ layer .w2_weight = torch .nn .Parameter (
592+ layer .w2_weight .view (self .fp4_dtype ),
593+ requires_grad = layer .w2_weight .requires_grad ,
594+ )
595+
584596 torch .cuda .empty_cache ()
585597
586598 def get_fused_moe_quant_config (
You can’t perform that action at this time.
0 commit comments