
Commit c465800

Add LayerNormLSTM
1 parent 82169f2 commit c465800

File tree

4 files changed: +79 -24 lines changed


layer_norm.py

Lines changed: 18 additions & 0 deletions
@@ -0,0 +1,18 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable


class LayerNorm1D(nn.Module):
    def __init__(self, num_outputs, eps=1e-5, affine=True):
        super(LayerNorm1D, self).__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(1, num_outputs))
        self.bias = nn.Parameter(torch.zeros(1, num_outputs))

    def forward(self, inputs):
        input_mean = inputs.mean(1).expand_as(inputs)
        input_std = inputs.std(1).expand_as(inputs)
        x = (inputs - input_mean) / (input_std + self.eps)
        return x * self.weight.expand_as(x) + self.bias.expand_as(x)
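
For context, a minimal usage sketch of LayerNorm1D (not part of the commit). It assumes the older PyTorch releases this code appears to target, where mean(1) and std(1) keep the reduced dimension so expand_as works; on newer releases keepdim=True would be needed.

# Hypothetical example: layer-normalize a batch of 5 vectors with 4 features each.
import torch
from torch.autograd import Variable
from layer_norm import LayerNorm1D

ln = LayerNorm1D(num_outputs=4)
x = Variable(torch.randn(5, 4))
y = ln(x)          # same shape as x; each row is normalized, then scaled and shifted
print(y.size())    # torch.Size([5, 4])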

layer_norm_lstm.py

Lines changed: 37 additions & 0 deletions
@@ -0,0 +1,37 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from layer_norm import LayerNorm1D


class LayerNormLSTMCell(nn.Module):

    def __init__(self, num_inputs, num_hidden, forget_gate_bias=-1):
        super(LayerNormLSTMCell, self).__init__()

        self.forget_gate_bias = forget_gate_bias
        self.num_hidden = num_hidden
        self.fc_i2h = nn.Linear(num_inputs, 4 * num_hidden)
        self.fc_h2h = nn.Linear(num_hidden, 4 * num_hidden)

        self.ln_i2h = LayerNorm1D(4 * num_hidden)
        self.ln_h2h = LayerNorm1D(4 * num_hidden)

        self.ln_h2o = LayerNorm1D(num_hidden)

    def forward(self, inputs, state):
        hx, cx = state
        i2h = self.fc_i2h(inputs)
        h2h = self.fc_h2h(hx)
        x = self.ln_i2h(i2h) + self.ln_h2h(h2h)
        gates = x.split(self.num_hidden, 1)

        in_gate = F.sigmoid(gates[0])
        forget_gate = F.sigmoid(gates[1] + self.forget_gate_bias)
        out_gate = F.sigmoid(gates[2])
        in_transform = F.tanh(gates[3])

        cx = forget_gate * cx + in_gate * in_transform
        hx = out_gate * F.tanh(self.ln_h2o(cx))
        return hx, cx
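
For context, a minimal single-step usage sketch of LayerNormLSTMCell (not part of the commit). The batch size, input width, and hidden size below are arbitrary, and the same PyTorch-version caveat as above applies.

# Hypothetical example: one recurrence step on a batch of 5 inputs.
import torch
from torch.autograd import Variable
from layer_norm_lstm import LayerNormLSTMCell

cell = LayerNormLSTMCell(num_inputs=8, num_hidden=16)
x = Variable(torch.randn(5, 8))
state = (Variable(torch.zeros(5, 16)), Variable(torch.zeros(5, 16)))  # (hx, cx)
hx, cx = cell(x, state)   # both returned with shape (5, 16)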

meta_optimizer.py

Lines changed: 13 additions & 24 deletions
@@ -7,6 +7,9 @@
 import torch.optim as optim
 from torch.autograd import Variable
 import math
+from utils import preprocess_gradients
+from layer_norm_lstm import LayerNormLSTMCell
+from layer_norm import LayerNorm1D
 
 class MetaOptimizer(nn.Module):
 
@@ -16,16 +19,12 @@ def __init__(self, model, num_layers, hidden_size):
 
         self.hidden_size = hidden_size
 
-        self.linear1 = nn.Linear(2, hidden_size)
+        self.linear1 = nn.Linear(3, hidden_size)
+        self.ln1 = LayerNorm1D(hidden_size)
 
         self.lstms = []
         for i in range(num_layers):
-            self.lstms.append(nn.LSTMCell(hidden_size, hidden_size))
-
-            self.lstms[-1].bias_ih.data.fill_(0)
-            self.lstms[-1].bias_hh.data.fill_(0)
-            self.lstms[-1].bias_hh.data[10:20].fill_(1)
-
+            self.lstms.append(LayerNormLSTMCell(hidden_size, hidden_size))
 
         self.linear2 = nn.Linear(hidden_size, 1)
         self.linear2.weight.data.mul_(0.1)
@@ -53,20 +52,9 @@ def reset_lstm(self, keep_states=False, model=None, use_cuda=False):
                 if use_cuda:
                     self.hx[i], self.cx[i] = self.hx[i].cuda(), self.cx[i].cuda()
 
-    def forward(self, inputs):
-        initial_size = inputs.size()
-        x = inputs.view(-1, 1)
-
+    def forward(self, x):
         # Gradients preprocessing
-        p = 10
-        eps = 1e-6
-        indicator = (x.abs() > math.exp(-p)).float()
-        x1 = (x.abs() + eps).log() / p * indicator - (1 - indicator)
-        x2 = x.sign() * indicator + math.exp(p) * x * (1 - indicator)
-
-        x = torch.cat((x1, x2), 1)
-
-        x = F.tanh(self.linear1(x))
+        x = F.tanh(self.ln1(self.linear1(x)))
 
         for i in range(len(self.lstms)):
             if x.size(0) != self.hx[i].size(0):
@@ -77,8 +65,7 @@ def forward(self, inputs):
             x = self.hx[i]
 
         x = self.linear2(x)
-        x = x.view(*initial_size)
-        return x
+        return x.squeeze()
 
     def meta_update(self, model_with_grads):
         # First we need to create a flat version of parameters and gradients
@@ -89,10 +76,12 @@ def meta_update(self, model_with_grads):
             grads.append(module._parameters['bias'].grad.data.view(-1))
 
         flat_params = self.meta_model.get_flat_params()
-        flat_grads = Variable(torch.cat(grads))
+        flat_grads = preprocess_gradients(torch.cat(grads))
+
+        inputs = Variable(torch.cat((flat_grads, flat_params.data), 1))
 
         # Meta update itself
-        flat_params = flat_params + self(flat_grads)
+        flat_params = flat_params + self(inputs)
 
         self.meta_model.set_flat_params(flat_params)
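
The change from nn.Linear(2, hidden_size) to nn.Linear(3, hidden_size) matches the new input layout built in meta_update: preprocess_gradients contributes two feature columns per parameter and the flat parameter values add a third. A rough sketch of that assembly, with hypothetical shapes and values (not part of the commit):

# Hypothetical illustration of the 3-column input fed to MetaOptimizer.forward.
# Assumes gradients and parameters are laid out as one row per parameter.
import torch
from torch.autograd import Variable
from utils import preprocess_gradients

n_params = 6
raw_grads = torch.randn(n_params, 1)            # one gradient value per parameter
flat_params = Variable(torch.randn(n_params, 1))

flat_grads = preprocess_gradients(raw_grads)    # (n_params, 2)
inputs = Variable(torch.cat((flat_grads, flat_params.data), 1))
print(inputs.size())                            # torch.Size([6, 3])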

utils.py

Lines changed: 11 additions & 0 deletions
@@ -0,0 +1,11 @@
import math
import torch

def preprocess_gradients(x):
    p = 10
    eps = 1e-6
    indicator = (x.abs() > math.exp(-p)).float()
    x1 = (x.abs() + eps).log() / p * indicator - (1 - indicator)
    x2 = x.sign() * indicator + math.exp(p) * x * (1 - indicator)

    return torch.cat((x1, x2), 1)
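
A small sketch of the preprocessing in isolation (not part of the commit). Depending on the PyTorch version, the raw gradients may need to be shaped as a column first so the concatenation along dimension 1 is well defined:

# Hypothetical example: each gradient value becomes two features, a clipped
# log-magnitude channel (x1) and a sign / rescaled-value channel (x2).
import torch
from utils import preprocess_gradients

g = torch.Tensor([0.5, -2.0, 1e-7]).view(-1, 1)  # column of raw gradient values
features = preprocess_gradients(g)
print(features.size())                           # torch.Size([3, 2])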
