
Commit 55aa005

Add NLP distillation showcase. (#122)
1 parent 91ec356 commit 55aa005

File tree

8 files changed: +910 −0 lines changed


example/distill/nlp/README.md

Lines changed: 47 additions & 0 deletions
@@ -0,0 +1,47 @@
# ERNIE distillation

We show how to distill knowledge from ERNIE into a mini model, BOW (other student models will follow), on a Chinese sentiment-classification task.

## Quick start

### Download dataset

```
wget https://paddle-edl.bj.bcebos.com/distillation/chnsenticorp/data.tgz
tar -xzvf data.tgz
```

### Get the teacher model

```
nohup python -u ./fine_tune.py > finetune.log 2>&1 &
```

When the job completes, the directories needed for distillation, `ernie_senti_server` and `ernie_senti_client`, will have been generated.
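You can confirm both artifacts exist with a plain `ls` (just a sanity check, not part of this example's scripts):

```
ls ernie_senti_server ernie_senti_client
```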
### Or download the teacher model directly

You can also download the teacher model directly, so you needn't generate it yourself:

```
wget https://paddle-edl.bj.bcebos.com/distillation/chnsenticorp/ernie_senti.tgz
tar -xzvf ernie_senti.tgz
```
### Start a local teacher

```
nohup python -m paddle_serving_server_gpu.serve \
    --model ./ernie_senti_server/ \
    --port 19290 \
    --thread 8 \
    --mem_optim \
    --gpu_ids 0 > teacher.log 2>&1 &
```
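To sanity-check that the teacher is serving, you can query it directly with the Paddle Serving client. This is a minimal sketch, not part of this example's files; it assumes the fetch variable is named `logits`, which is what `distill.py` requests:

```
from paddle_serving_client import Client
from paddle_serving_app.reader import ChineseBertReader

# Build the same reader the distillation script uses.
reader = ChineseBertReader({"max_seq_len": 256, "vocab_file": "./data/vocab.txt"})

client = Client()
client.load_client_config("./ernie_senti_client/serving_client_conf.prototxt")
client.connect(["127.0.0.1:19290"])

# Feed one sentence and fetch the teacher's logits.
feed = reader.process("这家酒店的服务很好")  # "The service at this hotel is great."
print(client.predict(feed=feed, fetch=["logits"]))
```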
### Start a student

The student model is currently BOW; CNN, LSTM, and tiny-ERNIE students will be added later.

```
python -u distill.py --fixed_teacher 127.0.0.1:19290
```
### Result

| model | dev dataset (acc) | test dataset (acc) |
| :----: | :-----: | :----: |
| BOW | 0.901 | 0.908 |
| BOW + distillation | 0.905 | 0.915 |
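For reference, `distill.py` combines the student's cross-entropy loss `loss_ce` with a KL term `loss_kd` computed against the teacher's logits. When a temperature `T` is set (default 2.0), the mixture is scaled by `T²`, as in the training loop below:

```
loss = T * T * (s_weight * loss_ce + (1 - s_weight) * loss_kd)
```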

example/distill/nlp/distill.py

Lines changed: 208 additions & 0 deletions
@@ -0,0 +1,208 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import sys

import numpy as np
import paddle as P
import paddle.fluid as F
import paddle.fluid.layers as L
import paddle.fluid.dygraph as D

from paddle_edl.distill.distill_reader import DistillReader
from paddle_serving_app.reader import ChineseBertReader

from model import CNN, AdamW, evaluate_student, KL, BOW, KL_T
from reader import ChnSentiCorp, pad_batch_data
parser = argparse.ArgumentParser(__doc__)
parser.add_argument(
    "--fixed_teacher",
    type=str,
    default=None,
    help="fixed teacher endpoint (host:port) for local distillation debugging")
parser.add_argument(
    "--s_weight",
    type=float,
    default=0.5,
    help="weight of the student cross-entropy term in the loss")
parser.add_argument(
    "--epoch_num", type=int, default=10, help="epochs per training run")
parser.add_argument(
    "--weight_decay",
    type=float,
    default=0.01,
    help="weight decay coefficient")
parser.add_argument(
    "--opt", type=str, default="AdamW", help="optimizer: AdamW or Adam")
parser.add_argument(
    "--train_range", type=int, default=10, help="number of training runs")
parser.add_argument(
    "--use_data_au", type=int, default=1, help="use data augmentation")
parser.add_argument(
    "--T", type=float, default=2.0, help="distillation temperature")
args = parser.parse_args()
print("parsed args:", args)

# Best dev/test accuracy recorded for each training run.
g_max_dev_acc = []
g_max_test_acc = []
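

# train_with_distill runs one full training of the BOW student. Each batch
# from train_reader carries the ground-truth labels together with the
# teacher's logits, so the loss can combine supervised cross-entropy with a
# KL term against the teacher (temperature-scaled when --T is set).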
def train_with_distill(train_reader, dev_reader, word_dict, test_reader,
                       epoch_num):
    boundaries = [2250 * 2, 2250 * 4, 2250 * 6]
    values = [1e-4, 1.5e-4, 2.5e-4, 4e-4]
    lr = D.PiecewiseDecay(boundaries, values, 0)
    model = BOW(word_dict)
    if args.opt == "Adam":
        opt = F.optimizer.Adam(
            learning_rate=lr,
            parameter_list=model.parameters(),
            regularization=F.regularizer.L2Decay(
                regularization_coeff=args.weight_decay))
    else:
        opt = AdamW(
            learning_rate=lr,
            parameter_list=model.parameters(),
            weight_decay=args.weight_decay)

    model.train()
    max_dev_acc = 0.0
    max_test_acc = 0.0
    for epoch in range(epoch_num):
        for step, output in enumerate(train_reader()):
            # The distill reader yields the teacher feed fields first,
            # then the student ids, the labels, and the teacher's logits.
            (_, _, _, _, ids_student, labels, logits_t) = output

            ids_student = D.base.to_variable(
                pad_batch_data(ids_student, 'int64'))
            labels = D.base.to_variable(np.array(labels).astype('int64'))
            logits_t = D.base.to_variable(np.array(logits_t).astype('float32'))
            logits_t.stop_gradient = True

            _, logits_s = model(ids_student)
            loss_ce, _ = model(ids_student, labels=labels)

            if args.T is None:
                loss_kd = KL(logits_s, logits_t)
                loss = args.s_weight * loss_ce + (
                    1.0 - args.s_weight) * loss_kd
            else:
                # Scale by T^2 to keep gradient magnitudes comparable
                # across temperatures.
                loss_kd = KL_T(logits_s, logits_t, args.T)
                loss = args.T * args.T * (args.s_weight * loss_ce +
                                          (1.0 - args.s_weight) * loss_kd)

            loss = L.reduce_mean(loss)
            loss.backward()
            if step % 10 == 0:
                print('[step %03d] distill train loss %.5f lr %.3e' %
                      (step, loss.numpy(), opt.current_step_lr()))
            opt.minimize(loss)
            model.clear_gradients()

        f1, acc = evaluate_student(model, dev_reader)
        print('student on dev f1 %.5f acc %.5f' % (f1, acc))
        if max_dev_acc < acc:
            max_dev_acc = acc

        f1, acc = evaluate_student(model, test_reader)
        print('student on test f1 %.5f acc %.5f' % (f1, acc))
        if max_test_acc < acc:
            max_test_acc = acc

    g_max_dev_acc.append(max_dev_acc)
    g_max_test_acc.append(max_test_acc)
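

# ernie_reader adapts batches from the student reader into the feed format
# the ERNIE teacher expects: each raw sentence is run through
# ChineseBertReader, the resulting fields are grouped by key, and the
# student ids and labels are appended at the end.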
def ernie_reader(s_reader, key_list):
    bert_reader = ChineseBertReader({
        'max_seq_len': 256,
        "vocab_file": "./data/vocab.txt"
    })

    def reader():
        for (ids_student, labels, ss) in s_reader():
            b = {}
            for k in key_list:
                b[k] = []

            for s in ss:
                feed_dict = bert_reader.process(s)
                for k in feed_dict:
                    b[k].append(feed_dict[k])
            b["ids_student"] = ids_student
            b["labels"] = labels

            # Emit the fields in the order given by key_list.
            feed_list = []
            for k in key_list:
                feed_list.append(b[k])

            yield feed_list

    return reader
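

# Main flow: build the student readers, then wrap the training reader with
# DistillReader so each batch is sent to the ERNIE teacher service and comes
# back with the teacher's 'logits' appended.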
if __name__ == "__main__":
    place = F.CUDAPlace(0)
    D.guard(place).__enter__()  # enter dygraph mode for the whole script

    ds = ChnSentiCorp()
    word_dict = ds.student_word_dict("./data/vocab.bow.txt")
    batch_size = 16

    input_files = []
    if args.use_data_au:
        for i in range(1, 5):
            input_files.append("./data/train-data-augmented/part.{}".format(i))
    else:
        input_files.append("./data/train.part.0")

    # student train and dev
    train_reader = ds.pad_batch_reader(
        input_files, word_dict, batch_size=batch_size)
    dev_reader = ds.pad_batch_reader(
        "./data/dev.part.0", word_dict, batch_size=batch_size)
    test_reader = ds.pad_batch_reader(
        "./data/test.part.0", word_dict, batch_size=batch_size)

    feed_keys = [
        "input_ids", "position_ids", "segment_ids", "input_mask",
        "ids_student", "labels"
    ]

    # distill reader and teacher
    dr = DistillReader(feed_keys, predicts=['logits'])
    dr.set_teacher_batch_size(batch_size)
    dr.set_serving_conf_file(
        "./ernie_senti_client/serving_client_conf.prototxt")
    if args.fixed_teacher:
        dr.set_fixed_teacher(args.fixed_teacher)

    dr_train_reader = ds.batch_reader(
        input_files, word_dict, batch_size=batch_size)
    dr_t = dr.set_batch_generator(ernie_reader(dr_train_reader, feed_keys))

    for i in range(args.train_range):
        train_with_distill(
            dr_t, dev_reader, word_dict, test_reader, epoch_num=args.epoch_num)

    arr = np.array(g_max_dev_acc)
    print("max_dev_acc:", arr, "average:", np.average(arr), "train_args:",
          args)

    arr = np.array(g_max_test_acc)
    print("max_test_acc:", arr, "average:", np.average(arr), "train_args:",
          args)
