Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions llm/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
| [LLaMA](./config/llama) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| [Qwen](./config/qwen) | ✅ | ✅ | ✅ | ✅ | ✅ | 🚧 | 🚧 | ✅ |
| [Mixtral](./config/mixtral) | ✅ | ✅ | ✅ | ❌ | 🚧 | 🚧 | 🚧 | 🚧 |
| [Mistral](./config/mistral) | ❌ | ✅ | ✅ | ✅ | ✅ | 🚧 | 🚧 | ✅ |
| [Baichuan/Baichuan2](./config/llama) | ✅ | ✅ | ✅ | ✅ | ✅ | 🚧 | ✅ | ✅ |
| [ChatGLM-6B](./config/chatglm) | ❌ | ✅ | ✅ | ✅ | 🚧 | 🚧 | ✅ | ❌ |
| [ChatGLM2/ChatGLM3](./config/chatglm2) | ❌ | ✅ | ✅ | ✅ | 🚧 | 🚧 | ✅ | ✅ |
Expand Down
20 changes: 20 additions & 0 deletions llm/config/mistral/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Mistral

## 1. 模型介绍

**支持模型权重:**

| Model |
|--------------------------------------|
| mistralai/Mistral-7B-Instruct-v0.3 |
| mistralai/Mistral-7B-v0.1 |



使用方法:

```python
from paddlenlp.transformers import AutoModelForCausalLM, AutoTokenizer
model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3")
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3")
```
38 changes: 38 additions & 0 deletions llm/config/mistral/dpo_argument.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
{
"model_name_or_path": "mistralai/Mistral-7B-Instruct-v0.3",
"train_dataset_path": "./dpo_data/train.jsonl",
"dev_dataset_path": "./dpo_data/train.jsonl",
"output_dir": "./checkpoints/dpo_ckpts",
"per_device_train_batch_size": 1,
"gradient_accumulation_steps": 8,
"per_device_eval_batch_size": 1,
"num_train_epochs": 1,
"max_steps": 100,
"learning_rate": 1e-06,
"warmup_steps": 10,
"logging_steps": 1,
"evaluation_strategy": "steps",
"save_strategy": "steps",
"eval_steps": 100,
"save_steps": 500,
"max_seq_len": 4096,
"max_prompt_len": 2048,
"bf16": true,
"fp16_opt_level": "O2",
"do_train": true,
"do_eval": true,
"disable_tqdm": true,
"load_best_model_at_end": true,
"tensor_parallel_degree": 8,
"sharding_parallel_degree": 1,
"sharding": "stage1",
"use_flash_attention": true,
"recompute": false,
"recompute_granularity": "full",
"dpo_beta": 0.1,
"benchmark": false,
"dpo_loss_type": "sigmoid",
"dpo_label_smoothing": 0.0,
"unified_checkpoint": true,
"autotuner_benchmark":false
}
32 changes: 32 additions & 0 deletions llm/config/mistral/lora_argument.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
{
"model_name_or_path": "mistralai/Mistral-7B-Instruct-v0.3",
"dataset_name_or_path": "./data",
"output_dir": "./checkpoints/mistral_lora_ckpts",
"per_device_train_batch_size": 4,
"gradient_accumulation_steps": 4,
"per_device_eval_batch_size": 8,
"eval_accumulation_steps":16,
"num_train_epochs": 3,
"learning_rate": 3e-04,
"warmup_steps": 30,
"logging_steps": 1,
"evaluation_strategy": "epoch",
"save_strategy": "epoch",
"src_length": 1024,
"max_length": 2048,
"bf16": true,
"fp16_opt_level": "O2",
"do_train": true,
"do_eval": true,
"disable_tqdm": true,
"recompute": true,
"load_best_model_at_end": true,
"eval_with_do_generation": false,
"metric_for_best_model": "accuracy",
"save_total_limit": 1,
"tensor_parallel_degree": 1,
"pipeline_parallel_degree": 1,
"use_flash_attention": true,
"zero_padding": true,
"lora": true
}
30 changes: 30 additions & 0 deletions llm/config/mistral/pt_argument.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
{
"model_name_or_path": "mistralai/Mistral-7B-Instruct-v0.3",
"dataset_name_or_path": "./data",
"output_dir": "./checkpoints/mistral_pt_ckpts",
"per_device_train_batch_size": 4,
"gradient_accumulation_steps": 4,
"per_device_eval_batch_size": 8,
"eval_accumulation_steps":16,
"num_train_epochs": 3,
"learning_rate": 3e-02,
"warmup_steps": 30,
"logging_steps": 1,
"evaluation_strategy": "epoch",
"save_strategy": "epoch",
"src_length": 1024,
"max_length": 2048,
"bf16": true,
"fp16_opt_level": "O2",
"do_train": true,
"do_eval": true,
"disable_tqdm": true,
"load_best_model_at_end": true,
"eval_with_do_generation": false,
"metric_for_best_model": "accuracy",
"recompute": false,
"save_total_limit": 1,
"tensor_parallel_degree": 1,
"pipeline_parallel_degree": 1,
"prefix_tuning": true
}
30 changes: 30 additions & 0 deletions llm/config/mistral/sft_argument.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
{
"model_name_or_path": "mistralai/Mistral-7B-Instruct-v0.3",
"dataset_name_or_path": "./data",
"output_dir": "./checkpoints/mistral_sft_ckpts",
"per_device_train_batch_size": 4,
"gradient_accumulation_steps": 4,
"per_device_eval_batch_size": 8,
"eval_accumulation_steps":16,
"num_train_epochs": 3,
"learning_rate": 3e-05,
"warmup_steps": 30,
"logging_steps": 1,
"evaluation_strategy": "epoch",
"save_strategy": "epoch",
"src_length": 1024,
"max_length": 2048,
"bf16": true,
"fp16_opt_level": "O2",
"do_train": true,
"do_eval": true,
"disable_tqdm": true,
"load_best_model_at_end": true,
"eval_with_do_generation": false,
"metric_for_best_model": "accuracy",
"recompute": true,
"save_total_limit": 1,
"zero_padding": true,
"tensor_parallel_degree": 8,
"pipeline_parallel_degree": 1
}
4 changes: 2 additions & 2 deletions llm/run_finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,11 +338,11 @@ def neft_post_hook(module, input, output):

