 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors

-from .interfaces import SupportsPP
-from .utils import (PPMissingLayer, extract_layer_index,
+from .interfaces import SupportsLoRA, SupportsPP
+from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index,
                      is_pp_missing_parameter,
                      make_empty_intermediate_tensors_factory, make_layers,
                      maybe_prefix)
@@ -427,66 +427,15 @@ def forward(

         return hidden_states

+    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:

-class Ernie4_5_MoeForCausalLM(nn.Module, SupportsPP):
-    packed_modules_mapping = {
-        "qkv_proj": [
-            "q_proj",
-            "k_proj",
-            "v_proj",
-        ],
-        "gate_up_proj": [
-            "gate_proj",
-            "up_proj",
-        ],
-    }
-
-    fall_back_to_pt_during_load = False
-
-    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
-        super().__init__()
-        config = vllm_config.model_config.hf_config
-        quant_config = vllm_config.quant_config
-        self.config = config
-        self.quant_config = quant_config
-        self.model = Ernie4_5_MoeModel(vllm_config=vllm_config,
-                                       prefix=maybe_prefix(prefix, "model"))
-
-        if get_pp_group().is_last_rank:
-            self.lm_head = ParallelLMHead(config.vocab_size,
-                                          config.hidden_size,
-                                          quant_config=quant_config)
-        else:
-            self.lm_head = PPMissingLayer()
-
-        if self.config.tie_word_embeddings:
-            self.lm_head.weight = self.model.embed_tokens.weight
-        self.logits_processor = LogitsProcessor(config.vocab_size)
-        self.make_empty_intermediate_tensors = (
-            self.model.make_empty_intermediate_tensors)
-
-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
-        return self.model.get_input_embeddings(input_ids)
-
-    def forward(
-        self,
-        input_ids: torch.Tensor,
-        positions: torch.Tensor,
-        intermediate_tensors: Optional[IntermediateTensors] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
-    ) -> Union[torch.Tensor, IntermediateTensors]:
-        hidden_states = self.model(input_ids, positions, intermediate_tensors,
-                                   inputs_embeds)
-        return hidden_states
-
-    def compute_logits(
-        self,
-        hidden_states: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
-    ) -> Optional[torch.Tensor]:
-        logits = self.logits_processor(self.lm_head, hidden_states,
-                                       sampling_metadata)
-        return logits
+        # Params for weights, fp8 weight scales, fp8 activation scales
+        # (param_name, weight_name, expert_id, shard_id)
+        return FusedMoE.make_expert_params_mapping(
+            ckpt_gate_proj_name="gate_proj",
+            ckpt_down_proj_name="down_proj",
+            ckpt_up_proj_name="up_proj",
+            num_experts=self.config.moe_num_experts)

     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
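
The list returned by `get_expert_mapping` pairs each fused MoE parameter with the per-expert checkpoint tensors that should be folded into it. As a rough illustration of the `(param_name, weight_name, expert_id, shard_id)` shape described in the comment above, the sketch below builds a toy mapping; the concrete strings (`experts.w13_weight`, shard ids `w1`/`w2`/`w3`) are assumptions for illustration, not necessarily the exact names produced by `FusedMoE.make_expert_params_mapping`.

```python
# Hypothetical illustration of the mapping shape, NOT vLLM's real output.
def toy_expert_params_mapping(
        num_experts: int) -> list[tuple[str, str, int, str]]:
    mapping: list[tuple[str, str, int, str]] = []
    for expert_id in range(num_experts):
        # Fused gate/up projection: two shards of one fused parameter.
        mapping.append(("experts.w13_weight",
                        f"experts.{expert_id}.gate_proj.weight", expert_id,
                        "w1"))
        mapping.append(("experts.w13_weight",
                        f"experts.{expert_id}.up_proj.weight", expert_id,
                        "w3"))
        # Down projection maps onto its own fused parameter.
        mapping.append(("experts.w2_weight",
                        f"experts.{expert_id}.down_proj.weight", expert_id,
                        "w2"))
    return mapping


print(toy_expert_params_mapping(2)[0])
# ('experts.w13_weight', 'experts.0.gate_proj.weight', 0, 'w1')
```
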
@@ -499,16 +448,9 @@ def load_weights(self, weights: Iterable[tuple[str,
             ("gate_up_proj", "up_proj", 1),
         ]

-        # Params for weights, fp8 weight scales, fp8 activation scales
-        # (param_name, weight_name, expert_id, shard_id)
-        expert_params_mapping = FusedMoE.make_expert_params_mapping(
-            ckpt_gate_proj_name="gate_proj",
-            ckpt_down_proj_name="down_proj",
-            ckpt_up_proj_name="up_proj",
-            num_experts=self.config.moe_num_experts)
-
         params_dict = dict(self.named_parameters())
         loaded_params: set[str] = set()
+        expert_params_mapping = self.get_expert_mapping()
         for name, loaded_weight in weights:
             if self.config.tie_word_embeddings and name.endswith(
                     "lm_head.weight"):
@@ -581,3 +523,76 @@ def load_weights(self, weights: Iterable[tuple[str,
                 weight_loader(param, loaded_weight)
             loaded_params.add(name)
         return loaded_params
+
+
+class Ernie4_5_MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
+    packed_modules_mapping = {
+        "qkv_proj": [
+            "q_proj",
+            "k_proj",
+            "v_proj",
+        ],
+        "gate_up_proj": [
+            "gate_proj",
+            "up_proj",
+        ],
+    }
+
+    fall_back_to_pt_during_load = False
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        super().__init__()
+        config = vllm_config.model_config.hf_config
+        quant_config = vllm_config.quant_config
+        self.config = config
+        self.quant_config = quant_config
+        self.model = Ernie4_5_MoeModel(vllm_config=vllm_config,
+                                       prefix=maybe_prefix(prefix, "model"))
+
+        if get_pp_group().is_last_rank:
+            self.lm_head = ParallelLMHead(config.vocab_size,
+                                          config.hidden_size,
+                                          quant_config=quant_config)
+        else:
+            self.lm_head = PPMissingLayer()
+
+        if self.config.tie_word_embeddings:
+            self.lm_head.weight = self.model.embed_tokens.weight
+        self.logits_processor = LogitsProcessor(config.vocab_size)
+        self.make_empty_intermediate_tensors = (
+            self.model.make_empty_intermediate_tensors)
+
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.get_input_embeddings(input_ids)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, IntermediateTensors]:
+        hidden_states = self.model(input_ids, positions, intermediate_tensors,
+                                   inputs_embeds)
+        return hidden_states
+
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
+        logits = self.logits_processor(self.lm_head, hidden_states,
+                                       sampling_metadata)
+        return logits
+
+    def load_weights(self, weights: Iterable[tuple[str,
+                                                   torch.Tensor]]) -> set[str]:
+        loader = AutoWeightsLoader(
+            self,
+            skip_prefixes=(["lm_head."]
+                           if self.config.tie_word_embeddings else None),
+        )
+        return loader.load_weights(weights)
+
+    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
+        return self.model.get_expert_mapping()
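
To see how the pieces fit together, here is a simplified sketch of the loading flow: the top-level module skips `lm_head.*` when embeddings are tied (the `skip_prefixes` argument above), and per-expert checkpoint names are rewritten onto fused parameters using the `(param_name, weight_name, expert_id, shard_id)` tuples from `get_expert_mapping`. The helper `toy_load` is a hypothetical stand-in, not vLLM's `AutoWeightsLoader` or `FusedMoE` weight loader; the real path also handles the stacked q/k/v and gate/up projections, fp8 scales, and pipeline-parallel missing layers.

```python
from collections.abc import Iterable

import torch


def toy_load(weights: Iterable[tuple[str, torch.Tensor]],
             expert_mapping: list[tuple[str, str, int, str]],
             tie_word_embeddings: bool) -> set[str]:
    """Simplified sketch of the dispatch performed during weight loading."""
    skip_prefixes = ["lm_head."] if tie_word_embeddings else []
    loaded: set[str] = set()
    for name, tensor in weights:
        if any(name.startswith(p) for p in skip_prefixes):
            continue  # tied embeddings: lm_head reuses embed_tokens.weight
        for param_name, weight_name, expert_id, shard_id in expert_mapping:
            if weight_name not in name:
                continue
            # Rewrite the per-expert checkpoint name onto the fused parameter,
            # e.g. "model.layers.3.mlp.experts.7.up_proj.weight"
            #   -> "model.layers.3.mlp.experts.w13_weight".
            fused_name = name.replace(weight_name, param_name)
            # The real loader would now call the parameter's weight_loader
            # with (tensor, expert_id, shard_id) to copy into the right slice.
            loaded.add(fused_name)
            break
        else:
            loaded.add(name)  # ordinary (non-expert) parameter
    return loaded
```
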