
Commit 99231ac

bmccann authored and soumith committed
adding word level accuracy as a metric
1 parent fd87818 · commit 99231ac

File tree

1 file changed (+24 −14 lines)


OpenNMT/train.py

Lines changed: 24 additions & 14 deletions
@@ -120,7 +120,7 @@
 
 def NMTCriterion(vocabSize):
     weight = torch.ones(vocabSize)
-    weightonmt.Constants.PAD] = 0
+    weight[onmt.Constants.PAD] = 0
     crit = nn.NLLLoss(weight, size_average=False)
     if opt.gpus:
         crit.cuda()
@@ -129,40 +129,45 @@ def NMTCriterion(vocabSize):
 
 def memoryEfficientLoss(outputs, targets, generator, crit, eval=False):
     # compute generations one piece at a time
-    loss = 0
+    num_correct, loss = 0, 0
     outputs = Variable(outputs.data, requires_grad=(not eval), volatile=eval)
 
     batch_size = outputs.size(1)
     outputs_split = torch.split(outputs, opt.max_generator_batches)
     targets_split = torch.split(targets, opt.max_generator_batches)
-    for out_t, targ_t in zip(outputs_split, targets_split):
+    for i, (out_t, targ_t) in enumerate(zip(outputs_split, targets_split)):
         out_t = out_t.view(-1, out_t.size(2))
-        pred_t = generator(out_t)
-        loss_t = crit(pred_t, targ_t.view(-1))
+        scores_t = generator(out_t)
+        loss_t = crit(scores_t, targ_t.view(-1))
+        pred_t = scores_t.max(1)[1]
+        num_correct_t = pred_t.data.eq(targ_t.data).masked_select(targ_t.ne(onmt.Constants.PAD).data).sum()
+        num_correct += num_correct_t
         loss += loss_t.data[0]
         if not eval:
             loss_t.div(batch_size).backward()
 
     grad_output = None if outputs.grad is None else outputs.grad.data
-    return loss, grad_output
+    return loss, grad_output, num_correct
 
 
 def eval(model, criterion, data):
     total_loss = 0
     total_words = 0
+    total_num_correct = 0
 
     model.eval()
     for i in range(len(data)):
         batch = data[i]
         outputs = model(batch)  # FIXME volatile
         targets = batch[1][1:]  # exclude <s> from targets
-        loss, _ = memoryEfficientLoss(
+        loss, _, num_correct = memoryEfficientLoss(
             outputs, targets, model.generator, criterion, eval=True)
         total_loss += loss
+        total_num_correct += num_correct
         total_words += targets.data.ne(onmt.Constants.PAD).sum()
 
     model.train()
-    return total_loss / total_words
+    return total_loss / total_words, total_num_correct / total_words
 
 
 def trainModel(model, trainData, validData, dataset, optim):
@@ -183,6 +188,7 @@ def trainEpoch(epoch):
 
         total_loss, report_loss = 0, 0
         total_words, report_tgt_words, report_src_words = 0, 0, 0
+        total_num_correct = 0
         start = time.time()
         for i in range(len(trainData)):
 
@@ -192,7 +198,7 @@ def trainEpoch(epoch):
             model.zero_grad()
             outputs = model(batch)
             targets = batch[1][1:]  # exclude <s> from targets
-            loss, gradOutput = memoryEfficientLoss(
+            loss, gradOutput, num_correct = memoryEfficientLoss(
                 outputs, targets, model.generator, criterion)
 
             outputs.backward(gradOutput)
@@ -201,36 +207,40 @@ def trainEpoch(epoch):
             optim.step()
 
             report_loss += loss
+            total_num_correct += num_correct
             total_loss += loss
             num_words = targets.data.ne(onmt.Constants.PAD).sum()
             total_words += num_words
             report_tgt_words += num_words
             report_src_words += batch[0].data.ne(onmt.Constants.PAD).sum()
             if i % opt.log_interval == 0 and i > 0:
-                print("Epoch %2d, %5d/%5d batches; perplexity: %6.2f; %3.0f source tokens/s; %3.0f target tokens/s; %6.0f s elapsed" %
+                print("Epoch %2d, %5d/%5d; ppl: %6.2f; acc: %6.2f; %3.0f src tok/s; %3.0f tgt tok/s; %6.0f s elapsed" %
                       (epoch, i, len(trainData),
                        math.exp(report_loss / report_tgt_words),
+                       num_correct / num_words * 100,
                        report_src_words/(time.time()-start),
                        report_tgt_words/(time.time()-start),
                        time.time()-start_time))
 
                 report_loss = report_tgt_words = report_src_words = 0
                 start = time.time()
 
-        return total_loss / total_words
+        return total_loss / total_words, total_num_correct / total_words
 
     for epoch in range(opt.start_epoch, opt.epochs + 1):
         print('')
 
         # (1) train for one epoch on the training set
-        train_loss = trainEpoch(epoch)
+        train_loss, train_acc = trainEpoch(epoch)
         train_ppl = math.exp(min(train_loss, 100))
         print('Train perplexity: %g' % train_ppl)
+        print('Train accuracy: %g' % train_acc)
 
         # (2) evaluate on the validation set
-        valid_loss = eval(model, criterion, validData)
+        valid_loss, valid_acc = eval(model, criterion, validData)
         valid_ppl = math.exp(min(valid_loss, 100))
         print('Validation perplexity: %g' % valid_ppl)
+        print('Validation accuracy: %g' % valid_acc)
 
         # (3) maybe update the learning rate
         if opt.update_learning_rate:
@@ -245,7 +255,7 @@ def trainEpoch(epoch):
                 'optim': optim,
             }
             torch.save(checkpoint,
-                       '%s_val_%.2f_train_%.2f_e%d.pt' % (opt.save_model, valid_ppl, train_ppl, epoch))
+                       '%s_acc_%.2f_ppl_%.2f_e%d.pt' % (opt.save_model, valid_acc, valid_ppl, epoch))
 
 def main():
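For readers skimming the diff, the heart of the change is the masked word-level accuracy computed in memoryEfficientLoss: take the argmax of the generator scores, compare it with the gold targets, and count matches only at non-PAD positions, so padding never inflates the metric. Below is a minimal standalone sketch of that masking logic; it assumes a recent PyTorch API (the file itself uses the old Variable API) and a placeholder PAD = 0 in place of onmt.Constants.PAD.

import torch

PAD = 0  # placeholder padding index; the real code uses onmt.Constants.PAD

# scores: (num_words, vocab_size) log-probabilities from the generator
# targets: (num_words,) gold word indices, with PAD at padded positions
scores = torch.randn(6, 10)
targets = torch.tensor([4, 2, 7, PAD, PAD, 1])

pred = scores.max(1)[1]    # predicted word index at each position
non_pad = targets.ne(PAD)  # mask that is False at padding positions
num_correct = pred.eq(targets).masked_select(non_pad).sum().item()
num_words = non_pad.sum().item()
accuracy = num_correct / num_words  # word-level accuracy, as in eval()

Both eval() and trainEpoch() normalize by the count of non-PAD target words for the same reason; the accuracy printed at each log interval uses num_correct and num_words from the most recent batch.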