Skip to content

Commit 87e4c4f

Browse files
authored
[NPU] support npu llama2-13B export & inference (#8442)
* [NPU] support npu llama2-13B export & inference * move csrc_npu to csrc/npu
1 parent 9064078 commit 87e4c4f

File tree

7 files changed

+241
-16
lines changed

7 files changed

+241
-16
lines changed

csrc/npu/README.md

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# PaddleNLP 自定义 OP
2+
3+
此文档介绍如何编译安装 PaddleNLP NPU 自定义 OP。
4+
5+
# 1. 安装 PaddleCustomDevice
6+
7+
参考 [PaddleCustomDevice NPU 安装文档](https://github.com/PaddlePaddle/PaddleCustomDevice/blob/develop/backends/npu/README_cn.md) 进行安装
8+
9+
# 2. 安装 paddlenlp_ops
10+
```shell
11+
python setup.py build bdist_wheel
12+
13+
pip install dist/paddlenlp_ops*.whl
14+
```
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from paddle_custom_device.npu.ops import *

csrc/npu/setup.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import os
16+
17+
from setuptools import Distribution, setup
18+
19+
packages = []
20+
package_data = {}
21+
22+
23+
class BinaryDistribution(Distribution):
    """Distribution subclass that always reports extension modules.

    ``has_ext_modules`` is overridden to return True unconditionally, so any
    build command that consults it sees this package as a non-pure one.
    """

    def has_ext_modules(self):
        """Return True regardless of the distribution's actual contents."""
        return True
26+
27+
28+
def main():
    """Run ``setuptools.setup`` with the metadata of the ``paddlenlp_ops`` package.

    The package sources are looked up under the ``python`` directory and the
    build uses ``BinaryDistribution`` as its distribution class.
    """
    options = {
        "name": "paddlenlp_ops",
        "version": "0.0.0",
        "description": "PaddleNLP NPU CustomOps",
        "long_description": "",
        "long_description_content_type": "text/markdown",
        "author_email": "Paddle-better@baidu.com",
        "maintainer": "PaddlePaddle",
        "maintainer_email": "Paddle-better@baidu.com",
        "project_urls": {},
        "license": "Apache Software License",
        "packages": ["paddlenlp_ops"],
        "include_package_data": True,
        "package_data": {"": ["*.py"]},
        "package_dir": {"": "python"},
        "zip_safe": False,
        "distclass": BinaryDistribution,
        "entry_points": {"console_scripts": []},
        "classifiers": [],
        "keywords": "PaddleNLP NPU CustomOps",
    }
    setup(**options)


if __name__ == "__main__":
    main()

llm/export_model.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,12 @@ def load_inference_model(model_path, model_name, param_name, exe):
4040
return paddle.static.io.load_inference_model(model_path, exe)
4141

4242

43-
def validate_pdmodel(model_path, model_prefix):
43+
def validate_pdmodel(model_path, model_prefix, device):
4444
paddle.enable_static()
45-
place = paddle.CUDAPlace(0)
45+
if device == "gpu":
46+
place = paddle.CUDAPlace(0)
47+
else:
48+
place = paddle.CustomPlace(device, 0)
4649
exe = paddle.static.Executor(place)
4750
scope = paddle.static.Scope()
4851

@@ -95,7 +98,12 @@ def main():
9598

9699
if tensor_parallel_degree > 1:
97100
export_args.output_path = os.path.join(export_args.output_path, f"rank_{tensor_parallel_rank}")
98-
validate_pdmodel(export_args.output_path, predictor_args.model_prefix)
101+
validate_pdmodel(export_args.output_path, predictor_args.model_prefix, predictor_args.device)
102+
103+
if predictor_args.device == "npu":
104+
from llama.npu.export_utils import process_params
105+
106+
process_params(os.path.join(export_args.output_path, predictor_args.model_prefix))
99107

100108

101109
if __name__ == "__main__":

llm/llama/npu/export_utils.py

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import argparse
16+
17+
import numpy as np
18+
import paddle
19+
from tqdm import tqdm
20+
21+
22+
def parse_arguments(argv=None):
    """Parse the command-line options of the NPU export post-processing script.

    Args:
        argv: Optional list of argument strings to parse. When None (the
            default) argparse falls back to ``sys.argv[1:]``, so existing
            callers that pass nothing behave as before.

    Returns:
        argparse.Namespace with a single ``model_path`` attribute.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_path",
        default="inference/model",
        # The value is forwarded unchanged to load_inference_model /
        # save_inference_model, which take a path prefix, not a directory.
        help="The path prefix of the exported model (without the .pdmodel/.pdiparams suffix).",
    )
    return parser.parse_args(argv)
26+
27+
28+
def trans_weight(var):
    """Transpose a 2-D weight variable in place.

    Both the shape recorded in the variable's descriptor and the data held by
    its tensor are swapped from ``[rows, cols]`` to ``[cols, rows]``.

    Args:
        var: Variable exposing ``desc.shape()``, ``desc.set_shape()``,
            ``get_value()`` and ``name``.

    Raises:
        ValueError: If the variable is not 2-D. The descriptor update swaps
            exactly two dimensions while ``ndarray.T`` reverses all of them,
            so any other rank would leave descriptor and data inconsistent.
    """
    shape = var.desc.shape()
    if len(shape) != 2:
        raise ValueError(f"trans_weight expects a 2-D weight, but {var.name!r} has shape {list(shape)}")
    var.desc.set_shape([shape[1], shape[0]])

    tensor = var.get_value()
    # Materialize the transposed view as a C-contiguous array before storing it.
    transposed = np.ascontiguousarray(np.array(tensor).T)
    tensor.set(transposed, paddle.CPUPlace())
35+
36+
37+
def convert_dequant_scale(var):
    """Repack a dequantization scale variable as int64 values, in place.

    The stored values are cast to float32 and flattened; each float's 32-bit
    pattern then becomes the low half of an int64 whose high half is zero.
    The result is written back to the variable's tensor on the CPU place.

    The original implementation interleaved each scale with a 0.0 float and
    reinterpreted the byte string as int64, which yields these same values on
    little-endian hosts only. Widening the ``uint32`` view gives that result
    independently of host byte order and without the doubled temporary buffer.

    Args:
        var: Variable exposing ``get_value()``; the returned tensor must be
            convertible with ``np.array`` and support ``set(array, place)``.
    """
    tensor = var.get_value()
    deq_scale = np.array(tensor).astype(np.float32).reshape(-1)
    packed = deq_scale.view(np.uint32).astype(np.int64)
    tensor.set(packed, paddle.CPUPlace())
41+
42+
43+
def _is_npu_transposed_weight(w_name):
    """Return True if ``w_name`` is one of the linear weights this script stores transposed."""
    return w_name == "llama_lm_head_0.w_0" or w_name.endswith(
        ("qkv_weight", "out_proj_weight", "ffn1_weight", "ffn2_weight")
    )


def _transpose_linear_weights(block):
    """Set ``trans_y=True`` on the selected ``matmul_v2`` ops and transpose their weights."""
    already_transposed = set()
    for op in tqdm(block.ops, desc="processing the linear layer for NPU"):
        if op.type != "matmul_v2":
            continue
        w_name = op.input_arg_names[-1]
        if not _is_npu_transposed_weight(w_name) or op.attr("trans_y"):
            continue
        op._set_attr("trans_y", True)
        # A weight consumed by several matmul ops must be transposed only once;
        # a second transpose would undo the first while both ops keep trans_y=True.
        if w_name not in already_transposed:
            trans_weight(block.var(w_name))
            already_transposed.add(w_name)


def _convert_dequant_scales(block):
    """Run ``convert_dequant_scale`` on every variable whose name ends with a known scale suffix."""
    scale_suffixes = ("qkv_out_scale", "linear_out_scale", "ffn1_out_scale", "ffn2_out_scale")
    for var_name in tqdm(block.vars, desc="processing the dequant layer for NPU"):
        if var_name.endswith(scale_suffixes):
            convert_dequant_scale(block.var(var_name))


def process_params(model_path):
    """Rewrite an exported inference model in place for NPU execution.

    The model at ``model_path`` is loaded on the CPU place into a fresh scope.
    Selected ``matmul_v2`` weights are transposed (with ``trans_y`` switched to
    True on the consuming ops), the dequantization scales are repacked by
    ``convert_dequant_scale``, and the program is saved back to the same path
    with ``skip_prune_program=True``.

    Args:
        model_path: Path passed unchanged to ``load_inference_model`` and
            ``save_inference_model``.
    """
    paddle.enable_static()
    exe = paddle.static.Executor(paddle.CPUPlace())

    prog = paddle.static.Program()
    startup_prog = paddle.static.Program()
    scope = paddle.static.Scope()
    # Everything below runs inside the guarded scope: the helpers read and
    # write parameter values through ``var.get_value()``.
    with paddle.base.scope_guard(scope), paddle.base.program_guard(prog, startup_prog):
        program, feed_target_names, fetch_targets = paddle.static.io.load_inference_model(model_path, exe)

        feed_names = set(feed_target_names)
        feed_targets = [var for var in program.list_vars() if var.name in feed_names]

        block = program.global_block()
        _transpose_linear_weights(block)
        _convert_dequant_scales(block)

        paddle.static.save_inference_model(
            model_path, feed_targets, fetch_targets, exe, program=program, skip_prune_program=True
        )
102+
103+
104+
def main():
    """Command-line entry point: post-process the model named by ``--model_path``."""
    process_params(parse_arguments().model_path)


if __name__ == "__main__":
    main()

llm/predictor.py

Lines changed: 31 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -647,8 +647,12 @@ def _create_predictor(self, predictor_args: PredictorArgument):
647647
if predictor_args.dtype == "bfloat16":
648648
config.delete_pass("gpu_cpu_map_matmul_v2_to_matmul_pass")
649649

650-
device_id = int(os.environ.get("FLAGS_selected_gpus", 0))
651-
config.enable_use_gpu(100, device_id)
650+
if predictor_args.device in paddle.device.get_all_custom_device_type():
651+
device_id = int(os.environ.get("FLAGS_selected_{}s".format(predictor_args.device), 0))
652+
config.enable_custom_device(predictor_args.device, device_id)
653+
else:
654+
device_id = int(os.environ.get("FLAGS_selected_gpus", 0))
655+
config.enable_use_gpu(100, device_id)
652656
config.enable_new_executor()
653657

654658
if self.tensor_parallel_degree > 1:
@@ -793,6 +797,8 @@ def __init__(self, config: PredictorArgument, tokenizer: PretrainedTokenizer):
793797
self.free_list = [i for i in range(self.max_block_nums)][::-1]
794798
self.used_list = [[] for _ in range(config.batch_size)]
795799

800+
self.benchmark = config.benchmark
801+
796802
def init_inputs(self, config: PredictorArgument):
797803
self.inputs = {}
798804

@@ -909,19 +915,20 @@ def _get_rotary_position_embedding(self, position_ids, head_dim):
909915
return rot_emb
910916

911917
def _preprocess(self, source):
912-
if self.tokenizer.chat_template is not None:
918+
if not self.benchmark and self.tokenizer.chat_template is not None:
913919
source = [source] if isinstance(source, str) else source
914920
source = [self.tokenizer.apply_chat_template(sentence, tokenize=False) for sentence in source]
915921

916922
for i, text in enumerate(source):
923+
add_special_tokens = self.tokenizer.chat_template is None or isinstance(self.tokenizer, (ChatGLMv2Tokenizer, ChatGLMTokenizer))
924+
add_special_tokens = add_special_tokens if not self.benchmark else False
917925
tokens = self.tokenizer(
918926
text,
919927
return_tensors="np",
920928
padding=True,
921929
max_length=self.config.src_length,
922930
# if use chat_template, it will not add special_tokens
923-
add_special_tokens=self.tokenizer.chat_template is None
924-
or isinstance(self.tokenizer, (ChatGLMv2Tokenizer, ChatGLMTokenizer)),
931+
add_special_tokens=add_special_tokens,
925932
)
926933
input_ids = tokens["input_ids"][0]
927934
length = len(input_ids)
@@ -1066,11 +1073,22 @@ def _create_predictor(self, predictor_args: PredictorArgument):
10661073
config = paddle.inference.Config(infer_model_path + ".pdmodel", infer_model_path + ".pdiparams")
10671074

10681075
config.switch_ir_optim(False)
1069-
device_id = int(os.environ.get("FLAGS_selected_gpus", 0))
1070-
config.enable_use_gpu(100, device_id)
1076+
if predictor_args.device in paddle.device.get_all_custom_device_type():
1077+
device_id = int(os.environ.get("FLAGS_selected_{}s".format(predictor_args.device), 0))
1078+
config.enable_custom_device(predictor_args.device, device_id)
1079+
else:
1080+
device_id = int(os.environ.get("FLAGS_selected_gpus", 0))
1081+
config.enable_use_gpu(100, device_id)
10711082
# config.disable_glog_info()
10721083
# config.enable_memory_optim()
10731084

1085+
if predictor_args.device == "npu":
1086+
import paddle_custom_device.npu.passes as passes
1087+
1088+
config.switch_ir_optim(True)
1089+
pass_builder = config.pass_builder()
1090+
passes.addPasses(pass_builder, self.model_config.model_type, self.model_config.quant_type)
1091+
10741092
if self.tensor_parallel_degree > 1:
10751093
trainer_endpoints = fleet.worker_endpoints()
10761094
current_endpoint = trainer_endpoints[self.tensor_parallel_rank]
@@ -1516,6 +1534,11 @@ def predict():
15161534
fleet.init(is_collective=True, strategy=strategy)
15171535

15181536
predictor = create_predictor(predictor_args, model_args)
1537+
1538+
if predictor_args.benchmark:
1539+
benchmark(predictor, predictor_args, model_args)
1540+
return
1541+
15191542
source_texts = []
15201543
target_texts = []
15211544
if model_args.data_file:
@@ -1559,14 +1582,10 @@ def predict():
15591582
out = {"src": source, "tgt": target, "output": output}
15601583
f.write(json.dumps(out, ensure_ascii=False) + "\n")
15611584

1562-
if predictor_args.benchmark:
1563-
benchmark(predictor, predictor_args, model_args)
1564-
15651585

15661586
def benchmark(predictor, predictor_args, model_args):
15671587
# Just construct a simple benchmark input. We pad input to the src_length.
1568-
test_texts = "hello world, how are you?"
1569-
benchmark_texts = [test_texts + "<pad>" * predictor_args.src_length for _ in range(predictor_args.batch_size)]
1588+
benchmark_texts = [predictor.tokenizer.pad_token * predictor_args.src_length for _ in range(predictor_args.batch_size)]
15701589

15711590
batch_benchmark_texts = batchfy_text(benchmark_texts, predictor_args.batch_size)
15721591
print("***********Start Benchmark**********")

paddlenlp/experimental/transformers/fused_transformer_layers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -570,7 +570,7 @@ def compute_layernorm_before_qkv(self, src, i):
570570
return ln_out
571571

572572
def compute_qkv_linear(self, ln_out, i):
573-
if float(paddle.version.cuda()) < 11.6:
573+
if paddle.version.cuda() == "False" or float(paddle.version.cuda()) < 11.6:
574574
qkv_out = paddle.matmul(ln_out, self.qkv_weights[i], False, True)
575575
if self.qkv_biases[i] is not None:
576576
qkv_out = paddle.add(qkv_out, self.qkv_biases[i])

0 commit comments

Comments
 (0)