Skip to content

Commit d068de7

Browse files
committed
multi-outputs init commit
1 parent 10759be commit d068de7

File tree

1 file changed

+101
-0
lines changed

1 file changed

+101
-0
lines changed

paddle/fluid/operators/elementwise/elementwise_min_op.cu

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,106 @@ class ElementwiseMinKernel<platform::CUDADeviceContext, T>
3535
}
3636
};
3737

38+
template <typename InT, typename OutT>
39+
struct MinGradXYFunctor {
40+
inline HOSTDEVICE paddle::framework::Array<OutT, 2> operator()(
41+
const InT& a, // x
42+
const InT& b, // y
43+
const InT& c) { // dout
44+
paddle::framework::Array<OutT, 2> outs;
45+
// dx = dout * (x < y)
46+
outs[0] = a < b ? c : static_cast<InT>(0);
47+
// dy = dout * (x >= y)
48+
outs[1] = (a > b || a == b) ? c : static_cast<InT>(0);
49+
return outs;
50+
}
51+
};
52+
53+
template <typename T>
54+
void ReduceWrapper(const platform::CUDADeviceContext& dev_ctx, int axis,
55+
const framework::Tensor* in, const framework::Tensor* out,
56+
framework::Tensor* src, framework::Tensor* dst) {
57+
std::vector<int> reduce_dims = GetReduceDim(in->dims(), out->dims(), axis);
58+
TensorReduceFunctorImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
59+
*src, dst, kps::IdentityFunctor<T>(), reduce_dims, dev_ctx.stream());
60+
}
61+
62+
template <typename DeviceContext, typename T>
63+
void DefaultElementMinGrad(const framework::ExecutionContext& ctx,
64+
const framework::Tensor* x,
65+
const framework::Tensor* y,
66+
const framework::Tensor* out,
67+
const framework::Tensor* dout, framework::Tensor* dx,
68+
framework::Tensor* dy) {
69+
int axis = ctx.Attr<int>("axis");
70+
const auto& dev_ctx =
71+
ctx.template device_context<platform::CUDADeviceContext>();
72+
framework::Tensor tmp_dx;
73+
framework::Tensor tmp_dy;
74+
tmp_dx.mutable_data<T>(dout->dims(), ctx.GetPlace());
75+
tmp_dy.mutable_data<T>(dout->dims(), ctx.GetPlace());
76+
77+
if (dx != nullptr && dy != nullptr) {
78+
dx->mutable_data<T>(ctx.GetPlace());
79+
dy->mutable_data<T>(ctx.GetPlace());
80+
std::vector<const framework::Tensor*> ins = {x, y, dout};
81+
std::vector<framework::Tensor*> outs;
82+
if (dx->dims() == dout->dims() && dy->dims() == dout->dims()) {
83+
outs = {dx, dy};
84+
} else if (dx->dims() != dout->dims() && dy->dims() == dout->dims()) {
85+
outs = {&tmp_dx, dy};
86+
} else if (dx->dims() == dout->dims() && dy->dims() != dout->dims()) {
87+
outs = {dx, &tmp_dy};
88+
} else if (dx->dims() != dout->dims() && dy->dims() != dout->dims()) {
89+
outs = {&tmp_dx, &tmp_dy};
90+
}
91+
auto functor = MinGradXYFunctor<T, T>();
92+
LaunchElementwiseCudaKernel<ElementwiseType::kTernary, T, T,
93+
decltype(functor), 2>(dev_ctx, ins, &outs, axis,
94+
functor);
95+
if (dx->dims() != dout->dims() && dy->dims() == dout->dims()) {
96+
ReduceWrapper<T>(dev_ctx, axis, x, out, &tmp_dx, dx);
97+
} else if (dx->dims() == dout->dims() && dy->dims() != dout->dims()) {
98+
ReduceWrapper<T>(dev_ctx, axis, y, out, &tmp_dy, dy);
99+
} else if (dx->dims() != dout->dims() && dy->dims() != dout->dims()) {
100+
ReduceWrapper<T>(dev_ctx, axis, x, out, &tmp_dx, dx);
101+
ReduceWrapper<T>(dev_ctx, axis, y, out, &tmp_dy, dy);
102+
}
103+
104+
} else if (dx != nullptr && dy == nullptr) {
105+
dx->mutable_data<T>(ctx.GetPlace());
106+
std::vector<const framework::Tensor*> ins = {x, y, dout};
107+
std::vector<framework::Tensor*> outs;
108+
if (dx->dims() != dout->dims()) {
109+
outs = {&tmp_dx};
110+
} else {
111+
outs = {dx};
112+
}
113+
114+
LaunchElementwiseCudaKernel<ElementwiseType::kTernary, T, T>(
115+
dev_ctx, ins, &outs, axis, TernaryLessThanFunctor<T>());
116+
if (dx->dims() != dout->dims()) {
117+
ReduceWrapper<T>(dev_ctx, axis, x, out, &tmp_dx, dx);
118+
}
119+
} else if (dx == nullptr && dy != nullptr) {
120+
dy->mutable_data<T>(ctx.GetPlace());
121+
std::vector<const framework::Tensor*> ins = {x, y, dout};
122+
std::vector<framework::Tensor*> outs;
123+
if (dy->dims() != dout->dims()) {
124+
outs = {&tmp_dy};
125+
} else {
126+
outs = {dy};
127+
}
128+
129+
LaunchElementwiseCudaKernel<ElementwiseType::kTernary, T, T>(
130+
dev_ctx, ins, &outs, axis, TernaryGreaterEqualThanFunctor<T>());
131+
if (dy->dims() != dout->dims()) {
132+
ReduceWrapper<T>(dev_ctx, axis, y, out, &tmp_dy, dy);
133+
}
134+
}
135+
}
136+
137+
/*
38138
template <typename DeviceContext, typename T>
39139
void DefaultElementMinGrad(const framework::ExecutionContext& ctx,
40140
const framework::Tensor* x,
@@ -95,6 +195,7 @@ void DefaultElementMinGrad(const framework::ExecutionContext& ctx,
95195
}
96196
}
97197
}
198+
*/
98199

99200
template <typename T>
100201
class ElementwiseMinGradKernel<platform::CUDADeviceContext, T>

0 commit comments

Comments
 (0)