1
1
#!/usr/bin/env python3
2
2
# -*- coding: utf-8 -*-
3
3
4
+ import os
4
5
from copy import deepcopy
5
6
6
7
import numpy as np
@@ -67,11 +68,11 @@ def gen_task(num_pts=N) -> DataLoader:
67
68
dataset = TensorDataset (x , y )
68
69
69
70
loader = DataLoader (
70
- dataset ,
71
- batch_size = BATCH_SIZE ,
72
- shuffle = True ,
73
- pin_memory = CUDA_AVAILABLE ,
74
- )
71
+ dataset ,
72
+ batch_size = BATCH_SIZE ,
73
+ shuffle = True ,
74
+ pin_memory = CUDA_AVAILABLE ,
75
+ )
75
76
76
77
return loader
77
78
@@ -111,8 +112,8 @@ def evaluate(model: Model, task: DataLoader, criterion=criterion) -> float:
111
112
model .eval ()
112
113
113
114
x , y = cuda (Variable (task .dataset .data_tensor )), cuda (
114
- Variable (task .dataset .target_tensor )
115
- )
115
+ Variable (task .dataset .target_tensor )
116
+ )
116
117
loss = criterion (model (x ), y )
117
118
return float (loss )
118
119
@@ -135,21 +136,21 @@ def sgd(meta_weights: Weights, epochs: int) -> Weights:
135
136
136
137
137
138
def REPTILE (
138
- meta_weights : Weights ,
139
- meta_batch_size : int = META_BATCH_SIZE ,
140
- epochs : int = EPOCHS ,
141
- ) -> Weights :
139
+ meta_weights : Weights ,
140
+ meta_batch_size : int = META_BATCH_SIZE ,
141
+ epochs : int = EPOCHS ,
142
+ ) -> Weights :
142
143
"""Run one iteration of REPTILE."""
143
144
weights = ray .get ([
144
145
sgd .remote (meta_weights , epochs ) for _ in range (meta_batch_size )
145
- ])
146
+ ])
146
147
weights = [P (w ) for w in weights ]
147
148
148
149
# TODO Implement custom optimizer that makes this work with builtin
149
150
# optimizers easily. The multiplication by 0 is to get a ParamDict of the
150
151
# right size as the identity element for summation.
151
152
meta_weights += (META_LR / epochs ) * sum ((w - meta_weights
152
- for w in weights ), 0 * meta_weights )
153
+ for w in weights ), 0 * meta_weights )
153
154
return meta_weights
154
155
155
156
@@ -172,19 +173,19 @@ def REPTILE(
172
173
# Set up plot
173
174
fig , ax = plt .subplots ()
174
175
true_curve = ax .plot (
175
- x_all .numpy (),
176
- y_all .numpy (),
177
- label = 'True' ,
178
- color = 'g' ,
179
- )
176
+ x_all .numpy (),
177
+ y_all .numpy (),
178
+ label = 'True' ,
179
+ color = 'g' ,
180
+ )
180
181
181
182
ax .plot (
182
- x_plot .numpy (),
183
- y_plot .numpy (),
184
- 'x' ,
185
- label = 'Training points' ,
186
- color = 'k' ,
187
- )
183
+ x_plot .numpy (),
184
+ y_plot .numpy (),
185
+ 'x' ,
186
+ label = 'Training points' ,
187
+ color = 'k' ,
188
+ )
188
189
189
190
ax .legend (loc = "lower right" )
190
191
ax .set_xlim (- 5 , 5 )
@@ -207,15 +208,15 @@ def REPTILE(
207
208
208
209
ax .set_title (f'REPTILE after { iteration :n} iterations.' )
209
210
curve , = ax .plot (
210
- x_all .numpy (),
211
- model (Variable (x_all )).data .numpy (),
212
- label = f'Pred after { TEST_GRAD_STEPS :n} gradient steps.' ,
213
- color = 'r' ,
214
- )
211
+ x_all .numpy (),
212
+ model (Variable (x_all )).data .numpy (),
213
+ label = f'Pred after { TEST_GRAD_STEPS :n} gradient steps.' ,
214
+ color = 'r' ,
215
+ )
215
216
217
+ os .makedirs ('figs' , exist_ok = True )
216
218
plt .savefig (f'figs/{ iteration } .png' )
217
219
218
220
ax .lines .remove (curve )
219
221
220
222
print (f'Iteration: { iteration } \t Loss: { evaluate (model , plot_task ):.3f} ' )
221
-
0 commit comments