 import torch.optim as optim
 from torch.autograd import Variable
 import math
+from utils import preprocess_gradients
+from layer_norm_lstm import LayerNormLSTMCell
+from layer_norm import LayerNorm1D

 class MetaOptimizer(nn.Module):

@@ -16,16 +19,12 @@ def __init__(self, model, num_layers, hidden_size):

         self.hidden_size = hidden_size

-        self.linear1 = nn.Linear(2, hidden_size)
+        self.linear1 = nn.Linear(3, hidden_size)
+        self.ln1 = LayerNorm1D(hidden_size)

         self.lstms = []
         for i in range(num_layers):
-            self.lstms.append(nn.LSTMCell(hidden_size, hidden_size))
-
-            self.lstms[-1].bias_ih.data.fill_(0)
-            self.lstms[-1].bias_hh.data.fill_(0)
-            self.lstms[-1].bias_hh.data[10:20].fill_(1)
-
+            self.lstms.append(LayerNormLSTMCell(hidden_size, hidden_size))

         self.linear2 = nn.Linear(hidden_size, 1)
         self.linear2.weight.data.mul_(0.1)
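
`LayerNorm1D` is imported from `layer_norm` but not defined in this diff. As a rough sketch only (the real module may differ), a 1-D layer normalization over the feature dimension with learnable gain and bias could look like:

```python
import torch
import torch.nn as nn


class LayerNorm1D(nn.Module):
    """Sketch: layer norm over the last (feature) dimension of a 2-D input."""

    def __init__(self, num_features, eps=1e-5):
        super(LayerNorm1D, self).__init__()
        self.eps = eps
        # Learnable per-feature gain and bias, initialized to the identity map.
        self.weight = nn.Parameter(torch.ones(1, num_features))
        self.bias = nn.Parameter(torch.zeros(1, num_features))

    def forward(self, x):
        # Normalize each row to zero mean / unit std, then rescale and shift.
        mean = x.mean(1, keepdim=True)
        std = x.std(1, keepdim=True)
        return (x - mean) / (std + self.eps) * self.weight + self.bias
```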
@@ -53,20 +52,9 @@ def reset_lstm(self, keep_states=False, model=None, use_cuda=False):
             if use_cuda:
                 self.hx[i], self.cx[i] = self.hx[i].cuda(), self.cx[i].cuda()

-    def forward(self, inputs):
-        initial_size = inputs.size()
-        x = inputs.view(-1, 1)
-
+    def forward(self, x):
         # Gradients preprocessing
-        p = 10
-        eps = 1e-6
-        indicator = (x.abs() > math.exp(-p)).float()
-        x1 = (x.abs() + eps).log() / p * indicator - (1 - indicator)
-        x2 = x.sign() * indicator + math.exp(p) * x * (1 - indicator)
-
-        x = torch.cat((x1, x2), 1)
-
-        x = F.tanh(self.linear1(x))
+        x = F.tanh(self.ln1(self.linear1(x)))

         for i in range(len(self.lstms)):
             if x.size(0) != self.hx[i].size(0):
@@ -77,8 +65,7 @@ def forward(self, inputs):
             x = self.hx[i]

         x = self.linear2(x)
-        x = x.view(*initial_size)
-        return x
+        return x.squeeze()

     def meta_update(self, model_with_grads):
         # First we need to create a flat version of parameters and gradients
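
The log/sign gradient preprocessing deleted from `forward` above (p = 10, eps = 1e-6) is presumably what `utils.preprocess_gradients` now provides. A sketch that simply reuses the deleted lines, with the reshape from the old `inputs.view(-1, 1)` folded in (the actual helper in `utils` may differ):

```python
import math

import torch


def preprocess_gradients(x):
    # Log/sign preprocessing of a flat 1-D tensor of gradients; each scalar
    # gradient is mapped to two features, as in the code removed from forward.
    p = 10
    eps = 1e-6
    x = x.view(-1, 1)
    indicator = (x.abs() > math.exp(-p)).float()
    x1 = (x.abs() + eps).log() / p * indicator - (1 - indicator)
    x2 = x.sign() * indicator + math.exp(p) * x * (1 - indicator)
    return torch.cat((x1, x2), 1)
```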
@@ -89,10 +76,12 @@ def meta_update(self, model_with_grads):
             grads.append(module._parameters['bias'].grad.data.view(-1))

         flat_params = self.meta_model.get_flat_params()
-        flat_grads = Variable(torch.cat(grads))
+        flat_grads = preprocess_gradients(torch.cat(grads))
+
+        inputs = Variable(torch.cat((flat_grads, flat_params.data), 1))

         # Meta update itself
-        flat_params = flat_params + self(flat_grads)
+        flat_params = flat_params + self(inputs)

         self.meta_model.set_flat_params(flat_params)

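`LayerNormLSTMCell` is also defined outside this diff. A minimal, assumed sketch of an LSTM cell with layer-normalized gate pre-activations that keeps the `(h, c)` interface of `nn.LSTMCell` (reusing the `LayerNorm1D` sketch above):

```python
import torch
import torch.nn as nn


class LayerNormLSTMCell(nn.Module):
    """Sketch: LSTM cell that layer-normalizes its gate pre-activations."""

    def __init__(self, input_size, hidden_size):
        super(LayerNormLSTMCell, self).__init__()
        self.hidden_size = hidden_size
        # Input-to-hidden and hidden-to-hidden projections for all four gates.
        self.fc_ih = nn.Linear(input_size, 4 * hidden_size)
        self.fc_hh = nn.Linear(hidden_size, 4 * hidden_size)
        self.ln_ih = LayerNorm1D(4 * hidden_size)
        self.ln_hh = LayerNorm1D(4 * hidden_size)
        self.ln_cell = LayerNorm1D(hidden_size)

    def forward(self, x, state):
        hx, cx = state
        # Normalize each projection separately, then combine and split gates.
        gates = self.ln_ih(self.fc_ih(x)) + self.ln_hh(self.fc_hh(hx))
        i, f, g, o = gates.chunk(4, dim=1)
        i, f, o = torch.sigmoid(i), torch.sigmoid(f), torch.sigmoid(o)
        g = torch.tanh(g)
        cy = f * cx + i * g
        hy = o * torch.tanh(self.ln_cell(cy))
        return hy, cy
```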