Commit fefe3cc

Initial commit

4 files changed: +212 -0 lines changed

data.py

Lines changed: 8 additions & 0 deletions
@@ -0,0 +1,8 @@
import torch


def get_batch(batch_size):
    x = torch.randn(batch_size, 10)
    x = x - 2 * x.pow(2)
    y = x.sum(1)
    return x, y
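
As a quick sanity check (not part of the commit), one can inspect a single batch; the shapes below follow directly from get_batch:

# Hypothetical sanity check for data.py.
from data import get_batch

x, y = get_batch(4)
print(x.size())  # torch.Size([4, 10]): ten random features per sample, transformed by x - 2 * x^2
print(y.size())  # one regression target per sample: the row-wise sum of the transformed features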

main.py

Lines changed: 78 additions & 0 deletions
@@ -0,0 +1,78 @@
import argparse
import operator
import sys

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from data import get_batch
from meta_optimizer import MetaOptimizer
from model import MetaModel, Model
from torch.autograd import Variable

parser = argparse.ArgumentParser(description='PyTorch meta optimizer example')
parser.add_argument('--batch_size', type=int, default=16, metavar='N',
                    help='batch size (default: 16)')
parser.add_argument('--optimizer_steps', type=int, default=10, metavar='N',
                    help='number of meta optimizer steps (default: 10)')
parser.add_argument('--updates_per_epoch', type=int, default=100, metavar='N',
                    help='updates per epoch (default: 100)')
parser.add_argument('--max_epoch', type=int, default=100, metavar='N',
                    help='number of epochs (default: 100)')
parser.add_argument('--hidden_size', type=int, default=10, metavar='N',
                    help='hidden size of the meta optimizer (default: 10)')
args = parser.parse_args()

meta_optimizer = MetaOptimizer(args.hidden_size)
optimizer = optim.Adam(meta_optimizer.parameters(), lr=1e-3)

for epoch in range(args.max_epoch):
    decrease_in_loss = 0.0
    for i in range(args.updates_per_epoch):

        # Sample a new model
        model = Model()

        # Create a helper class that tracks the meta updates
        meta_model = MetaModel()
        meta_model.copy_params_from(model)

        # Reset the LSTM state of the meta optimizer
        meta_optimizer.reset_lstm()

        x, y = get_batch(args.batch_size)
        x, y = Variable(x), Variable(y)

        # Compute the initial loss of the model
        f_x = model(x)
        initial_loss = (f_x - y).pow(2).mean()

        loss_sum = 0
        for j in range(args.optimizer_steps):
            x, y = get_batch(args.batch_size)
            x, y = Variable(x), Variable(y)

            # First we need to compute the gradients of the model
            f_x = model(x)
            loss = (f_x - y).pow(2).mean()
            model.zero_grad()
            loss.backward()

            # Perform a meta update
            meta_optimizer.meta_update(meta_model, model)

            # Compute the loss after this step of the meta optimizer
            f_x = meta_model(x)
            loss = (f_x - y).pow(2).mean()
            loss_sum += loss

        # Compute the relative decrease in the loss function w.r.t. the
        # initial value
        decrease_in_loss += loss.data[0] / initial_loss.data[0]

        # Update the parameters of the meta optimizer
        meta_optimizer.zero_grad()
        loss_sum.backward()
        optimizer.step()

    print("Epoch: {}, average final/initial loss ratio: {}".format(
        epoch, decrease_in_loss / args.updates_per_epoch))

meta_optimizer.py

Lines changed: 86 additions & 0 deletions
@@ -0,0 +1,86 @@
from functools import reduce
from operator import mul

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable


class MetaOptimizer(nn.Module):

    def __init__(self, hidden_size):
        super(MetaOptimizer, self).__init__()
        self.hidden_size = hidden_size

        self.linear1 = nn.Linear(1, hidden_size)

        self.lstm = nn.LSTMCell(hidden_size, hidden_size)

        self.linear2 = nn.Linear(hidden_size, 1)
        self.linear2.weight.data.mul_(0.1)
        self.linear2.bias.data.fill_(0.0)

        self.reset_lstm()

    def reset_lstm(self):
        self.hx = Variable(torch.zeros(1, self.hidden_size))
        self.cx = Variable(torch.zeros(1, self.hidden_size))

    def forward(self, inputs):
        # The update rule is applied coordinate-wise: every input entry goes
        # through the same Linear -> LSTMCell -> Linear pipeline.
        initial_size = inputs.size()
        x = inputs.view(-1, 1)
        x = F.tanh(self.linear1(x))

        # Expand the hidden and cell states to match the number of coordinates.
        if x.size(0) != self.hx.size(0):
            self.hx = self.hx.expand(x.size(0), self.hx.size(1))
            self.cx = self.cx.expand(x.size(0), self.cx.size(1))

        self.hx, self.cx = self.lstm(x, (self.hx, self.cx))
        x = self.hx

        x = self.linear2(x)
        x = x.view(*initial_size)
        return x

    def meta_update(self, meta_model, model_with_grads):
        # First we need to create a flat version of the parameters and gradients
        weight_shapes = []
        bias_shapes = []

        params = []
        grads = []

        for module in meta_model.children():
            weight_shapes.append(list(module._parameters['weight'].size()))
            bias_shapes.append(list(module._parameters['bias'].size()))

            params.append(module._parameters['weight'].view(-1))
            params.append(module._parameters['bias'].view(-1))

        for module in model_with_grads.children():
            grads.append(module._parameters['weight'].grad.view(-1))
            grads.append(module._parameters['bias'].grad.view(-1))

        flat_params = torch.cat(params)
        flat_grads = torch.cat(grads)

        # The meta update itself
        flat_params = flat_params + self(flat_grads)

        # Restore the original shapes
        offset = 0
        for i, module in enumerate(meta_model.children()):
            weight_flat_size = reduce(mul, weight_shapes[i], 1)
            bias_flat_size = reduce(mul, bias_shapes[i], 1)

            module._parameters['weight'] = flat_params[
                offset:offset + weight_flat_size].view(*weight_shapes[i])
            module._parameters['bias'] = flat_params[
                offset + weight_flat_size:offset + weight_flat_size + bias_flat_size].view(*bias_shapes[i])

            offset += weight_flat_size + bias_flat_size

        # Finally, copy values from the meta model to the normal one.
        meta_model.copy_params_to(model_with_grads)
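
A minimal sketch (not part of the commit) of the coordinate-wise behaviour of forward: every entry of the flattened gradient goes through the same Linear -> LSTMCell -> Linear pipeline, so the module emits one proposed update per parameter no matter how many parameters the optimizee has:

# Hypothetical shape check for MetaOptimizer.forward.
import torch
from torch.autograd import Variable

from meta_optimizer import MetaOptimizer

opt = MetaOptimizer(10)
opt.reset_lstm()

flat_grads = Variable(torch.randn(193))  # e.g. the 193 parameters of Model: 10*16 + 16 + 16*1 + 1
updates = opt(flat_grads)
print(updates.size())  # same size as the input: one proposed update per coordinate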

model.py

Lines changed: 40 additions & 0 deletions
@@ -0,0 +1,40 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable


class Model(nn.Module):

    def __init__(self):
        super(Model, self).__init__()
        self.linear1 = nn.Linear(10, 16)
        self.linear2 = nn.Linear(16, 1)

    def forward(self, inputs):
        x = F.tanh(self.linear1(inputs))
        x = self.linear2(x)
        return x


# A helper class that keeps track of meta updates.
# It does so by replacing the parameters with Variables and applying the
# updates to them.
class MetaModel(Model):

    def reset(self):
        for module in self.children():
            module._parameters['weight'] = Variable(
                module._parameters['weight'].data)
            module._parameters['bias'] = Variable(
                module._parameters['bias'].data)

    def copy_params_from(self, model):
        for modelA, modelB in zip(self.parameters(), model.parameters()):
            modelA.data.copy_(modelB.data)

    def copy_params_to(self, model):
        for modelA, modelB in zip(self.parameters(), model.parameters()):
            modelB.data.copy_(modelA.data)
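
A minimal usage sketch (not part of the commit) of the two copy helpers, mirroring how main.py and MetaOptimizer.meta_update shuttle weights between the sampled model and the helper:

# Hypothetical round trip between a sampled Model and its MetaModel helper.
from model import MetaModel, Model

model = Model()
meta_model = MetaModel()

# Pull the freshly sampled weights into the helper...
meta_model.copy_params_from(model)

# ...and, after a meta update has rewritten the helper's parameters,
# push the proposed weights back into the plain model.
meta_model.copy_params_to(model)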
