Commit 0c634a1

bmccann authored and soumith committed
new DataParallel allows dim 1; remove unnecessary transposes; add train_ppl to chkpt
1 parent c90842c commit 0c634a1

File tree: 3 files changed, +30 −31 lines

OpenNMT/onmt/Models.py
OpenNMT/onmt/Translator.py
OpenNMT/train.py

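For context on the first change: PyTorch's nn.DataParallel scatters its input across GPUs along a configurable dim and gathers the outputs along that same dimension. With time-major batches of shape (seq_len, batch, features), dim=1 splits along the batch axis, which is what lets the rest of this commit drop the batch-first transposes. A minimal sketch, assuming a multi-GPU machine and made-up sizes:

```python
import torch
import torch.nn as nn

# nn.LSTM consumes time-major input: (seq_len, batch, input_size).
rnn = nn.LSTM(input_size=500, hidden_size=500, num_layers=2)

x = torch.randn(20, 64, 500)  # (seq_len=20, batch=64, features=500)

if torch.cuda.device_count() > 1:
    # dim=1 scatters inputs and gathers outputs along the batch axis,
    # so time-major tensors flow through without any transposes.
    rnn = nn.DataParallel(rnn, dim=1).cuda()
    x = x.cuda()

out, (h_n, c_n) = rnn(x)  # out: (20, 64, 500), still time-major
```

On a single GPU or on CPU the wrapper is simply skipped and the module behaves identically.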

OpenNMT/onmt/Models.py

Lines changed: 8 additions & 7 deletions
@@ -26,9 +26,10 @@ def __init__(self, opt, dicts):
             self.word_lut.weight.copy_(pretrained)

     def forward(self, input, hidden=None):
-        batch_size = input.size(0) # batch first for multi-gpu compatibility
-        emb = self.word_lut(input).transpose(0, 1)
+        emb = self.word_lut(input)
+
         if hidden is None:
+            batch_size = emb.size(1)
             h_size = (self.layers * self.num_directions, batch_size, self.hidden_size)
             h_0 = Variable(emb.data.new(*h_size).zero_(), requires_grad=False)
             c_0 = Variable(emb.data.new(*h_size).zero_(), requires_grad=False)

@@ -91,9 +92,9 @@ def __init__(self, opt, dicts):


     def forward(self, input, hidden, context, init_output):
-        emb = self.word_lut(input).transpose(0, 1)
+        emb = self.word_lut(input)

-        batch_size = input.size(0)
+        batch_size = input.size(1)

         h_size = (batch_size, self.hidden_size)

@@ -102,7 +103,7 @@ def forward(self, input, hidden, context, init_output):
         # self.input_feed=False
         outputs = []
         output = init_output
-        for i, emb_t in enumerate(emb.split(1)):
+        for emb_t in emb.split(1):
             emb_t = emb_t.squeeze(0)
             if self.input_feed:
                 emb_t = torch.cat([emb_t, output], 1)

@@ -113,7 +114,7 @@ def forward(self, input, hidden, context, init_output):
             outputs += [output]

         outputs = torch.stack(outputs)
-        return outputs.transpose(0, 1), hidden, attn
+        return outputs, hidden, attn


 class NMTModel(nn.Module):

@@ -145,7 +146,7 @@ def _fix_enc_hidden(self, h):

     def forward(self, input):
         src = input[0]
-        tgt = input[1][:, :-1] # exclude last target from inputs
+        tgt = input[1][:-1] # exclude last target from inputs
         enc_hidden, context = self.encoder(src)
         init_output = self.make_init_decoder_output(context)
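A side note on the decoder loop above: with a time-major tensor, split(1) along dim 0 already yields one slice per time step, so nothing needs transposing on the way in or out. A small sketch with hypothetical sizes:

```python
import torch

emb = torch.randn(7, 4, 500)  # time-major: (seq_len=7, batch=4, emb_dim=500)
for emb_t in emb.split(1):    # 7 slices, each of shape (1, 4, 500)
    emb_t = emb_t.squeeze(0)  # (4, 500): one embedding per sentence
    assert emb_t.shape == (4, 500)
```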

OpenNMT/onmt/Translator.py

Lines changed: 11 additions & 12 deletions
@@ -48,17 +48,17 @@ def buildTargetTokens(self, pred, src, attn):

     def translateBatch(self, batch):
         srcBatch, tgtBatch = batch
-        batchSize = srcBatch.size(0)
+        batchSize = srcBatch.size(1)
         beamSize = self.opt.beam_size

         # (1) run the encoder on the src

         # have to execute the encoder manually to deal with padding
         encStates = None
         context = []
-        for srcBatch_t in srcBatch.chunk(srcBatch.size(1), dim=1):
+        for srcBatch_t in srcBatch.split(1):
             encStates, context_t = self.model.encoder(srcBatch_t, hidden=encStates)
-            batchPadIdx = srcBatch_t.data.squeeze(1).eq(onmt.Constants.PAD).nonzero()
+            batchPadIdx = srcBatch_t.data.squeeze(0).eq(onmt.Constants.PAD).nonzero()
             if batchPadIdx.nelement() > 0:
                 batchPadIdx = batchPadIdx.squeeze(1)
                 encStates[0].data.index_fill_(1, batchPadIdx, 0)

@@ -73,7 +73,7 @@ def translateBatch(self, batch):

         # This mask is applied to the attention model inside the decoder
         # so that the attention ignores source padding
-        padMask = srcBatch.data.eq(onmt.Constants.PAD)
+        padMask = srcBatch.data.eq(onmt.Constants.PAD).t()
         def applyContextMask(m):
             if isinstance(m, onmt.modules.GlobalAttention):
                 m.applyMask(padMask)

@@ -88,8 +88,8 @@ def applyContextMask(m):
         initOutput = self.model.make_init_decoder_output(context)

         decOut, decStates, attn = self.model.decoder(
-            tgtBatch[:, :-1], decStates, context, initOutput)
-        for dec_t, tgt_t in zip(decOut.transpose(0, 1), tgtBatch.transpose(0, 1)[1:].data):
+            tgtBatch[:-1], decStates, context, initOutput)
+        for dec_t, tgt_t in zip(decOut, tgtBatch[1:].data):
             gen_t = self.model.generator.forward(dec_t)
             tgt_t = tgt_t.unsqueeze(1)
             scores = gen_t.data.gather(1, tgt_t)

@@ -107,7 +107,7 @@ def applyContextMask(m):

         decOut = self.model.make_init_decoder_output(context)

-        padMask = srcBatch.data.eq(onmt.Constants.PAD).unsqueeze(0).repeat(beamSize, 1, 1)
+        padMask = srcBatch.data.eq(onmt.Constants.PAD).t().unsqueeze(0).repeat(beamSize, 1, 1)

         batchIdx = list(range(batchSize))
         remainingSents = batchSize

@@ -120,9 +120,9 @@ def applyContextMask(m):
                                  if not b.done]).t().contiguous().view(1, -1)

             decOut, decStates, attn = self.model.decoder(
-                Variable(input, volatile=True).transpose(0, 1), decStates, context, decOut)
+                Variable(input, volatile=True), decStates, context, decOut)
             # decOut: 1 x (beam*batch) x numWords
-            decOut = decOut.transpose(0, 1).squeeze(0)
+            decOut = decOut.squeeze(0)
             out = self.model.generator.forward(decOut)

             # batch x beam x numWords

@@ -177,7 +177,7 @@ def updateActive(t):
             scores, ks = beam[b].sortBest()

             allScores += [scores[:n_best]]
-            valid_attn = srcBatch.transpose(0, 1).data[:, b].ne(onmt.Constants.PAD).nonzero().squeeze(1)
+            valid_attn = srcBatch.data[:, b].ne(onmt.Constants.PAD).nonzero().squeeze(1)
             hyps, attn = zip(*[beam[b].getHyp(k) for k in ks[:n_best]])
             attn = [a.index_select(1, valid_attn) for a in attn]
             allHyp += [hyps]

@@ -189,14 +189,13 @@ def translate(self, srcBatch, goldBatch):
         # (1) convert words to indexes
         dataset = self.buildData(srcBatch, goldBatch)
         batch = dataset[0]
-        batch = [x.transpose(0, 1) for x in batch]

         # (2) translate
         pred, predScore, attn, goldScore = self.translateBatch(batch)

         # (3) convert indexes to words
         predBatch = []
-        for b in range(batch[0].size(0)):
+        for b in range(batch[0].size(1)):
             predBatch.append(
                 [self.buildTargetTokens(pred[b][n], srcBatch[b], attn[b][n])
                  for n in range(self.opt.n_best)]
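Since srcBatch is now time-major, a pad mask computed from it comes out as (seq_len, batch); the added .t() calls flip it to (batch, seq_len), one row per sentence, before it reaches the attention mask. A sketch with made-up token ids, written against current PyTorch tensor constructors (PAD stands in for onmt.Constants.PAD):

```python
import torch

PAD = 0  # stand-in for onmt.Constants.PAD
# Time-major source batch (seq_len=3, batch=2); the second sentence is shorter.
srcBatch = torch.tensor([[3,   5],
                         [4,   PAD],
                         [PAD, PAD]])
padMask = srcBatch.eq(PAD).t()  # (batch=2, seq_len=3)
# padMask: [[False, False, True],
#           [False, True,  True]]
```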

OpenNMT/train.py

Lines changed: 11 additions & 12 deletions
@@ -117,11 +117,11 @@ def NMTCriterion(vocabSize):
 def memoryEfficientLoss(outputs, targets, generator, crit, eval=False):
     # compute generations one piece at a time
     loss = 0
-    outputs = Variable(outputs.data, requires_grad=(not eval), volatile=eval).contiguous()
+    outputs = Variable(outputs.data, requires_grad=(not eval), volatile=eval)

     batch_size = outputs.size(1)
     outputs_split = torch.split(outputs, opt.max_generator_batches)
-    targets_split = torch.split(targets.contiguous(), opt.max_generator_batches)
+    targets_split = torch.split(targets, opt.max_generator_batches)
     for out_t, targ_t in zip(outputs_split, targets_split):
         out_t = out_t.view(-1, out_t.size(2))
         pred_t = generator(out_t)

@@ -140,9 +140,9 @@ def eval(model, criterion, data):

     model.eval()
     for i in range(len(data)):
-        batch = [x.transpose(0, 1) for x in data[i]] # must be batch first for gather/scatter in DataParallel
+        batch = data[i]
         outputs = model(batch) # FIXME volatile
-        targets = batch[1][:, 1:] # exclude <s> from targets
+        targets = batch[1][1:] # exclude <s> from targets
         loss, _ = memoryEfficientLoss(
             outputs, targets, model.generator, criterion, eval=True)
         total_loss += loss

@@ -172,11 +172,10 @@ def trainEpoch(epoch):

         batchIdx = batchOrder[i] if epoch >= opt.curriculum else i
         batch = trainData[batchIdx]
-        batch = [x.transpose(0, 1) for x in batch] # must be batch first for gather/scatter in DataParallel

         model.zero_grad()
         outputs = model(batch)
-        targets = batch[1][:, 1:] # exclude <s> from targets
+        targets = batch[1][1:] # exclude <s> from targets
         loss, gradOutput = memoryEfficientLoss(
             outputs, targets, model.generator, criterion)

@@ -209,7 +208,8 @@ def trainEpoch(epoch):

         # (1) train for one epoch on the training set
         train_loss = trainEpoch(epoch)
-        print('Train perplexity: %g' % math.exp(min(train_loss, 100)))
+        train_ppl = math.exp(min(train_loss, 100))
+        print('Train perplexity: %g' % train_ppl)

         # (2) evaluate on the validation set
         valid_loss = eval(model, criterion, validData)

@@ -229,8 +229,7 @@ def trainEpoch(epoch):
             'optim': optim,
         }
         torch.save(checkpoint,
-                   '%s_e%d_%.2f.pt' % (opt.save_model, epoch, valid_ppl))
-
+                   '%s_val%.2f_e%d_train%.2f.pt' % (opt.save_model, valid_ppl, epoch, train_ppl))

 def main():

@@ -258,11 +257,11 @@ def main():
     generator = nn.Sequential(
         nn.Linear(opt.rnn_size, dicts['tgt'].size()),
         nn.LogSoftmax())
-    if len(opt.gpus) > 1:
-        generator = nn.DataParallel(generator, device_ids=opt.gpus)
+    # if len(opt.gpus) > 1:
+    #     generator = nn.DataParallel(generator, device_ids=opt.gpus)
     model = onmt.Models.NMTModel(encoder, decoder, generator)
     if len(opt.gpus) > 1:
-        model = nn.DataParallel(model, device_ids=opt.gpus)
+        model = nn.DataParallel(model, device_ids=opt.gpus, dim=1)
     if opt.gpus:
         model.cuda()
     else:
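The checkpoint filename now records both perplexities, validation first, so saved files sort by validation score at a glance; the min(train_loss, 100) clamp keeps math.exp from overflowing on a diverging run. A quick sketch of the naming with illustrative values (the real ones come from the training loop above):

```python
# Illustrative values only; opt.save_model, valid_ppl, epoch and train_ppl
# are produced by the training loop above.
save_model, valid_ppl, epoch, train_ppl = 'demo-model', 14.53, 3, 12.07

old_name = '%s_e%d_%.2f.pt' % (save_model, epoch, valid_ppl)
new_name = '%s_val%.2f_e%d_train%.2f.pt' % (save_model, valid_ppl, epoch, train_ppl)

print(old_name)  # demo-model_e3_14.53.pt
print(new_name)  # demo-model_val14.53_e3_train12.07.pt
```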
