Commit b1ae7f9

csarofeen authored and apaszke committed
Added functionality for data parallel table (pytorch#843)
1 parent 8b61ee5 commit b1ae7f9

File tree: 6 files changed, +155 -50 lines changed

test/test_nn.py

Lines changed: 63 additions & 19 deletions
@@ -259,7 +259,8 @@ def bw_hook(inc, h_module, grad_input, grad_output):
         self.assertEqual(counter['forwards'], 2)
         self.assertEqual(counter['backwards'], 0)

-        test_bwd = module.register_backward_hook(lambda *args: bw_hook(1, *args))
+        test_bwd = module.register_backward_hook(
+            lambda *args: bw_hook(1, *args))

         output = module(input)
         self.assertEqual(counter['forwards'], 3)
@@ -816,7 +817,8 @@ def test_parallel_apply(self):
         inputs = ((i1,), (i2,))
         modules = (l1, l2)
         expected_outputs = (expected1, expected2)
-        outputs = dp.parallel_apply(modules, inputs)
+
+        outputs = dp.parallel_apply(modules, inputs, None)
         for out, expected in zip(outputs, expected_outputs):
             self.assertEqual(out.data, expected)

@@ -833,27 +835,67 @@ def test_data_parallel_noop(self):
     @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
     def test_data_parallel_multiple_input(self):
         class TestModule(nn.Module):
-            def forward(self, x, y):
-                return x + y
-
-        m = TestModule()
-        x = Variable(torch.randn(5, 5).float())
-        y = Variable(torch.randn(5, 5).float())
-        expected = m(x, y)

-        out = dp.data_parallel(m, (x, y), (0, 1))
-        self.assertEqual(out, expected)
+            def forward(self, var1, var2, float1, var3=None):
+                if var3 is None:
+                    return float1 * (var1 * var2)
+                else:
+                    return float1 * (var1 * var2 + var3)

-        out = dp.data_parallel(m, (x, y), (0,))
-        self.assertEqual(out, expected)
+        m = TestModule()
+        var1 = Variable(torch.randn(5, 5).float(), requires_grad=True)
+        var2 = Variable(torch.randn(5, 5).float(), requires_grad=True)
+        var3 = Variable(torch.randn(5, 5).float(), requires_grad=False)
+
+        float1 = torch.randn(1)[0]
+        target = Variable(torch.randn(5, 5).float()).cuda()
+        crit = nn.MSELoss()
+
+        expected = m(var1, var2, float1)
+        loss = expected.sum()
+        loss.backward()
+        gvar1_exp = var1.grad.clone()
+        gvar2_exp = var2.grad.clone()
+
+        def local_test(out):
+            var1.grad.data.fill_(0.0)
+            var2.grad.data.fill_(0.0)
+            loss = out.sum()
+            loss.backward()
+            self.assertEqual(out, expected)
+            self.assertEqual(gvar1_exp, var1.grad)
+            self.assertEqual(gvar2_exp, var2.grad)
+
+        out = dp.data_parallel(m, (var1, var2, float1), (0, 1))
+        local_test(out)
+
+        out = dp.data_parallel(m, (var1, var2, float1), (0,))
+        local_test(out)
+
+        var1.grad.data.fill_(0.0)
+        var2.grad.data.fill_(0.0)
+        expected = m(var1, var2, float1, var3=var3)
+        loss = expected.sum()
+        loss.backward()
+        gvar1_exp = var1.grad.clone()
+        gvar2_exp = var2.grad.clone()

         dpm = nn.DataParallel(TestModule())
-        out = dpm(x, y)
-        self.assertEqual(out, expected)
+        out = dpm(var1, var2, float1, var3=var3)
+        local_test(out)

         dpm = nn.DataParallel(TestModule(), device_ids=[0])
-        out = dpm(x, y)
-        self.assertEqual(out, expected)
+        out = dpm(var1, var2, float1, var3=var3)
+        local_test(out)
+
+        kwarg_wrap = {'var3': var3}
+        out = dp.data_parallel(
+            m, (var1, var2, float1), (0, 1), module_kwargs=kwarg_wrap)
+        local_test(out)
+
+        out = dp.data_parallel(
+            m, (var1, var2, float1), (0,), module_kwargs=kwarg_wrap)
+        local_test(out)

     @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
     def test_data_parallel_small_back(self):
@@ -1426,8 +1468,10 @@ def compare_cpu_gpu(outputs_cpu, outputs_gpu):
         for nonlinearity in ('tanh', 'relu'):
             hx_val = torch.randn(num_layers, batch, hidden_size)
             input_val = torch.randn(seq_length, batch, input_size)
-            grad_output = torch.randn(seq_length, batch, hidden_size * num_directions)
-            grad_hy = torch.randn(num_layers * num_directions, batch, hidden_size)
+            grad_output = torch.randn(
+                seq_length, batch, hidden_size * num_directions)
+            grad_hy = torch.randn(
+                num_layers * num_directions, batch, hidden_size)

             rnn = nn.RNN(input_size, hidden_size, num_layers, bias=bias, nonlinearity=nonlinearity)
             outputs_cpu = forward_backward(False, rnn, input_val, hx_val, grad_output, grad_hy, rnn.all_weights)

torch/cuda/comm.py

Lines changed: 1 addition & 0 deletions
@@ -94,6 +94,7 @@ def scatter(tensor, devices, chunk_sizes=None, dim=0):
     assert min(chunk_sizes) > 0, "got a negative chunk_size"
     chunks = [tensor.narrow(dim, start - size, size)
               for start, size in zip(_accumulate(chunk_sizes), chunk_sizes)]
+    chunks = tuple(chunk.contiguous() for chunk in chunks)
     # TODO: copy to a pinned buffer first (if copying from CPU)
     return tuple(chunk.cuda(gpu_id, async=chunk.is_contiguous())
                  for gpu_id, chunk in zip(devices, chunks))
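The single added line matters because tensor.narrow along any dimension other than the first returns a strided, non-contiguous view, and the copy on the next line only takes the async path when chunk.is_contiguous() is true. A minimal CPU-only sketch of that behaviour (the shapes are illustrative, not taken from the commit):

import torch

t = torch.randn(4, 6)

chunk_dim0 = t.narrow(0, 0, 2)   # rows 0-1: still a contiguous block
chunk_dim1 = t.narrow(1, 0, 3)   # columns 0-2: a strided, non-contiguous view

assert chunk_dim0.is_contiguous()
assert not chunk_dim1.is_contiguous()
assert chunk_dim1.contiguous().is_contiguous()   # .contiguous() copies into a dense block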

torch/nn/parallel/data_parallel.py

Lines changed: 57 additions & 14 deletions
@@ -7,6 +7,7 @@


 class DataParallel(Module):
+
     """Implements data parallelism at the module level.

     This container parallelizes the application of the given module by
@@ -21,6 +22,12 @@ class DataParallel(Module):

     See also: :ref:`cuda-nn-dataparallel-instead`

+    Arbitrary positional and keyword inputs are allowed to be passed into
+    DataParallel EXCEPT Tensors. All variables will be scattered on dim
+    specified (default 0). Primitive types will be broadcasted, but all
+    other types will be a shallow copy and can be corrupted if written to in
+    the model's forward pass.
+
     Args:
         module: module to be parallelized
         device_ids: CUDA devices (default: all devices)
@@ -36,48 +43,70 @@ class DataParallel(Module):
     """

     # TODO: update notes/cuda.rst when this class handles 8+ GPUs well
-    def __init__(self, module, device_ids=None, output_device=None):
+
+    def __init__(self, module, device_ids=None, output_device=None, dim=0):
         super(DataParallel, self).__init__()
         if device_ids is None:
             device_ids = list(range(torch.cuda.device_count()))
         if output_device is None:
             output_device = device_ids[0]
+        self.dim = dim
         self.module = module
         self.device_ids = device_ids
         self.output_device = output_device
         if len(self.device_ids) == 1:
             self.module.cuda(device_ids[0])

-    def forward(self, *inputs):
+    def forward(self, *inputs, **kwargs):
         def _to_cuda(obj):
             if isinstance(obj, Variable):
                 return obj.cuda()
-            return tuple((map(_to_cuda, obj)))
+            if isinstance(obj, tuple) or isinstance(obj, list):
+                return type(obj)((map(_to_cuda, obj)))
+            return obj

         if len(self.device_ids) == 1:
             with torch.cuda.device(self.device_ids[0]):
                 inputs_cuda = _to_cuda(inputs)
-            return self.module(*inputs_cuda)
+            if kwargs:
+                gpu_dict = {}
+                for key in kwargs.keys():
+                    gpu_dict[key] = _to_cuda(kwargs[key])
+                return self.module(*inputs_cuda, **gpu_dict)
+            else:
+                return self.module(*inputs_cuda)
+
         replicas = self.replicate(self.module, self.device_ids)
         scattered = self.scatter(inputs, self.device_ids)
+
+        gpu_dicts = None
+        if kwargs:
+            scatter_kwargs = {}
+            for key in kwargs.keys():
+                scatter_kwargs[key] = self.scatter(
+                    _to_cuda(kwargs[key]), self.device_ids)
+            gpu_dicts = tuple(
+                {key: values[i] for key, values in scatter_kwargs.items()}
+                for i in self.device_ids
+            )
         replicas = replicas[:len(scattered)]
-        outputs = self.parallel_apply(replicas, scattered)
+        outputs = self.parallel_apply(replicas, scattered, gpu_dicts)
         return self.gather(outputs, self.output_device)

     def replicate(self, module, device_ids):
         return replicate(module, device_ids)

     def scatter(self, input, device_ids):
-        return scatter(input, device_ids)
+        return scatter(input, device_ids, dim=self.dim)

-    def parallel_apply(self, replicas, inputs):
-        return parallel_apply(replicas, inputs)
+    def parallel_apply(self, replicas, inputs, kwargs):
+        return parallel_apply(replicas, inputs, kwargs)

     def gather(self, outputs, output_device):
-        return gather(outputs, output_device)
+        return gather(outputs, output_device, dim=self.dim)


-def data_parallel(module, inputs, device_ids, output_device=None):
+def data_parallel(module, inputs, device_ids, output_device=None, dim=0, module_kwargs=None):
     """Evaluates module(input) in parallel across the GPUs given in device_ids.

     This is the functional version of the DataParallel module.
@@ -96,13 +125,27 @@ def data_parallel(module, inputs, device_ids, output_device=None):
         inputs = (inputs,)

     if not device_ids:
-        return module(*inputs)
+        if module_kwargs is None:
+            return module(*inputs)
+        else:
+            return module(*inputs, **module_kwargs)

     if output_device is None:
         output_device = device_ids[0]

     replicas = replicate(module, device_ids)
-    scattered = scatter(inputs, device_ids)
+    scattered = scatter(inputs, device_ids, dim)
+
+    gpu_dicts = None
+    if module_kwargs:
+        scatter_kwargs = {}
+        for key in module_kwargs.keys():
+            scatter_kwargs[key] = scatter(module_kwargs[key], device_ids, dim)
+        gpu_dicts = tuple(
+            {key: values[i] for key, values in scatter_kwargs.items()}
+            for i in device_ids
+        )
+
     replicas = replicas[:len(scattered)]
-    outputs = parallel_apply(replicas, scattered)
-    return gather(outputs, output_device)
+    outputs = parallel_apply(replicas, scattered, gpu_dicts)
+    return gather(outputs, output_device, dim)
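Taken together, these changes let both the DataParallel wrapper and the functional data_parallel accept extra positional and keyword inputs and choose the scatter dimension. A hedged usage sketch, not part of the commit: the Scale module, its gain/bias names, and the shapes are invented, and at least two visible GPUs are assumed.

import torch
import torch.nn as nn
import torch.nn.parallel as dp
from torch.autograd import Variable


class Scale(nn.Module):
    def forward(self, x, gain, bias=None):
        out = x * gain            # gain is a plain float: broadcast to every replica
        if bias is not None:      # bias is a Variable: scattered along dim 0 like x
            out = out + bias
        return out


x = Variable(torch.randn(8, 5))
bias = Variable(torch.randn(8, 5))

# Module wrapper: extra positional and keyword inputs now pass straight through.
model = nn.DataParallel(Scale(), device_ids=[0, 1])   # dim=0 by default
out = model(x, 2.0, bias=bias)

# Functional form: keyword inputs travel through module_kwargs, and dim picks
# the dimension along which Variables are split across the devices.
out = dp.data_parallel(Scale(), (x, 2.0), (0, 1),
                       module_kwargs={'bias': bias}, dim=0)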

torch/nn/parallel/parallel_apply.py

Lines changed: 11 additions & 6 deletions
@@ -8,31 +8,36 @@
 import Queue as queue


-def parallel_apply(modules, inputs):
+def parallel_apply(modules, inputs, kwargs_tup=None):
     assert len(modules) == len(inputs)
+    if kwargs_tup:
+        assert len(modules) == len(kwargs_tup)
+    else:
+        kwargs_tup = ({},) * len(modules)
     # Fast track
     if len(modules) == 1:
-        return (modules[0](*inputs[0]),)
+        return (modules[0](*inputs[0], **kwargs_tup[0]), )

     lock = threading.Lock()
     results = {}

-    def _worker(module, input, results, lock):
+    def _worker(module, input, kwargs, results, lock):
         var_input = input
         while not isinstance(var_input, Variable):
             var_input = var_input[0]
         try:
             with torch.cuda.device_of(var_input):
-                output = module(*input)
+                output = module(*input, **kwargs)
                 with lock:
                     results[input] = output
         except Exception as e:
             with lock:
                 results[input] = e

     threads = [threading.Thread(target=_worker,
-                                args=(module, input, results, lock))
-               for module, input in zip(modules, inputs)]
+                                args=(module, input, kwargs, results, lock),
+                                )
+               for module, input, kwargs in zip(modules, inputs, kwargs_tup)]

     for thread in threads:
         thread.start()
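parallel_apply now takes an optional kwargs_tup with one keyword-argument dict per replica, or None as in the updated test above. A hedged sketch of the call shape, not from the commit: it assumes two GPUs and that each replica already lives on its device.

import torch
import torch.nn as nn
import torch.nn.parallel as dp
from torch.autograd import Variable

m0 = nn.Linear(5, 5).cuda(0)
m1 = nn.Linear(5, 5).cuda(1)

i0 = (Variable(torch.randn(3, 5).cuda(0)),)
i1 = (Variable(torch.randn(3, 5).cuda(1)),)

# No per-replica keyword arguments (None is what the updated test passes).
out0, out1 = dp.parallel_apply((m0, m1), (i0, i1), None)

# With per-replica keyword arguments: kwargs_tup needs one dict per module.
# nn.Linear takes no extra kwargs, so empty dicts stand in for the general case.
out0, out1 = dp.parallel_apply((m0, m1), (i0, i1), ({}, {}))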

torch/nn/parallel/scatter_gather.py

Lines changed: 21 additions & 9 deletions
@@ -1,25 +1,37 @@
+import torch
 from torch.autograd import Variable
 from ._functions import Scatter, Gather
+from torch.cuda.comm import broadcast


-def scatter(input, target_gpus):
-    """Slices a given variable into approximately equal chunks and distributes
-    them accross given GPUs
+def scatter(input, target_gpus, dim=0):
+    """
+    Slices variables into approximately equal chunks and
+    distributes them accross given GPUs. Duplicates
+    references to objects that are not variables. Does not
+    support Tensors.
     """
     def scatter_map(obj):
         if isinstance(obj, Variable):
-            return Scatter(target_gpus)(obj)
-        return tuple(zip(*map(scatter_map, obj)))
+            return Scatter(target_gpus, dim=dim)(obj)
+        assert not torch.is_tensor(obj), "Tensors not supported in scatter."
+        if isinstance(obj, tuple) or isinstance(obj, list):
+            return type(obj)(zip(*map(scatter_map, obj)))
+        return tuple(obj for targets in target_gpus)
+
     return scatter_map(input)


-def gather(outputs, target_device):
-    """Gathers variables from different GPUs on a specified device
-       (-1 means the CPU).
+def gather(outputs, target_device, dim=0):
+    """
+    Gathers variables from different GPUs on a specified device
+    (-1 means the CPU).
     """
     def gather_map(outputs):
         out = outputs[0]
         if isinstance(out, Variable):
-            return Gather(target_device)(*outputs)
+            return Gather(target_device, dim=dim)(*outputs)
+        if out is None:
+            return None
         return type(out)(map(gather_map, zip(*outputs)))
     return gather_map(outputs)
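The reworked scatter splits Variables along dim and duplicates non-Variable Python objects once per target GPU, while gather concatenates the per-GPU pieces back along the same dim. A hedged sketch, not from the commit: it assumes two GPUs, invented values, and that this file is importable as torch.nn.parallel.scatter_gather (the import path is an assumption of this note).

import torch
from torch.autograd import Variable
# Import path assumed, not confirmed by the diff above.
from torch.nn.parallel.scatter_gather import scatter, gather

x = Variable(torch.randn(4, 6))

# Split the Variable column-wise across GPUs 0 and 1: two pieces of size (4, 3).
x0, x1 = scatter(x, (0, 1), dim=1)

# A tuple input yields one per-GPU tuple; the non-Variable member (0.5 here)
# is duplicated into every one of them.
(x0, f0), (x1, f1) = scatter((x, 0.5), (0, 1), dim=1)

# gather concatenates the per-GPU Variables back along the same dim;
# target_device=-1 would collect onto the CPU instead.
full = gather((x0, x1), 0, dim=1)   # size (4, 6) on GPU 0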

torch/tensor.py

Lines changed: 2 additions & 2 deletions
@@ -143,14 +143,14 @@ def __iter__(self):
         return iter(map(lambda i: self.select(0, i), _range(self.size(0))))

     def split(self, split_size, dim=0):
-        """Splits this tensor into a list of tensors.
+        """Splits this tensor into a tuple of tensors.

         See :func:`torch.split`.
         """
         return torch.split(self, split_size, dim)

     def chunk(self, n_chunks, dim=0):
-        """Splits this tensor into a list of tensors.
+        """Splits this tensor into a tuple of tensors.

         See :func:`torch.chunk`.
         """
