# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
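
"""Distill an ERNIE teacher into a small BOW student on the ChnSentiCorp
sentiment classification dataset.

Teacher logits are obtained through paddle_edl's DistillReader, either
from discovered teacher serving instances or from a --fixed_teacher
endpoint, and are mixed with the hard labels via a distillation loss.
"""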

import argparse

import numpy as np

import paddle.fluid as F
import paddle.fluid.layers as L
import paddle.fluid.dygraph as D

from paddle_edl.distill.distill_reader import DistillReader
from paddle_serving_app.reader import ChineseBertReader

from reader import ChnSentiCorp, pad_batch_data
from model import AdamW, evaluate_student, KL, BOW, KL_T

parser = argparse.ArgumentParser(__doc__)
parser.add_argument(
    "--fixed_teacher",
    type=str,
    default=None,
    help="fixed teacher endpoint(s) for debugging local distillation")
parser.add_argument(
    "--s_weight",
    type=float,
    default=0.5,
    help="weight of the student's hard-label loss; the KD loss gets 1 - s_weight")
parser.add_argument(
    "--epoch_num", type=int, default=10, help="number of epochs per training run")
parser.add_argument(
    "--weight_decay", type=float, default=0.01, help="weight decay coefficient")
parser.add_argument(
    "--opt", type=str, default="AdamW", help="optimizer: Adam or AdamW")
parser.add_argument(
    "--train_range",
    type=int,
    default=10,
    help="number of independent training runs to average over")
parser.add_argument(
    "--use_data_au", type=int, default=1, help="use augmented training data")
parser.add_argument(
    "--T",
    type=float,
    default=2.0,
    help="distillation temperature; None disables temperature scaling")
args = parser.parse_args()
print("parsed args:", args)

g_max_dev_acc = []
g_max_test_acc = []


def train_with_distill(train_reader, dev_reader, word_dict, test_reader,
                       epoch_num):
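    """Train the BOW student against teacher logits for epoch_num epochs.

    Each batch from train_reader carries the ERNIE feed fields plus
    ids_student, labels and the teacher's logits; dev_reader and
    test_reader track the best accuracies, which are appended to the
    module-level g_max_dev_acc / g_max_test_acc lists.
    """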
    # warmup-style piecewise LR schedule: the rate steps up at the given
    # iteration boundaries
    boundaries = [2250 * 2, 2250 * 4, 2250 * 6]
    values = [1e-4, 1.5e-4, 2.5e-4, 4e-4]
    lr = D.PiecewiseDecay(boundaries, values, 0)
    model = BOW(word_dict)
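    # --opt selects how weight decay is applied: Adam folds it into the
    # loss as an L2 penalty, while AdamW decouples the decay from the
    # adaptive gradient update (Loshchilov & Hutter)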
    if args.opt == "Adam":
        opt = F.optimizer.Adam(
            learning_rate=lr,
            parameter_list=model.parameters(),
            regularization=F.regularizer.L2Decay(
                regularization_coeff=args.weight_decay))
    else:
        opt = AdamW(
            learning_rate=lr,
            parameter_list=model.parameters(),
            weight_decay=args.weight_decay)

    model.train()
    max_dev_acc = 0.0
    max_test_acc = 0.0
    for epoch in range(epoch_num):
        for step, output in enumerate(train_reader()):
            (_, _, _, _, ids_student, labels, logits_t) = output

            ids_student = D.base.to_variable(
                pad_batch_data(ids_student, 'int64'))
            labels = D.base.to_variable(np.array(labels).astype('int64'))
            logits_t = D.base.to_variable(
                np.array(logits_t).astype('float32'))
            # the teacher's logits are fixed soft targets; do not
            # backpropagate through them
            logits_t.stop_gradient = True

            _, logits_s = model(ids_student)
            loss_ce, _ = model(ids_student, labels=labels)

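            # Hinton-style KD: mix the hard-label CE loss with a KL term
            # against the teacher's soft targets; with a temperature T the
            # distributions are smoothed and the loss is rescaled by T*T
            # so its gradients keep a comparable magnitude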
            if args.T is None:
                loss_kd = KL(logits_s, logits_t)
                loss = args.s_weight * loss_ce + \
                    (1.0 - args.s_weight) * loss_kd
            else:
                loss_kd = KL_T(logits_s, logits_t, args.T)
                loss = args.T * args.T * (
                    args.s_weight * loss_ce +
                    (1.0 - args.s_weight) * loss_kd)

            loss = L.reduce_mean(loss)
            loss.backward()
            if step % 10 == 0:
                print('[step %03d] distill train loss %.5f lr %.3e' %
                      (step, loss.numpy(), opt.current_step_lr()))
            opt.minimize(loss)
            model.clear_gradients()
        f1, acc = evaluate_student(model, dev_reader)
        print('student on dev f1 %.5f acc %.5f' % (f1, acc))

        if max_dev_acc < acc:
            max_dev_acc = acc

        f1, acc = evaluate_student(model, test_reader)
        print('student on test f1 %.5f acc %.5f' % (f1, acc))

        if max_test_acc < acc:
            max_test_acc = acc

    # record this run's best accuracies
    g_max_dev_acc.append(max_dev_acc)
    g_max_test_acc.append(max_test_acc)


def ernie_reader(s_reader, key_list):
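    """Wrap a student reader into an ERNIE feed generator.

    Each raw text sample is run through ChineseBertReader to build the
    ERNIE input fields; the student ids and labels are carried along and
    every batch is yielded as a list ordered by key_list.
    """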
    bert_reader = ChineseBertReader({
        "max_seq_len": 256,
        "vocab_file": "./data/vocab.txt"
    })

    def reader():
        for (ids_student, labels, ss) in s_reader():
            b = {k: [] for k in key_list}

            for s in ss:
                feed_dict = bert_reader.process(s)
                for k in feed_dict:
                    b[k].append(feed_dict[k])
            b["ids_student"] = ids_student
            b["labels"] = labels

            yield [b[k] for k in key_list]

    return reader


if __name__ == "__main__":
    place = F.CUDAPlace(0)
    # enter dygraph mode for everything below
    D.guard(place).__enter__()

    ds = ChnSentiCorp()
    word_dict = ds.student_word_dict("./data/vocab.bow.txt")
    batch_size = 16

    input_files = []
    if args.use_data_au:
        for i in range(1, 5):
            input_files.append("./data/train-data-augmented/part.{}".format(i))
    else:
        input_files.append("./data/train.part.0")

    # student train, dev and test readers
    train_reader = ds.pad_batch_reader(
        input_files, word_dict, batch_size=batch_size)
    dev_reader = ds.pad_batch_reader(
        "./data/dev.part.0", word_dict, batch_size=batch_size)
    test_reader = ds.pad_batch_reader(
        "./data/test.part.0", word_dict, batch_size=batch_size)

    feed_keys = [
        "input_ids", "position_ids", "segment_ids", "input_mask",
        "ids_student", "labels"
    ]

    # distill reader and teacher: DistillReader sends each batch to an
    # ERNIE teacher served via Paddle Serving and appends the teacher's
    # 'logits' prediction to the batch
    dr = DistillReader(feed_keys, predicts=['logits'])
    dr.set_teacher_batch_size(batch_size)
    dr.set_serving_conf_file(
        "./ernie_senti_client/serving_client_conf.prototxt")
    if args.fixed_teacher:
        dr.set_fixed_teacher(args.fixed_teacher)
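
    # wrap the raw student batches into ERNIE feeds; dr_t then yields the
    # feed_keys fields with the teacher's logits appended at the end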
    dr_train_reader = ds.batch_reader(
        input_files, word_dict, batch_size=batch_size)
    dr_t = dr.set_batch_generator(ernie_reader(dr_train_reader, feed_keys))

    # run training train_range times and average the best accuracies
    for _ in range(args.train_range):
        train_with_distill(
            dr_t, dev_reader, word_dict, test_reader,
            epoch_num=args.epoch_num)

    arr = np.array(g_max_dev_acc)
    print("max_dev_acc:", arr, "average:", np.average(arr), "train_args:",
          args)

    arr = np.array(g_max_test_acc)
    print("max_test_acc:", arr, "average:", np.average(arr), "train_args:",
          args)