@@ -35,6 +35,106 @@ class ElementwiseMinKernel<platform::CUDADeviceContext, T>
}
};

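+// Fused min-grad functor: a single elementwise pass over (x, y, dout)
+// produces both gradients, so x and y are read only once. dx receives
+// dout where x < y; dy receives dout where x >= y.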
+template <typename InT, typename OutT>
+struct MinGradXYFunctor {
+  inline HOSTDEVICE paddle::framework::Array<OutT, 2> operator()(
+      const InT& a,    // x
+      const InT& b,    // y
+      const InT& c) {  // dout
+    paddle::framework::Array<OutT, 2> outs;
+    // dx = dout * (x < y)
+    outs[0] = a < b ? c : static_cast<InT>(0);
+    // dy = dout * (x >= y)
+    outs[1] = (a > b || a == b) ? c : static_cast<InT>(0);
+    return outs;
+  }
+};
+
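+// Sums the dout-shaped gradient `src` over the broadcast dimensions
+// (derived by GetReduceDim from the operand/output dims and `axis`) and
+// writes the reduced result to `dst`, the gradient of operand `in`.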
+template <typename T>
+void ReduceWrapper(const platform::CUDADeviceContext& dev_ctx, int axis,
+                   const framework::Tensor* in, const framework::Tensor* out,
+                   framework::Tensor* src, framework::Tensor* dst) {
+  std::vector<int> reduce_dims = GetReduceDim(in->dims(), out->dims(), axis);
+  TensorReduceFunctorImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
+      *src, dst, kps::IdentityFunctor<T>(), reduce_dims, dev_ctx.stream());
+}
+
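+// Computes dx and/or dy for elementwise_min. Gradients whose shape already
+// matches dout are written directly by the elementwise kernel; broadcast
+// operands are written into a dout-shaped temporary first and then summed
+// back to the operand's shape with ReduceWrapper.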
+template <typename DeviceContext, typename T>
+void DefaultElementMinGrad(const framework::ExecutionContext& ctx,
+                           const framework::Tensor* x,
+                           const framework::Tensor* y,
+                           const framework::Tensor* out,
+                           const framework::Tensor* dout, framework::Tensor* dx,
+                           framework::Tensor* dy) {
+  int axis = ctx.Attr<int>("axis");
+  const auto& dev_ctx =
+      ctx.template device_context<platform::CUDADeviceContext>();
+  framework::Tensor tmp_dx;
+  framework::Tensor tmp_dy;
+  tmp_dx.mutable_data<T>(dout->dims(), ctx.GetPlace());
+  tmp_dy.mutable_data<T>(dout->dims(), ctx.GetPlace());
+
+  if (dx != nullptr && dy != nullptr) {
+    dx->mutable_data<T>(ctx.GetPlace());
+    dy->mutable_data<T>(ctx.GetPlace());
+    std::vector<const framework::Tensor*> ins = {x, y, dout};
+    std::vector<framework::Tensor*> outs;
+    if (dx->dims() == dout->dims() && dy->dims() == dout->dims()) {
+      outs = {dx, dy};
+    } else if (dx->dims() != dout->dims() && dy->dims() == dout->dims()) {
+      outs = {&tmp_dx, dy};
+    } else if (dx->dims() == dout->dims() && dy->dims() != dout->dims()) {
+      outs = {dx, &tmp_dy};
+    } else if (dx->dims() != dout->dims() && dy->dims() != dout->dims()) {
+      outs = {&tmp_dx, &tmp_dy};
+    }
+    auto functor = MinGradXYFunctor<T, T>();
+    LaunchElementwiseCudaKernel<ElementwiseType::kTernary, T, T,
+                                decltype(functor), 2>(dev_ctx, ins, &outs,
+                                                      axis, functor);
+    if (dx->dims() != dout->dims() && dy->dims() == dout->dims()) {
+      ReduceWrapper<T>(dev_ctx, axis, x, out, &tmp_dx, dx);
+    } else if (dx->dims() == dout->dims() && dy->dims() != dout->dims()) {
+      ReduceWrapper<T>(dev_ctx, axis, y, out, &tmp_dy, dy);
+    } else if (dx->dims() != dout->dims() && dy->dims() != dout->dims()) {
+      ReduceWrapper<T>(dev_ctx, axis, x, out, &tmp_dx, dx);
+      ReduceWrapper<T>(dev_ctx, axis, y, out, &tmp_dy, dy);
+    }
+  } else if (dx != nullptr && dy == nullptr) {
+    dx->mutable_data<T>(ctx.GetPlace());
+    std::vector<const framework::Tensor*> ins = {x, y, dout};
+    std::vector<framework::Tensor*> outs;
+    if (dx->dims() != dout->dims()) {
+      outs = {&tmp_dx};
+    } else {
+      outs = {dx};
+    }
+
+    LaunchElementwiseCudaKernel<ElementwiseType::kTernary, T, T>(
+        dev_ctx, ins, &outs, axis, TernaryLessThanFunctor<T>());
+    if (dx->dims() != dout->dims()) {
+      ReduceWrapper<T>(dev_ctx, axis, x, out, &tmp_dx, dx);
+    }
+  } else if (dx == nullptr && dy != nullptr) {
+    dy->mutable_data<T>(ctx.GetPlace());
+    std::vector<const framework::Tensor*> ins = {x, y, dout};
+    std::vector<framework::Tensor*> outs;
+    if (dy->dims() != dout->dims()) {
+      outs = {&tmp_dy};
+    } else {
+      outs = {dy};
+    }
+
+    LaunchElementwiseCudaKernel<ElementwiseType::kTernary, T, T>(
+        dev_ctx, ins, &outs, axis, TernaryGreaterEqualThanFunctor<T>());
+    if (dy->dims() != dout->dims()) {
+      ReduceWrapper<T>(dev_ctx, axis, y, out, &tmp_dy, dy);
+    }
+  }
+}
+
137+ /*
38138template <typename DeviceContext, typename T>
39139void DefaultElementMinGrad(const framework::ExecutionContext& ctx,
40140 const framework::Tensor* x,
@@ -95,6 +195,7 @@ void DefaultElementMinGrad(const framework::ExecutionContext& ctx,
    }
  }
}
+*/

template <typename T>
class ElementwiseMinGradKernel<platform::CUDADeviceContext, T>