Commit 8628826

Adding rnn cell library
1 parent a559d94 commit 8628826

9 files changed: +298, -30 lines changed

test/test_nn.py

Lines changed: 30 additions & 3 deletions
@@ -618,12 +618,38 @@ def test_MaxUnpool2d_output_size(self):
                          mu(output_small, indices_small, (h, w)))
 
 
+    def test_RNN_cell(self):
+        # this is just a smoke test; these modules are implemented through
+        # autograd so no Jacobian test is needed
+        for module in (nn.rnn.cell.RNN, nn.rnn.cell.RNNReLU, nn.rnn.cell.GRU):
+            for bias in (True, False):
+                input = Variable(torch.randn(3, 10))
+                hx = Variable(torch.randn(3, 20))
+                cell = module(10, 20, bias=bias)
+                for i in range(6):
+                    hx = cell(input, hx)
+
+                hx.sum().backward()
+
+    def test_LSTM_cell(self):
+        # this is just a smoke test; these modules are implemented through
+        # autograd so no Jacobian test is needed
+        for bias in (True, False):
+            input = Variable(torch.randn(3, 10))
+            hx = Variable(torch.randn(3, 20))
+            cx = Variable(torch.randn(3, 20))
+            lstm = nn.rnn.cell.LSTM(10, 20, bias=bias)
+            for i in range(6):
+                hx, cx = lstm(input, (hx, cx))
+
+            (hx+cx).sum().backward()
+
     @unittest.skipIf(not TEST_CUDNN, "needs cudnn")
     def test_RNN_cpu_vs_cudnn(self):
 
         def forward_backward(cuda, module, bias, input_val, hx_val, weights_val):
             rnn = module(input_size, hidden_size, num_layers, bias=bias)
-            is_lstm = module == nn.LSTM
+            is_lstm = module == nn.rnn.LSTM
 
             for x_layer, y_layer in zip(rnn.all_weights, weights_val):
                 for x, y in zip(x_layer, y_layer):
@@ -678,7 +704,7 @@ def diff(t_cpu, t_gpu, name):
 
         # FIXME: we can't use torch.cuda.DoubleTensor because sum() is not yet defined on it
         with set_default_tensor_type('torch.FloatTensor'):
-            for module in (nn.RNN, nn.RNNReLU, nn.LSTM, nn.GRU):
+            for module in (nn.rnn.RNNTanh, nn.rnn.RNNReLU, nn.rnn.LSTM, nn.rnn.GRU):
                 for bias in (True, False):
                     input_val = torch.randn(seq_length, batch, input_size)
                     hx_val = torch.randn(num_layers, batch, hidden_size)
@@ -880,7 +906,8 @@ def add_test(test):
         constructor=lambda: nn.FractionalMaxPool2d(2, output_ratio=0.5, _random_samples=torch.DoubleTensor(1, 3, 2).uniform_()),
         input_size=(1, 3, 5, 5),
         fullname='FractionalMaxPool2d_ratio',
-        test_cuda=False),
+        test_cuda=False
+    ),
     dict(
         constructor=lambda: nn.FractionalMaxPool2d((2, 2), output_size=(4, 4), _random_samples=torch.DoubleTensor(1, 3, 2).uniform_()),
         input_size=(1, 3, 7, 7),
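
Read together, the two new smoke tests double as a usage sketch for the cell API this commit adds. A condensed, self-contained version (the nn.rnn.cell.* paths and old-style Variable wrapper are taken from the tests above; treat this as illustrative, not a stable API):

import torch
import torch.nn as nn
from torch.autograd import Variable

rnn_cell = nn.rnn.cell.RNN(10, 20)      # input_size=10, hidden_size=20
lstm_cell = nn.rnn.cell.LSTM(10, 20)

x = Variable(torch.randn(3, 10))        # batch of 3 inputs
h = Variable(torch.randn(3, 20))
c = Variable(torch.randn(3, 20))

h = rnn_cell(x, h)                      # RNN/GRU cells return the next hidden state
h, c = lstm_cell(x, (h, c))             # the LSTM cell returns (hy, cy)
(h + c).sum().backward()                # cells are autograd-backed, so backward just works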

torch/backends/cudnn/__init__.py

Lines changed: 0 additions & 3 deletions
@@ -117,8 +117,6 @@ def set(self, tensor):
     def as_tuple(self):
         return (self._type, tuple(self._size), tuple(self._stride))
 
-<<<<<<< 9cd68129da50023929aff0ca4e4ba667ae75d785
-=======
 
 class TensorDescriptorArray(object):
     def __init__(self, N):
@@ -148,7 +146,6 @@ def as_tuple(self):
         return (self._type, tuple(self._size), tuple(self._stride))
 
 
->>>>>>> CUDNN RNN bindings
 class ConvolutionDescriptor(object):
     def __init__(self):
         ptr = ctypes.c_void_p()

torch/csrc/utils.cpp

Lines changed: 1 addition & 1 deletion
@@ -202,7 +202,7 @@ void THPUtils_invalidArguments(PyObject *given_args,
   std::string error_msg;
   error_msg.reserve(2000);
   error_msg += function_name;
-  error_msg += " recieved an invalid combination of argument types - got ";
+  error_msg += " received an invalid combination of argument types - got ";
   va_list option_list;
   va_start(option_list, num_options);
   for (size_t i = 0; i < num_options; i++)

torch/nn/backends/thnn.py

Lines changed: 6 additions & 1 deletion
@@ -14,7 +14,8 @@ def _initialize_backend():
     from ..functions.thnn import _all_functions as _thnn_functions
     from ..functions.linear import Linear
     from ..functions.conv import Conv2d
-    from ..functions.rnn import RNN
+    from ..functions.rnn import RNN, \
+        RNNTanhCell, RNNReLUCell, GRUCell, LSTMCell
     from ..functions.dropout import Dropout, FeatureDropout
     from ..functions.activation import Softsign
     from ..functions.loss import CosineEmbeddingLoss, \
@@ -23,6 +24,10 @@ def _initialize_backend():
     backend.register_function('Linear', Linear)
     backend.register_function('Conv2d', Conv2d)
     backend.register_function('RNN', RNN)
+    backend.register_function('RNNTanhCell', RNNTanhCell)
+    backend.register_function('RNNReLUCell', RNNReLUCell)
+    backend.register_function('LSTMCell', LSTMCell)
+    backend.register_function('GRUCell', GRUCell)
     backend.register_function('Dropout', Dropout)
     backend.register_function('Dropout2d', FeatureDropout)
     backend.register_function('Dropout3d', FeatureDropout)
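
For context on what these registrations buy: the thnn backend is a name-to-function registry, so module code can look up its autograd implementation by string name instead of importing it directly. A toy registry mirroring that pattern (a sketch of the idea, not the real torch.nn.backends implementation):

class Backend(object):
    def __init__(self):
        self._functions = {}

    def register_function(self, name, fn):
        self._functions[name] = fn

    def __getattr__(self, name):
        # Only consulted when normal attribute lookup fails, so
        # registered functions resolve as plain attributes.
        try:
            return self._functions[name]
        except KeyError:
            raise AttributeError(name)

backend = Backend()
backend.register_function('LSTMCell', lambda *args: 'LSTMCell called')
print(backend.LSTMCell())  # -> 'LSTMCell called'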

torch/nn/functions/rnn.py

Lines changed: 9 additions & 14 deletions
@@ -9,8 +9,6 @@
 import torch.backends.cudnn.rnn
 
 
-
-
 def _getCudnnMode(mode):
     if mode == 'RNN_RELU':
         return cudnn.CUDNN_RNN_RELU
@@ -48,11 +46,11 @@ def linear(input, w, b):
 
 def RNNReLUCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None):
     hy = ReLU(linear(input, w_ih, b_ih) + linear(hidden, w_hh, b_hh))
-    return hy, hy
+    return hy
 
 def RNNTanhCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None):
     hy = tanh(linear(input, w_ih, b_ih) + linear(hidden, w_hh, b_hh))
-    return hy, hy
+    return hy
 
 def LSTMCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None):
     hx, cx = hidden
@@ -63,25 +61,23 @@ def LSTMCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None):
     forgetgate = sigmoid(gates[:,1*hsz:2*hsz])
     cellgate = tanh( gates[:,2*hsz:3*hsz])
     outgate = sigmoid(gates[:,3*hsz:4*hsz])
-    nextc = (forgetgate * cx) + (ingate * cellgate)
-    nexth = outgate * tanh(nextc)
+    cy = (forgetgate * cx) + (ingate * cellgate)
+    hy = outgate * tanh(cy)
 
-    return (nexth, nextc), nexth
+    return hy, cy
 
 def GRUCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None):
     hsz = hidden.size(1)
     gi = linear(input, w_ih, b_ih)
     gh = linear(hidden, w_hh, b_hh)
     # FIXME: chunk
 
-    # this is a bit weird, it doesn't match the order of parameters
-    # implied by the cudnn docs, and it also uses nexth for output...
     resetgate = sigmoid(gi[:,0*hsz:1*hsz] + gh[:,0*hsz:1*hsz])
     inputgate = sigmoid(gi[:,1*hsz:2*hsz] + gh[:,1*hsz:2*hsz])
     newgate = tanh(gi[:,2*hsz:3*hsz] + resetgate * gh[:,2*hsz:3*hsz])
-    nexth = newgate + inputgate * (hidden - newgate)
+    hy = newgate + inputgate * (hidden - newgate)
 
-    return nexth, nexth # FIXME: nexth, nexth ???
+    return hy
 
 def StackedRNN(cell, num_layers, lstm=False):
     def forward(input, hidden, weight):
@@ -92,8 +88,9 @@ def forward(input, hidden, weight):
             hidden = zip(*hidden)
 
         for i in range(num_layers):
-            hy, input = cell(input, hidden[i], *weight[i])
+            hy = cell(input, hidden[i], *weight[i])
             next_hidden.append(hy)
+            input = hy[0] if lstm else hy
 
         if lstm:
             next_h, next_c = zip(*next_hidden)
@@ -222,8 +219,6 @@ def backward_extended(self, grad_output, grad_hy):
             weight,
             grad_weight)
 
-        # FIXME: zero out grad_bias if necessary :)
-
         return grad_input, grad_weight, grad_hx

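The return-convention change is the crux of this file: each cell now returns only its next hidden state (hy for RNN/GRU cells, the (hy, cy) pair for the LSTM cell) instead of the old (hidden, output) tuple, and StackedRNN recovers the next layer's input itself. A standalone sketch of that threading logic (the toy cells below are stand-ins; only the control flow mirrors the diff):

def stacked_step(cells, input, hidden, lstm=False):
    next_hidden = []
    for i, cell in enumerate(cells):
        hy = cell(input, hidden[i])    # new convention: just the next hidden state
        next_hidden.append(hy)
        input = hy[0] if lstm else hy  # LSTM: the next layer consumes h, not c
    return next_hidden, input

rnn_cell = lambda x, h: x + h                  # returns hy
lstm_cell = lambda x, hc: (x + hc[0], hc[1])   # returns (hy, cy)
print(stacked_step([rnn_cell] * 2, 1.0, [0.0, 0.0]))
print(stacked_step([lstm_cell] * 2, 1.0, [(0.0, 0.0)] * 2, lstm=True))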
torch/nn/modules/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -15,4 +15,5 @@
 from .padding import ReflectionPad2d, ReplicationPad2d, ReplicationPad3d
 from .normalization import CrossMapLRN2d
 from .sparse import Embedding
-from .rnn import RNNBase, RNN, RNNReLU, GRU, LSTM
+# from .rnn import RNNBase, RNN, RNNReLU, GRU, LSTM, cell
+import rnn

torch/nn/modules/rnn/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
+from rnn import RNNBase, RNNTanh, RNNReLU, LSTM, GRU
+import cell
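
Net effect on the public namespace: torch/nn/modules/__init__.py now pulls in the rnn package wholesale rather than re-exporting individual classes, so sequence modules and single-step cells live under separate prefixes (paths as exercised by the updated tests; the cell module's location is inferred from the import above):

import torch.nn as nn

nn.rnn.LSTM        # full-sequence module, re-exported by the rnn package
nn.rnn.cell.LSTM   # single-step cell, from the sibling cell module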
