@@ -19,6 +19,7 @@ limitations under the License. */
1919#include " paddle/fluid/operators/elementwise/elementwise_op.h"
2020#include " paddle/fluid/operators/elementwise/elementwise_op_function.h"
2121#include " paddle/fluid/platform/eigen_ext.h"
22+ #include " paddle/fluid/platform/float16.h"
2223
2324namespace paddle {
2425namespace operators {
@@ -56,17 +57,39 @@ class ElementwiseFMinKernel : public framework::OpKernel<T> {
// Element-wise functor: forwards `dout` when `x < y` and yields zero
// otherwise. Judging by its name, it supplies the x-side gradient of an
// element-wise min.
//
// The result is picked with a select instead of being computed as
// `dout * (x < y)`: multiplying a non-finite `dout` by a zero mask produces
// NaN for the operand that should receive no gradient. The float16
// specialization of this functor already uses the select form, so the generic
// version now agrees with it.
//
// Parameters:
//   x, y  - the two operands being compared.
//   out   - accepted to keep the four-argument functor signature; not read.
//   dout  - value forwarded when the comparison holds.
// Returns `dout` if `x < y`, otherwise `T(0)`. A comparison involving NaN is
// false, so it yields zero.
template <typename T>
struct MinGradDx {
  HOSTDEVICE T operator()(T x, T y, T out, T dout) const {
    return x < y ? dout : static_cast<T>(0);
  }
};
6263
// Element-wise functor: forwards `dout` when `x >= y` and yields zero
// otherwise. Judging by its name, it supplies the y-side gradient of an
// element-wise min.
//
// The result is picked with a select instead of being computed as
// `dout * (x >= y)`: multiplying a non-finite `dout` by a zero mask produces
// NaN for the operand that should receive no gradient. The float16
// specialization of this functor already uses the select form, so the generic
// version now agrees with it.
//
// Parameters:
//   x, y  - the two operands being compared.
//   out   - accepted to keep the four-argument functor signature; not read.
//   dout  - value forwarded when the comparison holds.
// Returns `dout` if `x >= y`, otherwise `T(0)`. A comparison involving NaN is
// false, so it yields zero.
template <typename T>
struct MinGradDy {
  HOSTDEVICE T operator()(T x, T y, T out, T dout) const {
    return x >= y ? dout : static_cast<T>(0);
  }
};
6970
71+ #ifdef PADDLE_CUDA_FP16
72+ template <>
73+ struct MinGradDx <platform::float16> {
74+ HOSTDEVICE platform::float16 operator ()(platform::float16 x,
75+ platform::float16 y,
76+ platform::float16 out,
77+ platform::float16 dout) const {
78+ return x < y ? dout : static_cast <platform::float16>(0 );
79+ }
80+ };
81+
82+ template <>
83+ struct MinGradDy <platform::float16> {
84+ HOSTDEVICE platform::float16 operator ()(platform::float16 x,
85+ platform::float16 y,
86+ platform::float16 out,
87+ platform::float16 dout) const {
88+ return x >= y ? dout : static_cast <platform::float16>(0 );
89+ }
90+ };
91+ #endif
92+
7093template <typename DeviceContext, typename T>
7194class ElementwiseMinGradKernel : public ElemwiseGradKernel <T> {
7295 public:
0 commit comments