
Commit 942ca47

Copying weights for CUDNN

1 parent: b0e33fb

File tree: 8 files changed, +361 -118 lines

test/common_nn.py

Lines changed: 2 additions & 0 deletions
@@ -7,6 +7,7 @@
 import torch.cuda
 from torch.autograd import Variable
 from common import TestCase, to_gpu, get_numerical_jacobian, iter_tensors, contiguous
+import torch.backends.cudnn

 # tarfile module tries to obtain a file object name in python 3.3
 if sys.version_info[:2] == (3, 3):
@@ -15,6 +16,7 @@
 TemporaryFile = tempfile.TemporaryFile

 TEST_CUDA = torch.cuda.is_available()
+TEST_CUDNN = TEST_CUDA and torch.backends.cudnn.is_acceptable(torch.cuda.FloatTensor(1))
 PRECISION = 1e-5

 module_tests = [
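The new TEST_CUDNN flag is only true when CUDA is available and cuDNN accepts a CUDA float tensor, so cuDNN-specific tests can be skipped cleanly on machines without it. A minimal sketch of how a test might gate on the flag (the class and test name below are illustrative, not part of this commit):

import unittest
from common_nn import TEST_CUDNN

class ExampleCudnnTests(unittest.TestCase):
    # Hypothetical test, shown only to illustrate how TEST_CUDNN is meant to be used.
    @unittest.skipIf(not TEST_CUDNN, "needs cudnn")
    def test_uses_cudnn(self):
        pass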

test/test_nn.py

Lines changed: 91 additions & 1 deletion
@@ -2,16 +2,25 @@
 import torch
 import random
 import unittest
+import contextlib
 from copy import deepcopy
 from itertools import repeat

 import torch.nn as nn
 import torch.nn.parallel as dp
 from torch.autograd import Variable
 from common_nn import NNTestCase, ModuleTest, CriterionTest, TestBase, \
-    module_tests, criterion_tests, TEST_CUDA, PRECISION
+    module_tests, criterion_tests, TEST_CUDA, TEST_CUDNN, PRECISION
 from common import freeze_rng_state

+@contextlib.contextmanager
+def set_default_tensor_type(type):
+    old_type = torch.typename(torch.Tensor())
+    torch.set_default_tensor_type(type)
+    try:
+        yield
+    finally:
+        torch.set_default_tensor_type(old_type)

 class InputVariableMixin(object):
     def _get_input(self):
@@ -609,6 +618,87 @@ def test_MaxUnpool2d_output_size(self):
             mu(output_small, indices_small, (h, w)))


+    @unittest.skipIf(not TEST_CUDNN, "needs cudnn")
+    def test_RNN_cpu_vs_cudnn(self):
+
+        def forwardBackward(cuda, mode, input_val, hx_val, weights_val):
+            rnn = nn.RNNBase(mode, input_size, hidden_size, num_layers)
+
+            for x_layer, y_layer in zip(rnn.all_weights, weights_val):
+                for x, y in zip(x_layer, y_layer):
+                    x.data.copy_(y.data)
+
+            input = Variable(input_val.clone(), requires_grad=True)
+            if mode == 'LSTM':
+                hx = (Variable(hx_val.clone(), requires_grad=True),
+                      Variable(hx_val.add(1), requires_grad=True))
+            else:
+                hx = Variable(hx_val.clone(), requires_grad=True)
+
+            if cuda:
+                rnn.cuda()
+                input.data = input.data.cuda()
+                if mode == 'LSTM':
+                    hx[0].data = hx[0].data.cuda()
+                    hx[1].data = hx[1].data.cuda()
+                else:
+                    hx.data = hx.data.cuda()
+
+            output, hy = rnn(input, hx)
+            # FIXME this is because of a pytorch bug
+            if mode == 'LSTM':
+                fake_loss = 0*(hy[0] + hy[1]).sum()
+            else:
+                fake_loss = 0*hy.sum()
+
+            loss = output.sum() + fake_loss
+            loss.backward()
+
+            return {'output': output.data,
+                    'hy': hy[0].data if mode == 'LSTM' else hy.data,
+                    'weights': rnn.all_weights,
+                    'grad_input': input.grad,
+                    'grad_hx': hx[0].grad if mode == 'LSTM' else hx.grad,
+                    'cy': hy[1].data if mode == 'LSTM' else None,
+                    'grad_cx': hx[1].grad if mode == 'LSTM' else None}
+
+        def diff(t_cpu, t_gpu, name):
+            self.assertTrue(torch.is_tensor(t_cpu))
+            self.assertTrue(torch.is_tensor(t_gpu))
+            delta = t_gpu.cpu().add(-1, t_cpu).abs().max()
+            # print("{:30s} cpu: {:10g} gpu: {:10g} diff: {:10g}".format(name, t_cpu.abs().max(), t_gpu.abs().max(), delta))
+            self.assertLess(delta, 2 * PRECISION)
+
+        input_size = 10
+        hidden_size = 20
+        num_layers = 2
+        seq_length = 7
+        batch = 5
+
+        # FIXME: we can't use torch.cuda.DoubleTensor because sum() is not yet defined on it
+        with set_default_tensor_type('torch.FloatTensor'):
+            for mode in ("RNN_RELU", "RNN_TANH", "GRU", "LSTM"):
+                input_val = torch.randn(seq_length, batch, input_size)
+                hx_val = torch.randn(num_layers, batch, hidden_size)
+
+                weights_val = nn.RNNBase(mode, input_size, hidden_size, num_layers).all_weights
+
+                outputs_cpu = forwardBackward(False, mode, input_val, hx_val, weights_val)
+                outputs_gpu = forwardBackward(True, mode, input_val, hx_val, weights_val)
+
+                diff(outputs_cpu['output'], outputs_gpu['output'], 'output')
+                diff(outputs_cpu['hy'], outputs_gpu['hy'], 'hy')
+                diff(outputs_cpu['grad_input'], outputs_gpu['grad_input'], 'grad_input')
+                diff(outputs_cpu['grad_hx'], outputs_gpu['grad_hx'], 'grad_hx')
+                if outputs_cpu['cy'] is not None:
+                    diff(outputs_cpu['cy'], outputs_gpu['cy'], 'cy')
+                    diff(outputs_cpu['grad_cx'], outputs_gpu['grad_cx'], 'grad_cx')
+
+                for i, (cpu_layer_weight, gpu_layer_weight) in enumerate(zip(outputs_cpu['weights'], outputs_gpu['weights'])):
+                    for j, (cpu_weight, gpu_weight) in enumerate(zip(cpu_layer_weight, gpu_layer_weight)):
+                        diff(cpu_weight.grad, gpu_weight.grad, mode + ' grad_weight[{},{}]'.format(i, j))
+
+
 def add_test(test):
     test_name = test.get_name()
     cuda_test_name = test_name + '_cuda'
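The set_default_tensor_type helper added at the top of test_nn.py temporarily switches the process-wide default tensor type and restores the previous one even if the body raises. A small usage sketch, not part of this commit (shapes are illustrative):

# Inside the block, factory calls such as torch.randn produce FloatTensors;
# afterwards, the previously configured default type is restored.
with set_default_tensor_type('torch.FloatTensor'):
    x = torch.randn(4, 4)
    assert isinstance(x, torch.FloatTensor)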

torch/autograd/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 import torch

 from .variable import Variable
-from .function import Function
+from .function import Function, NestedIOFunction

 assert torch._C._autograd_init()

torch/autograd/function.py

Lines changed: 88 additions & 0 deletions
@@ -2,6 +2,8 @@
 from collections import OrderedDict
 from itertools import chain

+import torch # FIXME: is this ok? Needed for torch.is_tensor
+import collections

 class Function(_C._FunctionBase):

@@ -43,3 +45,89 @@ def __init__(self, inplace=False):
         super(InplaceFunction, self).__init__()
         self.inplace = inplace

+def _nested_map(condition, fn):
+    def _map(obj):
+        if condition(obj):
+            return fn(obj)
+        elif obj is None:
+            return None
+        elif isinstance(obj, (list, tuple)):
+            return type(obj)(_map(x) for x in obj)
+        else:
+            raise ValueError("NestedIOFunction doesn't know how to process "
+                             "an input object of type " + torch.typename(obj))
+    return _map
+
+def _iter_filter(condition):
+    def _iter(obj):
+        if condition(obj):
+            yield obj
+        elif obj is None:
+            return
+        elif isinstance(obj, (list, tuple)):
+            for o in obj:
+                for var in _iter(o):
+                    yield var
+        else:
+            raise ValueError("NestedIOFunction doesn't know how to process "
+                             "an input object of type " + torch.typename(obj))
+    return _iter
+
+
+_iter_variables = _iter_filter(lambda o: isinstance(o, torch.autograd.Variable))
+_iter_tensors = _iter_filter(torch.is_tensor)
+_iter_None_tensors = _iter_filter(lambda o: o is None or torch.is_tensor(o))
+_map_variable_tensor = _nested_map(lambda o: isinstance(o, torch.autograd.Variable), lambda o: o.data)
+_map_tensor_type = _nested_map(lambda o: torch.is_tensor(o), lambda o: o.type())
+
+def _map_tensor_fromiter(itr):
+    return _nested_map(lambda o: torch.is_tensor(o), lambda o: next(itr))
+def _map_variable_fromiter(itr):
+    return _nested_map(lambda o: isinstance(o, torch.autograd.Variable), lambda o: next(itr))
+
+class NestedIOFunction(Function):
+
+    def _do_forward(self, *input):
+        self._nested_input = input
+        flat_input = tuple(_iter_variables(input))
+        flat_output = super(NestedIOFunction, self)._do_forward(*flat_input)
+        nested_output = self._nested_output
+        nested_variables = _map_tensor_fromiter(iter(flat_output))(nested_output)
+        return nested_variables
+
+    def backward(self, *gradients):
+        nested_gradients = _map_tensor_fromiter(iter(gradients))(self._nested_output)
+        del self._nested_output
+        result = self.backward_extended(*nested_gradients)
+        return tuple(_iter_None_tensors(result))
+
+    __call__ = _do_forward
+
+    def forward(self, *args):
+        nested_tensors = _map_variable_tensor(self._nested_input)
+        result = self.forward_extended(*nested_tensors)
+        del self._nested_input
+        self._nested_output = result
+        return tuple(_iter_tensors(result))
+
+    def save_for_backward(self, *args):
+        self.to_save = tuple(_iter_tensors(args))
+        self._to_save_nested = args
+
+    @property
+    def saved_tensors(self):
+        flat_tensors = super(NestedIOFunction, self).saved_tensors
+        return _map_tensor_fromiter(iter(flat_tensors))(self._to_save_nested)
+
+    def mark_dirty(self, *args, **kwargs):
+        self.dirty_tensors = tuple(_iter_tensors((args, kwargs)))
+
+    def mark_non_differentiable(self, *args, **kwargs):
+        self.non_differentiable = tuple(_iter_tensors((args, kwargs)))
+
+    def forward_extended(self, *input):
+        raise NotImplementedError
+
+    def backward_extended(self, *grad_output):
+        raise NotImplementedError
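NestedIOFunction flattens arbitrarily nested tuples of Variables before handing them to the base Function machinery and re-nests the outputs and gradients, so a subclass only implements forward_extended and backward_extended on nested tensors (this is what lets the cuDNN RNN pass hx as an (h, c) pair for LSTM). A hedged sketch of a subclass, assuming the legacy autograd API of this commit; NestedAdd and the tensor shapes are illustrative only, not part of the commit:

import torch
from torch.autograd import Variable, NestedIOFunction

class NestedAdd(NestedIOFunction):
    # Hypothetical example, shown only to illustrate the nested-I/O contract.
    def forward_extended(self, x, pair):
        a, b = pair                      # inputs arrive as tensors, nested like the Variables
        return x + a, x + b              # outputs may themselves be nested

    def backward_extended(self, grad_out1, grad_out2):
        # Return gradients nested the same way as the inputs: (grad_x, (grad_a, grad_b)).
        return grad_out1 + grad_out2, (grad_out1, grad_out2)

x = Variable(torch.randn(3), requires_grad=True)
a = Variable(torch.randn(3), requires_grad=True)
b = Variable(torch.randn(3), requires_grad=True)
out1, out2 = NestedAdd()(x, (a, b))      # nested input: a Variable and a pair of Variables
(out1 + out2).sum().backward()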
