3
3
4
4
from copy import deepcopy
5
5
6
- import matplotlib .pyplot as plt
7
6
import numpy as np
7
+ import ray
8
8
import torch
9
9
import torch .nn .functional as F
10
10
from torch import Tensor , linspace , nn , randperm , sin
15
15
16
16
from utils import ParamDict as P
17
17
18
+ # To avoid tkinter not installed error on headless server
19
+ try :
20
+ import matplotlib
21
+ matplotlib .use ('AGG' )
22
+ import matplotlib .pyplot as plt
23
+ except :
24
+ pass
25
+
18
26
Weights = P
19
27
criterion = F .mse_loss
20
28
25
33
N = 50 # Use 50 evenly spaced points on sine wave.
26
34
27
35
LR , META_LR = 0.02 , 0.1 # Copy OpenAI's hyperparameters.
28
- BATCH_SIZE , META_BATCH_SIZE = 10 , 1
36
+ BATCH_SIZE , META_BATCH_SIZE = 10 , 3
29
37
EPOCHS , META_EPOCHS = 1 , 30_000
30
38
TEST_GRAD_STEPS = 2 ** 3
31
39
PLOT_EVERY = 3_000
@@ -54,16 +62,16 @@ def gen_task(num_pts=N) -> DataLoader:
54
62
# Need to make x N,1 instead of N, to avoid
55
63
# https://discuss.pytorch.org/t/dataloader-gives-double-instead-of-float
56
64
x = linspace (- 5 , 5 , num_pts )[:, None ].float ()
57
- y = a * sin (x + b )
65
+ y = a * sin (x + b ). float ()
58
66
59
67
dataset = TensorDataset (x , y )
60
68
61
69
loader = DataLoader (
62
- dataset ,
63
- batch_size = BATCH_SIZE ,
64
- shuffle = True ,
65
- pin_memory = CUDA_AVAILABLE ,
66
- )
70
+ dataset ,
71
+ batch_size = BATCH_SIZE ,
72
+ shuffle = True ,
73
+ pin_memory = CUDA_AVAILABLE ,
74
+ )
67
75
68
76
return loader
69
77
@@ -102,11 +110,14 @@ def evaluate(model: Model, task: DataLoader, criterion=criterion) -> float:
102
110
"""Evaluate model on all the task data at once."""
103
111
model .eval ()
104
112
105
- x , y = cuda (Variable (task .dataset .data_tensor )), cuda (Variable (task .dataset .target_tensor ))
113
+ x , y = cuda (Variable (task .dataset .data_tensor )), cuda (
114
+ Variable (task .dataset .target_tensor )
115
+ )
106
116
loss = criterion (model (x ), y )
107
117
return float (loss )
108
118
109
119
120
+ @ray .remote
110
121
def sgd (meta_weights : Weights , epochs : int ) -> Weights :
111
122
"""Run SGD on a randomly generated task."""
112
123
@@ -120,60 +131,68 @@ def sgd(meta_weights: Weights, epochs: int) -> Weights:
120
131
for x , y in task :
121
132
train_batch (x , y , model , opt )
122
133
123
- return P ( model .state_dict () )
134
+ return model .state_dict ()
124
135
125
136
126
137
def REPTILE (
127
- meta_weights : Weights ,
128
- meta_batch_size : int = META_BATCH_SIZE ,
129
- epochs : int = EPOCHS ,
130
- ) -> Weights :
138
+ meta_weights : Weights ,
139
+ meta_batch_size : int = META_BATCH_SIZE ,
140
+ epochs : int = EPOCHS ,
141
+ ) -> Weights :
131
142
"""Run one iteration of REPTILE."""
132
-
133
- weights = [sgd (meta_weights , epochs ) for _ in range (meta_batch_size )]
143
+ weights = ray .get ([
144
+ sgd .remote (meta_weights , epochs ) for _ in range (meta_batch_size )
145
+ ])
146
+ weights = [P (w ) for w in weights ]
134
147
135
148
# TODO Implement custom optimizer that makes this work with builtin
136
149
# optimizers easily. The multiplication by 0 is to get a ParamDict of the
137
150
# right size as the identity element for summation.
138
- meta_weights += (META_LR / epochs ) * sum ((w - meta_weights for w in weights ), 0 * meta_weights )
151
+ meta_weights += (META_LR / epochs ) * sum ((w - meta_weights
152
+ for w in weights ), 0 * meta_weights )
139
153
return meta_weights
140
154
141
155
142
156
if __name__ == '__main__' :
157
+ try :
158
+ ray .init ()
159
+ except Exception as e :
160
+ print (e )
143
161
144
162
# Need to put model on GPU first for tensors to have the right type.
145
- meta_weights = P (cuda (Model ()).state_dict ())
146
-
147
- # Generate fixed task to evaluate on.
148
- plot_task = gen_task ()
149
-
150
- x_all , y_all = plot_task .dataset .data_tensor , plot_task .dataset .target_tensor
151
- x_plot , y_plot = shuffle (x_all , y_all , length = 10 )
152
-
153
- # Set up plot
154
- fig , ax = plt .subplots ()
155
- true_curve = ax .plot (
156
- x_all .numpy (),
157
- y_all .numpy (),
158
- label = 'True' ,
159
- color = 'g' ,
160
- )
161
-
162
- ax .plot (
163
- x_plot .numpy (),
164
- y_plot .numpy (),
165
- 'x' ,
166
- label = 'Training points' ,
167
- color = 'k' ,
168
- )
169
-
170
- ax .legend (loc = "lower right" )
171
- ax .set_xlim (- 5 , 5 )
172
- ax .set_ylim (- 5 , 5 )
163
+ meta_weights = cuda (Model ()).state_dict ()
164
+
165
+ if PLOT :
166
+ # Generate fixed task to evaluate on.
167
+ plot_task = gen_task ()
168
+
169
+ x_all , y_all = plot_task .dataset .data_tensor , plot_task .dataset .target_tensor
170
+ x_plot , y_plot = shuffle (x_all , y_all , length = 10 )
171
+
172
+ # Set up plot
173
+ fig , ax = plt .subplots ()
174
+ true_curve = ax .plot (
175
+ x_all .numpy (),
176
+ y_all .numpy (),
177
+ label = 'True' ,
178
+ color = 'g' ,
179
+ )
180
+
181
+ ax .plot (
182
+ x_plot .numpy (),
183
+ y_plot .numpy (),
184
+ 'x' ,
185
+ label = 'Training points' ,
186
+ color = 'k' ,
187
+ )
188
+
189
+ ax .legend (loc = "lower right" )
190
+ ax .set_xlim (- 5 , 5 )
191
+ ax .set_ylim (- 5 , 5 )
173
192
174
193
for iteration in range (1 , META_EPOCHS + 1 ):
175
194
176
- meta_weights = REPTILE (meta_weights )
195
+ meta_weights = REPTILE (P ( meta_weights ) )
177
196
178
197
if iteration == 1 or iteration % PLOT_EVERY == 0 :
179
198
@@ -182,24 +201,21 @@ def REPTILE(
182
201
opt = SGD (model .parameters (), lr = LR )
183
202
184
203
for _ in range (TEST_GRAD_STEPS ):
185
- # Training on all the points rather than just a sample works better.
186
204
train_batch (x_plot , y_plot , model , opt )
187
205
188
206
if PLOT :
189
207
190
- ax .set_title (f'REPTILE after { iteration :n} iterations' )
191
- ( curve , ) = ax .plot (
192
- x_all .numpy (),
193
- model (Variable (x_all )).data .numpy (),
194
- label = f'Pred after { TEST_GRAD_STEPS } gradient steps.' ,
195
- color = 'r' ,
196
- )
208
+ ax .set_title (f'REPTILE after { iteration :n} iterations. ' )
209
+ curve , = ax .plot (
210
+ x_all .numpy (),
211
+ model (Variable (x_all )).data .numpy (),
212
+ label = f'Pred after { TEST_GRAD_STEPS :n } gradient steps.' ,
213
+ color = 'r' ,
214
+ )
197
215
198
216
plt .savefig (f'figs/{ iteration } .png' )
199
217
200
- # Pause before clearing and moving on to next plot.
201
- plt .pause (0.01 )
202
-
203
218
ax .lines .remove (curve )
204
219
205
220
print (f'Iteration: { iteration } \t Loss: { evaluate (model , plot_task ):.3f} ' )
221
+
0 commit comments