Commit c80330e

Merge branch 'reptile-ray'
* reptile-ray:
  Hide plotting behind flag
  Implement parallel gradient computations
  Implement parallel reduce for future use
2 parents: 116288f + 092fe20
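For context, the "parallel gradient computations" change uses the standard Ray task pattern: decorate a function with @ray.remote, launch copies with .remote(), and gather the results with ray.get(). A minimal, illustrative sketch of that pattern (the trivial worker below is only a stand-in for the per-task sgd() in reptile/main.py):

    import ray

    ray.init()

    @ray.remote
    def work(x):
        # Stand-in for a per-task SGD run; any picklable return value works.
        return x * x

    # Launch the tasks in parallel, then block until every result is ready.
    futures = [work.remote(i) for i in range(4)]
    print(ray.get(futures))  # [0, 1, 4, 9]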

File tree

2 files changed: +136, -57 lines


reptile/main.py

File mode changed from 100644 to 100755
Lines changed: 73 additions & 57 deletions
@@ -3,8 +3,8 @@
 
 from copy import deepcopy
 
-import matplotlib.pyplot as plt
 import numpy as np
+import ray
 import torch
 import torch.nn.functional as F
 from torch import Tensor, linspace, nn, randperm, sin
@@ -15,6 +15,14 @@
 
 from utils import ParamDict as P
 
+# To avoid tkinter not installed error on headless server
+try:
+    import matplotlib
+    matplotlib.use('AGG')
+    import matplotlib.pyplot as plt
+except:
+    pass
+
 Weights = P
 criterion = F.mse_loss
 
@@ -25,7 +33,7 @@
 N = 50  # Use 50 evenly spaced points on sine wave.
 
 LR, META_LR = 0.02, 0.1  # Copy OpenAI's hyperparameters.
-BATCH_SIZE, META_BATCH_SIZE = 10, 1
+BATCH_SIZE, META_BATCH_SIZE = 10, 3
 EPOCHS, META_EPOCHS = 1, 30_000
 TEST_GRAD_STEPS = 2**3
 PLOT_EVERY = 3_000
@@ -54,16 +62,16 @@ def gen_task(num_pts=N) -> DataLoader:
     # Need to make x N,1 instead of N, to avoid
     # https://discuss.pytorch.org/t/dataloader-gives-double-instead-of-float
     x = linspace(-5, 5, num_pts)[:, None].float()
-    y = a * sin(x + b)
+    y = a * sin(x + b).float()
 
     dataset = TensorDataset(x, y)
 
     loader = DataLoader(
-            dataset,
-            batch_size=BATCH_SIZE,
-            shuffle=True,
-            pin_memory=CUDA_AVAILABLE,
-            )
+        dataset,
+        batch_size=BATCH_SIZE,
+        shuffle=True,
+        pin_memory=CUDA_AVAILABLE,
+    )
 
     return loader
 
@@ -102,11 +110,14 @@ def evaluate(model: Model, task: DataLoader, criterion=criterion) -> float:
     """Evaluate model on all the task data at once."""
     model.eval()
 
-    x, y = cuda(Variable(task.dataset.data_tensor)), cuda(Variable(task.dataset.target_tensor))
+    x, y = cuda(Variable(task.dataset.data_tensor)), cuda(
+        Variable(task.dataset.target_tensor)
+    )
     loss = criterion(model(x), y)
     return float(loss)
 
 
+@ray.remote
 def sgd(meta_weights: Weights, epochs: int) -> Weights:
     """Run SGD on a randomly generated task."""
 
@@ -120,60 +131,68 @@ def sgd(meta_weights: Weights, epochs: int) -> Weights:
         for x, y in task:
             train_batch(x, y, model, opt)
 
-    return P(model.state_dict())
+    return model.state_dict()
 
 
 def REPTILE(
-        meta_weights: Weights,
-        meta_batch_size: int = META_BATCH_SIZE,
-        epochs: int = EPOCHS,
-        ) -> Weights:
+    meta_weights: Weights,
+    meta_batch_size: int = META_BATCH_SIZE,
+    epochs: int = EPOCHS,
+) -> Weights:
     """Run one iteration of REPTILE."""
-
-    weights = [sgd(meta_weights, epochs) for _ in range(meta_batch_size)]
+    weights = ray.get([
+        sgd.remote(meta_weights, epochs) for _ in range(meta_batch_size)
+    ])
+    weights = [P(w) for w in weights]
 
     # TODO Implement custom optimizer that makes this work with builtin
     # optimizers easily. The multiplication by 0 is to get a ParamDict of the
     # right size as the identity element for summation.
-    meta_weights += (META_LR / epochs) * sum((w - meta_weights for w in weights), 0 * meta_weights)
+    meta_weights += (META_LR / epochs) * sum((w - meta_weights
+                                              for w in weights), 0 * meta_weights)
     return meta_weights
 
 
 if __name__ == '__main__':
+    try:
+        ray.init()
+    except Exception as e:
+        print(e)
 
     # Need to put model on GPU first for tensors to have the right type.
-    meta_weights = P(cuda(Model()).state_dict())
-
-    # Generate fixed task to evaluate on.
-    plot_task = gen_task()
-
-    x_all, y_all = plot_task.dataset.data_tensor, plot_task.dataset.target_tensor
-    x_plot, y_plot = shuffle(x_all, y_all, length=10)
-
-    # Set up plot
-    fig, ax = plt.subplots()
-    true_curve = ax.plot(
-        x_all.numpy(),
-        y_all.numpy(),
-        label='True',
-        color='g',
-    )
-
-    ax.plot(
-        x_plot.numpy(),
-        y_plot.numpy(),
-        'x',
-        label='Training points',
-        color='k',
-    )
-
-    ax.legend(loc="lower right")
-    ax.set_xlim(-5, 5)
-    ax.set_ylim(-5, 5)
+    meta_weights = cuda(Model()).state_dict()
+
+    if PLOT:
+        # Generate fixed task to evaluate on.
+        plot_task = gen_task()
+
+        x_all, y_all = plot_task.dataset.data_tensor, plot_task.dataset.target_tensor
+        x_plot, y_plot = shuffle(x_all, y_all, length=10)
+
+        # Set up plot
+        fig, ax = plt.subplots()
+        true_curve = ax.plot(
+            x_all.numpy(),
+            y_all.numpy(),
+            label='True',
+            color='g',
+        )
+
+        ax.plot(
+            x_plot.numpy(),
+            y_plot.numpy(),
+            'x',
+            label='Training points',
+            color='k',
+        )
+
+        ax.legend(loc="lower right")
+        ax.set_xlim(-5, 5)
+        ax.set_ylim(-5, 5)
 
     for iteration in range(1, META_EPOCHS + 1):
 
-        meta_weights = REPTILE(meta_weights)
+        meta_weights = REPTILE(P(meta_weights))
 
         if iteration == 1 or iteration % PLOT_EVERY == 0:
 
@@ -182,24 +201,21 @@ def REPTILE(
             opt = SGD(model.parameters(), lr=LR)
 
             for _ in range(TEST_GRAD_STEPS):
-                # Training on all the points rather than just a sample works better.
                 train_batch(x_plot, y_plot, model, opt)
 
             if PLOT:
 
-                ax.set_title(f'REPTILE after {iteration:n} iterations')
-                (curve, ) = ax.plot(
-                    x_all.numpy(),
-                    model(Variable(x_all)).data.numpy(),
-                    label=f'Pred after {TEST_GRAD_STEPS} gradient steps.',
-                    color='r',
-                    )
+                ax.set_title(f'REPTILE after {iteration:n} iterations.')
+                curve, = ax.plot(
+                    x_all.numpy(),
+                    model(Variable(x_all)).data.numpy(),
+                    label=f'Pred after {TEST_GRAD_STEPS:n} gradient steps.',
+                    color='r',
+                )
 
                 plt.savefig(f'figs/{iteration}.png')
 
-                # Pause before clearing and moving on to next plot.
-                plt.pause(0.01)
-
                 ax.lines.remove(curve)
 
             print(f'Iteration: {iteration}\tLoss: {evaluate(model, plot_task):.3f}')
+
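Aside (not part of the diff): the meta_weights update inside REPTILE() is the Reptile rule, meta_weights <- meta_weights + (META_LR / epochs) * sum_i (w_i - meta_weights), where the w_i are the task-adapted weights returned by the parallel sgd() calls. A minimal sketch of that arithmetic with plain tensors and made-up values (nothing below is from the repo):

    import torch

    META_LR, EPOCHS = 0.1, 1

    # Toy stand-ins: current meta weights and two task-adapted copies.
    meta = {'w': torch.zeros(3)}
    adapted = [{'w': torch.ones(3)}, {'w': 2 * torch.ones(3)}]

    # Move the meta weights toward the adapted weights, scaled by META_LR / EPOCHS,
    # mirroring the sum() expression in REPTILE().
    for k in meta:
        meta[k] = meta[k] + (META_LR / EPOCHS) * sum(w[k] - meta[k] for w in adapted)

    print(meta['w'])  # tensor([0.3000, 0.3000, 0.3000]) = 0 + 0.1 * ((1 - 0) + (2 - 0))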
reptile/parallel_reduce.py

Lines changed: 63 additions & 0 deletions
@@ -0,0 +1,63 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+from typing import List
+import ray
+
+# TODO parallelize
+
+
+def partition(xs) -> List[slice]:
+    def split(xs) -> List[int]:
+        # `bin` gives str '0b----', so we drop the first 2 chars, '0b'. We can
+        # ignore the negative case since length is always nonnegative.
+        base2_decomp = reversed([int(x) for x in f'{len(xs):b}'])
+
+        pows = sorted([i for i, a in enumerate(base2_decomp) if a != 0], )
+        return pows
+
+    parts = split(xs)
+
+    slices, start = [], 0
+
+    for n in parts:
+        stop = start + 2**n
+        slices.append(slice(start, stop))
+        start = stop
+
+    return slices
+
+
+@ray.remote()
+def foldr(f, xs):
+    L = len(xs)
+    if L == 1:
+        return xs[0]
+    elif L == 0:
+        raise ValueError('Sequence must have length greater than 0.')
+    else:
+        slices = partition(xs)
+
+        @ray.remote
+        def _foldr(chunk):
+            while len(chunk) > 1:
+                chunk = [f(chunk[2 * i], chunk[2 * i + 1]) for i, _ in enumerate(chunk[::2])]
+            return chunk[0]
+
+        return foldr(f, [_foldr.remote(xs[slice]) for slice in slices])
+
+
+if __name__ == '__main__':
+    from operator import add
+    import hypothesis
+    import hypothesis.strategies as st
+    from hypothesis import assume, example, given, infer
+    from functools import reduce
+
+    @given(xs=infer)
+    def test_foldr(f, xs: List[int]):
+        assume(xs != [])
+        assert foldr(f, xs) == reduce(f, xs)
+
+    f = add
+    test_foldr(f)
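Aside (not part of the diff): partition() reads off the binary decomposition of len(xs) and splits the list into power-of-two-sized chunks, which foldr() then reduces in parallel. A small worked example, assuming the partition() defined above:

    # Length 13 = 0b1101 = 2**0 + 2**2 + 2**3, so the powers are [0, 2, 3]
    # and the chunks have sizes 1, 4 and 8.
    xs = list(range(13))
    print(partition(xs))
    # -> [slice(0, 1), slice(1, 5), slice(5, 13)]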