if data_args.zero_padding:
if (
model.base_model_prefix not in ["llama", "bloom", "chatglm", "chatglm_v2", "qwen"]
model.base_model_prefix not in ["llama", "bloom", "chatglm", "chatglm_v2", "qwen", "mistral"]
and training_args.pipeline_parallel_degree < 1
):
raise NotImplementedError(
"Zero Padding data stream is only implemented for LLaMA, Bloom, ChatGLM and QWen so far."
"Zero Padding data stream is only implemented for LLaMA, Bloom, ChatGLM, QWen and Mistral so far."
)
train_ds = (
train_ds.map(partial(trans_func, is_test=False, zero_padding=data_args.zero_padding, flash_mask=model_args.flash_mask))
Expand Down
1 change: 1 addition & 0 deletions llm/utils/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def get_convert_example(model):
"opt",
"qwen",
"mixtral",
"mistral",
"gemma",
"qwen2",
"qwen2_moe",
Expand Down
20 changes: 20 additions & 0 deletions llm/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,14 @@ def get_prefix_tuning_params(model):
hidden_size = model.config.hidden_size
postprocess_past_key_value = llama_postprocess_past_key_value
multi_query_group_num = None
elif model.base_model_prefix == "mistral":
from paddlenlp.peft.prefix import mistral_postprocess_past_key_value

num_attention_heads = model.config.num_attention_heads
num_hidden_layers = model.config.num_hidden_layers
hidden_size = model.config.hidden_size
postprocess_past_key_value = mistral_postprocess_past_key_value
multi_query_group_num = model.config.num_key_value_heads
elif model.base_model_prefix == "qwen":
from paddlenlp.peft.prefix import qwen_postprocess_past_key_value

Expand Down Expand Up @@ -190,6 +198,17 @@ def get_lora_target_modules(model):
".*w2.*",
".*w3.*",
]
elif model.base_model_prefix == "mistral":
target_modules = [
".*q_proj.*",
".*k_proj.*",
".*v_proj.*",
".*o_proj.*",
".*gate.*",
".*w1.*",
".*w2.*",
".*w3.*",
]
elif model.base_model_prefix == "qwen2_moe":
target_modules = [
".*q_proj.*",
Expand Down Expand Up @@ -279,6 +298,7 @@ def prediction_step(
)[0]
all_preds = []
for pred_tokens in generated_tokens:
pred_tokens = pred_tokens.numpy()
pred_tokens = pred_tokens[pred_tokens != self.tokenizer.pad_token_id].tolist()
all_preds.append(pred_tokens)
max_pred_length = max([len(x) for x in all_preds])
Expand Down
1 change: 1 addition & 0 deletions paddlenlp/peft/prefix/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,6 @@
bloom_postprocess_past_key_value,
chatglm_postprocess_past_key_value,
llama_postprocess_past_key_value,
mistral_postprocess_past_key_value,
qwen_postprocess_past_key_value,
)
7 changes: 7 additions & 0 deletions paddlenlp/peft/prefix/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,13 @@ def llama_postprocess_past_key_value(past_key_values):
return tuple(zip(keys, values))


