
Commit fcc2ab1

bmccann authored and soumith committed
altering translate to be compatible with nn.DataParallel
1 parent 0dee89a commit fcc2ab1

File tree

4 files changed, +31 -23 lines changed

OpenNMT/onmt/Models.py
Lines changed: 3 additions & 3 deletions

@@ -110,18 +110,18 @@ def forward(self, input, hidden, context, init_output):
         # self.input_feed=False
         outputs = []
         output = init_output
-        for emb_t in emb.chunk(emb.size(0)):
+        for i, emb_t in enumerate(emb.chunk(emb.size(0), dim=0)):
             emb_t = emb_t.squeeze(0)
             if self.input_feed:
                 emb_t = torch.cat([emb_t, output], 1)
 
-            output, hidden = self.rnn(emb_t, hidden)
+            output, h = self.rnn(emb_t, hidden)
             output, attn = self.attn(output, context.t())
             output = self.dropout(output)
             outputs += [output]
 
         outputs = torch.stack(outputs)
-        return outputs.transpose(0, 1), hidden, attn
+        return outputs.transpose(0, 1), h, attn
 
 
 class NMTModel(nn.Module):
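
The decoder change above is what makes the model friendly to nn.DataParallel: DataParallel scatters its inputs and gathers its outputs along dimension 0, so a wrapped module has to consume and produce batch-first tensors, which is why the decoder returns outputs.transpose(0, 1) with shape (batch, time, rnn_size). A minimal, self-contained sketch of that convention follows; the toy module and shapes are illustrative only, not the repo's code.

import torch
import torch.nn as nn

# Illustrative sketch (not the repo's code): nn.DataParallel splits inputs and
# re-assembles outputs along dim 0, so the module must be batch-first end to end.
class ToyDecoder(nn.Module):
    def __init__(self, vocab_size=1000, hidden_size=512):
        super(ToyDecoder, self).__init__()
        self.emb = nn.Embedding(vocab_size, hidden_size)
        self.rnn = nn.LSTM(hidden_size, hidden_size, batch_first=True)

    def forward(self, tgt):                  # tgt: (batch, time)
        out, _ = self.rnn(self.emb(tgt))     # out: (batch, time, hidden_size)
        return out                           # batch-first, safe to gather

model = ToyDecoder()
tgt = torch.zeros(8, 20).long()              # (batch=8, time=20) dummy target ids
if torch.cuda.is_available() and torch.cuda.device_count() > 1:
    model, tgt = nn.DataParallel(model).cuda(), tgt.cuda()   # each GPU gets a batch slice
out = model(tgt)
print(out.size())                            # torch.Size([8, 20, 512]), still batch-first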

OpenNMT/onmt/Translator.py
Lines changed: 15 additions & 11 deletions

@@ -48,29 +48,32 @@ def buildTargetTokens(self, pred, src, attn):
 
     def translateBatch(self, batch):
         srcBatch, tgtBatch = batch
-        batchSize = srcBatch.size(1)
+        batchSize = srcBatch.size(0)
         beamSize = self.opt.beam_size
 
         # (1) run the encoder on the src
 
         # have to execute the encoder manually to deal with padding
         encStates = None
         context = []
-        for srcBatch_t in srcBatch.chunk(srcBatch.size(0)):
+        for srcBatch_t in srcBatch.chunk(srcBatch.size(1), dim=1):
             encStates, context_t = self.model.encoder(srcBatch_t, hidden=encStates)
-            batchPadIdx = srcBatch_t.data.squeeze(0).eq(onmt.Constants.PAD).nonzero()
+            batchPadIdx = srcBatch_t.data.squeeze(1).eq(onmt.Constants.PAD).nonzero()
             if batchPadIdx.nelement() > 0:
                 batchPadIdx = batchPadIdx.squeeze(1)
                 encStates[0].data.index_fill_(1, batchPadIdx, 0)
                 encStates[1].data.index_fill_(1, batchPadIdx, 0)
             context += [context_t]
 
+        encStates = (self.model._fix_enc_hidden(encStates[0]),
+                     self.model._fix_enc_hidden(encStates[1]))
+
         context = torch.cat(context)
         rnnSize = context.size(2)
 
         # This mask is applied to the attention model inside the decoder
         # so that the attention ignores source padding
-        padMask = srcBatch.data.eq(onmt.Constants.PAD).t()
+        padMask = srcBatch.data.eq(onmt.Constants.PAD)
         def applyContextMask(m):
             if isinstance(m, onmt.modules.GlobalAttention):
                 m.applyMask(padMask)

@@ -85,8 +88,8 @@ def applyContextMask(m):
         initOutput = self.model.make_init_decoder_output(context)
 
         decOut, decStates, attn = self.model.decoder(
-            tgtBatch[:-1], decStates, context, initOutput)
-        for dec_t, tgt_t in zip(decOut, tgtBatch[1:].data):
+            tgtBatch[:, :-1], decStates, context, initOutput)
+        for dec_t, tgt_t in zip(decOut.transpose(0, 1), tgtBatch.transpose(0, 1)[1:].data):
             gen_t = self.model.generator.forward(dec_t)
             tgt_t = tgt_t.unsqueeze(1)
             scores = gen_t.data.gather(1, tgt_t)

@@ -104,7 +107,7 @@ def applyContextMask(m):
 
         decOut = self.model.make_init_decoder_output(context)
 
-        padMask = srcBatch.data.eq(onmt.Constants.PAD).t().unsqueeze(0).repeat(beamSize, 1, 1)
+        padMask = srcBatch.data.eq(onmt.Constants.PAD).unsqueeze(0).repeat(beamSize, 1, 1)
 
         batchIdx = list(range(batchSize))
         remainingSents = batchSize

@@ -117,9 +120,9 @@ def applyContextMask(m):
                                  if not b.done]).t().contiguous().view(1, -1)
 
             decOut, decStates, attn = self.model.decoder(
-                Variable(input), decStates, context, decOut)
+                Variable(input).transpose(0, 1), decStates, context, decOut)
             # decOut: 1 x (beam*batch) x numWords
-            decOut = decOut.squeeze(0)
+            decOut = decOut.transpose(0, 1).squeeze(0)
             out = self.model.generator.forward(decOut)
 
             # batch x beam x numWords

@@ -174,7 +177,7 @@ def updateActive(t):
             scores, ks = beam[b].sortBest()
 
             allScores += [scores[:n_best]]
-            valid_attn = srcBatch.data[:, b].ne(onmt.Constants.PAD).nonzero().squeeze(1)
+            valid_attn = srcBatch.transpose(0, 1).data[:, b].ne(onmt.Constants.PAD).nonzero().squeeze(1)
             hyps, attn = zip(*[beam[b].getHyp(k) for k in ks[:n_best]])
             attn = [a.index_select(1, valid_attn) for a in attn]
             allHyp += [hyps]

@@ -186,13 +189,14 @@ def translate(self, srcBatch, goldBatch):
         # (1) convert words to indexes
         dataset = self.buildData(srcBatch, goldBatch)
         batch = dataset[0]
+        batch = [x.transpose(0, 1) for x in batch]
 
         # (2) translate
         pred, predScore, attn, goldScore = self.translateBatch(batch)
 
         # (3) convert indexes to words
         predBatch = []
-        for b in range(batch[0].size(1)):
+        for b in range(batch[0].size(0)):
             predBatch.append(
                 [self.buildTargetTokens(pred[b][n], srcBatch[b], attn[b][n])
                  for n in range(self.opt.n_best)]
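
With the source batch now batch-first, the attention pad mask no longer needs a transpose: srcBatch.data.eq(onmt.Constants.PAD) is already shaped (batch, seq_len), and for beam search it is simply tiled once per beam. A small standalone sketch of that shape bookkeeping; the PAD value and tensors here are made up for illustration.

import torch

PAD = 0                                           # stand-in for onmt.Constants.PAD
# Two batch-first source sentences padded to length 4.
srcBatch = torch.LongTensor([[4, 5, 6, PAD],
                             [7, 8, PAD, PAD]])

padMask = srcBatch.eq(PAD)                        # (batch, seq_len); no .t() needed
print(padMask)

beamSize = 5
beamMask = padMask.unsqueeze(0).repeat(beamSize, 1, 1)   # (beam, batch, seq_len)
print(beamMask.size())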

OpenNMT/train.py
Lines changed: 7 additions & 7 deletions

@@ -25,9 +25,9 @@
 
 parser.add_argument('-layers', type=int, default=2,
                     help='Number of layers in the LSTM encoder/decoder')
-parser.add_argument('-rnn_size', type=int, default=500,
+parser.add_argument('-rnn_size', type=int, default=512,
                     help='Size of LSTM hidden states')
-parser.add_argument('-word_vec_size', type=int, default=500,
+parser.add_argument('-word_vec_size', type=int, default=300,
                     help='Word embedding sizes')
 parser.add_argument('-input_feed', type=int, default=1,
                     help="""Feed the context vector at each time step as

@@ -43,13 +43,13 @@
 
 ## Optimization options
 
-parser.add_argument('-batch_size', type=int, default=64,
+parser.add_argument('-batch_size', type=int, default=256,
                     help='Maximum batch size')
 parser.add_argument('-max_generator_batches', type=int, default=32,
                     help="""Maximum batches of words in a sequence to run
                     the generator on in parallel. Higher is faster, but uses
                     more memory.""")
-parser.add_argument('-epochs', type=int, default=13,
+parser.add_argument('-epochs', type=int, default=50,
                     help='Number of training epochs')
 parser.add_argument('-start_epoch', type=int, default=1,
                     help='The epoch from which to start')

@@ -58,16 +58,16 @@
                     with support (-param_init, param_init)""")
 parser.add_argument('-optim', default='sgd',
                     help="Optimization method. [sgd|adagrad|adadelta|adam]")
-parser.add_argument('-learning_rate', type=float, default=1,
+parser.add_argument('-learning_rate', type=float, default=1.0,
                     help="""Starting learning rate. If adagrad/adadelta/adam is
                     used, then this is the global learning rate. Recommended
                     settings: sgd = 1, adagrad = 0.1, adadelta = 1, adam = 0.1""")
 parser.add_argument('-max_grad_norm', type=float, default=5,
                     help="""If the norm of the gradient vector exceeds this,
                     renormalize it to have the norm equal to max_grad_norm""")
-parser.add_argument('-dropout', type=float, default=0.3,
+parser.add_argument('-dropout', type=float, default=0.2,
                     help='Dropout probability; applied between LSTM stacks.')
-parser.add_argument('-learning_rate_decay', type=float, default=0.5,
+parser.add_argument('-learning_rate_decay', type=float, default=0.9,
                     help="""Decay learning rate by this much if (i) perplexity
                     does not decrease on the validation set or (ii) epoch has
                     gone past the start_decay_at_limit""")
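
These are default changes only; every value can still be overridden on the command line. For reference, the new defaults can be reproduced in isolation with a throwaway parser. This is just a sketch mirroring the flags changed above, not the full train.py option set.

import argparse

# Mirrors only the options whose defaults change in this commit.
parser = argparse.ArgumentParser(description='train.py defaults (sketch)')
parser.add_argument('-rnn_size', type=int, default=512)
parser.add_argument('-word_vec_size', type=int, default=300)
parser.add_argument('-batch_size', type=int, default=256)
parser.add_argument('-epochs', type=int, default=50)
parser.add_argument('-learning_rate', type=float, default=1.0)
parser.add_argument('-dropout', type=float, default=0.2)
parser.add_argument('-learning_rate_decay', type=float, default=0.9)

opt = parser.parse_args([])        # empty argv -> just the defaults
print(opt)
# Namespace(batch_size=256, dropout=0.2, epochs=50, learning_rate=1.0,
#           learning_rate_decay=0.9, rnn_size=512, word_vec_size=300)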

OpenNMT/translate.py
Lines changed: 6 additions & 2 deletions

@@ -1,7 +1,6 @@
 import onmt
 import torch
 import argparse
-import time
 import math
 
 parser = argparse.ArgumentParser(description='translate.py')

@@ -37,9 +36,11 @@
                     help="""If verbose is set, will output the n_best
                     decoded sentences""")
 
-parser.add_argument('-cuda', action="store_true",
+parser.add_argument('-gpu', type=int, default=7,
                     help="Use CUDA")
 
+
+
 def reportScore(name, scoreTotal, wordsTotal):
     print("%s AVG SCORE: %.4f, %s PPL: %.4f" % (
         name, scoreTotal / wordsTotal,

@@ -48,6 +49,8 @@ def reportScore(name, scoreTotal, wordsTotal):
 
 def main():
     opt = parser.parse_args()
+    opt.cuda = True
+    torch.cuda.set_device(opt.gpu)
 
     translator = onmt.Translator(opt)
 

@@ -58,6 +61,7 @@ def main():
     srcBatch, tgtBatch = [], []
 
     count = 0
+
     tgtF = open(opt.tgt) if opt.tgt else None
     for line in open(opt.src):
