 import torch.backends.cudnn.rnn


-
-
 def _getCudnnMode(mode):
     if mode == 'RNN_RELU':
         return cudnn.CUDNN_RNN_RELU
@@ -48,11 +46,11 @@ def linear(input, w, b):

 def RNNReLUCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None):
     hy = ReLU(linear(input, w_ih, b_ih) + linear(hidden, w_hh, b_hh))
-    return hy, hy
+    return hy

 def RNNTanhCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None):
     hy = tanh(linear(input, w_ih, b_ih) + linear(hidden, w_hh, b_hh))
-    return hy, hy
+    return hy

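Note (illustration, not part of the commit): both elementwise cells now return the new hidden state once instead of the duplicated `(hy, hy)` pair. A minimal sketch of the same tanh-cell math with current torch ops, with all names and shapes assumed:

```python
# Illustrative only; mirrors RNNTanhCell above with assumed shapes.
import torch
import torch.nn.functional as F

batch, input_size, hidden_size = 4, 10, 20
x = torch.randn(batch, input_size)
hx = torch.randn(batch, hidden_size)
w_ih = torch.randn(hidden_size, input_size)   # F.linear expects (out, in)
w_hh = torch.randn(hidden_size, hidden_size)

hy = torch.tanh(F.linear(x, w_ih) + F.linear(hx, w_hh))  # single tensor out
```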
 def LSTMCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None):
     hx, cx = hidden
@@ -63,25 +61,23 @@ def LSTMCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None):
     forgetgate = sigmoid(gates[:,1 * hsz:2 * hsz])
     cellgate = tanh(gates[:,2 * hsz:3 * hsz])
     outgate = sigmoid(gates[:,3 * hsz:4 * hsz])
-    nextc = (forgetgate * cx) + (ingate * cellgate)
-    nexth = outgate * tanh(nextc)
+    cy = (forgetgate * cx) + (ingate * cellgate)
+    hy = outgate * tanh(cy)

-    return (nexth, nextc), nexth
+    return hy, cy

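For reference, the same LSTM update with current torch ops, showing why the flat `(hy, cy)` return is enough for callers (shapes and names below are assumptions, not part of this change):

```python
# Illustrative only; mirrors LSTMCell above with assumed shapes.
import torch
import torch.nn.functional as F

batch, input_size, hidden_size = 4, 10, 20
x = torch.randn(batch, input_size)
hx = torch.randn(batch, hidden_size)
cx = torch.randn(batch, hidden_size)
w_ih = torch.randn(4 * hidden_size, input_size)
w_hh = torch.randn(4 * hidden_size, hidden_size)

gates = F.linear(x, w_ih) + F.linear(hx, w_hh)
i, f, g, o = gates.chunk(4, 1)                 # same four slices as above
cy = torch.sigmoid(f) * cx + torch.sigmoid(i) * torch.tanh(g)
hy = torch.sigmoid(o) * torch.tanh(cy)
# callers unpack the pair directly:  hy, cy = LSTMCell(...)
```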
 def GRUCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None):
     hsz = hidden.size(1)
     gi = linear(input, w_ih, b_ih)
     gh = linear(hidden, w_hh, b_hh)
     # FIXME: chunk

-    # this is a bit weird, it doesn't match the order of parameters
-    # implied by the cudnn docs, and it also uses nexth for output...
     resetgate = sigmoid(gi[:,0 * hsz:1 * hsz] + gh[:,0 * hsz:1 * hsz])
     inputgate = sigmoid(gi[:,1 * hsz:2 * hsz] + gh[:,1 * hsz:2 * hsz])
     newgate = tanh(gi[:,2 * hsz:3 * hsz] + resetgate * gh[:,2 * hsz:3 * hsz])
-    nexth = newgate + inputgate * (hidden - newgate)
+    hy = newgate + inputgate * (hidden - newgate)

-    return nexth, nexth  # FIXME: nexth, nexth ???
+    return hy

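The `# FIXME: chunk` above could be addressed with `Tensor.chunk`; a sketch of the same GRU update using it (illustrative only, assuming current torch ops):

```python
# Illustrative only; same GRU math as above, using chunk() for the slices.
import torch
import torch.nn.functional as F

batch, input_size, hidden_size = 4, 10, 20
x = torch.randn(batch, input_size)
hx = torch.randn(batch, hidden_size)
w_ih = torch.randn(3 * hidden_size, input_size)
w_hh = torch.randn(3 * hidden_size, hidden_size)

gi = F.linear(x, w_ih)
gh = F.linear(hx, w_hh)
i_r, i_z, i_n = gi.chunk(3, 1)
h_r, h_z, h_n = gh.chunk(3, 1)
resetgate = torch.sigmoid(i_r + h_r)
inputgate = torch.sigmoid(i_z + h_z)
newgate = torch.tanh(i_n + resetgate * h_n)
hy = newgate + inputgate * (hx - newgate)   # single tensor, as returned above
```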
 def StackedRNN(cell, num_layers, lstm=False):
     def forward(input, hidden, weight):
@@ -92,8 +88,9 @@ def forward(input, hidden, weight):
             hidden = zip(*hidden)

         for i in range(num_layers):
-            hy, input = cell(input, hidden[i], *weight[i])
+            hy = cell(input, hidden[i], *weight[i])
             next_hidden.append(hy)
+            input = hy[0] if lstm else hy

         if lstm:
             next_h, next_c = zip(*next_hidden)
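Since each cell now returns only its state, the loop picks the next layer's input explicitly: `hy[0]` (the h half of an LSTM's `(hy, cy)` pair) or `hy` itself. A small sketch of the `zip(*next_hidden)` regrouping used in the LSTM branch (values assumed for illustration):

```python
# Illustrative only: regrouping per-layer (h, c) pairs into (all h, all c).
import torch

next_hidden = [(torch.zeros(4, 20), torch.ones(4, 20)) for _ in range(2)]
next_h, next_c = zip(*next_hidden)   # ((h0, h1), (c0, c1))
```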
@@ -222,8 +219,6 @@ def backward_extended(self, grad_output, grad_hy):
                 weight,
                 grad_weight)

-        # FIXME: zero out grad_bias if necessary :)
-
         return grad_input, grad_weight, grad_hx
