Skip to content

Commit 9ef8711

Browse files
committed
Update README to reflect parallelism
1 parent c80330e commit 9ef8711

File tree

2 files changed

+36
-31
lines changed

2 files changed

+36
-31
lines changed

reptile/README.md

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,18 @@ time. It evaluates on a fixed task every 1,000 iterations, taking 5
88
gradient descent steps per task. Turns out that 1 is enough to get good
99
performance, indicating that meta-learning is actually working.
1010

11+
Each meta batch will run in parallel thanks to Ray.
12+
1113
## Requirements
1214

1315
- Python 3.6+
1416
- Numpy
1517
- Matplotlib
18+
- [Ray](https://github.com/ray-project/ray)
1619

1720
## Running the Script
1821

1922
python3 main.py
2023

21-
It will pop up a Matplotlib window that updates every 3,000 iterations.
24+
If the `PLOT` flag in the code is set to `True`, it will create the
25+
directory `figs` and save plots to it.

reptile/main.py

Lines changed: 31 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#!/usr/bin/env python3
22
# -*- coding: utf-8 -*-
33

4+
import os
45
from copy import deepcopy
56

67
import numpy as np
@@ -67,11 +68,11 @@ def gen_task(num_pts=N) -> DataLoader:
6768
dataset = TensorDataset(x, y)
6869

6970
loader = DataLoader(
70-
dataset,
71-
batch_size=BATCH_SIZE,
72-
shuffle=True,
73-
pin_memory=CUDA_AVAILABLE,
74-
)
71+
dataset,
72+
batch_size=BATCH_SIZE,
73+
shuffle=True,
74+
pin_memory=CUDA_AVAILABLE,
75+
)
7576

7677
return loader
7778

@@ -111,8 +112,8 @@ def evaluate(model: Model, task: DataLoader, criterion=criterion) -> float:
111112
model.eval()
112113

113114
x, y = cuda(Variable(task.dataset.data_tensor)), cuda(
114-
Variable(task.dataset.target_tensor)
115-
)
115+
Variable(task.dataset.target_tensor)
116+
)
116117
loss = criterion(model(x), y)
117118
return float(loss)
118119

@@ -135,21 +136,21 @@ def sgd(meta_weights: Weights, epochs: int) -> Weights:
135136

136137

137138
def REPTILE(
138-
meta_weights: Weights,
139-
meta_batch_size: int = META_BATCH_SIZE,
140-
epochs: int = EPOCHS,
141-
) -> Weights:
139+
meta_weights: Weights,
140+
meta_batch_size: int = META_BATCH_SIZE,
141+
epochs: int = EPOCHS,
142+
) -> Weights:
142143
"""Run one iteration of REPTILE."""
143144
weights = ray.get([
144145
sgd.remote(meta_weights, epochs) for _ in range(meta_batch_size)
145-
])
146+
])
146147
weights = [P(w) for w in weights]
147148

148149
# TODO Implement custom optimizer that makes this work with builtin
149150
# optimizers easily. The multiplication by 0 is to get a ParamDict of the
150151
# right size as the identity element for summation.
151152
meta_weights += (META_LR / epochs) * sum((w - meta_weights
152-
for w in weights), 0 * meta_weights)
153+
for w in weights), 0 * meta_weights)
153154
return meta_weights
154155

155156

@@ -172,19 +173,19 @@ def REPTILE(
172173
# Set up plot
173174
fig, ax = plt.subplots()
174175
true_curve = ax.plot(
175-
x_all.numpy(),
176-
y_all.numpy(),
177-
label='True',
178-
color='g',
179-
)
176+
x_all.numpy(),
177+
y_all.numpy(),
178+
label='True',
179+
color='g',
180+
)
180181

181182
ax.plot(
182-
x_plot.numpy(),
183-
y_plot.numpy(),
184-
'x',
185-
label='Training points',
186-
color='k',
187-
)
183+
x_plot.numpy(),
184+
y_plot.numpy(),
185+
'x',
186+
label='Training points',
187+
color='k',
188+
)
188189

189190
ax.legend(loc="lower right")
190191
ax.set_xlim(-5, 5)
@@ -207,15 +208,15 @@ def REPTILE(
207208

208209
ax.set_title(f'REPTILE after {iteration:n} iterations.')
209210
curve, = ax.plot(
210-
x_all.numpy(),
211-
model(Variable(x_all)).data.numpy(),
212-
label=f'Pred after {TEST_GRAD_STEPS:n} gradient steps.',
213-
color='r',
214-
)
211+
x_all.numpy(),
212+
model(Variable(x_all)).data.numpy(),
213+
label=f'Pred after {TEST_GRAD_STEPS:n} gradient steps.',
214+
color='r',
215+
)
215216

217+
os.makedirs('figs', exist_ok=True)
216218
plt.savefig(f'figs/{iteration}.png')
217219

218220
ax.lines.remove(curve)
219221

220222
print(f'Iteration: {iteration}\tLoss: {evaluate(model, plot_task):.3f}')
221-

0 commit comments

Comments
 (0)