Commit f14db82

Merge pull request #783 from guoshengCS/fix-transformer-batchsize
Decouple the program desc from batch_size in Transformer.
2 parents 76526e5 + baa01f6 commit f14db82
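The gist of the change: instead of baking a concrete batch size into the compiled program desc, the actual data shape now travels with the data as an extra int32 feed (src_data_shape / trg_data_shape) and is used to reshape activations at run time, so one program desc serves any batch size. A minimal sketch of the feed side, with invented helper names and toy numbers:

import numpy as np

# Hypothetical helper: the batch dimension stays -1 (inferred at run time);
# only the sequence length and model width are concrete.
def make_data_shape(max_len, d_model):
    return np.array([-1, max_len, d_model], dtype="int32")

feed = {"src_data_shape": make_data_shape(40, 512)}  # works for any batch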

4 files changed: +170 additions, -105 deletions

fluid/neural_machine_translation/transformer/config.py
3 additions, 2 deletions

@@ -25,8 +25,7 @@ class TrainTaskConfig(object):
 class InferTaskConfig(object):
     use_gpu = False
     # the number of examples in one run for sequence generation.
-    # currently the batch size can only be set to 1.
-    batch_size = 1
+    batch_size = 10

     # the parameters for beam search.
     beam_size = 5
@@ -103,6 +102,7 @@ class ModelHyperParams(object):
         "src_word",
         "src_pos",
         "src_slf_attn_bias",
+        "src_data_shape",
         "src_slf_attn_pre_softmax_shape",
         "src_slf_attn_post_softmax_shape", )
@@ -112,6 +112,7 @@ class ModelHyperParams(object):
         "trg_pos",
         "trg_slf_attn_bias",
         "trg_src_attn_bias",
+        "trg_data_shape",
         "trg_slf_attn_pre_softmax_shape",
         "trg_slf_attn_post_softmax_shape",
         "trg_src_attn_pre_softmax_shape",

fluid/neural_machine_translation/transformer/infer.py
61 additions, 29 deletions

@@ -24,6 +24,7 @@ def translate_batch(exe,
                     n_best,
                     batch_size,
                     n_head,
+                    d_model,
                     src_pad_idx,
                     trg_pad_idx,
                     bos_idx,
@@ -43,6 +44,11 @@ def translate_batch(exe,
         return_pos=True,
         return_attn_bias=True,
         return_max_len=False)
+    # Append the data shape input to reshape the output of embedding layer.
+    enc_in_data = enc_in_data + [
+        np.array(
+            [-1, enc_in_data[2].shape[-1], d_model], dtype="int32")
+    ]
     # Append the shape inputs to reshape before and after softmax in encoder
     # self attention.
     enc_in_data = enc_in_data + [
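What the appended array contains: enc_in_data[2] is src_slf_attn_bias (the third name in the encoder feed list above), shaped [batch, n_head, max_src_len, max_src_len], so its last axis yields the padded source length. A toy check with invented sizes:

import numpy as np

d_model = 8                                  # toy value
src_slf_attn_bias = np.zeros((2, 2, 5, 5))   # [batch, n_head, len, len]
src_data_shape = np.array(
    [-1, src_slf_attn_bias.shape[-1], d_model], dtype="int32")
print(src_data_shape)  # [-1  5  8]: batch left to be inferred at run time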
@@ -59,9 +65,14 @@
     scores = np.zeros((batch_size, beam_size), dtype="float32")
     prev_branchs = [[] for i in range(batch_size)]
     next_ids = [[] for i in range(batch_size)]
-    # Use beam_map to map the instance idx in batch to beam idx, since the
+    # Use beam_inst_map to map beam idx to the instance idx in batch, since the
     # size of feeded batch is changing.
-    beam_map = range(batch_size)
+    beam_inst_map = {
+        beam_idx: inst_idx
+        for inst_idx, beam_idx in enumerate(range(batch_size))
+    }
+    # Use active_beams to recode the alive.
+    active_beams = range(batch_size)

     def beam_backtrace(prev_branchs, next_ids, n_best=beam_size):
         """
@@ -98,8 +109,14 @@ def init_dec_in_data(batch_size, beam_size, enc_in_data, enc_output):
             [-1e9]).astype("float32")
         # This is used to remove attention on the paddings of source sequences.
         trg_src_attn_bias = np.tile(
-            src_slf_attn_bias[:, :, ::src_max_length, :],
-            [beam_size, 1, trg_max_len, 1])
+            src_slf_attn_bias[:, :, ::src_max_length, :][:, np.newaxis],
+            [1, beam_size, 1, trg_max_len, 1]).reshape([
+                -1, src_slf_attn_bias.shape[1], trg_max_len,
+                src_slf_attn_bias.shape[-1]
+            ])
+        # Append the shape input to reshape the output of embedding layer.
+        trg_data_shape = np.array(
+            [batch_size * beam_size, trg_max_len, d_model], dtype="int32")
         # Append the shape inputs to reshape before and after softmax in
         # decoder self attention.
         trg_slf_attn_pre_softmax_shape = np.array(
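The key fix in this hunk: np.tile(x, [beam_size, 1, ...]) repeats the whole batch block-wise ([a, b, a, b, ...]), while the rest of the code addresses rows as inst_idx * beam_size + k, i.e. instance-major ([a, a, a, b, b, b]). Inserting an axis, tiling along it, and reshaping produces the instance-major layout. A toy comparison with invented shapes:

import numpy as np

x = np.arange(8).reshape(2, 1, 4)                 # instances a, b
beam_size = 3
old = np.tile(x, [beam_size, 1, 1])               # rows: a, b, a, b, a, b
new = np.tile(x[:, np.newaxis],
              [1, beam_size, 1, 1]).reshape(-1, 1, 4)  # a, a, a, b, b, b
# The decode loop's inst_idx * beam_size + k indexing matches the
# instance-major layout of `new`, not the block-repeated layout of `old`.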
@@ -112,22 +129,24 @@ def init_dec_in_data(batch_size, beam_size, enc_in_data, enc_output):
             [-1, trg_src_attn_bias.shape[-1]], dtype="int32")
         trg_src_attn_post_softmax_shape = np.array(
             trg_src_attn_bias.shape, dtype="int32")
-        enc_output = np.tile(enc_output, [beam_size, 1, 1])
+        enc_output = np.tile(
+            enc_output[:, np.newaxis], [1, beam_size, 1, 1]).reshape(
+                [-1, enc_output.shape[-2], enc_output.shape[-1]])
         return trg_words, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, \
-            trg_slf_attn_pre_softmax_shape, trg_slf_attn_post_softmax_shape, \
-            trg_src_attn_pre_softmax_shape, trg_src_attn_post_softmax_shape, \
-            enc_output
+            trg_data_shape, trg_slf_attn_pre_softmax_shape, \
+            trg_slf_attn_post_softmax_shape, trg_src_attn_pre_softmax_shape, \
+            trg_src_attn_post_softmax_shape, enc_output

-    def update_dec_in_data(dec_in_data, next_ids, active_beams):
+    def update_dec_in_data(dec_in_data, next_ids, active_beams, beam_inst_map):
         """
         Update the input data of decoder mainly by slicing from the previous
         input data and dropping the finished instance beams.
         """
         trg_words, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, \
-            trg_slf_attn_pre_softmax_shape, trg_slf_attn_post_softmax_shape, \
-            trg_src_attn_pre_softmax_shape, trg_src_attn_post_softmax_shape, \
-            enc_output = dec_in_data
-        trg_cur_len = len(next_ids[0]) + 1  # include the <bos>
+            trg_data_shape, trg_slf_attn_pre_softmax_shape, \
+            trg_slf_attn_post_softmax_shape, trg_src_attn_pre_softmax_shape, \
+            trg_src_attn_post_softmax_shape, enc_output = dec_in_data
+        trg_cur_len = trg_slf_attn_bias.shape[-1] + 1
         trg_words = np.array(
             [
                 beam_backtrace(prev_branchs[beam_idx], next_ids[beam_idx])
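Why trg_cur_len no longer counts next_ids[0]: only alive beams are appended to in the decode loop, so next_ids[0] stalls once instance 0 finishes, whereas the time axis of the previous step's self-attention bias stays current regardless of which beams survive. A toy trace, values invented:

# Step 3, instance 0 finished at step 2:
next_ids = [[7, 9], [7, 4, 2]]            # lengths have diverged
# old: len(next_ids[0]) + 1 -> 3 (stale once beam 0 is done)
# new: trg_slf_attn_bias.shape[-1] + 1 -> previous length + 1, always current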
@@ -138,6 +157,7 @@ def update_dec_in_data(dec_in_data, next_ids, active_beams):
         trg_pos = np.array(
             [range(1, trg_cur_len + 1)] * len(active_beams) * beam_size,
             dtype="int64").reshape([-1, 1])
+        active_beams = [beam_inst_map[beam_idx] for beam_idx in active_beams]
         active_beams_indice = (
             (np.array(active_beams) * beam_size)[:, np.newaxis] +
             np.array(range(beam_size))[np.newaxis, :]).flatten()
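After the remap, active_beams holds row indices into the previous step's batch, and active_beams_indice expands each to its beam_size rows. A toy check with invented values:

import numpy as np

active_beams, beam_size = [0, 2], 3
indice = ((np.array(active_beams) * beam_size)[:, np.newaxis] +
          np.arange(beam_size)[np.newaxis, :]).flatten()
print(indice)  # [0 1 2 6 7 8]: the three beam rows of instances 0 and 2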
@@ -152,6 +172,10 @@ def update_dec_in_data(dec_in_data, next_ids, active_beams):
         trg_src_attn_bias = np.tile(trg_src_attn_bias[
             active_beams_indice, :, ::trg_src_attn_bias.shape[2], :],
                                     [1, 1, trg_cur_len, 1])
+        # Append the shape input to reshape the output of embedding layer.
+        trg_data_shape = np.array(
+            [len(active_beams) * beam_size, trg_cur_len, d_model],
+            dtype="int32")
         # Append the shape inputs to reshape before and after softmax in
         # decoder self attention.
         trg_slf_attn_pre_softmax_shape = np.array(
@@ -166,9 +190,9 @@ def update_dec_in_data(dec_in_data, next_ids, active_beams):
             trg_src_attn_bias.shape, dtype="int32")
         enc_output = enc_output[active_beams_indice, :, :]
         return trg_words, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, \
-            trg_slf_attn_pre_softmax_shape, trg_slf_attn_post_softmax_shape, \
-            trg_src_attn_pre_softmax_shape, trg_src_attn_post_softmax_shape, \
-            enc_output
+            trg_data_shape, trg_slf_attn_pre_softmax_shape, \
+            trg_slf_attn_post_softmax_shape, trg_src_attn_pre_softmax_shape, \
+            trg_src_attn_post_softmax_shape, enc_output

     dec_in_data = init_dec_in_data(batch_size, beam_size, enc_in_data,
                                    enc_output)
@@ -177,15 +201,18 @@ def update_dec_in_data(dec_in_data, next_ids, active_beams):
             feed=dict(zip(dec_in_names, dec_in_data)),
             fetch_list=dec_out_names)[0]
         predict_all = np.log(
-            predict_all.reshape([len(beam_map) * beam_size, i + 1, -1])[:,
-                                                                        -1, :])
-        predict_all = (predict_all + scores[beam_map].reshape(
-            [len(beam_map) * beam_size, -1])).reshape(
-                [len(beam_map), beam_size, -1])
+            predict_all.reshape([len(beam_inst_map) * beam_size, i + 1, -1])
+            [:, -1, :])
+        predict_all = (predict_all + scores[active_beams].reshape(
+            [len(beam_inst_map) * beam_size, -1])).reshape(
+                [len(beam_inst_map), beam_size, -1])
         if not output_unk:  # To exclude the <unk> token.
             predict_all[:, :, unk_idx] = -1e9
         active_beams = []
-        for inst_idx, beam_idx in enumerate(beam_map):
+        for beam_idx in range(batch_size):
+            if not beam_inst_map.has_key(beam_idx):
+                continue
+            inst_idx = beam_inst_map[beam_idx]
             predict = (predict_all[inst_idx, :, :]
                        if i != 0 else predict_all[inst_idx, 0, :]).flatten()
             top_k_indice = np.argpartition(predict, -beam_size)[-beam_size:]
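A portability note: dict.has_key exists only in Python 2, which this code targets; on Python 3 the same membership test would be spelled:

if beam_idx not in beam_inst_map:  # equivalent to Python 2's has_key
    continue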
@@ -198,10 +225,14 @@ def update_dec_in_data(dec_in_data, next_ids, active_beams):
             next_ids[beam_idx].append(top_scores_ids % predict_all.shape[-1])
             if next_ids[beam_idx][-1][0] != eos_idx:
                 active_beams.append(beam_idx)
-        beam_map = active_beams
-        if len(beam_map) == 0:
+        if len(active_beams) == 0:
             break
-        dec_in_data = update_dec_in_data(dec_in_data, next_ids, active_beams)
+        dec_in_data = update_dec_in_data(dec_in_data, next_ids, active_beams,
+                                         beam_inst_map)
+        beam_inst_map = {
+            beam_idx: inst_idx
+            for inst_idx, beam_idx in enumerate(active_beams)
+        }

     # Decode beams and select n_best sequences for each instance by backtrace.
     seqs = [
@@ -215,10 +246,8 @@ def update_dec_in_data(dec_in_data, next_ids, active_beams):
 def main():
     place = fluid.CUDAPlace(0) if InferTaskConfig.use_gpu else fluid.CPUPlace()
     exe = fluid.Executor(place)
-    # The current program desc is coupled with batch_size and the only
-    # supported batch size is 1 currently.
+
     encoder_program = fluid.Program()
-    model.batch_size = InferTaskConfig.batch_size
     with fluid.program_guard(main_program=encoder_program):
         enc_output = encoder(
             ModelHyperParams.src_vocab_size + 1,
@@ -228,7 +257,6 @@ def main():
             ModelHyperParams.d_inner_hid, ModelHyperParams.dropout,
             ModelHyperParams.src_pad_idx, ModelHyperParams.pos_pad_idx)

-    model.batch_size = InferTaskConfig.batch_size * InferTaskConfig.beam_size
     decoder_program = fluid.Program()
     with fluid.program_guard(main_program=decoder_program):
         predict = decoder(
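Deleting the two model.batch_size assignments is the heart of the decoupling: previously the global was mutated so each program desc was built for one fixed size, and inference only supported batch_size = 1. With shapes fed as data, the same two programs handle any batch, roughly as below (read_batches, build_enc_inputs, and enc_in_names are placeholders, not from this commit):

# Hypothetical sketch: one compiled program, varying batch sizes.
for batch in read_batches():                  # batches of any size
    enc_in_data = build_enc_inputs(batch)     # includes src_data_shape
    out = exe.run(encoder_program,
                  feed=dict(zip(enc_in_names, enc_in_data)),
                  fetch_list=[enc_output])[0]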
@@ -273,6 +301,9 @@ def main():

     trg_idx2word = paddle.dataset.wmt16.get_dict(
         "de", dict_size=ModelHyperParams.trg_vocab_size, reverse=True)
+    # Append the <pad> token since the dict provided by dataset.wmt16 does
+    # not include it.
+    trg_idx2word[ModelHyperParams.trg_pad_idx] = "<pad>"

     def post_process_seq(seq,
                          bos_idx=ModelHyperParams.bos_idx,
@@ -306,6 +337,7 @@ def post_process_seq(seq,
             InferTaskConfig.n_best,
             len(data),
             ModelHyperParams.n_head,
+            ModelHyperParams.d_model,
             ModelHyperParams.src_pad_idx,
             ModelHyperParams.trg_pad_idx,
             ModelHyperParams.bos_idx,
