 config.n_embed = len(inputs.vocab)
 config.d_out = len(answers.vocab)
 config.n_cells = config.n_layers
+
+# double the number of cells for bidirectional networks
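+# (a bidirectional RNN keeps a separate forward and backward cell in every layer)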
 if config.birnn:
     config.n_cells *= 2

...
     train_iter.init_epoch()
     n_correct, n_total = 0, 0
     for batch_idx, batch in enumerate(train_iter):
+
+        # switch model to training mode, clear gradient accumulators
         model.train(); opt.zero_grad()
+
         iterations += 1
+
+        # forward pass
         answer = model(batch)
+
+        # track correct predictions and total examples seen so far this epoch
         n_correct += (torch.max(answer, 1)[1].view(batch.label.size()).data == batch.label.data).sum()
         n_total += batch.batch_size
         train_acc = 100. * n_correct/n_total
+
+        # calculate loss of the network output with respect to training labels
         loss = criterion(answer, batch.label)
+
+        # backpropagate and take an optimizer step to update the parameters
         loss.backward(); opt.step()
+
+        # checkpoint model periodically
         if iterations % args.save_every == 0:
             snapshot_prefix = os.path.join(args.save_path, 'snapshot')
             snapshot_path = snapshot_prefix + '_acc_{:.4f}_loss_{:.6f}_iter_{}_model.pt'.format(train_acc, loss.data[0], iterations)
             torch.save(model, snapshot_path)
             for f in glob.glob(snapshot_prefix + '*'):
                 if f != snapshot_path:
                     os.remove(f)
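+            # (a snapshot saved with torch.save(model, path) can be restored
+            #  later with model = torch.load(path))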
+
+        # evaluate performance on validation set periodically
         if iterations % args.dev_every == 0:
+
+            # switch model to evaluation mode
             model.eval(); dev_iter.init_epoch()
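+            # (in eval mode, dropout and batch-norm layers switch to their
+            #  inference behavior)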
+
+            # calculate accuracy on validation set
             n_dev_correct, dev_loss = 0, 0
             for dev_batch_idx, dev_batch in enumerate(dev_iter):
                 answer = model(dev_batch)
                 n_dev_correct += (torch.max(answer, 1)[1].view(dev_batch.label.size()).data == dev_batch.label.data).sum()
                 dev_loss = criterion(answer, dev_batch.label)
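+                # (note: dev_loss is overwritten on every batch, so the value
+                #  printed below is the loss of the last validation batch only)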
             dev_acc = 100. * n_dev_correct / len(dev)
+
             print(dev_log_template.format(time.time()-start,
                 epoch, iterations, 1+batch_idx, len(train_iter),
                 100. * (1+batch_idx) / len(train_iter), loss.data[0], dev_loss.data[0], train_acc, dev_acc))
+
+            # update best validation set accuracy
             if dev_acc > best_dev_acc:
+
+                # found a model with better validation set accuracy
+
                 best_dev_acc = dev_acc
                 snapshot_prefix = os.path.join(args.save_path, 'best_snapshot')
                 snapshot_path = snapshot_prefix + '_devacc_{}_devloss_{}__iter_{}_model.pt'.format(dev_acc, dev_loss.data[0], iterations)
+
+                # save model, delete previous 'best_snapshot' files
                 torch.save(model, snapshot_path)
                 for f in glob.glob(snapshot_prefix + '*'):
                     if f != snapshot_path:
                         os.remove(f)
+
         elif iterations % args.log_every == 0:
+
+            # print progress message
             print(log_template.format(time.time()-start,
                 epoch, iterations, 1+batch_idx, len(train_iter),
                 100. * (1+batch_idx) / len(train_iter), loss.data[0], ' '*8, n_correct/n_total*100, ' '*12))
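
Note: this example uses the pre-0.4 PyTorch API, where losses were 1-element tensors.
On PyTorch 0.4 and later, loss.data[0] raises an error and comparison sums come back as
0-dim tensors. A minimal sketch of the equivalent scalar extraction, assuming the same
loss, answer, and batch variables as above:

    loss_val = loss.item()  # replaces loss.data[0]
    n_correct += (torch.max(answer, 1)[1]
                  .view(batch.label.size()) == batch.label).sum().item()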