56 changes: 31 additions & 25 deletions llm/predict/predictor.py
@@ -134,16 +134,16 @@ class PredictorArgument:
},
)

@property
def total_max_length(self):
if self.device == "npu":
return self.src_length + self.max_length
else:
return 8192 # Maximum sequence length.
total_max_length: int = field(
default=4096, metadata={"help": "Hyperparameter. Maximum sequence length (encoder + decoder)."}
Contributor: Has this been confirmed with the colleagues responsible for NPU?

Collaborator (Author): Confirmed, no problem.

)

def __post_init__(self):
if self.append_attn:
self.block_attn = True
assert (
self.src_length + self.max_length <= self.total_max_length
), "src_length + max_length should smaller than total_max_length."


@dataclass
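
For readers skimming the hunk above: total_max_length changes from an NPU-only property into an ordinary field (default 4096), and __post_init__ now enforces the length budget. A minimal usage sketch, assuming PredictorArgument is importable from this file and the remaining fields keep their defaults:

from llm.predict.predictor import PredictorArgument  # hypothetical import path

args = PredictorArgument(src_length=3072, max_length=1024, total_max_length=4096)
# passes: 3072 + 1024 <= 4096
args = PredictorArgument(src_length=3584, max_length=1024, total_max_length=4096)
# raises AssertionError: src_length + max_length should not exceed total_max_length
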
@@ -520,7 +520,7 @@ def _preprocess(self, source):
alibi_slopes = llm_utils.get_alibi_slopes(self.model_config.n_head)
inputs["position_ids"] = paddle.to_tensor(alibi_slopes, dtype="float32")
arange_tensor_encoder = paddle.arange(self.config.total_max_length, dtype=self.config.dtype)
alibi = alibi_slopes[None, :, None, None] * arange_tensor_encoder
alibi = (alibi_slopes[None, :, None, None] * arange_tensor_encoder).astype(self.config.dtype)

Contributor: Hmm, is relying on config dtype safe here? Users can change that value. How about using the dtype of one of the tensors instead?

Collaborator (Author): This dtype does need to stay consistent with config.dtype.

if self.model_config.tensor_parallel_degree > 1:
block_size = self.model_config.n_head // self.model_config.tensor_parallel_degree
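
As background for the alibi change above, a conceptual sketch of the bias that the slopes-times-positions product builds (this follows the standard ALiBi formulation; llm_utils.get_alibi_slopes may treat non-power-of-two head counts differently):

import numpy as np

n_head, total_max_length = 8, 16  # toy sizes, assumed for illustration
slopes = 2.0 ** (-8.0 * np.arange(1, n_head + 1) / n_head)  # one geometric slope per head
alibi = slopes[None, :, None, None] * np.arange(total_max_length, dtype="float32")
print(alibi.shape)  # (1, 8, 1, 16); the real code casts this bias to config.dtype before use
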
@@ -1352,13 +1352,19 @@ def create_predictor(
predictor_args.model_name_or_path, config=config, dtype=predictor_args.dtype
)
model.eval()

elif "chatglmv2forcausallm" in config.architectures[0].lower():
from paddlenlp.experimental.transformers import (
ChatGLMv2ForCausalLMInferenceModel as Model,
)

model = Model.from_pretrained(
predictor_args.total_max_length = config.seq_length
if predictor_args.block_attn:
Contributor: Hmm, I'd suggest putting block_attn into the config as an attribute and letting ChatGLMv2InferenceModel handle it internally. If we change it here, too many models will need the same kind of modification later on.

Collaborator (Author): Strictly speaking, though, this doesn't belong in each model's config; if we added it to, say, LlamaConfig, every model's config would need it. Let's keep it this way for now and look for a better approach during the refactor.

config.block_size = predictor_args.block_size
config.max_seq_len = predictor_args.total_max_length
from paddlenlp.experimental.transformers import (
ChatGLMv2ForCausalLMBlockInferenceModel as ChatGLMv2InferenceModel,
)
else:
from paddlenlp.experimental.transformers import (
ChatGLMv2ForCausalLMInferenceModel as ChatGLMv2InferenceModel,
)
model = ChatGLMv2InferenceModel.from_pretrained(
predictor_args.model_name_or_path, config=config, dtype=predictor_args.dtype
)
model.eval()
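
A rough, self-contained sketch of how this branch is driven from the predictor side; the SimpleNamespace objects stand in for the real PredictorArgument and model config, and the block_size / seq_length values are assumed:

from types import SimpleNamespace

predictor_args = SimpleNamespace(block_attn=True, block_size=64, total_max_length=None)
config = SimpleNamespace(seq_length=32768, block_size=None, max_seq_len=None)

predictor_args.total_max_length = config.seq_length  # ChatGLMv2 forces this override
if predictor_args.block_attn:  # block-attention path -> ChatGLMv2ForCausalLMBlockInferenceModel
    config.block_size = predictor_args.block_size
    config.max_seq_len = predictor_args.total_max_length
# otherwise the plain ChatGLMv2ForCausalLMInferenceModel is used
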
Expand Down Expand Up @@ -1522,19 +1528,19 @@ def create_predictor(
config, predictor_args.batch_size, predictor_args.total_max_length
)
elif "chatglmv2forcausallm" in config.architectures[0].lower():
from paddlenlp.experimental.transformers import (
ChatGLMv2ForCausalLMInferenceModel,
)

cache_kvs_shape = ChatGLMv2ForCausalLMInferenceModel.get_cache_kvs_shape(
config, predictor_args.batch_size, predictor_args.total_max_length
)
elif "chatglmv2forcausallm" in config.architectures[0].lower():
from paddlenlp.experimental.transformers import (
ChatGLMv2ForCausalLMInferenceModel,
)
predictor_args.total_max_length = config.seq_length
if predictor_args.block_attn:
config.block_size = predictor_args.block_size
config.max_seq_len = predictor_args.total_max_length
from paddlenlp.experimental.transformers import (
ChatGLMv2ForCausalLMBlockInferenceModel as ChatGLMv2InferenceModel,
)
else:
from paddlenlp.experimental.transformers import (
ChatGLMv2ForCausalLMInferenceModel as ChatGLMv2InferenceModel,
)

cache_kvs_shape = ChatGLMv2ForCausalLMInferenceModel.get_cache_kvs_shape(
cache_kvs_shape = ChatGLMv2InferenceModel.get_cache_kvs_shape(
config, predictor_args.batch_size, predictor_args.total_max_length
)
elif "chatglmforcausallm" in config.architectures[0].lower():
208 changes: 198 additions & 10 deletions paddlenlp/experimental/transformers/chatglm_v2/modeling.py
@@ -21,13 +21,17 @@
from paddle.nn.quant import weight_quantize

from paddlenlp.experimental.transformers.fused_transformer_layers import (
FusedBlockMultiTransformer,
FusedBlockMultiTransformerWeightOnly,
FusedMultiTransformerBase,
FusedMultiTransformerConfig,
FusedMultiTransformerWeightOnly,
)
from paddlenlp.experimental.transformers.generation_utils import (
GenerationBlockInferenceModel,
GenerationInferenceModel,
)
from paddlenlp.experimental.transformers.utils import infererence_model_from_pretrained

from paddlenlp.transformers import ChatGLMv2Config, ChatGLMv2PretrainedModel
from paddlenlp.transformers.chatglm_v2.modeling import (
Embedding,
@@ -39,9 +43,7 @@
register_base_model,
)

__all__ = [
"ChatGLMv2ForCausalLMInferenceModel",
]
__all__ = ["ChatGLMv2ForCausalLMInferenceModel", "ChatGLMv2ForCausalLMBlockInferenceModel"]



@register_base_model
@@ -176,17 +178,20 @@
kv_num_heads=config.multi_query_group_num,
)

if self.use_weight_only:
self.transformer_block = FusedMultiTransformerWeightOnly(transformer_config)
else:
self.transformer_block = FusedMultiTransformerBase(transformer_config)
self.set_transformer_block(transformer_config)


self.post_layer_norm = config.post_layer_norm
if self.post_layer_norm:
LayerNormFunc = RMSNorm if config.rmsnorm else nn.LayerNorm
# Final layer norm before output.
self.final_layernorm = LayerNormFunc(config.hidden_size, epsilon=config.layernorm_epsilon, config=config)

def set_transformer_block(self, transformer_config):
if self.use_weight_only:
self.transformer_block = FusedMultiTransformerWeightOnly(transformer_config)

else:
self.transformer_block = FusedMultiTransformerBase(transformer_config)


def get_input_embeddings(self):
return self.embedding.word_embeddings

@@ -341,7 +346,7 @@

if self.use_weight_only:
linear_quanted_weight_tensor, linear_weight_scale_tensor = weight_quantize(
out_proj_weight, algo=self.quant_algo
paddle.to_tensor(out_proj_weight), algo=self.quant_algo
)
self.transformer_block.linear_weights[i].set_value(linear_quanted_weight_tensor)
self.transformer_block.linear_weights_scale[i].set_value(linear_weight_scale_tensor)
@@ -352,7 +357,7 @@

if self.use_weight_only:
ffn1_quanted_weight_tensor, ffn1_weight_scale_tensor = weight_quantize(
ffn1_weight, algo=self.quant_algo
paddle.to_tensor(ffn1_weight), algo=self.quant_algo
)
self.transformer_block.ffn1_weights[i].set_value(ffn1_quanted_weight_tensor)
self.transformer_block.ffn1_weights_scale[i].set_value(ffn1_weight_scale_tensor)
@@ -361,20 +366,87 @@

if self.use_weight_only:
ffn2_quanted_weight_tensor, ffn2_weight_scale_tensor = weight_quantize(
ffn2_weight, algo=self.quant_algo
paddle.to_tensor(ffn2_weight), algo=self.quant_algo
)
self.transformer_block.ffn2_weights[i].set_value(ffn2_quanted_weight_tensor)
self.transformer_block.ffn2_weights_scale[i].set_value(ffn2_weight_scale_tensor)
else:
self.transformer_block.ffn2_weights[i].set_value(ffn2_weight)


@register_base_model
class ChatGLMv2BlockInferenceModel(ChatGLMv2InferenceModel):
def __init__(self, config: ChatGLMv2Config):
super().__init__(config)
self.max_seq_len = config.max_sequence_length
self.block_size = config.block_size


def set_transformer_block(self, transformer_config):
if self.use_weight_only:
self.transformer_block = FusedBlockMultiTransformerWeightOnly(transformer_config)

else:
self.transformer_block = FusedBlockMultiTransformer(transformer_config)


def remove_padding(self, input_ids, seq_lens_this_time):
cum_offsets_now = paddle.cumsum(self.max_seq_len - seq_lens_this_time)
token_num = paddle.sum(seq_lens_this_time)
from paddlenlp_ops import get_padding_offset_v2


ids_remove_padding, cum_offsets, padding_offset, cu_seqlens_q, cu_seqlens_k = get_padding_offset_v2(

input_ids, cum_offsets_now, token_num, seq_lens_this_time
)
return ids_remove_padding, padding_offset, cum_offsets, cu_seqlens_q, cu_seqlens_k
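
Conceptually, remove_padding packs the padded batch into a flat token stream and records the offsets; the heavy lifting is done by the custom paddlenlp_ops.get_padding_offset_v2 kernel, but the idea can be illustrated with plain numpy (toy data, not the exact kernel outputs):

import numpy as np

max_seq_len = 8
input_ids = np.array([[1, 2, 3, 0, 0, 0, 0, 0],
                      [4, 5, 0, 0, 0, 0, 0, 0]])  # toy batch, 0 = padding
seq_lens_this_time = np.array([3, 2])

ids_remove_padding = np.concatenate([row[:n] for row, n in zip(input_ids, seq_lens_this_time)])
cu_seqlens = np.concatenate([[0], np.cumsum(seq_lens_this_time)])  # [0, 3, 5]
cum_offsets_now = np.cumsum(max_seq_len - seq_lens_this_time)      # padding skipped so far
print(ids_remove_padding)  # [1 2 3 4 5]
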


def forward(

self,
input_ids=None,
attention_mask=None,
inputs_embeds=None,
caches=None,
pre_caches=None,
output_attentions=False,
output_hidden_states=None,
return_dict=False,
**kwargs,
):
seq_lens_this_time = kwargs.get("seq_lens_this_time", None)
rope_emb = kwargs.get("rope_emb", None)
ids_remove_padding, padding_offset, cum_offsets, cu_seqlens_q, cu_seqlens_k = self.remove_padding(

input_ids, seq_lens_this_time
)
kwargs["cu_seqlens_q"] = cu_seqlens_q
kwargs["cu_seqlens_k"] = cu_seqlens_k
kwargs["padding_offsets"] = padding_offset
kwargs["max_input_length"] = self.max_seq_len


inputs_embeds = self.embedding.word_embeddings(ids_remove_padding)


with dy2st_nocheck_guard_context():
hidden_states, _ = self.transformer_block(

input_ids=input_ids,
src=inputs_embeds,
cum_offsets=cum_offsets,
attn_mask=attention_mask,
caches=caches,
pre_caches=None,
rotary_embs=rope_emb,
**kwargs,
)
hidden_states = self.final_layernorm(hidden_states)


return tuple(v for v in [hidden_states, None, None, None] if v is not None)



class ChatGLMv2ForCausalLMInferenceModel(GenerationInferenceModel, ChatGLMv2PretrainedModel):
def __init__(self, config: ChatGLMv2Config):
super().__init__(config)
self.max_sequence_length = config.max_sequence_length
self.chatglm_v2 = ChatGLMv2InferenceModel(config)

@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
return infererence_model_from_pretrained(cls, pretrained_model_name_or_path, args, kwargs)


@classmethod
def get_cache_kvs_shape(cls, config: ChatGLMv2Config, max_batch_size: int = None, max_length: int = None):
"""get cache_kvs tensor for opt model
@@ -487,3 +559,119 @@
@paddle.no_grad()
def set_state_dict(self, state_dict):
self.chatglm_v2.set_state_dict(state_dict)


class ChatGLMv2ForCausalLMBlockInferenceModel(GenerationBlockInferenceModel, ChatGLMv2PretrainedModel):
def __init__(self, config):
super().__init__(config)
self.chatglm_v2 = ChatGLMv2BlockInferenceModel(config)
self.max_sequence_length = config.max_sequence_length


@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
return infererence_model_from_pretrained(cls, pretrained_model_name_or_path, args, kwargs)


@classmethod
def get_cache_kvs_shape(cls, config: ChatGLMv2Config, max_batch_size: int = None, max_length: int = None):

"""get cache_kvs tensor for chatglmv2 model

Args:
max_batch_size (int): the max batch size
max_length (int | None, optional): the max_length of cache_kvs. Defaults to None.

Returns:
list[paddle.Tensor]: the list tensor shape for cache
"""
max_block_per_seq = (config.max_seq_len + config.block_size - 1) // config.block_size
if max_batch_size == -1:
max_block_nums = None

else:
max_block_nums = max_batch_size * max_block_per_seq


cache_kvs = []
for _ in range(config.num_hidden_layers):
cache_kv_shape = [

max_block_nums,
config.multi_query_group_num,
config.block_size,
config.hidden_size // config.num_attention_heads,
]
cache_kvs.append(cache_kv_shape)
cache_kvs.append(cache_kv_shape)
return cache_kvs
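
A worked example of the shapes this returns, using assumed values in the ballpark of chatglm2-6b (multi-query attention with 2 KV groups); illustrative only:

max_seq_len, block_size, max_batch_size = 32768, 64, 2
multi_query_group_num, hidden_size, num_attention_heads, num_hidden_layers = 2, 4096, 32, 28

max_block_per_seq = (max_seq_len + block_size - 1) // block_size  # 512
max_block_nums = max_batch_size * max_block_per_seq               # 1024
cache_kv_shape = [max_block_nums, multi_query_group_num, block_size,
                  hidden_size // num_attention_heads]             # [1024, 2, 64, 128]
# the shape is appended twice per layer (K and V), so the list has 2 * 28 = 56 entries
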


def prepare_inputs_for_generation(self, **kwargs):

# only last token for inputs_ids if cache is defined in kwargs
input_ids = kwargs["input_ids"]
src_mask = kwargs.get("src_mask", None)
block_tables = kwargs.get("block_tables", None)


pre_caches = kwargs.get("pre_caches", None)
caches = kwargs.get("caches", None)


rope_emb = kwargs["rope_emb"]
seq_lens_this_time = kwargs["seq_lens_this_time"]
seq_lens_encoder = kwargs["seq_lens_encoder"]
seq_lens_decoder = kwargs["seq_lens_decoder"]
k_quant_scales = kwargs.get("k_quant_scales", None)
v_quant_scales = kwargs.get("v_quant_scales", None)
k_dequant_scales = kwargs.get("k_dequant_scales", None)
v_dequant_scales = kwargs.get("v_dequant_scales", None)
model_inputs = {

"input_ids": input_ids,
"src_mask": src_mask,
"rope_emb": rope_emb,
"pre_caches": pre_caches,
"caches": caches,
"seq_lens_this_time": seq_lens_this_time,
"seq_lens_encoder": seq_lens_encoder,
"seq_lens_decoder": seq_lens_decoder,
"block_tables": block_tables,
"k_quant_scales": k_quant_scales,
"v_quant_scales": v_quant_scales,
"k_dequant_scales": k_dequant_scales,
"v_dequant_scales": v_dequant_scales,
}
return model_inputs


def forward(

self,
input_ids,
src_mask=None,
pre_caches=None,
caches=None,
seq_lens_this_time=None,
seq_lens_encoder=None,
seq_lens_decoder=None,
rope_emb=None,
block_tables=None,
k_quant_scales=None,
v_quant_scales=None,
k_dequant_scales=None,
v_dequant_scales=None,
):
outputs = self.chatglm_v2(

input_ids,
src_mask=src_mask,
caches=caches,
rope_emb=rope_emb,
block_tables=block_tables,
pre_caches=pre_caches,
seq_lens_this_time=seq_lens_this_time,
seq_lens_encoder=seq_lens_encoder,
seq_lens_decoder=seq_lens_decoder,
k_quant_scales=k_quant_scales,
v_quant_scales=v_quant_scales,
k_dequant_scales=k_dequant_scales,
v_dequant_scales=v_dequant_scales,
)

hidden_states = outputs[0]
lm_logits = self.chatglm_v2.output_layer(hidden_states)
output = (lm_logits,) + outputs[1:]


return output


@paddle.no_grad()
def set_state_dict(self, state_dict):
self.chatglm_v2.set_state_dict(state_dict)
