Skip to content

Commit d6e6324

Browse files
andreh7soumith
authored andcommitted
added comments in snli/train.py, no code changes (pytorch#177)
1 parent 1c6d9d2 commit d6e6324

File tree

1 file changed

+32
-0
lines changed

1 file changed

+32
-0
lines changed

snli/train.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@
3838
config.n_embed = len(inputs.vocab)
3939
config.d_out = len(answers.vocab)
4040
config.n_cells = config.n_layers
41+
42+
# double the number of cells for bidirectional networks
4143
if config.birnn:
4244
config.n_cells *= 2
4345

@@ -66,41 +68,71 @@
6668
train_iter.init_epoch()
6769
n_correct, n_total = 0, 0
6870
for batch_idx, batch in enumerate(train_iter):
71+
72+
# switch model to training mode, clear gradient accumulators
6973
model.train(); opt.zero_grad()
74+
7075
iterations += 1
76+
77+
# forward pass
7178
answer = model(batch)
79+
80+
# calculate accuracy of predictions in the current batch
7281
n_correct += (torch.max(answer, 1)[1].view(batch.label.size()).data == batch.label.data).sum()
7382
n_total += batch.batch_size
7483
train_acc = 100. * n_correct/n_total
84+
85+
# calculate loss of the network output with respect to training labels
7586
loss = criterion(answer, batch.label)
87+
88+
# backpropagate and update optimizer learning rate
7689
loss.backward(); opt.step()
90+
91+
# checkpoint model periodically
7792
if iterations % args.save_every == 0:
7893
snapshot_prefix = os.path.join(args.save_path, 'snapshot')
7994
snapshot_path = snapshot_prefix + '_acc_{:.4f}_loss_{:.6f}_iter_{}_model.pt'.format(train_acc, loss.data[0], iterations)
8095
torch.save(model, snapshot_path)
8196
for f in glob.glob(snapshot_prefix + '*'):
8297
if f != snapshot_path:
8398
os.remove(f)
99+
100+
# evaluate performance on validation set periodically
84101
if iterations % args.dev_every == 0:
102+
103+
# switch model to evaluation mode
85104
model.eval(); dev_iter.init_epoch()
105+
106+
# calculate accuracy on validation set
86107
n_dev_correct, dev_loss = 0, 0
87108
for dev_batch_idx, dev_batch in enumerate(dev_iter):
88109
answer = model(dev_batch)
89110
n_dev_correct += (torch.max(answer, 1)[1].view(dev_batch.label.size()).data == dev_batch.label.data).sum()
90111
dev_loss = criterion(answer, dev_batch.label)
91112
dev_acc = 100. * n_dev_correct / len(dev)
113+
92114
print(dev_log_template.format(time.time()-start,
93115
epoch, iterations, 1+batch_idx, len(train_iter),
94116
100. * (1+batch_idx) / len(train_iter), loss.data[0], dev_loss.data[0], train_acc, dev_acc))
117+
118+
# update best valiation set accuracy
95119
if dev_acc > best_dev_acc:
120+
121+
# found a model with better validation set accuracy
122+
96123
best_dev_acc = dev_acc
97124
snapshot_prefix = os.path.join(args.save_path, 'best_snapshot')
98125
snapshot_path = snapshot_prefix + '_devacc_{}_devloss_{}__iter_{}_model.pt'.format(dev_acc, dev_loss.data[0], iterations)
126+
127+
# save model, delete previous 'best_snapshot' files
99128
torch.save(model, snapshot_path)
100129
for f in glob.glob(snapshot_prefix + '*'):
101130
if f != snapshot_path:
102131
os.remove(f)
132+
103133
elif iterations % args.log_every == 0:
134+
135+
# print progress message
104136
print(log_template.format(time.time()-start,
105137
epoch, iterations, 1+batch_idx, len(train_iter),
106138
100. * (1+batch_idx) / len(train_iter), loss.data[0], ' '*8, n_correct/n_total*100, ' '*12))

0 commit comments

Comments
 (0)