def mistral_postprocess_past_key_value(past_key_values):
# (layer_num, bs, head_num/tensor_parallel_degree, prefixlen, head_dim)*2
keys, values = paddle.transpose(past_key_values, perm=[2, 0, 3, 1, 4]).split(2)

return tuple(zip(keys, values))


def qwen_postprocess_past_key_value(past_key_values):
# (layer_num, bs, prefixlen, head_num/tensor_parallel_degree, head_dim)*2
keys, values = paddle.transpose(past_key_values, perm=[2, 0, 1, 3, 4]).split(2)
Expand Down
2 changes: 2 additions & 0 deletions paddlenlp/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,8 @@
from .rw.modeling import *
from .rw.configuration import *
from .rw.tokenizer import *
from .mistral.modeling import *
from .mistral.configuration import *
from .qwen import *
from .mixtral.modeling import *
from .mixtral.configuration import *
Expand Down
1 change: 1 addition & 0 deletions paddlenlp/transformers/auto/modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@
("Blip", "blip"),
("Bloom", "bloom"),
("QWen", "qwen"),
("Mistral", "mistral"),
("Mixtral", "mixtral"),
("Qwen2", "qwen2"),
("Qwen2Moe", "qwen2_moe"),
Expand Down
15 changes: 15 additions & 0 deletions paddlenlp/transformers/mistral/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .configuration import MistralConfig
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

mistral为什么没有tokenizer文件

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

from .modeling import MistralForCausalLM
69 changes: 69 additions & 0 deletions paddlenlp/transformers/mistral/configuration.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Mistral model configuration"""

from ..configuration_utils import PretrainedConfig


class MistralConfig(PretrainedConfig):
model_type = "mistral"
keys_to_ignore_at_inference = ["past_key_values"]

def __init__(
self,
vocab_size=32000,
hidden_size=4096,
intermediate_size=14336,
num_hidden_layers=32,
num_attention_heads=32,
num_key_value_heads=8,
hidden_act="silu",
max_position_embeddings=4096 * 32,
initializer_range=0.02,
rms_norm_eps=1e-6,
use_cache=True,
pad_token_id=None,
bos_token_id=1,
eos_token_id=2,
tie_word_embeddings=False,
rope_theta=10000.0,
sliding_window=4096,
**kwargs,
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.sliding_window = sliding_window

# for backward compatibility
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads

self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.rope_theta = rope_theta

super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
Loading