
Commit 5d9e11a

Modified sigmoid by the elementwise interface. (#39898)
* Modified sigmoid by the elementwise interface.
* Used TensorReduceImpl to replace the Sum function.
* Used TensorReduceImpl to calculate the norm variable.
* Removed useless code.
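In short, the hand-written GPUSigmoidForward/GPUSigmoidBackward CUDA kernels are replaced by functors passed to phi::funcs::ElementwiseKernel, each returning both the per-element result and a validity count in a single pass. As a rough, self-contained sketch of that idea (plain C++17 with hypothetical names such as SigmoidFwd; not the PaddlePaddle API):

    #include <algorithm>
    #include <array>
    #include <cmath>
    #include <cstdio>
    #include <vector>

    // Standalone analog of SigmoidFwdFunctor: one elementwise call yields two
    // outputs, {loss, count}, mirroring the phi::Array<T, 2> return in the diff.
    template <typename T>
    struct SigmoidFwd {
      T ignore_index;
      T eps = static_cast<T>(1e-5);

      std::array<T, 2> operator()(T x, T label) const {
        T diff = label - ignore_index;
        if (diff > -eps && diff < eps) return {T(0), T(0)};  // ignored element
        // Out = max(X, 0) - X * Labels + log(1 + exp(-abs(X)))
        T out = std::max(x, T(0)) - x * label +
                std::log(T(1) + std::exp(-std::abs(x)));
        return {out, T(1)};  // valid element contributes 1 to the count
      }
    };

    int main() {
      std::vector<double> x = {0.5, -1.2, 2.0}, label = {1.0, -100.0, 0.0};
      SigmoidFwd<double> fwd{/*ignore_index=*/-100.0};
      for (size_t i = 0; i < x.size(); ++i) {
        auto [loss, cnt] = fwd(x[i], label[i]);
        std::printf("loss=%f count=%f\n", loss, cnt);
      }
    }

Returning a two-element array is what lets a single ElementwiseKernel launch fill both Out and the counts tensor at once.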
1 parent 3e56e81 commit 5d9e11a

1 file changed: +139 −68 lines

paddle/fluid/operators/sigmoid_cross_entropy_with_logits_op.cu

@@ -20,9 +20,11 @@ namespace cub = hipcub;
 #endif
 #include "paddle/fluid/memory/malloc.h"
 #include "paddle/fluid/operators/math.h"
+#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
 #include "paddle/fluid/operators/sigmoid_cross_entropy_with_logits_op.h"
 #include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
 #include "paddle/phi/core/hostdevice.h"
+#include "paddle/phi/kernels/funcs/elementwise_base.h"
 
 namespace paddle {
 namespace operators {
@@ -42,71 +44,86 @@ static inline int NumBlocks(const int N) {
 }
 
 template <typename T>
-__global__ void GPUSigmoidForward(const T *x_data, const T *label_data,
-                                  const int ignore_index, const int limit,
-                                  T *out_data, T *counts) {
-  CUDA_KERNEL_LOOP(i, limit) {
-    T x = x_data[i];
-    T label = label_data[i];
-    T eps = static_cast<T>(1e-5);
-    T diff = label - static_cast<T>(ignore_index);
+struct NonzeroFunctor {
+  HOSTDEVICE explicit inline NonzeroFunctor() {}
+  HOSTDEVICE inline T operator()(const T x) const {
+    return static_cast<T>(static_cast<double>(x) != 0);
+  }
+};
+
+template <typename T>
+struct SigmoidFwdFunctor {
+  T ignore_index_;
+  T eps = static_cast<T>(1e-5);
+
+  HOSTDEVICE inline SigmoidFwdFunctor(const T ignore_index)
+      : ignore_index_(ignore_index) {}
+
+  HOSTDEVICE inline phi::Array<T, 2> operator()(const T x, const T label) {
+    T counts;
+    T out_data;
+
+    T diff = label - static_cast<T>(ignore_index_);
     if ((diff > -eps) && (diff < eps)) {
-      out_data[i] = static_cast<T>(0.);
-      counts[i] = 0;
+      out_data = static_cast<T>(0.);
+      counts = 0;
     } else {
       T term1 = (x > 0) ? x : 0;
       T term2 = x * label;
       T term3 = real_log(static_cast<T>(1) + real_exp(static_cast<T>(-abs(x))));
-      out_data[i] = term1 - term2 + term3;
-      counts[i] = 1;
+
+      out_data = term1 - term2 + term3;
+      counts = 1;
     }
-  }
-}
+    phi::Array<T, 2> outs;
 
-template <typename T, int BlockDim>
-__global__ void Sum(const T *counts, int num, const T eps, T *sum) {
-  typedef cub::BlockReduce<double, BlockDim> BlockReduce;
-  __shared__ typename BlockReduce::TempStorage temp_storage;
-  T in = 0;
-  for (int i = threadIdx.x; i < num; i += BlockDim) {
-    in += counts[i];
+    outs[0] = out_data;
+    outs[1] = counts;
+    return outs;
   }
-  __syncthreads();
-  auto out =
-      BlockReduce(temp_storage).Reduce(static_cast<double>(in), cub::Sum());
-  __syncthreads();
-  if (threadIdx.x == 0) {
-    T a = out > eps ? out : eps;
-    sum[0] = a;
-  }
-}
+};
 
 template <typename T>
-__global__ void Div(T *loss, const int num, const T *norm) {
-  CUDA_KERNEL_LOOP(i, num) { loss[i] /= norm[0]; }
-}
+struct SigmoidBwdFunctor {
+  T ignore_index_;
+  T eps = static_cast<T>(1e-5);
 
-template <typename T>
-__global__ void GPUSigmoidBackward(const T *x_data, const T *label_data,
-                                   const int ignore_index, const T *dout_data,
-                                   const int limit, T *dx_data, T *counts) {
-  CUDA_KERNEL_LOOP(i, limit) {
-    T x = x_data[i];
-    T label = label_data[i];
-    T dout = dout_data[i];
-    T eps = static_cast<T>(1e-5);
-    T diff = label - static_cast<T>(ignore_index);
+  HOSTDEVICE inline SigmoidBwdFunctor(const T ignore_index)
+      : ignore_index_(ignore_index) {}
+
+  HOSTDEVICE inline phi::Array<T, 2> operator()(const T x, const T label,
+                                                const T dout) {
+    T counts;
+    T dx_data;
+
+    T diff = label - static_cast<T>(ignore_index_);
     if ((diff > -eps) && (diff < eps)) {
-      dx_data[i] = static_cast<T>(0.);
-      counts[i] = 0;
+      dx_data = static_cast<T>(0.);
+      counts = 0;
     } else {
       T simoid_x = static_cast<T>(1) / (static_cast<T>(1) + real_exp(-x));
       T diff = simoid_x - label;
-      dx_data[i] = dout * diff;
-      counts[i] = 1;
+      dx_data = dout * diff;
+      counts = 1;
     }
+    phi::Array<T, 2> outs;
+
+    outs[0] = dx_data;
+    outs[1] = counts;
+    return outs;
   }
-}
+};
+
+template <typename T>
+struct DivFunctor {
+  const T norm_;
+  HOSTDEVICE inline DivFunctor(const T norm) : norm_(norm) {}
+
+  HOSTDEVICE inline T operator()(T loss) {
+    loss /= norm_;
+    return loss;
+  }
+};
 
 // Out = max(X, 0) - X * Labels + log(1 + exp(-abs(X)))
 template <typename DeviceContext, typename T>
@@ -123,20 +140,48 @@ class GPUSigmoidCrossEntropyWithLogitsKernel : public framework::OpKernel<T> {
     bool normalize = context.Attr<bool>("normalize");
 
     // Temporary memory
-    auto cnt_ptr = memory::Alloc(dev_ctx, Labels->numel() * sizeof(T));
-    T *counts = reinterpret_cast<T *>(cnt_ptr->ptr());
-
+    Tensor *counts_tensor = new Tensor();
+    counts_tensor->mutable_data<T>(context.GetPlace(),
+                                   Labels->numel() * sizeof(T));
+    counts_tensor->Resize(Out->dims());
     int limit = Out->numel();
     int blocks = NumBlocks(limit);
     int threads = kNumCUDAThreads;
-    GPUSigmoidForward<T><<<blocks, threads, 0, dev_ctx.stream()>>>(
-        X->data<T>(), Labels->data<T>(), ignore_index, limit, out_data, counts);
+    std::vector<const framework::Tensor *> ins = {X, Labels};
+    std::vector<framework::Tensor *> outs = {Out, counts_tensor};
+    auto functor = SigmoidFwdFunctor<T>(ignore_index);
+    constexpr int Size = 2;
+    phi::funcs::ElementwiseKernel<T, decltype(functor), Size>(dev_ctx, ins,
+                                                              &outs, functor);
     if (normalize) {
-      auto norm_ptr = memory::Alloc(dev_ctx, sizeof(T));
-      T *norm = reinterpret_cast<T *>(norm_ptr->ptr());
-      Sum<T, kNumCUDAThreads><<<1, kNumCUDAThreads, 0, dev_ctx.stream()>>>(
-          counts, limit, static_cast<T>(1e-5), norm);
-      Div<T><<<blocks, threads, 0, dev_ctx.stream()>>>(out_data, limit, norm);
+      T *counts = counts_tensor->mutable_data<T>(context.GetPlace());
+      Tensor *norm_tensor = new Tensor();
+      norm_tensor->mutable_data<T>(context.GetPlace(), sizeof(T));
+      auto dims = phi::vectorize(counts_tensor->dims());
+      std::vector<int> reduce_dim = {};
+      for (int i = 0; i < dims.size(); i++) {
+        reduce_dim.push_back(i);
+      }
+
+      TensorReduceImpl<T, T, kps::AddFunctor, NonzeroFunctor<T>>(
+          context.cuda_device_context(), *counts_tensor, norm_tensor,
+          NonzeroFunctor<T>(), reduce_dim, dev_ctx.stream());
+      T *norm = norm_tensor->mutable_data<T>(context.GetPlace());
+      auto norm_cpu_mem = memory::Alloc(platform::CPUPlace(), sizeof(T));
+      T *norm_cpu_ptr = reinterpret_cast<T *>(norm_cpu_mem->ptr());
+      memory::Copy(platform::CPUPlace(), norm_cpu_ptr, dev_ctx.GetPlace(), norm,
+                   sizeof(T), dev_ctx.stream());
+      auto eps = static_cast<T>(1e-5);
+      *norm_cpu_ptr = *norm_cpu_ptr > eps ? *norm_cpu_ptr : eps;
+
+      std::vector<const framework::Tensor *> div_ins = {Out};
+      std::vector<framework::Tensor *> div_outs = {Out};
+      auto div_functor = DivFunctor<T>(*norm_cpu_ptr);
+      phi::funcs::ElementwiseKernel<T>(dev_ctx, div_ins, &div_outs,
+                                       div_functor);
+
+      delete norm_tensor;
+      delete counts_tensor;
     }
   }
 };
@@ -157,22 +202,48 @@ class GPUSigmoidCrossEntropyWithLogitsGradKernel
 
     auto &dev_ctx = context.cuda_device_context();
     // Temporary memory
-    auto cnt_ptr = memory::Alloc(dev_ctx, X->numel() * sizeof(T));
-    T *counts = reinterpret_cast<T *>(cnt_ptr->ptr());
+    Tensor *counts_tensor = new Tensor();
+    counts_tensor->mutable_data<T>(context.GetPlace(),
+                                   Labels->numel() * sizeof(T));
+    counts_tensor->Resize(dX->dims());
 
     int limit = dX->numel();
     int blocks = NumBlocks(limit);
     int threads = kNumCUDAThreads;
-    GPUSigmoidBackward<T><<<blocks, threads, 0, dev_ctx.stream()>>>(
-        X->data<T>(), Labels->data<T>(), ignore_index, dOut->data<T>(), limit,
-        dx_data, counts);
+    std::vector<const framework::Tensor *> ins = {X, Labels, dOut};
+    std::vector<framework::Tensor *> outs = {dX, counts_tensor};
+    auto functor = SigmoidBwdFunctor<T>(ignore_index);
+    constexpr int Size = 2;
+    phi::funcs::ElementwiseKernel<T, decltype(functor), Size>(dev_ctx, ins,
+                                                              &outs, functor);
     bool normalize = context.Attr<bool>("normalize");
     if (normalize) {
-      auto norm_ptr = memory::Alloc(dev_ctx, sizeof(T));
-      T *norm = reinterpret_cast<T *>(norm_ptr->ptr());
-      Sum<T, kNumCUDAThreads><<<1, kNumCUDAThreads, 0, dev_ctx.stream()>>>(
-          counts, limit, static_cast<T>(1e-5), norm);
-      Div<T><<<blocks, threads, 0, dev_ctx.stream()>>>(dx_data, limit, norm);
+      T *counts = counts_tensor->mutable_data<T>(context.GetPlace());
+      Tensor *norm_tensor = new Tensor();
+      norm_tensor->mutable_data<T>(context.GetPlace(), sizeof(T));
+      auto dims = phi::vectorize(counts_tensor->dims());
+      std::vector<int> reduce_dim = {};
+      for (int i = 0; i < dims.size(); i++) {
+        reduce_dim.push_back(i);
+      }
+
+      TensorReduceImpl<T, T, kps::AddFunctor, NonzeroFunctor<T>>(
+          context.cuda_device_context(), *counts_tensor, norm_tensor,
+          NonzeroFunctor<T>(), reduce_dim, dev_ctx.stream());
+      T *norm = norm_tensor->mutable_data<T>(context.GetPlace());
+      auto norm_cpu_mem = memory::Alloc(platform::CPUPlace(), sizeof(T));
+      T *norm_cpu_ptr = reinterpret_cast<T *>(norm_cpu_mem->ptr());
+      memory::Copy(platform::CPUPlace(), norm_cpu_ptr, dev_ctx.GetPlace(), norm,
+                   sizeof(T), dev_ctx.stream());
+      auto eps = static_cast<T>(1e-5);
+      *norm_cpu_ptr = *norm_cpu_ptr > eps ? *norm_cpu_ptr : eps;
+
+      std::vector<const framework::Tensor *> div_ins = {dX};
+      std::vector<framework::Tensor *> div_outs = {dX};
+      auto div_functor = DivFunctor<T>(*norm_cpu_ptr);
+      phi::funcs::ElementwiseKernel<T>(dev_ctx, div_ins, &div_outs,
+                                       div_functor);
+      delete norm_tensor;
    }
  }
 };
