
Commit 466e878

[Quantization] Enable BNB support for more MoE models (vllm-project#21100)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
1 parent 2179372 commit 466e878
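
In plain terms, this commit lets the listed mixture-of-experts models be loaded with in-flight bitsandbytes (BNB) quantization and marks them as LoRA-capable. A minimal usage sketch, assuming vLLM's standard BNB flow; the model choice and sampling settings below are illustrative and are not taken from this commit:

# Hedged sketch: load one of the MoE models touched by this commit with
# in-flight bitsandbytes (BNB) quantization via vLLM's offline LLM API.
from vllm import LLM, SamplingParams

llm = LLM(
    model="baidu/ERNIE-4.5-21B-A3B-PT",   # example MoE model from the table below
    quantization="bitsandbytes",          # triggers in-flight BNB quantization
    trust_remote_code=True,               # may be needed for this checkpoint
)

params = SamplingParams(temperature=0.7, max_tokens=64)
outputs = llm.generate(["Explain mixture-of-experts in one sentence."], params)
print(outputs[0].outputs[0].text)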

File tree

5 files changed: +223, -181 lines changed


docs/models/supported_models.md
Lines changed: 4 additions & 4 deletions

@@ -316,7 +316,7 @@ Specified using `--task generate`.
 | `AquilaForCausalLM` | Aquila, Aquila2 | `BAAI/Aquila-7B`, `BAAI/AquilaChat-7B`, etc. | ✅︎ | ✅︎ | ✅︎ |
 | `ArcticForCausalLM` | Arctic | `Snowflake/snowflake-arctic-base`, `Snowflake/snowflake-arctic-instruct`, etc. | | ✅︎ | ✅︎ |
 | `BaiChuanForCausalLM` | Baichuan2, Baichuan | `baichuan-inc/Baichuan2-13B-Chat`, `baichuan-inc/Baichuan-7B`, etc. | ✅︎ | ✅︎ | ✅︎ |
-| `BailingMoeForCausalLM` | Ling | `inclusionAI/Ling-lite-1.5`, `inclusionAI/Ling-plus`, etc. | | ✅︎ | ✅︎ |
+| `BailingMoeForCausalLM` | Ling | `inclusionAI/Ling-lite-1.5`, `inclusionAI/Ling-plus`, etc. | ✅︎ | ✅︎ | ✅︎ |
 | `BambaForCausalLM` | Bamba | `ibm-ai-platform/Bamba-9B-fp8`, `ibm-ai-platform/Bamba-9B` | ✅︎ | ✅︎ | ✅︎ |
 | `BloomForCausalLM` | BLOOM, BLOOMZ, BLOOMChat | `bigscience/bloom`, `bigscience/bloomz`, etc. | | ✅︎ | |
 | `BartForConditionalGeneration` | BART | `facebook/bart-base`, `facebook/bart-large-cnn`, etc. | | | |

@@ -328,8 +328,8 @@ Specified using `--task generate`.
 | `DeepseekV2ForCausalLM` | DeepSeek-V2 | `deepseek-ai/DeepSeek-V2`, `deepseek-ai/DeepSeek-V2-Chat`, etc. | | ✅︎ | ✅︎ |
 | `DeepseekV3ForCausalLM` | DeepSeek-V3 | `deepseek-ai/DeepSeek-V3-Base`, `deepseek-ai/DeepSeek-V3`, etc. | | ✅︎ | ✅︎ |
 | `Dots1ForCausalLM` | dots.llm1 | `rednote-hilab/dots.llm1.base`, `rednote-hilab/dots.llm1.inst`, etc. | | ✅︎ | ✅︎ |
-| `Ernie4_5_ForCausalLM` | Ernie4.5 | `baidu/ERNIE-4.5-0.3B-PT`, etc. | | ✅︎ | ✅︎ |
-| `Ernie4_5_MoeForCausalLM` | Ernie4.5MoE | `baidu/ERNIE-4.5-21B-A3B-PT`, `baidu/ERNIE-4.5-300B-A47B-PT`, etc. | | ✅︎ | ✅︎ |
+| `Ernie4_5_ForCausalLM` | Ernie4.5 | `baidu/ERNIE-4.5-0.3B-PT`, etc. | ✅︎ | ✅︎ | ✅︎ |
+| `Ernie4_5_MoeForCausalLM` | Ernie4.5MoE | `baidu/ERNIE-4.5-21B-A3B-PT`, `baidu/ERNIE-4.5-300B-A47B-PT`, etc. |✅︎| ✅︎ | ✅︎ |
 | `ExaoneForCausalLM` | EXAONE-3 | `LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
 | `Fairseq2LlamaForCausalLM` | Llama (fairseq2 format) | `mgleize/fairseq2-dummy-Llama-3.2-1B`, etc. | ✅︎ | ✅︎ | ✅︎ |
 | `FalconForCausalLM` | Falcon | `tiiuae/falcon-7b`, `tiiuae/falcon-40b`, `tiiuae/falcon-rw-7b`, etc. | | ✅︎ | ✅︎ |

@@ -351,7 +351,7 @@ Specified using `--task generate`.
 | `GraniteMoeSharedForCausalLM` | Granite MoE Shared | `ibm-research/moe-7b-1b-active-shared-experts` (test model) | ✅︎ | ✅︎ | ✅︎ |
 | `GritLM` | GritLM | `parasail-ai/GritLM-7B-vllm`. | ✅︎ | ✅︎ | |
 | `Grok1ModelForCausalLM` | Grok1 | `hpcai-tech/grok-1`. | ✅︎ | ✅︎ | ✅︎ |
-| `HunYuanMoEV1ForCausalLM` | Hunyuan-80B-A13B | `tencent/Hunyuan-A13B-Instruct`, `tencent/Hunyuan-A13B-Pretrain`, `tencent/Hunyuan-A13B-Instruct-FP8`, etc. | | | ✅︎ |
+| `HunYuanMoEV1ForCausalLM` | Hunyuan-80B-A13B | `tencent/Hunyuan-A13B-Instruct`, `tencent/Hunyuan-A13B-Pretrain`, `tencent/Hunyuan-A13B-Instruct-FP8`, etc. | ✅︎ | | ✅︎ |
 | `InternLMForCausalLM` | InternLM | `internlm/internlm-7b`, `internlm/internlm-chat-7b`, etc. | ✅︎ | ✅︎ | ✅︎ |
 | `InternLM2ForCausalLM` | InternLM2 | `internlm/internlm2-7b`, `internlm/internlm2-chat-7b`, etc. | ✅︎ | ✅︎ | ✅︎ |
 | `InternLM3ForCausalLM` | InternLM3 | `internlm/internlm3-8b-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
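
The check marks added above line up with the code changes below: each architecture gains a get_expert_mapping() hook (used by the bitsandbytes loader) and the SupportsLoRA interface. A hedged sketch of serving one of the newly check-marked models with a LoRA adapter; the adapter name and path are placeholders and are not part of this commit:

# Hedged sketch: offline generation with a LoRA adapter on a model whose row
# gained a check mark above. The adapter name and path are placeholders.
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

llm = LLM(
    model="inclusionAI/Ling-lite-1.5",  # BailingMoeForCausalLM row above
    enable_lora=True,
    max_lora_rank=16,
)

outputs = llm.generate(
    ["Write a haiku about routing tokens to experts."],
    SamplingParams(max_tokens=48),
    lora_request=LoRARequest("example-adapter", 1, "/path/to/lora/adapter"),
)
print(outputs[0].outputs[0].text)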

vllm/model_executor/models/bailing_moe.py
Lines changed: 14 additions & 7 deletions

@@ -53,7 +53,7 @@
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
 
-from .interfaces import SupportsPP
+from .interfaces import SupportsLoRA, SupportsPP
 from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)

@@ -374,21 +374,25 @@ def forward(
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
 
+    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
+        return FusedMoE.make_expert_params_mapping(
+            ckpt_gate_proj_name="gate_proj",
+            ckpt_down_proj_name="down_proj",
+            ckpt_up_proj_name="up_proj",
+            num_experts=self.config.num_experts,
+        )
+
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("gate_up_proj", "gate_proj", 0),
             ("gate_up_proj", "up_proj", 1),
         ]
-        expert_params_mapping = FusedMoE.make_expert_params_mapping(
-            ckpt_gate_proj_name="gate_proj",
-            ckpt_down_proj_name="down_proj",
-            ckpt_up_proj_name="up_proj",
-            num_experts=self.config.num_experts)
 
         params_dict = dict(self.named_parameters(remove_duplicate=False))
         loaded_params: set[str] = set()
+        expert_params_mapping = self.get_expert_mapping()
         for name, loaded_weight in weights:
             if self.config.norm_head and "lm_head.weight" in name:
                 loaded_weight = F.normalize(loaded_weight,

@@ -449,7 +453,7 @@ def load_weights(self, weights: Iterable[tuple[str,
         return loaded_params
 
 
-class BailingMoeForCausalLM(nn.Module, SupportsPP):
+class BailingMoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
 
     packed_modules_mapping = {
         "query_key_value": ["query_key_value"],

@@ -518,3 +522,6 @@ def load_weights(self, weights: Iterable[tuple[str,
                            if self.config.tie_word_embeddings else None),
         )
         return loader.load_weights(weights)
+
+    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
+        return self.model.get_expert_mapping()
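
The recurring pattern in each model file is to hoist the expert parameter mapping out of load_weights() into a get_expert_mapping() method, re-exposed on the top-level ForCausalLM class, so other loaders such as the bitsandbytes loader can query it without duplicating logic. Below is a self-contained, illustrative stand-in for what FusedMoE.make_expert_params_mapping returns; it mimics the (param_name, weight_name, expert_id, shard_id) tuples but is not vLLM code:

# Illustrative stand-in (not vLLM code) for FusedMoE.make_expert_params_mapping.
# It yields one (param_name, weight_name, expert_id, shard_id) tuple per expert
# and per projection, letting a loader route per-expert checkpoint tensors such
# as "experts.3.gate_proj.weight" into the fused MoE parameters.
def make_expert_params_mapping_sketch(ckpt_gate_proj_name: str,
                                      ckpt_down_proj_name: str,
                                      ckpt_up_proj_name: str,
                                      num_experts: int):
    mapping = []
    for expert_id in range(num_experts):
        # gate and up projections land in the fused "w13" parameter,
        # the down projection lands in "w2"
        mapping.append(("experts.w13_",
                        f"experts.{expert_id}.{ckpt_gate_proj_name}.",
                        expert_id, "w1"))
        mapping.append(("experts.w13_",
                        f"experts.{expert_id}.{ckpt_up_proj_name}.",
                        expert_id, "w3"))
        mapping.append(("experts.w2_",
                        f"experts.{expert_id}.{ckpt_down_proj_name}.",
                        expert_id, "w2"))
    return mapping


if __name__ == "__main__":
    # Same checkpoint projection names BailingMoeModel passes above.
    for entry in make_expert_params_mapping_sketch("gate_proj", "down_proj",
                                                   "up_proj", num_experts=2):
        print(entry)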

vllm/model_executor/models/ernie45_moe.py
Lines changed: 84 additions & 69 deletions

@@ -51,8 +51,8 @@
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
 
-from .interfaces import SupportsPP
-from .utils import (PPMissingLayer, extract_layer_index,
+from .interfaces import SupportsLoRA, SupportsPP
+from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index,
                     is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)

@@ -427,66 +427,15 @@ def forward(
 
         return hidden_states
 
+    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
 
-class Ernie4_5_MoeForCausalLM(nn.Module, SupportsPP):
-    packed_modules_mapping = {
-        "qkv_proj": [
-            "q_proj",
-            "k_proj",
-            "v_proj",
-        ],
-        "gate_up_proj": [
-            "gate_proj",
-            "up_proj",
-        ],
-    }
-
-    fall_back_to_pt_during_load = False
-
-    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
-        super().__init__()
-        config = vllm_config.model_config.hf_config
-        quant_config = vllm_config.quant_config
-        self.config = config
-        self.quant_config = quant_config
-        self.model = Ernie4_5_MoeModel(vllm_config=vllm_config,
-                                       prefix=maybe_prefix(prefix, "model"))
-
-        if get_pp_group().is_last_rank:
-            self.lm_head = ParallelLMHead(config.vocab_size,
-                                          config.hidden_size,
-                                          quant_config=quant_config)
-        else:
-            self.lm_head = PPMissingLayer()
-
-        if self.config.tie_word_embeddings:
-            self.lm_head.weight = self.model.embed_tokens.weight
-        self.logits_processor = LogitsProcessor(config.vocab_size)
-        self.make_empty_intermediate_tensors = (
-            self.model.make_empty_intermediate_tensors)
-
-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
-        return self.model.get_input_embeddings(input_ids)
-
-    def forward(
-        self,
-        input_ids: torch.Tensor,
-        positions: torch.Tensor,
-        intermediate_tensors: Optional[IntermediateTensors] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
-    ) -> Union[torch.Tensor, IntermediateTensors]:
-        hidden_states = self.model(input_ids, positions, intermediate_tensors,
-                                   inputs_embeds)
-        return hidden_states
-
-    def compute_logits(
-        self,
-        hidden_states: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
-    ) -> Optional[torch.Tensor]:
-        logits = self.logits_processor(self.lm_head, hidden_states,
-                                       sampling_metadata)
-        return logits
+        # Params for weights, fp8 weight scales, fp8 activation scales
+        # (param_name, weight_name, expert_id, shard_id)
+        return FusedMoE.make_expert_params_mapping(
+            ckpt_gate_proj_name="gate_proj",
+            ckpt_down_proj_name="down_proj",
+            ckpt_up_proj_name="up_proj",
+            num_experts=self.config.moe_num_experts)
 
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:

@@ -499,16 +448,9 @@ def load_weights(self, weights: Iterable[tuple[str,
             ("gate_up_proj", "up_proj", 1),
         ]
 
-        # Params for weights, fp8 weight scales, fp8 activation scales
-        # (param_name, weight_name, expert_id, shard_id)
-        expert_params_mapping = FusedMoE.make_expert_params_mapping(
-            ckpt_gate_proj_name="gate_proj",
-            ckpt_down_proj_name="down_proj",
-            ckpt_up_proj_name="up_proj",
-            num_experts=self.config.moe_num_experts)
-
         params_dict = dict(self.named_parameters())
         loaded_params: set[str] = set()
+        expert_params_mapping = self.get_expert_mapping()
         for name, loaded_weight in weights:
             if self.config.tie_word_embeddings and name.endswith(
                     "lm_head.weight"):

@@ -581,3 +523,76 @@ def load_weights(self, weights: Iterable[tuple[str,
             weight_loader(param, loaded_weight)
         loaded_params.add(name)
         return loaded_params
+
+
+class Ernie4_5_MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
+    packed_modules_mapping = {
+        "qkv_proj": [
+            "q_proj",
+            "k_proj",
+            "v_proj",
+        ],
+        "gate_up_proj": [
+            "gate_proj",
+            "up_proj",
+        ],
+    }
+
+    fall_back_to_pt_during_load = False
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        super().__init__()
+        config = vllm_config.model_config.hf_config
+        quant_config = vllm_config.quant_config
+        self.config = config
+        self.quant_config = quant_config
+        self.model = Ernie4_5_MoeModel(vllm_config=vllm_config,
+                                       prefix=maybe_prefix(prefix, "model"))
+
+        if get_pp_group().is_last_rank:
+            self.lm_head = ParallelLMHead(config.vocab_size,
+                                          config.hidden_size,
+                                          quant_config=quant_config)
+        else:
+            self.lm_head = PPMissingLayer()
+
+        if self.config.tie_word_embeddings:
+            self.lm_head.weight = self.model.embed_tokens.weight
+        self.logits_processor = LogitsProcessor(config.vocab_size)
+        self.make_empty_intermediate_tensors = (
+            self.model.make_empty_intermediate_tensors)
+
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.get_input_embeddings(input_ids)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, IntermediateTensors]:
+        hidden_states = self.model(input_ids, positions, intermediate_tensors,
+                                   inputs_embeds)
+        return hidden_states
+
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
+        logits = self.logits_processor(self.lm_head, hidden_states,
+                                       sampling_metadata)
+        return logits
+
+    def load_weights(self, weights: Iterable[tuple[str,
+                                                   torch.Tensor]]) -> set[str]:
+        loader = AutoWeightsLoader(
+            self,
+            skip_prefixes=(["lm_head."]
+                           if self.config.tie_word_embeddings else None),
+        )
+        return loader.load_weights(weights)
+
+    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
+        return self.model.get_expert_mapping()
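
Ernie4_5_MoeForCausalLM now delegates weight loading to AutoWeightsLoader with skip_prefixes, so the checkpoint's lm_head tensors are ignored when word embeddings are tied. A minimal stand-alone sketch of that filtering idea, not vLLM internals:

# Minimal sketch of the skip_prefixes idea: with tied word embeddings the
# checkpoint's lm_head weights are redundant, so the loader skips any tensor
# whose name starts with "lm_head.".
def filter_skipped(weights, tie_word_embeddings: bool):
    skip_prefixes = ["lm_head."] if tie_word_embeddings else []
    for name, tensor in weights:
        if any(name.startswith(prefix) for prefix in skip_prefixes):
            continue
        yield name, tensor


dummy_weights = [("lm_head.weight", None), ("model.embed_tokens.weight", None)]
print([name for name, _ in filter_skipped(dummy_weights,
                                           tie_word_embeddings=True)])
# prints: ['model.embed_tokens.weight']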

vllm/model_executor/models/grok1.py
Lines changed: 14 additions & 10 deletions

@@ -360,6 +360,16 @@ def forward(
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
 
+    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
+        # Map Grok1's unique expert parameter names to standard names
+        # Grok1 uses "num_experts" in its config
+        num_experts = getattr(self.config, "num_experts", 8)
+        return FusedMoE.make_expert_params_mapping(
+            ckpt_gate_proj_name="linear",  # Grok1 specific
+            ckpt_down_proj_name="linear_1",  # Grok1 specific
+            ckpt_up_proj_name="linear_v",  # Grok1 specific
+            num_experts=num_experts)
+
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
         stacked_params_mapping = [

@@ -369,18 +379,9 @@ def load_weights(self, weights: Iterable[tuple[str,
             ("qkv_proj", "v_proj", "v"),
         ]
 
-        # Map Grok1's unique expert parameter names to standard names
-        # Grok1 uses "num_experts" in its config
-        num_experts = getattr(self.config, "num_experts", 8)
-        expert_params_mapping = FusedMoE.make_expert_params_mapping(
-            ckpt_gate_proj_name="linear",  # Grok1 specific
-            ckpt_down_proj_name="linear_1",  # Grok1 specific
-            ckpt_up_proj_name="linear_v",  # Grok1 specific
-            num_experts=num_experts)
-
         params_dict = dict(self.named_parameters())
         loaded_params: set[str] = set()
-
+        expert_params_mapping = self.get_expert_mapping()
         for name, loaded_weight in weights:
             if (self.quant_config is not None and
                     (scale_name := self.quant_config.get_cache_scale(name))):

@@ -544,3 +545,6 @@ def load_weights(self, weights: Iterable[tuple[str,
             skip_prefixes=skip_prefixes,
         )
         return loader.load_weights(weights)
+
+    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
+        return self.model.get_expert_mapping()
