
Commit ddc9e76

py-andy-c authored and jeejeelee committed
Enable prequant
1 parent 1b99028 commit ddc9e76

File tree

2 files changed: +9 -7 lines changed

vllm/model_executor/model_loader/bitsandbytes_loader.py

Lines changed: 2 additions & 6 deletions

@@ -427,13 +427,9 @@ def _get_bnb_target_modules(self, model: nn.Module) -> None:
         elif isinstance(module, FusedMoE) and hasattr(
                 module.quant_method, "quant_config"):
             # TODO: support FusedMoE with prequant and 8bit.
-            if self.pre_quant:
+            if self.pre_quant and self.load_8bit:
                 raise ValueError(
-                    "Prequant BitsAndBytes models with FusedMoE is not "
-                    "supported yet.")
-            if self.load_8bit:
-                raise ValueError(
-                    "BitsAndBytes 8bit quantization with FusedMoE is not "
+                    "Prequant BitsAndBytes 8bit models with FusedMoE is not "
                     "supported yet.")
             # Get the corresponding weight name using module name and
             # expert_params_mapping.
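
Note: this change narrows the loader's FusedMoE guard: prequantized 4bit BitsAndBytes checkpoints with FusedMoE layers can now load, and only the prequant + 8bit combination is still rejected. A minimal sketch of the resulting check, with vLLM's loader reduced to a hypothetical stub carrying just the two flags from the diff:

class BnbLoaderStub:
    """Hypothetical stand-in for vLLM's BitsAndBytes loader state."""

    def __init__(self, pre_quant: bool, load_8bit: bool):
        self.pre_quant = pre_quant
        self.load_8bit = load_8bit

    def check_fused_moe(self) -> None:
        # Before this commit, pre_quant alone (and load_8bit alone) raised
        # here; now only prequant combined with 8bit loading is rejected.
        if self.pre_quant and self.load_8bit:
            raise ValueError(
                "Prequant BitsAndBytes 8bit models with FusedMoE is not "
                "supported yet.")

BnbLoaderStub(pre_quant=True, load_8bit=False).check_fused_moe()  # 4bit prequant: now allowed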

vllm/model_executor/models/qwen3_moe.py

Lines changed: 7 additions & 1 deletion

@@ -52,6 +52,7 @@
     default_weight_loader, maybe_remap_kv_scale_name)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
+from vllm.model_executor.layers.quantization.bitsandbytes import BitsAndBytesConfig
 
 from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP
 from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index,
@@ -149,8 +150,13 @@ def __init__(
         self.gate = ReplicatedLinear(config.hidden_size,
                                      config.num_experts,
                                      bias=False,
-                                     quant_config=None,
+                                     quant_config=self._maybe_ignore_quant_config(quant_config),  # Some quantization methods do not quantize the gate
                                      prefix=f"{prefix}.gate")
+
+    def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig):
+        if not isinstance(quant_config, (BitsAndBytesConfig)):
+            return None
+        return quant_config
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         # NOTE: hidden_states can have either 1D or 2D shape.
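
The new helper keeps the quant config on the MoE router gate only for BitsAndBytes; any other quantization method (or no config at all) leaves the gate's ReplicatedLinear unquantized, matching the previous hard-coded quant_config=None. A self-contained sketch of that behavior, using stand-in classes for vLLM's QuantizationConfig and BitsAndBytesConfig:

from typing import Optional

class QuantizationConfig:  # stand-in for vLLM's base quant config class
    pass

class BitsAndBytesConfig(QuantizationConfig):  # stand-in for the BnB config
    pass

def maybe_ignore_quant_config(
        quant_config: Optional[QuantizationConfig]
) -> Optional[QuantizationConfig]:
    # Keep the config only for BitsAndBytes; every other method leaves the
    # gate unquantized, which was previously the behavior for all methods.
    if not isinstance(quant_config, BitsAndBytesConfig):
        return None
    return quant_config

assert maybe_ignore_quant_config(None) is None
assert maybe_ignore_quant_config(QuantizationConfig()) is None
assert maybe_ignore_quant_config(BitsAndBytesConfig()) is not None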

0 commit comments