There was an error while loading. Please reload this page.
1 parent 0e77a0b commit 8bb7c5aCopy full SHA for 8bb7c5a
OpenNMT/train.py
@@ -186,8 +186,8 @@ def trainEpoch(epoch):
186
# shuffle mini batch order
187
batchOrder = torch.randperm(len(trainData))
188
189
- total_loss, total_words, total_num_correct = 0
190
- report_loss, report_tgt_words, report_src_words, report_num_correct = 0
+ total_loss, total_words, total_num_correct = 0, 0, 0
+ report_loss, report_tgt_words, report_src_words, report_num_correct = 0, 0, 0, 0
191
start = time.time()
192
for i in range(len(trainData)):
193
@@ -215,7 +215,7 @@ def trainEpoch(epoch):
215
total_words += num_words
216
if i % opt.log_interval == -1 % opt.log_interval:
217
print("Epoch %2d, %5d/%5d; acc: %6.2f; ppl: %6.2f; %3.0f src tok/s; %3.0f tgt tok/s; %6.0f s elapsed" %
218
- (epoch, i, len(trainData),
+ (epoch, i+1, len(trainData),
219
report_num_correct / report_tgt_words * 100,
220
math.exp(report_loss / report_tgt_words),
221
report_src_words/(time.time()-start),
0 commit comments