Commit 2e8b220

[LLM INFER] Fix some bugs and chatglm_v2 support block_attn (#9271)
* chatglm2 support block_attn and fix some bugs
* fix ci
* fix more ut error
* update
1 parent b237ba7 commit 2e8b220

5 files changed, +265 -138 lines changed


llm/predict/predictor.py

Lines changed: 31 additions & 25 deletions
@@ -134,16 +134,16 @@ class PredictorArgument:
         },
     )
 
-    @property
-    def total_max_length(self):
-        if self.device == "npu":
-            return self.src_length + self.max_length
-        else:
-            return 8192  # Maximum sequence length.
+    total_max_length: int = field(
+        default=4096, metadata={"help": "Super parameter. Maximum sequence length(encoder+decoder)."}
+    )
 
     def __post_init__(self):
         if self.append_attn:
             self.block_attn = True
+        assert (
+            self.src_length + self.max_length <= self.total_max_length
+        ), "src_length + max_length should smaller than total_max_length."
 
 
 @dataclass
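
Note: a minimal standalone sketch of how the new field and the __post_init__ check interact (illustrative only; the real PredictorArgument carries many more fields):

from dataclasses import dataclass, field

@dataclass
class _Args:  # hypothetical stand-in for PredictorArgument
    src_length: int = 1024
    max_length: int = 2048
    total_max_length: int = field(
        default=4096, metadata={"help": "Maximum sequence length (encoder + decoder)."}
    )

    def __post_init__(self):
        assert (
            self.src_length + self.max_length <= self.total_max_length
        ), "src_length + max_length should be smaller than total_max_length."

_Args(src_length=1024, max_length=2048)    # ok: 1024 + 2048 <= 4096
# _Args(src_length=4096, max_length=2048)  # would raise AssertionError
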
@@ -520,7 +520,7 @@ def _preprocess(self, source):
             alibi_slopes = llm_utils.get_alibi_slopes(self.model_config.n_head)
             inputs["position_ids"] = paddle.to_tensor(alibi_slopes, dtype="float32")
             arange_tensor_encoder = paddle.arange(self.config.total_max_length, dtype=self.config.dtype)
-            alibi = alibi_slopes[None, :, None, None] * arange_tensor_encoder
+            alibi = (alibi_slopes[None, :, None, None] * arange_tensor_encoder).astype(self.config.dtype)
 
             if self.model_config.tensor_parallel_degree > 1:
                 block_size = self.model_config.n_head // self.model_config.tensor_parallel_degree
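
Note: the added .astype matters because the alibi slopes are float32 while the predictor may run in float16/bfloat16. A standalone illustration of the casting pattern with made-up values (not the predictor's actual tensors):

import paddle

slopes = paddle.to_tensor([0.5, 0.25], dtype="float32")              # per-head alibi slopes
positions = paddle.arange(8, dtype="float32")                         # 0 .. total_max_length - 1
alibi = (slopes[None, :, None, None] * positions).astype("float16")   # cast to the working dtype
print(alibi.shape, alibi.dtype)                                       # [1, 2, 1, 8] paddle.float16
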
@@ -1352,13 +1352,19 @@ def create_predictor(
                 predictor_args.model_name_or_path, config=config, dtype=predictor_args.dtype
             )
             model.eval()
-
         elif "chatglmv2forcausallm" in config.architectures[0].lower():
-            from paddlenlp.experimental.transformers import (
-                ChatGLMv2ForCausalLMInferenceModel as Model,
-            )
-
-            model = Model.from_pretrained(
+            predictor_args.total_max_length = config.seq_length
+            if predictor_args.block_attn:
+                config.block_size = predictor_args.block_size
+                config.max_seq_len = predictor_args.total_max_length
+                from paddlenlp.experimental.transformers import (
+                    ChatGLMv2ForCausalLMBlockInferenceModel as ChatGLMv2InferenceModel,
+                )
+            else:
+                from paddlenlp.experimental.transformers import (
+                    ChatGLMv2ForCausalLMInferenceModel as ChatGLMv2InferenceModel,
+                )
+            model = ChatGLMv2InferenceModel.from_pretrained(
                 predictor_args.model_name_or_path, config=config, dtype=predictor_args.dtype
             )
             model.eval()
@@ -1522,19 +1528,19 @@ def create_predictor(
                 config, predictor_args.batch_size, predictor_args.total_max_length
             )
         elif "chatglmv2forcausallm" in config.architectures[0].lower():
-            from paddlenlp.experimental.transformers import (
-                ChatGLMv2ForCausalLMInferenceModel,
-            )
-
-            cache_kvs_shape = ChatGLMv2ForCausalLMInferenceModel.get_cache_kvs_shape(
-                config, predictor_args.batch_size, predictor_args.total_max_length
-            )
-        elif "chatglmv2forcausallm" in config.architectures[0].lower():
-            from paddlenlp.experimental.transformers import (
-                ChatGLMv2ForCausalLMInferenceModel,
-            )
+            predictor_args.total_max_length = config.seq_length
+            if predictor_args.block_attn:
+                config.block_size = predictor_args.block_size
+                config.max_seq_len = predictor_args.total_max_length
+                from paddlenlp.experimental.transformers import (
+                    ChatGLMv2ForCausalLMBlockInferenceModel as ChatGLMv2InferenceModel,
+                )
+            else:
+                from paddlenlp.experimental.transformers import (
+                    ChatGLMv2ForCausalLMInferenceModel as ChatGLMv2InferenceModel,
+                )
 
-            cache_kvs_shape = ChatGLMv2ForCausalLMInferenceModel.get_cache_kvs_shape(
+            cache_kvs_shape = ChatGLMv2InferenceModel.get_cache_kvs_shape(
                 config, predictor_args.batch_size, predictor_args.total_max_length
             )
         elif "chatglmforcausallm" in config.architectures[0].lower():

paddlenlp/experimental/transformers/chatglm_v2/modeling.py

Lines changed: 198 additions & 10 deletions
@@ -21,13 +21,17 @@
 from paddle.nn.quant import weight_quantize
 
 from paddlenlp.experimental.transformers.fused_transformer_layers import (
+    FusedBlockMultiTransformer,
+    FusedBlockMultiTransformerWeightOnly,
     FusedMultiTransformerBase,
     FusedMultiTransformerConfig,
     FusedMultiTransformerWeightOnly,
 )
 from paddlenlp.experimental.transformers.generation_utils import (
+    GenerationBlockInferenceModel,
     GenerationInferenceModel,
 )
+from paddlenlp.experimental.transformers.utils import infererence_model_from_pretrained
 from paddlenlp.transformers import ChatGLMv2Config, ChatGLMv2PretrainedModel
 from paddlenlp.transformers.chatglm_v2.modeling import (
     Embedding,
@@ -39,9 +43,7 @@
     register_base_model,
 )
 
-__all__ = [
-    "ChatGLMv2ForCausalLMInferenceModel",
-]
+__all__ = ["ChatGLMv2ForCausalLMInferenceModel", "ChatGLMv2ForCausalLMBlockInferenceModel"]
 
 
 @register_base_model
@@ -176,17 +178,20 @@ def __init__(self, config: ChatGLMv2Config, empty_init=True):
             kv_num_heads=config.multi_query_group_num,
         )
 
-        if self.use_weight_only:
-            self.transformer_block = FusedMultiTransformerWeightOnly(transformer_config)
-        else:
-            self.transformer_block = FusedMultiTransformerBase(transformer_config)
+        self.set_transformer_block(transformer_config)
 
         self.post_layer_norm = config.post_layer_norm
         if self.post_layer_norm:
             LayerNormFunc = RMSNorm if config.rmsnorm else nn.LayerNorm
             # Final layer norm before output.
             self.final_layernorm = LayerNormFunc(config.hidden_size, epsilon=config.layernorm_epsilon, config=config)
 
+    def set_transformer_block(self, transformer_config):
+        if self.use_weight_only:
+            self.transformer_block = FusedMultiTransformerWeightOnly(transformer_config)
+        else:
+            self.transformer_block = FusedMultiTransformerBase(transformer_config)
+
     def get_input_embeddings(self):
         return self.embedding.word_embeddings

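Note: the point of this refactor is a small hook. The base model now builds its fused transformer in one place, so the block-attention subclass added further down only overrides set_transformer_block instead of copying __init__. A minimal standalone sketch of the pattern (toy placeholder classes, not the real ones):

class _Base:  # stands in for ChatGLMv2InferenceModel
    def __init__(self, use_weight_only=False):
        self.use_weight_only = use_weight_only
        self.set_transformer_block(transformer_config=None)

    def set_transformer_block(self, transformer_config):
        # the base class picks the dense fused transformer
        self.transformer_block = (
            "FusedMultiTransformerWeightOnly" if self.use_weight_only else "FusedMultiTransformerBase"
        )

class _Block(_Base):  # stands in for ChatGLMv2BlockInferenceModel
    def set_transformer_block(self, transformer_config):
        # the subclass swaps in the block-attention variant, nothing else changes
        self.transformer_block = (
            "FusedBlockMultiTransformerWeightOnly" if self.use_weight_only else "FusedBlockMultiTransformer"
        )

print(_Block().transformer_block)  # FusedBlockMultiTransformer
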
@@ -341,7 +346,7 @@ def key(name):
 
             if self.use_weight_only:
                 linear_quanted_weight_tensor, linear_weight_scale_tensor = weight_quantize(
-                    out_proj_weight, algo=self.quant_algo
+                    paddle.to_tensor(out_proj_weight), algo=self.quant_algo
                 )
                 self.transformer_block.linear_weights[i].set_value(linear_quanted_weight_tensor)
                 self.transformer_block.linear_weights_scale[i].set_value(linear_weight_scale_tensor)
@@ -352,7 +357,7 @@ def key(name):
 
             if self.use_weight_only:
                 ffn1_quanted_weight_tensor, ffn1_weight_scale_tensor = weight_quantize(
-                    ffn1_weight, algo=self.quant_algo
+                    paddle.to_tensor(ffn1_weight), algo=self.quant_algo
                 )
                 self.transformer_block.ffn1_weights[i].set_value(ffn1_quanted_weight_tensor)
                 self.transformer_block.ffn1_weights_scale[i].set_value(ffn1_weight_scale_tensor)
@@ -361,20 +366,87 @@ def key(name):
 
             if self.use_weight_only:
                 ffn2_quanted_weight_tensor, ffn2_weight_scale_tensor = weight_quantize(
-                    ffn2_weight, algo=self.quant_algo
+                    paddle.to_tensor(ffn2_weight), algo=self.quant_algo
                 )
                 self.transformer_block.ffn2_weights[i].set_value(ffn2_quanted_weight_tensor)
                 self.transformer_block.ffn2_weights_scale[i].set_value(ffn2_weight_scale_tensor)
             else:
                 self.transformer_block.ffn2_weights[i].set_value(ffn2_weight)
 
 
+@register_base_model
+class ChatGLMv2BlockInferenceModel(ChatGLMv2InferenceModel):
+    def __init__(self, config: ChatGLMv2Config):
+        super().__init__(config)
+        self.max_seq_len = config.max_sequence_length
+        self.block_size = config.block_size
+
+    def set_transformer_block(self, transformer_config):
+        if self.use_weight_only:
+            self.transformer_block = FusedBlockMultiTransformerWeightOnly(transformer_config)
+        else:
+            self.transformer_block = FusedBlockMultiTransformer(transformer_config)
+
+    def remove_padding(self, input_ids, seq_lens_this_time):
+        cum_offsets_now = paddle.cumsum(self.max_seq_len - seq_lens_this_time)
+        token_num = paddle.sum(seq_lens_this_time)
+        from paddlenlp_ops import get_padding_offset_v2
+
+        ids_remove_padding, cum_offsets, padding_offset, cu_seqlens_q, cu_seqlens_k = get_padding_offset_v2(
+            input_ids, cum_offsets_now, token_num, seq_lens_this_time
+        )
+        return ids_remove_padding, padding_offset, cum_offsets, cu_seqlens_q, cu_seqlens_k
+
+    def forward(
+        self,
+        input_ids=None,
+        attention_mask=None,
+        inputs_embeds=None,
+        caches=None,
+        pre_caches=None,
+        output_attentions=False,
+        output_hidden_states=None,
+        return_dict=False,
+        **kwargs,
+    ):
+        seq_lens_this_time = kwargs.get("seq_lens_this_time", None)
+        rope_emb = kwargs.get("rope_emb", None)
+        ids_remove_padding, padding_offset, cum_offsets, cu_seqlens_q, cu_seqlens_k = self.remove_padding(
+            input_ids, seq_lens_this_time
+        )
+        kwargs["cu_seqlens_q"] = cu_seqlens_q
+        kwargs["cu_seqlens_k"] = cu_seqlens_k
+        kwargs["padding_offsets"] = padding_offset
+        kwargs["max_input_length"] = self.max_seq_len
+
+        inputs_embeds = self.embedding.word_embeddings(ids_remove_padding)
+
+        with dy2st_nocheck_guard_context():
+            hidden_states, _ = self.transformer_block(
+                input_ids=input_ids,
+                src=inputs_embeds,
+                cum_offsets=cum_offsets,
+                attn_mask=attention_mask,
+                caches=caches,
+                pre_caches=None,
+                rotary_embs=rope_emb,
+                **kwargs,
+            )
+        hidden_states = self.final_layernorm(hidden_states)
+
+        return tuple(v for v in [hidden_states, None, None, None] if v is not None)
+
+
 class ChatGLMv2ForCausalLMInferenceModel(GenerationInferenceModel, ChatGLMv2PretrainedModel):
     def __init__(self, config: ChatGLMv2Config):
         super().__init__(config)
         self.max_sequence_length = config.max_sequence_length
         self.chatglm_v2 = ChatGLMv2InferenceModel(config)
 
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
+        return infererence_model_from_pretrained(cls, pretrained_model_name_or_path, args, kwargs)
+
     @classmethod
     def get_cache_kvs_shape(cls, config: ChatGLMv2Config, max_batch_size: int = None, max_length: int = None):
         """get cache_kvs tensor for opt model
@@ -487,3 +559,119 @@ def forward(
     @paddle.no_grad()
     def set_state_dict(self, state_dict):
         self.chatglm_v2.set_state_dict(state_dict)
+
+
+class ChatGLMv2ForCausalLMBlockInferenceModel(GenerationBlockInferenceModel, ChatGLMv2PretrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+        self.chatglm_v2 = ChatGLMv2BlockInferenceModel(config)
+        self.max_sequence_length = config.max_sequence_length
+
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
+        return infererence_model_from_pretrained(cls, pretrained_model_name_or_path, args, kwargs)
+
+    @classmethod
+    def get_cache_kvs_shape(cls, config: ChatGLMv2Config, max_batch_size: int = None, max_length: int = None):
+        """get cache_kvs tensor for chatglmv2 model
+
+        Args:
+            max_batch_size (int): the max batch size
+            max_length (int | None, optional): the max_length of cache_kvs. Defaults to None.
+
+        Returns:
+            list[paddle.Tensor]: the list tensor shape for cache
+        """
+        max_block_per_seq = (config.max_seq_len + config.block_size - 1) // config.block_size
+        if max_batch_size == -1:
+            max_block_nums = None
+        else:
+            max_block_nums = max_batch_size * max_block_per_seq
+
+        cache_kvs = []
+        for _ in range(config.num_hidden_layers):
+            cache_kv_shape = [
+                max_block_nums,
+                config.multi_query_group_num,
+                config.block_size,
+                config.hidden_size // config.num_attention_heads,
+            ]
+            cache_kvs.append(cache_kv_shape)
+            cache_kvs.append(cache_kv_shape)
+        return cache_kvs
+
+    def prepare_inputs_for_generation(self, **kwargs):
+        # only last token for inputs_ids if cache is defined in kwargs
+        input_ids = kwargs["input_ids"]
+        src_mask = kwargs.get("src_mask", None)
+        block_tables = kwargs.get("block_tables", None)
+
+        pre_caches = kwargs.get("pre_caches", None)
+        caches = kwargs.get("caches", None)
+
+        rope_emb = kwargs["rope_emb"]
+        seq_lens_this_time = kwargs["seq_lens_this_time"]
+        seq_lens_encoder = kwargs["seq_lens_encoder"]
+        seq_lens_decoder = kwargs["seq_lens_decoder"]
+        k_quant_scales = kwargs.get("k_quant_scales", None)
+        v_quant_scales = kwargs.get("v_quant_scales", None)
+        k_dequant_scales = kwargs.get("k_dequant_scales", None)
+        v_dequant_scales = kwargs.get("v_dequant_scales", None)
+        model_inputs = {
+            "input_ids": input_ids,
+            "src_mask": src_mask,
+            "rope_emb": rope_emb,
+            "pre_caches": pre_caches,
+            "caches": caches,
+            "seq_lens_this_time": seq_lens_this_time,
+            "seq_lens_encoder": seq_lens_encoder,
+            "seq_lens_decoder": seq_lens_decoder,
+            "block_tables": block_tables,
+            "k_quant_scales": k_quant_scales,
+            "v_quant_scales": v_quant_scales,
+            "k_dequant_scales": k_dequant_scales,
+            "v_dequant_scales": v_dequant_scales,
+        }
+        return model_inputs
+
+    def forward(
+        self,
+        input_ids,
+        src_mask=None,
+        pre_caches=None,
+        caches=None,
+        seq_lens_this_time=None,
+        seq_lens_encoder=None,
+        seq_lens_decoder=None,
+        rope_emb=None,
+        block_tables=None,
+        k_quant_scales=None,
+        v_quant_scales=None,
+        k_dequant_scales=None,
+        v_dequant_scales=None,
+    ):
+        outputs = self.chatglm_v2(
+            input_ids,
+            src_mask=src_mask,
+            caches=caches,
+            rope_emb=rope_emb,
+            block_tables=block_tables,
+            pre_caches=pre_caches,
+            seq_lens_this_time=seq_lens_this_time,
+            seq_lens_encoder=seq_lens_encoder,
+            seq_lens_decoder=seq_lens_decoder,
+            k_quant_scales=k_quant_scales,
+            v_quant_scales=v_quant_scales,
+            k_dequant_scales=k_dequant_scales,
+            v_dequant_scales=v_dequant_scales,
+        )
+
+        hidden_states = outputs[0]
+        lm_logits = self.chatglm_v2.output_layer(hidden_states)
+        output = (lm_logits,) + outputs[1:]
+
+        return output
+
+    @paddle.no_grad()
+    def set_state_dict(self, state_dict):
+        self.chatglm_v2.set_state_dict(state_dict)

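Note: a worked example of the cache geometry that get_cache_kvs_shape returns, using hypothetical config values rather than a real chatglm2 checkpoint:

max_seq_len, block_size = 4096, 64
max_batch_size = 2
multi_query_group_num = 2            # KV heads under multi-query attention
hidden_size, num_attention_heads = 4096, 32
num_hidden_layers = 28

max_block_per_seq = (max_seq_len + block_size - 1) // block_size  # ceil(4096 / 64) = 64
max_block_nums = max_batch_size * max_block_per_seq               # 2 * 64 = 128 cache blocks
head_dim = hidden_size // num_attention_heads                     # 128

cache_kv_shape = [max_block_nums, multi_query_group_num, block_size, head_dim]
print(cache_kv_shape)         # [128, 2, 64, 128], appended twice per layer (one K, one V)
print(2 * num_hidden_layers)  # 56 cache tensors in total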