Skip to content

Commit 3c18889

Browse files
committed
fix fp16 reduce_mean
1 parent 25c35ba commit 3c18889

File tree

2 files changed

+26
-2
lines changed

2 files changed

+26
-2
lines changed

paddle/fluid/operators/elementwise/elementwise_min_op.h

Lines changed: 25 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -19,6 +19,7 @@ limitations under the License. */
1919
#include "paddle/fluid/operators/elementwise/elementwise_op.h"
2020
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
2121
#include "paddle/fluid/platform/eigen_ext.h"
22+
#include "paddle/fluid/platform/float16.h"
2223

2324
namespace paddle {
2425
namespace operators {
@@ -56,17 +57,39 @@ class ElementwiseFMinKernel : public framework::OpKernel<T> {
5657
template <typename T>
5758
struct MinGradDx {
5859
HOSTDEVICE T operator()(T x, T y, T out, T dout) const {
59-
return dout * static_cast<T>(x < y);
60+
return dout * (x < y);
6061
}
6162
};
6263

6364
template <typename T>
6465
struct MinGradDy {
6566
HOSTDEVICE T operator()(T x, T y, T out, T dout) const {
66-
return dout * static_cast<T>(x >= y);
67+
return dout * (x >= y);
6768
}
6869
};
6970

71+
#ifdef PADDLE_CUDA_FP16
72+
template <>
73+
struct MinGradDx<platform::float16> {
74+
HOSTDEVICE platform::float16 operator()(platform::float16 x,
75+
platform::float16 y,
76+
platform::float16 out,
77+
platform::float16 dout) const {
78+
return x < y ? dout : static_cast<platform::float16>(0);
79+
}
80+
};
81+
82+
template <>
83+
struct MinGradDy<platform::float16> {
84+
HOSTDEVICE platform::float16 operator()(platform::float16 x,
85+
platform::float16 y,
86+
platform::float16 out,
87+
platform::float16 dout) const {
88+
return x >= y ? dout : static_cast<platform::float16>(0);
89+
}
90+
};
91+
#endif
92+
7093
template <typename DeviceContext, typename T>
7194
class ElementwiseMinGradKernel : public ElemwiseGradKernel<T> {
7295
public:

paddle/fluid/operators/reduce_ops/reduce_mean_op.cu

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -17,6 +17,7 @@
1717
#include "paddle/fluid/operators/reduce_ops/reduce_op.h"
1818

1919
REGISTER_OP_CUDA_KERNEL(
20+
reduce_mean,
2021
ops::ReduceCudaKernel<bool, kps::AddFunctor, kps::DivideFunctor>,
2122
ops::ReduceCudaKernel<paddle::platform::float16, kps::AddFunctor,
2223
kps::DivideFunctor>,

0 commit comments

Comments
 (0)