 from paddle.nn.quant import weight_quantize

 from paddlenlp.experimental.transformers.fused_transformer_layers import (
+    FusedBlockMultiTransformer,
+    FusedBlockMultiTransformerWeightOnly,
     FusedMultiTransformerBase,
     FusedMultiTransformerConfig,
     FusedMultiTransformerWeightOnly,
 )
 from paddlenlp.experimental.transformers.generation_utils import (
+    GenerationBlockInferenceModel,
     GenerationInferenceModel,
 )
+from paddlenlp.experimental.transformers.utils import infererence_model_from_pretrained
 from paddlenlp.transformers import ChatGLMv2Config, ChatGLMv2PretrainedModel
 from paddlenlp.transformers.chatglm_v2.modeling import (
     Embedding,
     register_base_model,
 )

-__all__ = [
-    "ChatGLMv2ForCausalLMInferenceModel",
-]
+__all__ = ["ChatGLMv2ForCausalLMInferenceModel", "ChatGLMv2ForCausalLMBlockInferenceModel"]


 @register_base_model
@@ -176,17 +178,20 @@ def __init__(self, config: ChatGLMv2Config, empty_init=True):
             kv_num_heads=config.multi_query_group_num,
         )

-        if self.use_weight_only:
-            self.transformer_block = FusedMultiTransformerWeightOnly(transformer_config)
-        else:
-            self.transformer_block = FusedMultiTransformerBase(transformer_config)
+        self.set_transformer_block(transformer_config)

         self.post_layer_norm = config.post_layer_norm
         if self.post_layer_norm:
             LayerNormFunc = RMSNorm if config.rmsnorm else nn.LayerNorm
             # Final layer norm before output.
             self.final_layernorm = LayerNormFunc(config.hidden_size, epsilon=config.layernorm_epsilon, config=config)

+    def set_transformer_block(self, transformer_config):
+        if self.use_weight_only:
+            self.transformer_block = FusedMultiTransformerWeightOnly(transformer_config)
+        else:
+            self.transformer_block = FusedMultiTransformerBase(transformer_config)
+
     def get_input_embeddings(self):
         return self.embedding.word_embeddings

@@ -341,7 +346,7 @@ def key(name):

             if self.use_weight_only:
                 linear_quanted_weight_tensor, linear_weight_scale_tensor = weight_quantize(
-                    out_proj_weight, algo=self.quant_algo
+                    paddle.to_tensor(out_proj_weight), algo=self.quant_algo
                 )
                 self.transformer_block.linear_weights[i].set_value(linear_quanted_weight_tensor)
                 self.transformer_block.linear_weights_scale[i].set_value(linear_weight_scale_tensor)
@@ -352,7 +357,7 @@ def key(name):

             if self.use_weight_only:
                 ffn1_quanted_weight_tensor, ffn1_weight_scale_tensor = weight_quantize(
-                    ffn1_weight, algo=self.quant_algo
+                    paddle.to_tensor(ffn1_weight), algo=self.quant_algo
                )
                 self.transformer_block.ffn1_weights[i].set_value(ffn1_quanted_weight_tensor)
                 self.transformer_block.ffn1_weights_scale[i].set_value(ffn1_weight_scale_tensor)
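The set_state_dict hunks above and below share one pattern: the raw weight is passed through paddle.to_tensor before weight_quantize, since paddle.nn.quant.weight_quantize operates on paddle Tensors while the weights at this point in state-dict loading appear to be numpy arrays (hence the wrapper). A minimal standalone sketch of that call, with a hypothetical weight shape and assuming the weight_only_int8 algorithm:

import numpy as np
import paddle
from paddle.nn.quant import weight_quantize

# Hypothetical fp16 projection weight as it would appear in a loaded state dict (numpy array).
np_weight = np.random.randn(4096, 4096).astype("float16")

# Convert to a paddle Tensor before quantizing; the call returns the quantized
# weight plus its scale, which the fused transformer block stores separately.
quant_weight, weight_scale = weight_quantize(paddle.to_tensor(np_weight), algo="weight_only_int8")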
@@ -361,20 +366,87 @@ def key(name):

             if self.use_weight_only:
                 ffn2_quanted_weight_tensor, ffn2_weight_scale_tensor = weight_quantize(
-                    ffn2_weight, algo=self.quant_algo
+                    paddle.to_tensor(ffn2_weight), algo=self.quant_algo
                 )
                 self.transformer_block.ffn2_weights[i].set_value(ffn2_quanted_weight_tensor)
                 self.transformer_block.ffn2_weights_scale[i].set_value(ffn2_weight_scale_tensor)
             else:
                 self.transformer_block.ffn2_weights[i].set_value(ffn2_weight)


+@register_base_model
+class ChatGLMv2BlockInferenceModel(ChatGLMv2InferenceModel):
+    def __init__(self, config: ChatGLMv2Config):
+        super().__init__(config)
+        self.max_seq_len = config.max_sequence_length
+        self.block_size = config.block_size
+
+    def set_transformer_block(self, transformer_config):
+        if self.use_weight_only:
+            self.transformer_block = FusedBlockMultiTransformerWeightOnly(transformer_config)
+        else:
+            self.transformer_block = FusedBlockMultiTransformer(transformer_config)
+
+    def remove_padding(self, input_ids, seq_lens_this_time):
+        cum_offsets_now = paddle.cumsum(self.max_seq_len - seq_lens_this_time)
+        token_num = paddle.sum(seq_lens_this_time)
+        from paddlenlp_ops import get_padding_offset_v2
+
+        ids_remove_padding, cum_offsets, padding_offset, cu_seqlens_q, cu_seqlens_k = get_padding_offset_v2(
+            input_ids, cum_offsets_now, token_num, seq_lens_this_time
+        )
+        return ids_remove_padding, padding_offset, cum_offsets, cu_seqlens_q, cu_seqlens_k
+
+    def forward(
+        self,
+        input_ids=None,
+        attention_mask=None,
+        inputs_embeds=None,
+        caches=None,
+        pre_caches=None,
+        output_attentions=False,
+        output_hidden_states=None,
+        return_dict=False,
+        **kwargs,
+    ):
+        seq_lens_this_time = kwargs.get("seq_lens_this_time", None)
+        rope_emb = kwargs.get("rope_emb", None)
+        ids_remove_padding, padding_offset, cum_offsets, cu_seqlens_q, cu_seqlens_k = self.remove_padding(
+            input_ids, seq_lens_this_time
+        )
+        kwargs["cu_seqlens_q"] = cu_seqlens_q
+        kwargs["cu_seqlens_k"] = cu_seqlens_k
+        kwargs["padding_offsets"] = padding_offset
+        kwargs["max_input_length"] = self.max_seq_len
+
+        inputs_embeds = self.embedding.word_embeddings(ids_remove_padding)
+
+        with dy2st_nocheck_guard_context():
+            hidden_states, _ = self.transformer_block(
+                input_ids=input_ids,
+                src=inputs_embeds,
+                cum_offsets=cum_offsets,
+                attn_mask=attention_mask,
+                caches=caches,
+                pre_caches=None,
+                rotary_embs=rope_emb,
+                **kwargs,
+            )
+        hidden_states = self.final_layernorm(hidden_states)
+
+        return tuple(v for v in [hidden_states, None, None, None] if v is not None)
+
+
 class ChatGLMv2ForCausalLMInferenceModel(GenerationInferenceModel, ChatGLMv2PretrainedModel):
     def __init__(self, config: ChatGLMv2Config):
         super().__init__(config)
         self.max_sequence_length = config.max_sequence_length
         self.chatglm_v2 = ChatGLMv2InferenceModel(config)

+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
+        return infererence_model_from_pretrained(cls, pretrained_model_name_or_path, args, kwargs)
+
     @classmethod
     def get_cache_kvs_shape(cls, config: ChatGLMv2Config, max_batch_size: int = None, max_length: int = None):
         """get cache_kvs tensor for opt model
@@ -487,3 +559,119 @@ def forward(
     @paddle.no_grad()
     def set_state_dict(self, state_dict):
         self.chatglm_v2.set_state_dict(state_dict)
+
+
+class ChatGLMv2ForCausalLMBlockInferenceModel(GenerationBlockInferenceModel, ChatGLMv2PretrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+        self.chatglm_v2 = ChatGLMv2BlockInferenceModel(config)
+        self.max_sequence_length = config.max_sequence_length
+
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
+        return infererence_model_from_pretrained(cls, pretrained_model_name_or_path, args, kwargs)
+
+    @classmethod
+    def get_cache_kvs_shape(cls, config: ChatGLMv2Config, max_batch_size: int = None, max_length: int = None):
+        """get cache_kvs tensor for chatglmv2 model
+
+        Args:
+            max_batch_size (int): the max batch size
+            max_length (int | None, optional): the max_length of cache_kvs. Defaults to None.
+
+        Returns:
+            list[paddle.Tensor]: the list tensor shape for cache
+        """
+        max_block_per_seq = (config.max_seq_len + config.block_size - 1) // config.block_size
+        if max_batch_size == -1:
+            max_block_nums = None
+        else:
+            max_block_nums = max_batch_size * max_block_per_seq
+
+        cache_kvs = []
+        for _ in range(config.num_hidden_layers):
+            cache_kv_shape = [
+                max_block_nums,
+                config.multi_query_group_num,
+                config.block_size,
+                config.hidden_size // config.num_attention_heads,
+            ]
+            cache_kvs.append(cache_kv_shape)
+            cache_kvs.append(cache_kv_shape)
+        return cache_kvs
+
+    def prepare_inputs_for_generation(self, **kwargs):
+        # only last token for inputs_ids if cache is defined in kwargs
+        input_ids = kwargs["input_ids"]
+        src_mask = kwargs.get("src_mask", None)
+        block_tables = kwargs.get("block_tables", None)
+
+        pre_caches = kwargs.get("pre_caches", None)
+        caches = kwargs.get("caches", None)
+
+        rope_emb = kwargs["rope_emb"]
+        seq_lens_this_time = kwargs["seq_lens_this_time"]
+        seq_lens_encoder = kwargs["seq_lens_encoder"]
+        seq_lens_decoder = kwargs["seq_lens_decoder"]
+        k_quant_scales = kwargs.get("k_quant_scales", None)
+        v_quant_scales = kwargs.get("v_quant_scales", None)
+        k_dequant_scales = kwargs.get("k_dequant_scales", None)
+        v_dequant_scales = kwargs.get("v_dequant_scales", None)
+        model_inputs = {
+            "input_ids": input_ids,
+            "src_mask": src_mask,
+            "rope_emb": rope_emb,
+            "pre_caches": pre_caches,
+            "caches": caches,
+            "seq_lens_this_time": seq_lens_this_time,
+            "seq_lens_encoder": seq_lens_encoder,
+            "seq_lens_decoder": seq_lens_decoder,
+            "block_tables": block_tables,
+            "k_quant_scales": k_quant_scales,
+            "v_quant_scales": v_quant_scales,
+            "k_dequant_scales": k_dequant_scales,
+            "v_dequant_scales": v_dequant_scales,
+        }
+        return model_inputs
+
+    def forward(
+        self,
+        input_ids,
+        src_mask=None,
+        pre_caches=None,
+        caches=None,
+        seq_lens_this_time=None,
+        seq_lens_encoder=None,
+        seq_lens_decoder=None,
+        rope_emb=None,
+        block_tables=None,
+        k_quant_scales=None,
+        v_quant_scales=None,
+        k_dequant_scales=None,
+        v_dequant_scales=None,
+    ):
+        outputs = self.chatglm_v2(
+            input_ids,
+            src_mask=src_mask,
+            caches=caches,
+            rope_emb=rope_emb,
+            block_tables=block_tables,
+            pre_caches=pre_caches,
+            seq_lens_this_time=seq_lens_this_time,
+            seq_lens_encoder=seq_lens_encoder,
+            seq_lens_decoder=seq_lens_decoder,
+            k_quant_scales=k_quant_scales,
+            v_quant_scales=v_quant_scales,
+            k_dequant_scales=k_dequant_scales,
+            v_dequant_scales=v_dequant_scales,
+        )
+
+        hidden_states = outputs[0]
+        lm_logits = self.chatglm_v2.output_layer(hidden_states)
+        output = (lm_logits,) + outputs[1:]
+
+        return output
+
+    @paddle.no_grad()
+    def set_state_dict(self, state_dict):
+        self.chatglm_v2.set_state_dict(state_dict)
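For orientation, a rough usage sketch of the new block-attention entry point. Assumptions: this module lives at paddlenlp.experimental.transformers.chatglm_v2.modeling, the export/serving pipeline injects the block_size and max_seq_len attributes consumed by get_cache_kvs_shape (they are not standard ChatGLMv2Config fields), and the model name, dtype, and config keyword argument are illustrative only.

from paddlenlp.transformers import ChatGLMv2Config
from paddlenlp.experimental.transformers.chatglm_v2.modeling import (
    ChatGLMv2ForCausalLMBlockInferenceModel,
)

config = ChatGLMv2Config.from_pretrained("THUDM/chatglm2-6b")
config.block_size = 64      # assumed to be set by the export/serving pipeline
config.max_seq_len = 2048   # assumed to be set by the export/serving pipeline

# One [max_block_nums, kv_heads, block_size, head_dim] shape per K cache and per V cache, per layer.
cache_kv_shapes = ChatGLMv2ForCausalLMBlockInferenceModel.get_cache_kvs_shape(config, max_batch_size=8)

# from_pretrained routes through infererence_model_from_pretrained; the extra kwargs here are assumptions.
model = ChatGLMv2ForCausalLMBlockInferenceModel.from_pretrained(
    "THUDM/chatglm2-6b", config=config, dtype="float16"
)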