@@ -69,8 +69,14 @@ static ClassRegistrar<ActivationFunction> gActivationRegistrar;
 class IdentityActivation : public ActivationFunction {
 public:
   static const std::string name;
-  void forward(Argument& act) { (void)act; }
-  void backward(Argument& act) { (void)act; }
+  Error __must_check forward(Argument& act) {
+    (void)act;
+    return Error();
+  }
+  Error __must_check backward(Argument& act) {
+    (void)act;
+    return Error();
+  }
   const std::string& getName() const { return name; }
 };
 const std::string IdentityActivation::name = "";
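The pattern above repeats through the whole commit: each activation hook now returns an Error, where a default-constructed Error() signals success, and __must_check forces callers to consume the status. A minimal call-site sketch under that contract (the names activation_ and output_ are hypothetical, not part of this diff):

    // Hypothetical call site: propagate failure to the caller instead of
    // aborting inside the activation; this mirrors the idiom the diff
    // itself uses below (`if (!status) return status;`).
    Error status = activation_->forward(output_);
    if (!status) return status;  // Error converts to true on success
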
@@ -86,8 +92,14 @@ static InitFunction __reg_activation__identity([] {
  * \f]
  */
 BEGIN_DEFINE_ACTIVATION(sigmoid)
-void forward(Argument& act) { act.value->sigmoid(*act.value); }
-void backward(Argument& act) { act.grad->sigmoidDerivative(*act.value); }
+Error __must_check forward(Argument& act) {
+  act.value->sigmoid(*act.value);
+  return Error();
+}
+Error __must_check backward(Argument& act) {
+  act.grad->sigmoidDerivative(*act.value);
+  return Error();
+}
 END_DEFINE_ACTIVATION(sigmoid)
 
 /**
@@ -103,9 +115,12 @@ MatrixPtr sftMaxDot_;
 MatrixPtr one_;
 
 public:
-void forward(Argument& act) { act.value->softmax(*act.value); }
+Error __must_check forward(Argument& act) {
+  act.value->softmax(*act.value);
+  return Error();
+}
 
-void backward(Argument& act) {
+Error __must_check backward(Argument& act) {
   MatrixPtr outputV = act.value;
   MatrixPtr outputG = act.grad;
 
@@ -137,6 +152,7 @@ void backward(Argument& act) {
 
     act.grad->softmaxDerivative(*act.value, *sftMaxSum_);
   }
+  return Error();
 }
 END_DEFINE_ACTIVATION(softmax)
 
@@ -151,8 +167,11 @@ ACTIVATION_CLASS_NAME(softmax) softmax_;
 Argument argument_;
 
 public:
-void forward(Argument& act) {
-  CHECK_EQ(act.value->getWidth(), 1UL);
+Error __must_check forward(Argument& act) {
+  if (act.value->getWidth() != 1UL) {
+    return Error(
+        "Input width for each timestep of sequence softmax should be 1");
+  }
 
   if (!argument_.value) {
     argument_.value = Matrix::create(nullptr,
@@ -169,10 +188,14 @@ void forward(Argument& act) {
 
   auto starts = act.sequenceStartPositions->getVector(useGpu(act.deviceId));
   act.value->sequenceSoftmax(*act.value, *starts);
+  return Error();
 }
 
-void backward(Argument& act) {
-  CHECK_EQ(act.grad->getWidth(), 1UL);
+Error __must_check backward(Argument& act) {
+  if (act.value->getWidth() != 1UL) {
+    return Error(
+        "Input width for each timestep of sequence softmax should be 1");
+  }
 
   size_t numSequences = act.getNumSequences();
   const int* starts = act.sequenceStartPositions->getData(false);
@@ -184,8 +207,10 @@ void backward(Argument& act) {
     argument_.value->setData(act.value->getData() + offset, 1UL, size);
     argument_.grad->setData(act.grad->getData() + offset, 1UL, size);
 
-    softmax_.backward(argument_);
+    Error status = softmax_.backward(argument_);
+    if (!status) return status;
   }
+  return Error();
 }
 END_DEFINE_ACTIVATION(sequence_softmax)
 
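In this file, sequence_softmax is the one activation that composes another (it reuses softmax_ per subsequence), so its backward demonstrates the propagation idiom rather than just returning Error(). The __must_check annotation is what keeps such statuses from being silently dropped; a sketch of the usual definition (an assumption about Paddle's compiler-portability header, which this diff does not show):

    // Common definition of a must-check annotation on GCC/Clang;
    // other compilers get a no-op fallback.
    #if defined(__GNUC__)
    #define __must_check __attribute__((warn_unused_result))
    #else
    #define __must_check
    #endif
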
@@ -200,9 +225,15 @@ END_DEFINE_ACTIVATION(sequence_softmax)
  * 0 otherwise.
  */
 BEGIN_DEFINE_ACTIVATION(relu)
-void forward(Argument& act) { act.value->relu(*act.value); }
+Error __must_check forward(Argument& act) {
+  act.value->relu(*act.value);
+  return Error();
+}
 
-void backward(Argument& act) { act.grad->reluDerivative(*act.value); }
+Error __must_check backward(Argument& act) {
+  act.grad->reluDerivative(*act.value);
+  return Error();
+}
 END_DEFINE_ACTIVATION(relu)
 
 /**
@@ -219,9 +250,15 @@ END_DEFINE_ACTIVATION(relu)
  * TODO(yuyang18): Remove magic number 24 or make it configurable.
  */
 BEGIN_DEFINE_ACTIVATION(brelu)
-void forward(Argument& act) { act.value->brelu(*act.value); }
+Error __must_check forward(Argument& act) {
+  act.value->brelu(*act.value);
+  return Error();
+}
 
-void backward(Argument& act) { act.grad->breluDerivative(*act.value); }
+Error __must_check backward(Argument& act) {
+  act.grad->breluDerivative(*act.value);
+  return Error();
+}
 END_DEFINE_ACTIVATION(brelu)
 
 /**
@@ -231,9 +268,15 @@ END_DEFINE_ACTIVATION(brelu)
  * \f]
  */
 BEGIN_DEFINE_ACTIVATION(tanh)
-void forward(Argument& act) { act.value->tanh(*act.value); }
+Error __must_check forward(Argument& act) {
+  act.value->tanh(*act.value);
+  return Error();
+}
 
-void backward(Argument& act) { act.grad->tanhDerivative(*act.value); }
+Error __must_check backward(Argument& act) {
+  act.grad->tanhDerivative(*act.value);
+  return Error();
+}
 END_DEFINE_ACTIVATION(tanh)
 
 /**
@@ -248,10 +291,14 @@ real a, b;
 
 public:
 ACTIVATION_CLASS_NAME(stanh)() : a(1.7159), b(2. / 3.) {}
-void forward(Argument& act) { act.value->scaledTanh(*act.value, a, b); }
+Error __must_check forward(Argument& act) {
+  act.value->scaledTanh(*act.value, a, b);
+  return Error();
+}
 
-void backward(Argument& act) {
+Error __must_check backward(Argument& act) {
   act.grad->scaledTanhDerivative(*act.value, a, b);
+  return Error();
 }
 END_DEFINE_ACTIVATION(stanh)
 
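For reference, the untouched constructor defaults a = 1.7159 and b = 2/3 implement LeCun's recommended scaled tanh, \f[ f(z) = 1.7159 \tanh(\tfrac{2}{3} z), \f] whose constants are chosen so that f(\pm 1) \approx \pm 1, which keeps unit-variance inputs in the near-linear region of the curve.
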
@@ -262,9 +309,15 @@ END_DEFINE_ACTIVATION(stanh)
  * \f]
  */
 BEGIN_DEFINE_ACTIVATION(softrelu)
-void forward(Argument& act) { act.value->softrelu(*act.value); }
+Error __must_check forward(Argument& act) {
+  act.value->softrelu(*act.value);
+  return Error();
+}
 
-void backward(Argument& act) { act.grad->softreluDerivative(*act.value); }
+Error __must_check backward(Argument& act) {
+  act.grad->softreluDerivative(*act.value);
+  return Error();
+}
 END_DEFINE_ACTIVATION(softrelu)
 
 /**
@@ -280,7 +333,7 @@ END_DEFINE_ACTIVATION(softrelu)
  * 0 if z=0
  */
 BEGIN_DEFINE_ACTIVATION(abs)
-void forward(Argument& act) {
+Error __must_check forward(Argument& act) {
   SetDevice device(act.deviceId);
   Matrix::resizeOrCreate(act.in,
                          act.value->getHeight(),
@@ -290,9 +343,13 @@ void forward(Argument& act) {
 
   act.in->copyFrom(*act.value);
   act.value->abs2(*act.value);
+  return Error();
 }
 
-void backward(Argument& act) { act.grad->absDerivative(*act.in); }
+Error __must_check backward(Argument& act) {
+  act.grad->absDerivative(*act.in);
+  return Error();
+}
 END_DEFINE_ACTIVATION(abs)
 
 /**
@@ -302,7 +359,7 @@ END_DEFINE_ACTIVATION(abs)
  * \f]
  */
 BEGIN_DEFINE_ACTIVATION(square)
-void forward(Argument& act) {
+Error __must_check forward(Argument& act) {
   SetDevice device(act.deviceId);
   Matrix::resizeOrCreate(act.in,
                          act.value->getHeight(),
@@ -312,9 +369,13 @@ void forward(Argument& act) {
 
   act.in->copyFrom(*act.value);
   act.value->square2(*act.value);
+  return Error();
 }
 
-void backward(Argument& act) { act.grad->squareDerivative(*act.in); }
+Error __must_check backward(Argument& act) {
+  act.grad->squareDerivative(*act.in);
+  return Error();
+}
 END_DEFINE_ACTIVATION(square)
 
 /**
@@ -324,9 +385,15 @@ END_DEFINE_ACTIVATION(square)
  * \f]
  */
 BEGIN_DEFINE_ACTIVATION(exponential)
-void forward(Argument& act) { act.value->exp2(*act.value); }
+Error __must_check forward(Argument& act) {
+  act.value->exp2(*act.value);
+  return Error();
+}
 
-void backward(Argument& act) { act.grad->expDerivative(*act.value); }
+Error __must_check backward(Argument& act) {
+  act.grad->expDerivative(*act.value);
+  return Error();
+}
 END_DEFINE_ACTIVATION(exponential)
 
 /**
@@ -336,7 +403,7 @@ END_DEFINE_ACTIVATION(exponential)
  * \f]
  */
 BEGIN_DEFINE_ACTIVATION(log)
-void forward(Argument& act) {
+Error __must_check forward(Argument& act) {
   SetDevice device(act.deviceId);
   Matrix::resizeOrCreate(act.in,
                          act.value->getHeight(),
@@ -346,9 +413,13 @@ void forward(Argument& act) {
 
   act.in->copyFrom(*act.value);
   act.value->log2(*act.value);
+  return Error();
 }
 
-void backward(Argument& act) { act.grad->dotDiv(*act.grad, *act.in); }
+Error __must_check backward(Argument& act) {
+  act.grad->dotDiv(*act.grad, *act.in);
+  return Error();
+}
 END_DEFINE_ACTIVATION(log)
 
 ActivationFunction* ActivationFunction::create(const std::string& type) {
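Taken together, a caller of the refactored interface looks roughly like this (a sketch only; arg stands for a populated Argument, and error handling is reduced to propagation):

    // Hypothetical end-to-end use of the new status-returning API.
    ActivationFunction* act = ActivationFunction::create("sigmoid");
    Error status = act->forward(arg);   // in-place: reads and writes arg.value
    if (!status) return status;
    status = act->backward(arg);        // updates arg.grad from arg.value
    if (!status) return status;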