
Commit a559d94

docs and such
1 parent 1eb6870 commit a559d94

3 files changed (+187, -29 lines)


test/test_nn.py

Lines changed: 25 additions & 24 deletions
@@ -620,55 +620,56 @@ def test_MaxUnpool2d_output_size(self):
 
     @unittest.skipIf(not TEST_CUDNN, "needs cudnn")
     def test_RNN_cpu_vs_cudnn(self):
-
-        def forward_backward(cuda, mode, bias, input_val, hx_val, weights_val):
-            rnn = nn.RNNBase(mode, input_size, hidden_size, num_layers, bias=bias)
+
+        def forward_backward(cuda, module, bias, input_val, hx_val, weights_val):
+            rnn = module(input_size, hidden_size, num_layers, bias=bias)
+            is_lstm = module == nn.LSTM
 
             for x_layer, y_layer in zip(rnn.all_weights, weights_val):
                 for x, y in zip(x_layer, y_layer):
                     x.data.copy_(y.data)
 
             input = Variable(input_val.clone(), requires_grad=True)
-            if mode == 'LSTM':
+            if is_lstm:
                 hx = (Variable(hx_val.clone(), requires_grad=True),
                       Variable(hx_val.add(1), requires_grad=True))
             else:
                 hx = Variable(hx_val.clone(), requires_grad=True)
-
+
             if cuda:
                 rnn.cuda()
                 input.data = input.data.cuda()
-                if mode == 'LSTM':
+                if is_lstm:
                     hx[0].data = hx[0].data.cuda()
                     hx[1].data = hx[1].data.cuda()
                 else:
                     hx.data = hx.data.cuda()
 
             output, hy = rnn(input, hx)
             # FIXME this is because of a pytorch bug
-            if mode == 'LSTM':
+            if is_lstm:
                 fake_loss = 0*(hy[0] + hy[1]).sum()
             else:
                 fake_loss = 0*hy.sum()
-
+
             loss = output.sum() + fake_loss
             loss.backward()
-
+
             return {'output': output.data,
-                    'hy': hy[0].data if mode == 'LSTM' else hy.data,
+                    'hy': hy[0].data if is_lstm else hy.data,
                     'weights': rnn.all_weights,
                     'grad_input': input.grad,
-                    'grad_hx': hx[0].grad if mode == 'LSTM' else hx.grad,
-                    'cy': hy[1].data if mode == 'LSTM' else None,
-                    'grad_cx': hx[1].grad if mode == 'LSTM' else None}
-
+                    'grad_hx': hx[0].grad if is_lstm else hx.grad,
+                    'cy': hy[1].data if is_lstm else None,
+                    'grad_cx': hx[1].grad if is_lstm else None}
+
         def diff(t_cpu, t_gpu, name):
             self.assertTrue(torch.is_tensor(t_cpu))
             self.assertTrue(torch.is_tensor(t_gpu))
             delta = t_gpu.cpu().add(-1, t_cpu).abs().max()
             # print("{:30s} cpu: {:10g} gpu: {:10g} diff: {:10g}".format(name, t_cpu.abs().max(), t_gpu.abs().max(), delta))
             self.assertLess(delta, 2 * PRECISION)
-
+
         input_size = 10
         hidden_size = 20
         num_layers = 2
@@ -677,27 +678,27 @@ def diff(t_cpu, t_gpu, name):
 
         # FIXME: we can't use torch.cuda.DoubleTensor because sum() is not yet defined on it
         with set_default_tensor_type('torch.FloatTensor'):
-            for mode in ("RNN_RELU", "RNN_TANH", "GRU", "LSTM"):
+            for module in (nn.RNN, nn.RNNReLU, nn.LSTM, nn.GRU):
                 for bias in (True, False):
                     input_val = torch.randn(seq_length, batch, input_size)
                     hx_val = torch.randn(num_layers, batch, hidden_size)
-
-                    weights_val = nn.RNNBase(mode, input_size, hidden_size, num_layers).all_weights
-
-                    outputs_cpu = forward_backward(False, mode, bias, input_val, hx_val, weights_val)
-                    outputs_gpu = forward_backward(True, mode, bias, input_val, hx_val, weights_val)
-
+
+                    weights_val = module(input_size, hidden_size, num_layers).all_weights
+
+                    outputs_cpu = forward_backward(False, module, bias, input_val, hx_val, weights_val)
+                    outputs_gpu = forward_backward(True, module, bias, input_val, hx_val, weights_val)
+
                     diff(outputs_cpu['output'], outputs_gpu['output'], 'output')
                     diff(outputs_cpu['hy'], outputs_gpu['hy'], 'hy')
                     diff(outputs_cpu['grad_input'], outputs_gpu['grad_input'], 'grad_input')
                     diff(outputs_cpu['grad_hx'], outputs_gpu['grad_hx'], 'grad_hx')
                     if outputs_cpu['cy'] is not None:
                         diff(outputs_cpu['cy'], outputs_gpu['cy'], 'cy')
                         diff(outputs_cpu['grad_cx'], outputs_gpu['grad_cx'], 'grad_cx')
-
+
                     for i, (cpu_layer_weight, gpu_layer_weight) in enumerate(zip(outputs_cpu['weights'], outputs_gpu['weights'])):
                         for j, (cpu_weight, gpu_weight) in enumerate(zip(cpu_layer_weight, gpu_layer_weight)):
-                            diff(cpu_weight.grad, gpu_weight.grad, mode + ' grad_weight[{},{}]'.format(i, j))
+                            diff(cpu_weight.grad, gpu_weight.grad, 'grad_weight[{},{}]'.format(i, j))
 
 
 def add_test(test):
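The key change in this test is that forward_backward now receives the module class itself (nn.RNN, nn.RNNReLU, nn.LSTM, nn.GRU) instead of a cuDNN mode string, and derives LSTM-specific handling from is_lstm = module == nn.LSTM. A minimal standalone sketch of that parameterization pattern, using the same APIs as the test (sizes here are hypothetical, chosen only for illustration):

```
import torch
import torch.nn as nn
from torch.autograd import Variable

# Hypothetical sizes for illustration only.
input_size, hidden_size, num_layers = 10, 20, 2
seq_length, batch = 7, 6

# Iterate over the module classes directly instead of mode strings;
# the class itself tells us whether the hidden state is an (h, c) tuple.
for module in (nn.RNN, nn.RNNReLU, nn.LSTM, nn.GRU):
    rnn = module(input_size, hidden_size, num_layers)   # no mode string needed
    is_lstm = module == nn.LSTM

    input = Variable(torch.randn(seq_length, batch, input_size))
    h0 = Variable(torch.randn(num_layers, batch, hidden_size))
    # LSTM is the only class whose hidden state is an (h, c) tuple.
    hx = (h0, Variable(torch.randn(num_layers, batch, hidden_size))) if is_lstm else h0

    output, hy = rnn(input, hx)
```

The real test additionally copies a shared weights_val into both the CPU and GPU instances so the two runs start from identical parameters before comparing outputs and gradients.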

torch/nn/functions/rnn.py

Lines changed: 3 additions & 3 deletions
@@ -77,9 +77,9 @@ def GRUCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None):
     # this is a bit weird, it doesn't match the order of parameters
     # implied by the cudnn docs, and it also uses nexth for output...
     resetgate = sigmoid(gi[:,0*hsz:1*hsz] + gh[:,0*hsz:1*hsz])
-    updategate = sigmoid(gi[:,1*hsz:2*hsz] + gh[:,1*hsz:2*hsz])
-    output = tanh(gi[:,2*hsz:3*hsz] + resetgate * gh[:,2*hsz:3*hsz])
-    nexth = output + updategate * (hidden - output)
+    inputgate = sigmoid(gi[:,1*hsz:2*hsz] + gh[:,1*hsz:2*hsz])
+    newgate = tanh(gi[:,2*hsz:3*hsz] + resetgate * gh[:,2*hsz:3*hsz])
+    nexth = newgate + inputgate * (hidden - newgate)
 
     return nexth, nexth  # FIXME: nexth, nexth ???
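The rename maps the GRUCell variables onto the conventional GRU gate names (reset, input/update, new), matching the GRU docstring added below. A standalone NumPy sketch of the same cell update, not the library code (the function name and the row-stacked weight layout are illustrative assumptions), shows that nexth = newgate + inputgate * (hidden - newgate) is the usual (1 - z) * n + z * h update:

```
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def gru_cell(x, h, w_ih, w_hh, b_ih, b_hh):
    """One GRU step; w_ih/w_hh stack the reset, input, new gates along rows (3*hsz)."""
    hsz = h.shape[1]
    gi = x @ w_ih.T + b_ih   # input-to-hidden contribution, shape (batch, 3*hsz)
    gh = h @ w_hh.T + b_hh   # hidden-to-hidden contribution, shape (batch, 3*hsz)

    resetgate = sigmoid(gi[:, 0*hsz:1*hsz] + gh[:, 0*hsz:1*hsz])
    inputgate = sigmoid(gi[:, 1*hsz:2*hsz] + gh[:, 1*hsz:2*hsz])
    newgate = np.tanh(gi[:, 2*hsz:3*hsz] + resetgate * gh[:, 2*hsz:3*hsz])

    # Algebraically identical to (1 - inputgate) * newgate + inputgate * h.
    return newgate + inputgate * (h - newgate)
```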

torch/nn/modules/rnn.py

Lines changed: 159 additions & 2 deletions
@@ -7,7 +7,6 @@
 
 
 class RNNBase(Module):
-    # FIXME: docstring
 
     def __init__(self, mode, input_size, hidden_size,
                  num_layers=1, bias=True, batch_first=False, dropout=0):
@@ -22,7 +21,6 @@ def __init__(self, mode, input_size, hidden_size,
         self.all_weights = []
         super_weights = {}
         for layer in range(num_layers):
-            # FIXME: sizes are different for LSTM/GRU
             layer_input_size = input_size if layer == 0 else hidden_size
             if mode == 'LSTM':
                 gate_size = 4 * hidden_size
@@ -73,17 +71,176 @@ def forward(self, input, hx):
 
 
 class RNN(RNNBase):
+    """Applies a multi-layer RNN with tanh non-linearity to an input sequence.
+
+    For each element in the input sequence, each layer computes the following
+    function:
+    ```
+    h_t = tanh(w_ih * x_t + b_ih + w_hh * h_(t-1) + b_hh)
+    ```
+    where `h_t` is the hidden state at time t, and `x_t` is the hidden
+    state of the previous layer at time t, or `input_t` for the first layer.
+
+    Args:
+        input_size: The number of expected features in the input x
+        hidden_size: The number of features in the hidden state h
+        num_layers: The number of recurrent layers.
+        bias: If False, then the layer does not use bias weights b_ih and b_hh (default=True).
+        batch_first: If True, then the input tensor is provided as (batch, seq, feature).
+        dropout: If non-zero, introduces a dropout layer on the outputs of each RNN layer.
+    Input: input, h_0
+        input: A (seq_len x batch x input_size) tensor containing the features of the input sequence.
+        h_0: A (num_layers x batch x hidden_size) tensor containing the initial hidden state for each element in the batch.
+    Output: output, h_n
+        output: A (seq_len x batch x hidden_size) tensor containing the output features (h_k) from the last layer of the RNN, for each k.
+        h_n: A (num_layers x batch x hidden_size) tensor containing the hidden state for k=seq_len.
+    Members:
+        weight_ih_l[k]: the learnable input-hidden weights of the k-th layer, of shape (input_size x hidden_size)
+        weight_hh_l[k]: the learnable hidden-hidden weights of the k-th layer, of shape (hidden_size x hidden_size)
+        bias_ih_l[k]: the learnable input-hidden bias of the k-th layer, of shape (hidden_size)
+        bias_hh_l[k]: the learnable hidden-hidden bias of the k-th layer, of shape (hidden_size)
+    Examples:
+        >>> rnn = nn.RNN(10, 20, 2)
+        >>> input = Variable(torch.randn(5, 3, 10))
+        >>> h0 = Variable(torch.randn(2, 3, 20))
+        >>> output, hn = rnn(input, h0)
+    """
+
     def __init__(self, *args, **kwargs):
         super(RNN, self).__init__('RNN_TANH', *args, **kwargs)
 
 class RNNReLU(RNNBase):
+    """Applies a multi-layer RNN with ReLU non-linearity to an input sequence.
+
+    For each element in the input sequence, each layer computes the following
+    function:
+    ```
+    h_t = ReLU(w_ih * x_t + b_ih + w_hh * h_(t-1) + b_hh)
+    ```
+    where `h_t` is the hidden state at time t, and `x_t` is the hidden
+    state of the previous layer at time t, or `input_t` for the first layer.
+
+    Args:
+        input_size: The number of expected features in the input x
+        hidden_size: The number of features in the hidden state h
+        num_layers: The number of recurrent layers.
+        bias: If False, then the layer does not use bias weights b_ih and b_hh (default=True).
+        batch_first: If True, then the input tensor is provided as (batch, seq, feature).
+        dropout: If non-zero, introduces a dropout layer on the outputs of each RNN layer.
+    Input: input, h_0
+        input: A (seq_len x batch x input_size) tensor containing the features of the input sequence.
+        h_0: A (num_layers x batch x hidden_size) tensor containing the initial hidden state for each element in the batch.
+    Output: output, h_n
+        output: A (seq_len x batch x hidden_size) tensor containing the output features (h_k) from the last layer of the RNN, for each k.
+        h_n: A (num_layers x batch x hidden_size) tensor containing the hidden state for k=seq_len.
+    Members:
+        weight_ih_l[k]: the learnable input-hidden weights of the k-th layer, of shape (input_size x hidden_size)
+        weight_hh_l[k]: the learnable hidden-hidden weights of the k-th layer, of shape (hidden_size x hidden_size)
+        bias_ih_l[k]: the learnable input-hidden bias of the k-th layer, of shape (hidden_size)
+        bias_hh_l[k]: the learnable hidden-hidden bias of the k-th layer, of shape (hidden_size)
+    Examples:
+        >>> rnn = nn.RNNReLU(10, 20, 2)
+        >>> input = Variable(torch.randn(5, 3, 10))
+        >>> h0 = Variable(torch.randn(2, 3, 20))
+        >>> output, hn = rnn(input, h0)
+    """
+
     def __init__(self, *args, **kwargs):
         super(RNNReLU, self).__init__('RNN_RELU', *args, **kwargs)
 
 class LSTM(RNNBase):
+    """Applies a multi-layer long short-term memory (LSTM) RNN to an input sequence.
+
+    For each element in the input sequence, each layer computes the following
+    function:
+    ```
+    i_t = sigmoid(W_ii x_t + b_ii + W_hi h_(t-1) + b_hi)
+    f_t = sigmoid(W_if x_t + b_if + W_hf h_(t-1) + b_hf)
+    g_t = tanh(W_ig x_t + b_ig + W_hg h_(t-1) + b_hg)
+    o_t = sigmoid(W_io x_t + b_io + W_ho h_(t-1) + b_ho)
+    c_t = f_t * c_(t-1) + i_t * g_t
+    h_t = o_t * tanh(c_t)
+    ```
+    where `h_t` is the hidden state at time t, `c_t` is the cell state at time t,
+    `x_t` is the hidden state of the previous layer at time t (or input_t for the first layer),
+    and `i_t`, `f_t`, `g_t`, `o_t` are the input, forget, cell, and output gates, respectively.
+
+    Args:
+        input_size: The number of expected features in the input x
+        hidden_size: The number of features in the hidden state h
+        num_layers: The number of recurrent layers.
+        bias: If False, then the layer does not use bias weights b_ih and b_hh (default=True).
+        batch_first: If True, then the input tensor is provided as (batch, seq, feature).
+        dropout: If non-zero, introduces a dropout layer on the outputs of each RNN layer.
+    Input: input, (h_0, c_0)
+        input: A (seq_len x batch x input_size) tensor containing the features of the input sequence.
+        h_0: A (num_layers x batch x hidden_size) tensor containing the initial hidden state for each element in the batch.
+        c_0: A (num_layers x batch x hidden_size) tensor containing the initial cell state for each element in the batch.
+    Output: output, (h_n, c_n)
+        output: A (seq_len x batch x hidden_size) tensor containing the output features (h_t) from the last layer of the RNN, for each t.
+        h_n: A (num_layers x batch x hidden_size) tensor containing the hidden state for t=seq_len.
+        c_n: A (num_layers x batch x hidden_size) tensor containing the cell state for t=seq_len.
+    Members:
+        weight_ih_l[k]: the learnable input-hidden weights of the k-th layer (W_ii|W_if|W_ig|W_io), of shape (input_size x 4*hidden_size)
+        weight_hh_l[k]: the learnable hidden-hidden weights of the k-th layer (W_hi|W_hf|W_hg|W_ho), of shape (hidden_size x 4*hidden_size)
+        bias_ih_l[k]: the learnable input-hidden bias of the k-th layer (b_ii|b_if|b_ig|b_io), of shape (4*hidden_size)
+        bias_hh_l[k]: the learnable hidden-hidden bias of the k-th layer (b_hi|b_hf|b_hg|b_ho), of shape (4*hidden_size)
+    Examples:
+        >>> rnn = nn.LSTM(10, 20, 2)
+        >>> input = Variable(torch.randn(5, 3, 10))
+        >>> h0 = Variable(torch.randn(2, 3, 20))
+        >>> c0 = Variable(torch.randn(2, 3, 20))
+        >>> output, (hn, cn) = rnn(input, (h0, c0))
+    """
+
     def __init__(self, *args, **kwargs):
         super(LSTM, self).__init__('LSTM', *args, **kwargs)
 
 class GRU(RNNBase):
+    """Applies a multi-layer gated recurrent unit (GRU) RNN to an input sequence.
+
+    For each element in the input sequence, each layer computes the following
+    function:
+    ```
+    r_t = sigmoid(W_ir x_t + b_ir + W_hr h_(t-1) + b_hr)
+    i_t = sigmoid(W_ii x_t + b_ii + W_hi h_(t-1) + b_hi)
+    n_t = tanh(W_in x_t + b_in + r_t * (W_hn h_(t-1) + b_hn))
+    h_t = (1 - i_t) * n_t + i_t * h_(t-1)
+    ```
+    where `h_t` is the hidden state at time t, `x_t` is the hidden
+    state of the previous layer at time t (or input_t for the first layer),
+    and `r_t`, `i_t`, `n_t` are the reset, input, and new gates, respectively.
+
+    Args:
+        input_size: The number of expected features in the input x
+        hidden_size: The number of features in the hidden state h
+        num_layers: The number of recurrent layers.
+        bias: If False, then the layer does not use bias weights b_ih and b_hh (default=True).
+        batch_first: If True, then the input tensor is provided as (batch, seq, feature).
+        dropout: If non-zero, introduces a dropout layer on the outputs of each RNN layer.
+    Input: input, h_0
+        input: A (seq_len x batch x input_size) tensor containing the features of the input sequence.
+        h_0: A (num_layers x batch x hidden_size) tensor containing the initial hidden state for each element in the batch.
+    Output: output, h_n
+        output: A (seq_len x batch x hidden_size) tensor containing the output features (h_t) from the last layer of the RNN, for each t.
+        h_n: A (num_layers x batch x hidden_size) tensor containing the hidden state for t=seq_len.
+    Members:
+        weight_ih_l[k]: the learnable input-hidden weights of the k-th layer (W_ir|W_ii|W_in), of shape (input_size x 3*hidden_size)
+        weight_hh_l[k]: the learnable hidden-hidden weights of the k-th layer (W_hr|W_hi|W_hn), of shape (hidden_size x 3*hidden_size)
+        bias_ih_l[k]: the learnable input-hidden bias of the k-th layer (b_ir|b_ii|b_in), of shape (3*hidden_size)
+        bias_hh_l[k]: the learnable hidden-hidden bias of the k-th layer (b_hr|b_hi|b_hn), of shape (3*hidden_size)
+    Examples:
+        >>> rnn = nn.GRU(10, 20, 2)
+        >>> input = Variable(torch.randn(5, 3, 10))
+        >>> h0 = Variable(torch.randn(2, 3, 20))
+        >>> output, hn = rnn(input, h0)
+    """
+
     def __init__(self, *args, **kwargs):
         super(GRU, self).__init__('GRU', *args, **kwargs)
+
+
+# FIXME: add module wrappers around XXXCell, and maybe StackedRNN and Recurrent
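For readers cross-checking the LSTM docstring above, its equations describe a single cell step. A minimal NumPy sketch of that step (illustrative only: the function name and the i, f, g, o gate ordering are assumptions; the 4 * hidden_size gate width comes from the RNNBase hunk above):

```
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def lstm_cell(x, h, c, w_ih, w_hh, b_ih, b_hh):
    """One LSTM step; w_ih/w_hh stack the i, f, g, o gates along rows (4*hsz)."""
    hsz = h.shape[1]
    gates = x @ w_ih.T + b_ih + h @ w_hh.T + b_hh   # shape (batch, 4*hsz)

    i = sigmoid(gates[:, 0*hsz:1*hsz])   # input gate
    f = sigmoid(gates[:, 1*hsz:2*hsz])   # forget gate
    g = np.tanh(gates[:, 2*hsz:3*hsz])   # candidate cell state
    o = sigmoid(gates[:, 3*hsz:4*hsz])   # output gate

    c_next = f * c + i * g               # c_t = f_t * c_(t-1) + i_t * g_t
    h_next = o * np.tanh(c_next)         # h_t = o_t * tanh(c_t)
    return h_next, c_next
```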
