Commit eef3f47: Add cuda support
Parent: 44ca158

3 files changed: +80 -54 lines

README.md

Lines changed: 1 addition & 0 deletions
@@ -12,6 +12,7 @@ python main.py
 - [x] Initial implementation
 - [x] Toy data
 - [x] LSTM updates
+- [ ] Refactor, find a better way to organize the modules
 - [ ] Compare with standard optimizers
 - [ ] Real data
 - [ ] More difficult models

main.py

Lines changed: 76 additions & 53 deletions
@@ -24,61 +24,84 @@
                     help='number of epoch (default: 100)')
 parser.add_argument('--hidden_size', type=int, default=10, metavar='N',
                     help='hidden size of the meta optimizer (default: 10)')
+parser.add_argument('--no-cuda', action='store_true', default=False,
+                    help='enables CUDA training')
 args = parser.parse_args()
+args.cuda = not args.no_cuda and torch.cuda.is_available()
 
 assert args.optimizer_steps % args.truncated_bptt_step == 0
 
-# Create a meta optimizer that wraps a model into a meta model
-# to keep track of the meta updates.
-meta_optimizer = MetaOptimizer(MetaModel(Model()), args.hidden_size)
-optimizer = optim.Adam(meta_optimizer.parameters(), lr=1e-3)
-loss_fn = lambda f_x, y: (f_x - y).pow(2).mean()
 
-for epoch in range(args.max_epoch):
-    decrease_in_loss = 0.0
-    for i in range(args.updates_per_epoch):
-
-        # Sample a new model
-        model = Model()
-
-        x, y = get_batch(args.batch_size)
-        x, y = Variable(x), Variable(y)
-
-        # Compute initial loss of the model
-        f_x = model(x)
-        initial_loss = loss_fn(f_x, y)
-
-        for k in range(args.optimizer_steps // args.truncated_bptt_step):
-            # Keep states for truncated BPTT
-            meta_optimizer.reset_lstm(keep_states=k > 0, model=model)
-
-            loss_sum = 0
-            for j in range(args.truncated_bptt_step):
-                x, y = get_batch(args.batch_size)
-                x, y = Variable(x), Variable(y)
-
-                # First we need to compute the gradients of the model
-                f_x = model(x)
-                loss = loss_fn(f_x, y)
-                model.zero_grad()
-                loss.backward()
-
-                # Perform a meta update using gradients from model
-                # and return the current meta model saved in the optimizer
-                meta_model = meta_optimizer.meta_update(model)
-
-                # Compute a loss for a step of the meta optimizer
-                f_x = meta_model(x)
-                loss = loss_fn(f_x, y)
-                loss_sum += loss
-
-            # Update the parameters of the meta optimizer
-            meta_optimizer.zero_grad()
-            loss_sum.backward()
-            optimizer.step()
-
-        # Compute relative decrease in the loss function w.r.t initial value
-        decrease_in_loss += loss.data[0] / initial_loss.data[0]
-
-    print("Epoch: {}, average final/initial loss ratio: {}".format(epoch,
-                                                                   decrease_in_loss / args.updates_per_epoch))
+def main():
+    # Create a meta optimizer that wraps a model into a meta model
+    # to keep track of the meta updates.
+    meta_model = Model()
+    if args.cuda:
+        meta_model.cuda()
+
+    meta_optimizer = MetaOptimizer(MetaModel(meta_model), args.hidden_size)
+    if args.cuda:
+        meta_optimizer.cuda()
+
+    optimizer = optim.Adam(meta_optimizer.parameters(), lr=1e-3)
+    loss_fn = lambda f_x, y: (f_x - y).pow(2).mean()
+
+    for epoch in range(args.max_epoch):
+        decrease_in_loss = 0.0
+        for i in range(args.updates_per_epoch):
+
+            # Sample a new model
+            model = Model()
+            if args.cuda:
+                model.cuda()
+
+            x, y = get_batch(args.batch_size)
+            x, y = Variable(x), Variable(y)
+            if args.cuda:
+                x, y = x.cuda(), y.cuda()
+
+            # Compute initial loss of the model
+            f_x = model(x)
+            initial_loss = loss_fn(f_x, y)
+
+            for k in range(args.optimizer_steps // args.truncated_bptt_step):
+                # Keep states for truncated BPTT
+                meta_optimizer.reset_lstm(
+                    keep_states=k > 0, model=model, use_cuda=args.cuda)
+
+                loss_sum = 0
+                for j in range(args.truncated_bptt_step):
+                    x, y = get_batch(args.batch_size)
+                    x, y = Variable(x), Variable(y)
+                    if args.cuda:
+                        x, y = x.cuda(), y.cuda()
+
+                    # First we need to compute the gradients of the model
+                    f_x = model(x)
+                    loss = loss_fn(f_x, y)
+                    model.zero_grad()
+                    loss.backward()
+
+                    # Perform a meta update using gradients from model
+                    # and return the current meta model saved in the optimizer
+                    meta_model = meta_optimizer.meta_update(model)
+
+                    # Compute a loss for a step of the meta optimizer
+                    f_x = meta_model(x)
+                    loss = loss_fn(f_x, y)
+                    loss_sum += loss
+
+                # Update the parameters of the meta optimizer
+                meta_optimizer.zero_grad()
+                loss_sum.backward()
+                optimizer.step()
+
+            # Compute relative decrease in the loss function w.r.t initial
+            # value
+            decrease_in_loss += loss.data[0] / initial_loss.data[0]
+
+        print("Epoch: {}, average final/initial loss ratio: {}".format(epoch,
+                                                                       decrease_in_loss / args.updates_per_epoch))
+
+if __name__ == "__main__":
+    main()
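
Usage note: with the new flag, `python main.py` trains on the GPU whenever `torch.cuda.is_available()` returns true, and `python main.py --no-cuda` forces CPU training; this follows from the `args.cuda = not args.no_cuda and torch.cuda.is_available()` line added above.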

meta_optimizer.py

Lines changed: 3 additions & 1 deletion
@@ -24,7 +24,7 @@ def __init__(self, model, hidden_size):
         self.linear2.weight.data.mul_(0.1)
         self.linear2.bias.data.fill_(0.0)
 
-    def reset_lstm(self, keep_states=False, model=None):
+    def reset_lstm(self, keep_states=False, model=None, use_cuda=False):
         self.meta_model.reset()
         self.meta_model.copy_params_from(model)
 
@@ -34,6 +34,8 @@ def reset_lstm(self, keep_states=False, model=None):
         else:
             self.hx = Variable(torch.zeros(1, self.hidden_size))
             self.cx = Variable(torch.zeros(1, self.hidden_size))
+            if use_cuda:
+                self.hx, self.cx = self.hx.cuda(), self.cx.cuda()
 
     def forward(self, inputs):
         initial_size = inputs.size()
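
Note: `use_cuda` only moves freshly zeroed `hx`/`cx` states to the GPU; when `keep_states=True` the previous states are reused, on the assumption that they already live on the correct device from the preceding truncated-BPTT segment.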
