Skip to content

Commit 5cb794b

Browse files
committed
Modify loss
1 parent d20e6e9 commit 5cb794b

File tree

1 file changed

+11
-5
lines changed

1 file changed

+11
-5
lines changed

main.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,10 @@
1919
help='number of meta optimizer steps (default: 100)')
2020
parser.add_argument('--truncated_bptt_step', type=int, default=20, metavar='N',
2121
help='step at which it truncates bptt (default: 20)')
22-
parser.add_argument('--updates_per_epoch', type=int, default=10, metavar='N',
22+
parser.add_argument('--updates_per_epoch', type=int, default=100, metavar='N',
2323
help='updates per epoch (default: 100)')
24-
parser.add_argument('--max_epoch', type=int, default=100, metavar='N',
25-
help='number of epoch (default: 100)')
24+
parser.add_argument('--max_epoch', type=int, default=10000, metavar='N',
25+
help='number of epoch (default: 10000)')
2626
parser.add_argument('--hidden_size', type=int, default=10, metavar='N',
2727
help='hidden size of the meta optimizer (default: 10)')
2828
parser.add_argument('--no-cuda', action='store_true', default=False,
@@ -85,6 +85,9 @@ def main():
8585
keep_states=k > 0, model=model, use_cuda=args.cuda)
8686

8787
loss_sum = 0
88+
prev_loss = torch.zeros(1)
89+
if args.cuda:
90+
prev_loss = prev_loss.cuda()
8891
for j in range(args.truncated_bptt_step):
8992
x, y = next(train_iter)
9093
if args.cuda:
@@ -104,7 +107,10 @@ def main():
104107
# Compute a loss for a step the meta optimizer
105108
f_x = meta_model(x)
106109
loss = F.nll_loss(f_x, y)
107-
loss_sum += loss
110+
111+
loss_sum += (loss - Variable(prev_loss))
112+
113+
prev_loss = loss.data
108114

109115
# Update the parameters of the meta optimizer
110116
meta_optimizer.zero_grad()
@@ -117,7 +123,7 @@ def main():
117123
# value
118124
decrease_in_loss += loss.data[0] / initial_loss.data[0]
119125

120-
print("Epoch: {}, average final/initial loss ratio: {}".format(epoch,
126+
print("Epoch: {}, final loss {}, average final/initial loss ratio: {}".format(epoch, loss.data[0],
121127
decrease_in_loss / args.updates_per_epoch))
122128

123129
if __name__ == "__main__":

0 commit comments

Comments
 (0)