Commit 5d3be3b (1 parent: 4f882be)

[Bugfix][LoRA][FusedMoE] Select MxFP4 Backend based on LoRA Enablement (vllm-project#27487)

Authored by Varun Sundar Rabindranath (varun-sundar-rabindranath)
Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>

File tree: 3 files changed, +34 -6 lines changed


vllm/model_executor/layers/fused_moe/config.py

Lines changed: 2 additions & 0 deletions
@@ -825,6 +825,8 @@ class FusedMoEConfig:
 
     is_act_and_mul: bool = True
 
+    is_lora_enabled: bool = False
+
     def __post_init__(self):
         if self.dp_size > 1:
             logger.debug_once(

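For readers skimming the diff: the new flag is a plain dataclass field with a False default, so construction sites that never pass it keep the pre-existing backend selection. A minimal, self-contained sketch of that shape (this is not the full vllm FusedMoEConfig, which carries many more fields):

from dataclasses import dataclass


# Sketch only: the real FusedMoEConfig has many more fields (parallelism,
# dtypes, chunk sizes, ...). Shown here just to make the default-False
# behaviour of the new flag concrete.
@dataclass
class FusedMoEConfigSketch:
    is_act_and_mul: bool = True
    is_lora_enabled: bool = False


# Call sites that do not mention the flag behave exactly as before.
assert FusedMoEConfigSketch().is_lora_enabled is False
# The FusedMoE layer sets it from the engine's LoRA config (see layer.py below).
assert FusedMoEConfigSketch(is_lora_enabled=True).is_lora_enabled is True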
vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 12 additions & 3 deletions
@@ -982,6 +982,7 @@ def maybe_roundup_hidden_size(
     act_dtype: torch.dtype,
     quant_config: QuantizationConfig | None,
     moe_parallel_config: FusedMoEParallelConfig,
+    is_lora_enabled: bool,
 ) -> int:
     """
     Given layer hidden size and MoE configurations, round up hidden_size
@@ -992,6 +993,9 @@
         act_dtype: Data type of the layer activations.
         quant_config: Fused MoE quantization configuration.
         moe_parallel_config: Fused MoE parallelization strategy configuration.
+        is_lora_enabled: True if the engine is enabled with LoRA. This
+            is used in the case of mxfp4 quantization in selecting the
+            MxFP4Backend.
 
     Return:
         Rounded up hidden_size if rounding up is required based on the configs.
@@ -1015,7 +1019,7 @@
             get_mxfp4_backend,
         )
 
-        current_mxfp4_backend = get_mxfp4_backend()
+        current_mxfp4_backend = get_mxfp4_backend(is_lora_enabled)
         if (
             current_mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16
             or current_mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS
@@ -1139,7 +1143,11 @@
 
         # Round up hidden size if needed.
         hidden_size = maybe_roundup_hidden_size(
-            hidden_size, moe_in_dtype, quant_config, self.moe_parallel_config
+            hidden_size,
+            moe_in_dtype,
+            quant_config,
+            self.moe_parallel_config,
+            is_lora_enabled=self.vllm_config.lora_config is not None,
         )
 
         # For smuggling this layer into the fused moe custom op
@@ -1270,8 +1278,9 @@
             max_num_tokens=envs.VLLM_MOE_DP_CHUNK_SIZE,
             has_bias=has_bias,
             is_act_and_mul=is_act_and_mul,
+            is_lora_enabled=vllm_config.lora_config is not None,
         )
-        self.moe_config = moe
+        self.moe_config: FusedMoEConfig = moe
         self.moe_quant_config: FusedMoEQuantConfig | None = None
         self.quant_config = quant_config

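The flag is derived once from the engine config with `lora_config is not None` and threaded through both `maybe_roundup_hidden_size` and the `FusedMoEConfig` constructor. A hedged sketch of that idiom follows; `VllmConfigStub`, `LoRAConfigStub`, and the trivial `maybe_roundup_hidden_size` body are illustrative stand-ins, not the vllm implementations:

from dataclasses import dataclass
from typing import Optional


@dataclass
class LoRAConfigStub:
    """Stand-in for vllm's LoRAConfig; present only when LoRA is enabled."""
    max_lora_rank: int = 16


@dataclass
class VllmConfigStub:
    """Stand-in for vllm's VllmConfig."""
    lora_config: Optional[LoRAConfigStub] = None


def maybe_roundup_hidden_size(hidden_size: int, is_lora_enabled: bool) -> int:
    # Sketch only: the real function also inspects the activation dtype,
    # quantization config and MoE parallel config before rounding up.
    return hidden_size


vllm_config = VllmConfigStub(lora_config=LoRAConfigStub())

# Same idiom as the diff: LoRA enablement == "a lora_config exists".
hidden_size = maybe_roundup_hidden_size(
    4096,
    is_lora_enabled=vllm_config.lora_config is not None,
)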
vllm/model_executor/layers/quantization/mxfp4.py

Lines changed: 20 additions & 3 deletions
@@ -73,8 +73,24 @@ class Mxfp4Backend(Enum):
     TRITON = 6
 
 
-def get_mxfp4_backend():
+def get_mxfp4_backend_with_lora() -> Mxfp4Backend:
+    """
+    Not all MXFP4 backends support LoRA. Select backends that are known to
+    have LoRA support.
+    """
+    if not current_platform.is_cuda():
+        return Mxfp4Backend.NONE
+
+    logger.info_once("[get_mxfp4_backend_with_lora] Using Marlin backend")
+    return Mxfp4Backend.MARLIN
+
+
+def get_mxfp4_backend(with_lora_support: bool) -> Mxfp4Backend:
     # Backend Selection
+
+    if with_lora_support:
+        return get_mxfp4_backend_with_lora()
+
     if current_platform.is_cuda():
         if (
             current_platform.is_device_capability(90)
@@ -183,13 +199,14 @@ def __init__(self, moe: FusedMoEConfig):
         super().__init__(moe)
         self.topk_indices_dtype = None
         self.moe = moe
-        self.mxfp4_backend = get_mxfp4_backend()
+        self.mxfp4_backend = get_mxfp4_backend(moe.is_lora_enabled)
         self.max_capture_size = (
            get_current_vllm_config().compilation_config.max_cudagraph_capture_size
         )
 
         assert self.mxfp4_backend != Mxfp4Backend.NONE, (
-            "No MXFP4 MoE backend (FlashInfer/Marlin/Triton) available."
+            f"get_mxfp4_backend(with_lora_support={moe.is_lora_enabled}) found"
+            "no compatible MXFP4 MoE backend (FlashInfer/Marlin/Triton)."
             "Please check your environment and try again."
         )
         self._cache_permute_indices: dict[torch.Size, torch.Tensor] = {}

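Putting the three files together, backend selection now forks on the LoRA flag before any capability probing: with LoRA enabled the only backend handed out is Marlin (or NONE off-CUDA), otherwise the existing FlashInfer/Marlin/Triton selection runs unchanged. A self-contained sketch of that dispatch; the enum values and the `is_cuda` parameter are illustrative stand-ins for vllm's `current_platform` checks:

from enum import Enum


class Mxfp4BackendSketch(Enum):
    # Member values are illustrative; only NONE/MARLIN/TRITON are named here.
    NONE = 0
    MARLIN = 1
    TRITON = 2


def get_mxfp4_backend_with_lora(is_cuda: bool) -> Mxfp4BackendSketch:
    # Not all MXFP4 backends support LoRA; Marlin is the one selected here.
    return Mxfp4BackendSketch.MARLIN if is_cuda else Mxfp4BackendSketch.NONE


def get_mxfp4_backend(with_lora_support: bool, is_cuda: bool = True) -> Mxfp4BackendSketch:
    if with_lora_support:
        # The LoRA path short-circuits the capability-based selection below.
        return get_mxfp4_backend_with_lora(is_cuda)
    # ... existing capability-based FlashInfer / Marlin / Triton selection ...
    return Mxfp4BackendSketch.TRITON


# With LoRA enabled on CUDA, Marlin is always chosen.
assert get_mxfp4_backend(with_lora_support=True) is Mxfp4BackendSketch.MARLIN
# Off-CUDA with LoRA there is no usable backend, which trips the assert
# in Mxfp4MoEMethod.__init__.
assert get_mxfp4_backend(True, is_cuda=False) is Mxfp4BackendSketch.NONE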