class RNNBase(Module):
-    # FIXME: docstring

    def __init__(self, mode, input_size, hidden_size,
                 num_layers=1, bias=True, batch_first=False, dropout=0):
@@ -22,7 +21,6 @@ def __init__(self, mode, input_size, hidden_size,
        self.all_weights = []
        super_weights = {}
        for layer in range(num_layers):
-            # FIXME: sizes are different for LSTM/GRU
            layer_input_size = input_size if layer == 0 else hidden_size
            if mode == 'LSTM':
                gate_size = 4 * hidden_size
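
The GRU and plain-RNN branches of this conditional are elided from the hunk; a minimal sketch of the usual gate multipliers, consistent with the 3*hidden_size and hidden_size parameter shapes documented in the class docstrings below (the `_gate_size` helper is hypothetical, not code from this commit):
```
def _gate_size(mode, hidden_size):
    # LSTM concatenates input/forget/cell/output gate weights
    if mode == 'LSTM':
        return 4 * hidden_size
    # GRU concatenates reset/input/new gate weights
    elif mode == 'GRU':
        return 3 * hidden_size
    # plain RNN_TANH / RNN_RELU has a single hidden transform
    else:
        return hidden_size
```
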
@@ -73,17 +71,176 @@ def forward(self, input, hx):


class RNN(RNNBase):
74+ """Applies a multi-layer RNN with tanh non-linearity to an input sequence.
75+
76+
77+ For each element in the input sequence, each layer computes the following
78+ function:
79+ ```
80+ h_t = tanh(w_ih * x_t + b_ih + w_hh * h_(t-1) + b_hh)
81+ ```
82+ where `h_t` is the hidden state at time t, and `x_t` is the hidden
83+ state of the previous layer at time t or `input_t` for the first layer.
84+
85+ Args:
86+ input_size: The number of expected features in the input x
87+ hidden_size: The number of features in the hidden state h
88+ num_layers: the size of the convolving kernel.
89+ bias: If False, then the layer does not use bias weights b_ih and b_hh (default=True).
90+ batch_first: If True, then the input tensor is provided as (batch, seq, feature)
91+ dropout: If non-zero, introduces a dropout layer on the outputs of each RNN layer
92+ Input: input, h_0
93+ input: A (seq_len x batch x input_size) tensor containing the features of the input sequence.
94+ h_0: A (num_layers x batch x hidden_size) tensor containing the initial hidden state for each element in the batch.
95+ Output: output, h_n
96+ output: A (seq_len x batch x hidden_size) tensor containing the output features (h_k) from the last layer of the RNN, for each k
97+ h_n: A (num_layers x batch x hidden_size) tensor containing the hidden state for k=seq_len
98+ Members:
99+ weight_ih_l[k]: the learnable input-hidden weights of the k-th layer, of shape (input_size x hidden_size)
100+ weight_hh_l[k]: the learnable hidden-hidden weights of the k-th layer, of shape (hidden_size x hidden_size)
101+ bias_ih_l[k]: the learnable input-hidden bias of the k-th layer, of shape (hidden_size)
102+ bias_hh_l[k]: the learnable hidden-hidden bias of the k-th layer, of shape (hidden_size)
103+ Examples:
104+ >>> rnn = nn.RNN(10, 20, 2)
105+ >>> input = Variable(torch.randn(5, 3, 10))
106+ >>> h0 = Variable(torch.randn(2, 3, 20))
107+ >>> output, hn = rnn(input, h0)
108+ """
+
    def __init__(self, *args, **kwargs):
        super(RNN, self).__init__('RNN_TANH', *args, **kwargs)

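A slightly expanded version of the RNN docstring example above, checking the documented input/output shapes; it assumes the same `Variable`-era API used in the examples:
```
import torch
import torch.nn as nn
from torch.autograd import Variable

rnn = nn.RNN(10, 20, 2)                  # input_size=10, hidden_size=20, num_layers=2
input = Variable(torch.randn(5, 3, 10))  # (seq_len x batch x input_size)
h0 = Variable(torch.randn(2, 3, 20))     # (num_layers x batch x hidden_size)
output, hn = rnn(input, h0)

# shapes documented in the class docstring
assert list(output.size()) == [5, 3, 20]  # (seq_len x batch x hidden_size)
assert list(hn.size()) == [2, 3, 20]      # (num_layers x batch x hidden_size)
```
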
class RNNReLU(RNNBase):
114+ """Applies a multi-layer RNN with ReLU non-linearity to an input sequence.
115+
116+
117+ For each element in the input sequence, each layer computes the following
118+ function:
119+ ```
120+ h_t = ReLU(w_ih x_t + b_ih + w_hh h_(t-1) + b_hh)
121+ ```
122+ where `h_t` is the hidden state at time t, and `x_t` is the hidden
123+ state of the previous layer at time t or `input_t` for the first layer.
124+
125+ Args:
126+ input_size: The number of expected features in the input x
127+ hidden_size: The number of features in the hidden state h
128+ num_layers: the size of the convolving kernel.
129+ bias: If False, then the layer does not use bias weights b_ih and b_hh (default=True).
130+ batch_first: If True, then the input tensor is provided as (batch, seq, feature)
131+ dropout: If non-zero, introduces a dropout layer on the outputs of each RNN layer
132+ Input: input, h_0
133+ input: A (seq_len x batch x input_size) tensor containing the features of the input sequence.
134+ h_0: A (num_layers x batch x hidden_size) tensor containing the initial hidden state for each element in the batch.
135+ Output: output, h_n
136+ output: A (seq_len x batch x hidden_size) tensor containing the output features (h_k) from the last layer of the RNN, for each k
137+ h_n: A (num_layers x batch x hidden_size) tensor containing the hidden state for k=seq_len
138+ Members:
139+ weight_ih_l[k]: the learnable input-hidden weights of the k-th layer, of shape (input_size x hidden_size)
140+ weight_hh_l[k]: the learnable hidden-hidden weights of the k-th layer, of shape (hidden_size x hidden_size)
141+ bias_ih_l[k]: the learnable input-hidden bias of the k-th layer, of shape (hidden_size)
142+ bias_hh_l[k]: the learnable hidden-hidden bias of the k-th layer, of shape (hidden_size)
143+ Examples:
144+ >>> rnn = nn.RNNReLU(10, 20, 2)
145+ >>> input = Variable(torch.randn(5, 3, 10))
146+ >>> h0 = Variable(torch.randn(2, 3, 20))
147+ >>> output, hn = rnn(input, h0)
148+ """
+
    def __init__(self, *args, **kwargs):
        super(RNNReLU, self).__init__('RNN_RELU', *args, **kwargs)

class LSTM(RNNBase):
154+ """Applies a multi-layer long short-term memory (LSTM) RNN to an input sequence.
155+
156+
157+ For each element in the input sequence, each layer computes the following
158+ function:
159+ ```
160+ i_t = sigmoid(W_ii x_t + b_ii + W_hi h_(t-1) + b_hi)
161+ f_t = sigmoid(W_if x_t + b_if + W_hf h_(t-1) + b_hf)
162+ g_t = tanh(W_ig x_t + b_ig + W_hc h_(t-1) + b_hg)
163+ o_t = sigmoid(W_io x_t + b_io + W_ho h_(t-1) + b_ho)
164+ c_t = f_t * c_(t-1) + i_t * c_t
165+ h_t = o_t * tanh(c_t)
166+ ```
167+ where `h_t` is the hidden state at time t, `c_t` is the cell state at time t,
168+ `x_t` is the hidden state of the previous layer at time t or input_t for the first layer,
169+ and `i_t`, `f_t`, `g_t`, `o_t` are the input, forget, cell, and out gates, respectively.
170+
171+ Args:
172+ input_size: The number of expected features in the input x
173+ hidden_size: The number of features in the hidden state h
174+ num_layers: the size of the convolving kernel.
175+ bias: If False, then the layer does not use bias weights b_ih and b_hh (default=True).
176+ batch_first: If True, then the input tensor is provided as (batch, seq, feature)
177+ dropout: If non-zero, introduces a dropout layer on the outputs of each RNN layer
178+ Input: input, (h_0, c_0)
179+ input: A (seq_len x batch x input_size) tensor containing the features of the input sequence.
180+ h_0: A (num_layers x batch x hidden_size) tensor containing the initial hidden state for each element in the batch.
181+ c_0: A (num_layers x batch x hidden_size) tensor containing the initial cell state for each element in the batch.
182+ Output: output, (h_n, c_n)
183+ output: A (seq_len x batch x hidden_size) tensor containing the output features (h_t) from the last layer of the RNN, for each t
184+ h_n: A (num_layers x batch x hidden_size) tensor containing the hidden state for t=seq_len
185+ c_n: A (num_layers x batch x hidden_size) tensor containing the cell state for t=seq_len
186+ Members:
187+ weight_ih_l[k]: the learnable input-hidden weights of the k-th layer (W_ir|W_ii|W_in), of shape (input_size x 3*hidden_size)
188+ weight_hh_l[k]: the learnable hidden-hidden weights of the k-th layer (W_hr|W_hi|W_hn), of shape (hidden_size x 3*hidden_size)
189+ bias_ih_l[k]: the learnable input-hidden bias of the k-th layer (b_ir|b_ii|b_in), of shape (3*hidden_size)
190+ bias_hh_l[k]: the learnable hidden-hidden bias of the k-th layer (W_hr|W_hi|W_hn), of shape (3*hidden_size)
191+ Examples:
192+ >>> rnn = nn.LSTM(10, 20, 2)
193+ >>> input = Variable(torch.randn(5, 3, 10))
194+ >>> h0 = Variable(torch.randn(2, 3, 20))
195+ >>> c0 = Variable(torch.randn(2, 3, 20))
196+ >>> output, hn = rnn(input, (h0, c0))
197+ """
    def __init__(self, *args, **kwargs):
        super(LSTM, self).__init__('LSTM', *args, **kwargs)

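The LSTM example above returns its hidden and cell states as a tuple; a minimal sketch of the documented unpacking and shapes, again assuming the `Variable`-era API used in the docstring examples:
```
import torch
import torch.nn as nn
from torch.autograd import Variable

lstm = nn.LSTM(10, 20, 2)
input = Variable(torch.randn(5, 3, 10))  # (seq_len x batch x input_size)
h0 = Variable(torch.randn(2, 3, 20))     # (num_layers x batch x hidden_size)
c0 = Variable(torch.randn(2, 3, 20))     # (num_layers x batch x hidden_size)

# per the docstring, the second return value is the (h_n, c_n) pair
output, (hn, cn) = lstm(input, (h0, c0))
assert list(output.size()) == [5, 3, 20]
assert list(hn.size()) == [2, 3, 20]
assert list(cn.size()) == [2, 3, 20]
```
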
class GRU(RNNBase):
202+ """Applies a multi-layer gated recurrent unit (GRU) RNN to an input sequence.
203+
204+
205+ For each element in the input sequence, each layer computes the following
206+ function:
207+ ```
208+ r_t = sigmoid(W_ir x_t + b_ir + W_hr h_(t-1) + b_hr)
209+ i_t = sigmoid(W_ii x_t + b_ii + W_hi h_(t-1) + b_hi)
210+ n_t = tanh(W_in x_t + resetgate * W_hn h_(t-1))
211+ h_t = (1 - i_t) * n_t + i_t * h_(t-1)
212+ ```
213+ where `h_t` is the hidden state at time t, `x_t` is the hidden
214+ state of the previous layer at time t or input_t for the first layer,
215+ and `r_t`, `i_t`, `n_t` are the reset, input, and new gates, respectively.
216+
217+ Args:
218+ input_size: The number of expected features in the input x
219+ hidden_size: The number of features in the hidden state h
220+ num_layers: the size of the convolving kernel.
221+ bias: If False, then the layer does not use bias weights b_ih and b_hh (default=True).
222+ batch_first: If True, then the input tensor is provided as (batch, seq, feature)
223+ dropout: If non-zero, introduces a dropout layer on the outputs of each RNN layer
224+ Input: input, h_0
225+ input: A (seq_len x batch x input_size) tensor containing the features of the input sequence.
226+ h_0: A (num_layers x batch x hidden_size) tensor containing the initial hidden state for each element in the batch.
227+ Output: output, h_n
228+ output: A (seq_len x batch x hidden_size) tensor containing the output features (h_t) from the last layer of the RNN, for each t
229+ h_n: A (num_layers x batch x hidden_size) tensor containing the hidden state for t=seq_len
230+ Members:
231+ weight_ih_l[k]: the learnable input-hidden weights of the k-th layer (W_ir|W_ii|W_in), of shape (input_size x 3*hidden_size)
232+ weight_hh_l[k]: the learnable hidden-hidden weights of the k-th layer (W_hr|W_hi|W_hn), of shape (hidden_size x 3*hidden_size)
233+ bias_ih_l[k]: the learnable input-hidden bias of the k-th layer (b_ir|b_ii|b_in), of shape (3*hidden_size)
234+ bias_hh_l[k]: the learnable hidden-hidden bias of the k-th layer (W_hr|W_hi|W_hn), of shape (3*hidden_size)
235+ Examples:
236+ >>> rnn = nn.GRU(10, 20, 2)
237+ >>> input = Variable(torch.randn(5, 3, 10))
238+ >>> h0 = Variable(torch.randn(2, 3, 20))
239+ >>> output, hn = rnn(input, h0)
240+ """
+
    def __init__(self, *args, **kwargs):
        super(GRU, self).__init__('GRU', *args, **kwargs)
+
+
+# FIXME: add module wrappers around XXXCell, and maybe StackedRNN and Recurrent
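
As a worked illustration of the GRU gate equations documented above (not the module's internal implementation), a single time step can be written out directly; the `gru_step` name and the (hidden_size x input_size) weight layout are assumptions made for this sketch:
```
import torch

def gru_step(x_t, h_prev, W_ir, W_ii, W_in, W_hr, W_hi, W_hn,
             b_ir, b_ii, b_in, b_hr, b_hi, b_hn):
    # illustration of the documented equations; weight layout chosen for this sketch
    # r_t = sigmoid(W_ir x_t + b_ir + W_hr h_(t-1) + b_hr)
    r_t = torch.sigmoid(W_ir.mv(x_t) + b_ir + W_hr.mv(h_prev) + b_hr)
    # i_t = sigmoid(W_ii x_t + b_ii + W_hi h_(t-1) + b_hi)
    i_t = torch.sigmoid(W_ii.mv(x_t) + b_ii + W_hi.mv(h_prev) + b_hi)
    # n_t = tanh(W_in x_t + b_in + r_t * (W_hn h_(t-1) + b_hn))
    n_t = torch.tanh(W_in.mv(x_t) + b_in + r_t * (W_hn.mv(h_prev) + b_hn))
    # h_t = (1 - i_t) * n_t + i_t * h_(t-1)
    return (1 - i_t) * n_t + i_t * h_prev

input_size, hidden_size = 10, 20
x_t = torch.randn(input_size)
h_prev = torch.randn(hidden_size)
W_ih = [torch.randn(hidden_size, input_size) for _ in range(3)]   # W_ir, W_ii, W_in
W_hh = [torch.randn(hidden_size, hidden_size) for _ in range(3)]  # W_hr, W_hi, W_hn
biases = [torch.randn(hidden_size) for _ in range(6)]             # b_ir, b_ii, b_in, b_hr, b_hi, b_hn
h_t = gru_step(x_t, h_prev, *W_ih, *W_hh, *biases)
assert list(h_t.size()) == [hidden_size]
```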