Skip to content

Commit 6f69fbc

Browse files
authored
fix elu grad when alpha is less than zero, test=develop (#26543)
1 parent 786373b commit 6f69fbc

File tree

1 file changed

+14
-3
lines changed

1 file changed

+14
-3
lines changed

paddle/fluid/operators/activation_op.h

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1134,9 +1134,20 @@ struct ELUGradFunctor : public BaseActivationFunctor<T> {
11341134
template <typename Device, typename X, typename Out, typename dOut,
11351135
typename dX>
11361136
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
1137-
dx.device(d) = dout * (x > static_cast<T>(0)).template cast<T>() +
1138-
dout * static_cast<T>(alpha) * x.exp() *
1139-
(x <= static_cast<T>(0)).template cast<T>();
1137+
auto temp_a_pos = static_cast<T>(alpha > 0);
1138+
auto temp_a_neg = static_cast<T>(alpha <= 0);
1139+
auto temp_x_pos = (x > static_cast<T>(0)).template cast<T>();
1140+
auto temp_x_neg = (x <= static_cast<T>(0)).template cast<T>();
1141+
1142+
// dx = dout, if alpha > 0 and x > 0
1143+
// dx = dout * alpha * x.exp(), if alpha > 0 and x <= 0
1144+
// dx = dout * (1 + alpha * x.exp()), if alpha <= 0 and x > 0
1145+
// dx = 0, if alpha <= 0 and x <= 0
1146+
dx.device(d) =
1147+
dout * temp_a_pos * temp_x_pos +
1148+
dout * static_cast<T>(alpha) * x.exp() * temp_a_pos * temp_x_neg +
1149+
dout * (static_cast<T>(1) + static_cast<T>(alpha) * x.exp()) *
1150+
temp_a_neg * temp_x_pos;
11401151
}
11411152

11421153
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }

0 commit comments

Comments
 (0)