
Commit 6e62924

Add truncated BPTT
1 parent f00d26a commit 6e62924

File tree: main.py, meta_optimizer.py

2 files changed: +48 −37 lines


main.py

Lines changed: 41 additions & 34 deletions
@@ -13,17 +13,21 @@

 parser = argparse.ArgumentParser(description='PyTorch REINFORCE example')
 parser.add_argument('--batch_size', type=int, default=16, metavar='N',
-                    help='batch size (default: 16)')
-parser.add_argument('--optimizer_steps', type=int, default=10, metavar='N',
-                    help='number of meta optimizer steps (default: 10)')
+                    help='batch size (default: 16)')
+parser.add_argument('--optimizer_steps', type=int, default=20, metavar='N',
+                    help='number of meta optimizer steps (default: 20)')
+parser.add_argument('--truncated_bptt_step', type=int, default=10, metavar='N',
+                    help='step at which to truncate BPTT (default: 10)')
 parser.add_argument('--updates_per_epoch', type=int, default=100, metavar='N',
-                    help='updates per epoch (default: 100)')
+                    help='updates per epoch (default: 100)')
 parser.add_argument('--max_epoch', type=int, default=100, metavar='N',
-                    help='number of epochs (default: 100)')
+                    help='number of epochs (default: 100)')
 parser.add_argument('--hidden_size', type=int, default=10, metavar='N',
-                    help='hidden size of the meta optimizer (default: 10)')
+                    help='hidden size of the meta optimizer (default: 10)')
 args = parser.parse_args()

+assert args.optimizer_steps % args.truncated_bptt_step == 0
+
 meta_optimizer = MetaOptimizer(args.hidden_size)
 optimizer = optim.Adam(meta_optimizer.parameters(), lr=1e-3)
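The new assert ties the two flags together: args.optimizer_steps must be an exact multiple of args.truncated_bptt_step, so that the unrolled optimization splits into whole truncation windows. With the defaults this gives 20 // 10 = 2 windows of 10 unrolled steps each; a value like optimizer_steps=25 would fail the assert.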

@@ -34,44 +38,47 @@

         # Sample a new model
         model = Model()

-        # Create a helper class
-        meta_model = MetaModel()
-        meta_model.copy_params_from(model)
-
-        # Reset lstm values of the meta optimizer
-        meta_optimizer.reset_lstm()
-
         x, y = get_batch(args.batch_size)
         x, y = Variable(x), Variable(y)

         # Compute initial loss of the model
         f_x = model(x)
         initial_loss = (f_x - y).pow(2).mean()
-        loss_sum = 0
-        for j in range(args.optimizer_steps):
-            x, y = get_batch(args.batch_size)
-            x, y = Variable(x), Variable(y)

-            # First we need to compute the gradients of the model
-            f_x = model(x)
-            loss = (f_x - y).pow(2).mean()
-            model.zero_grad()
-            loss.backward()
+        for k in range(args.optimizer_steps // args.truncated_bptt_step):
+            # Keep states for truncated BPTT
+            meta_optimizer.reset_lstm(keep_states=k > 0)
+
+            # Create a helper class
+            meta_model = MetaModel()
+            meta_model.copy_params_from(model)

-            # Perform a meta update
-            meta_optimizer.meta_update(meta_model, model)
+            loss_sum = 0
+            for j in range(args.truncated_bptt_step):
+                x, y = get_batch(args.batch_size)
+                x, y = Variable(x), Variable(y)

-            # Compute the loss for this step of the meta optimizer
-            f_x = meta_model(x)
-            loss = (f_x - y).pow(2).mean()
-            loss_sum += loss
+                # First we need to compute the gradients of the model
+                f_x = model(x)
+                loss = (f_x - y).pow(2).mean()
+                model.zero_grad()
+                loss.backward()
+
+                # Perform a meta update
+                meta_optimizer.meta_update(meta_model, model)
+
+                # Compute the loss for this step of the meta optimizer
+                f_x = meta_model(x)
+                loss = (f_x - y).pow(2).mean()
+                loss_sum += loss
+
+            # Update the parameters of the meta optimizer
+            meta_optimizer.zero_grad()
+            loss_sum.backward()
+            optimizer.step()

         # Compute relative decrease in the loss function w.r.t initial value
         decrease_in_loss += loss.data[0] / initial_loss.data[0]

-        # Update the parameters of the meta optimizer
-        meta_optimizer.zero_grad()
-        loss_sum.backward()
-        optimizer.step()
-
-    print("Epoch: {}, average final/initial loss ratio: {}".format(epoch, decrease_in_loss / args.updates_per_epoch))
+    print("Epoch: {}, average final/initial loss ratio: {}".format(epoch,
+                                                                   decrease_in_loss / args.updates_per_epoch))
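The rewritten loop above is the standard truncated-BPTT pattern: unroll the meta optimizer for truncated_bptt_step inner steps, accumulate the per-step losses, backpropagate once per window, and carry the LSTM state (detached) into the next window so later windows still start from the learned state. As a minimal self-contained sketch of the same pattern on a plain LSTMCell in post-0.4 PyTorch; the toy model, data, and sizes are illustrative assumptions, not code from this repo:

    import torch
    import torch.nn as nn
    import torch.optim as optim

    seq_len, bptt_step, batch, dim = 40, 10, 16, 8
    assert seq_len % bptt_step == 0  # windows must tile the sequence exactly

    lstm = nn.LSTMCell(dim, dim)
    head = nn.Linear(dim, 1)
    opt = optim.Adam(list(lstm.parameters()) + list(head.parameters()), lr=1e-3)

    # Toy task (illustrative): regress the mean of each input vector.
    xs = torch.randn(seq_len, batch, dim)
    ys = xs.mean(dim=2, keepdim=True)

    hx = torch.zeros(batch, dim)
    cx = torch.zeros(batch, dim)

    for k in range(seq_len // bptt_step):
        # Detach the carried state: gradients stop at the window boundary,
        # which is what reset_lstm(keep_states=True) achieves in this commit.
        hx, cx = hx.detach(), cx.detach()

        loss_sum = 0
        for j in range(bptt_step):
            t = k * bptt_step + j
            hx, cx = lstm(xs[t], (hx, cx))
            loss_sum = loss_sum + (head(hx) - ys[t]).pow(2).mean()

        # One backward per window; the graph only spans bptt_step steps.
        opt.zero_grad()
        loss_sum.backward()
        opt.step()

The memory and compute cost of each backward pass is now bounded by bptt_step rather than by the full unroll length, which is the point of the change.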

meta_optimizer.py

Lines changed: 7 additions & 3 deletions
@@ -24,9 +24,13 @@ def __init__(self, hidden_size):

         self.reset_lstm()

-    def reset_lstm(self):
-        self.hx = Variable(torch.zeros(1, self.hidden_size))
-        self.cx = Variable(torch.zeros(1, self.hidden_size))
+    def reset_lstm(self, keep_states=False):
+        if keep_states:
+            self.hx = Variable(self.hx.data)
+            self.cx = Variable(self.cx.data)
+        else:
+            self.hx = Variable(torch.zeros(1, self.hidden_size))
+            self.cx = Variable(torch.zeros(1, self.hidden_size))

     def forward(self, inputs):
         initial_size = inputs.size()