@@ -180,6 +180,77 @@ def dynamic_lstm(input,
180180 return hidden , cell
181181
182182
183+ def gru_unit (input ,
184+ hidden ,
185+ size ,
186+ weight = None ,
187+ bias = None ,
188+ activation = 'tanh' ,
189+ gate_activation = 'sigmoid' ,
190+ main_program = None ,
191+ startup_program = None ):
192+ """
193+ GRUUnit Operator implements partial calculations of the GRU unit as following:
194+
195+ $$
196+ update \ gate: u_t = actGate(xu_t + W_u * h_{t-1} + b_u) \\
197+ reset \ gate: r_t = actGate(xr_t + W_r * h_{t-1} + b_r) \\
198+ output \ candidate: {h}_t = actNode(xc_t + W_c * dot(r_t, h_{t-1}) + b_c) \\
199+ output: h_t = dot((1 - u_t), h_{t-1}) + dot(u_t, {h}_t)
200+ $$
201+
202+ which is same as one time step of GRU Operator.
203+
204+ @note To implement the complete GRU unit, fully-connected operator must be
205+ used before to feed xu, xr and xc as the Input of GRUUnit operator.
206+
207+ TODO(ChunweiYan) add more document here
208+ """
209+ activation_dict = dict (
210+ identity = 0 ,
211+ sigmoid = 1 ,
212+ tanh = 2 ,
213+ relu = 3 , )
214+ activation = activation_dict [activation ]
215+ gate_activation = activation_dict [gate_activation ]
216+
217+ helper = LayerHelper ('gru_unit' , ** locals ())
218+ dtype = helper .input_dtype ()
219+ size = size / 3
220+
221+ # create weight
222+ if weight is None :
223+ weight = helper .create_parameter (
224+ attr = helper .param_attr , shape = [size , 3 * size ], dtype = dtype )
225+
226+ # create bias
227+ if bias is None :
228+ bias_size = [1 , 3 * size ]
229+ bias = helper .create_parameter (
230+ attr = helper .bias_attr , shape = bias_size , dtype = dtype , is_bias = True )
231+
232+ gate = helper .create_tmp_variable (dtype )
233+ reset_hidden_pre = helper .create_tmp_variable (dtype )
234+ updated_hidden = helper .create_tmp_variable (dtype )
235+
236+ helper .append_op (
237+ type = 'gru_unit' ,
238+ inputs = {'Input' : input ,
239+ 'HiddenPrev' : hidden ,
240+ 'Weight' : weight },
241+ outputs = {
242+ 'Gate' : gate ,
243+ 'ResetHiddenPrev' : reset_hidden_pre ,
244+ 'Hidden' : updated_hidden ,
245+ },
246+ attrs = {
247+ 'activation' : 0 ,
248+ 'gate_activation' : 1 ,
249+ })
250+
251+ return updated_hidden , reset_hidden_pre , gate
252+
253+
183254def data (name ,
184255 shape ,
185256 append_batch_size = True ,
0 commit comments