@@ -20,9 +20,11 @@ namespace cub = hipcub;
 #endif
 #include "paddle/fluid/memory/malloc.h"
 #include "paddle/fluid/operators/math.h"
+#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
 #include "paddle/fluid/operators/sigmoid_cross_entropy_with_logits_op.h"
 #include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
 #include "paddle/phi/core/hostdevice.h"
+#include "paddle/phi/kernels/funcs/elementwise_base.h"
 
 namespace paddle {
 namespace operators {
@@ -42,71 +44,86 @@ static inline int NumBlocks(const int N) {
 }
 
 template <typename T>
-__global__ void GPUSigmoidForward(const T *x_data, const T *label_data,
-                                  const int ignore_index, const int limit,
-                                  T *out_data, T *counts) {
-  CUDA_KERNEL_LOOP(i, limit) {
-    T x = x_data[i];
-    T label = label_data[i];
-    T eps = static_cast<T>(1e-5);
-    T diff = label - static_cast<T>(ignore_index);
+struct NonzeroFunctor {
+  HOSTDEVICE explicit inline NonzeroFunctor() {}
+  HOSTDEVICE inline T operator()(const T x) const {
+    return static_cast<T>(static_cast<double>(x) != 0);
+  }
+};
+
+template <typename T>
+struct SigmoidFwdFunctor {
+  T ignore_index_;
+  T eps = static_cast<T>(1e-5);
+
+  HOSTDEVICE inline SigmoidFwdFunctor(const T ignore_index)
+      : ignore_index_(ignore_index) {}
+
+  HOSTDEVICE inline phi::Array<T, 2> operator()(const T x, const T label) {
+    T counts;
+    T out_data;
+
+    T diff = label - static_cast<T>(ignore_index_);
     if ((diff > -eps) && (diff < eps)) {
-      out_data[i] = static_cast<T>(0.);
-      counts[i] = 0;
+      out_data = static_cast<T>(0.);
+      counts = 0;
     } else {
       T term1 = (x > 0) ? x : 0;
       T term2 = x * label;
       T term3 = real_log(static_cast<T>(1) + real_exp(static_cast<T>(-abs(x))));
-      out_data[i] = term1 - term2 + term3;
-      counts[i] = 1;
+
+      out_data = term1 - term2 + term3;
+      counts = 1;
     }
-  }
-}
+    phi::Array<T, 2> outs;
 
-template <typename T, int BlockDim>
-__global__ void Sum(const T *counts, int num, const T eps, T *sum) {
-  typedef cub::BlockReduce<double, BlockDim> BlockReduce;
-  __shared__ typename BlockReduce::TempStorage temp_storage;
-  T in = 0;
-  for (int i = threadIdx.x; i < num; i += BlockDim) {
-    in += counts[i];
+    outs[0] = out_data;
+    outs[1] = counts;
+    return outs;
   }
-  __syncthreads();
-  auto out =
-      BlockReduce(temp_storage).Reduce(static_cast<double>(in), cub::Sum());
-  __syncthreads();
-  if (threadIdx.x == 0) {
-    T a = out > eps ? out : eps;
-    sum[0] = a;
-  }
-}
+};
 
 template <typename T>
-__global__ void Div(T *loss, const int num, const T *norm) {
-  CUDA_KERNEL_LOOP(i, num) { loss[i] /= norm[0]; }
-}
+struct SigmoidBwdFunctor {
+  T ignore_index_;
+  T eps = static_cast<T>(1e-5);
 
-template <typename T>
-__global__ void GPUSigmoidBackward(const T *x_data, const T *label_data,
-                                   const int ignore_index, const T *dout_data,
-                                   const int limit, T *dx_data, T *counts) {
-  CUDA_KERNEL_LOOP(i, limit) {
-    T x = x_data[i];
-    T label = label_data[i];
-    T dout = dout_data[i];
-    T eps = static_cast<T>(1e-5);
-    T diff = label - static_cast<T>(ignore_index);
+  HOSTDEVICE inline SigmoidBwdFunctor(const T ignore_index)
+      : ignore_index_(ignore_index) {}
+
+  HOSTDEVICE inline phi::Array<T, 2> operator()(const T x, const T label,
+                                                const T dout) {
+    T counts;
+    T dx_data;
+
+    T diff = label - static_cast<T>(ignore_index_);
     if ((diff > -eps) && (diff < eps)) {
-      dx_data[i] = static_cast<T>(0.);
-      counts[i] = 0;
+      dx_data = static_cast<T>(0.);
+      counts = 0;
     } else {
       T simoid_x = static_cast<T>(1) / (static_cast<T>(1) + real_exp(-x));
       T diff = simoid_x - label;
-      dx_data[i] = dout * diff;
-      counts[i] = 1;
+      dx_data = dout * diff;
+      counts = 1;
     }
+    phi::Array<T, 2> outs;
+
+    outs[0] = dx_data;
+    outs[1] = counts;
+    return outs;
   }
-}
+};
+
+template <typename T>
+struct DivFunctor {
+  const T norm_;
+  HOSTDEVICE inline DivFunctor(const T norm) : norm_(norm) {}
+
+  HOSTDEVICE inline T operator()(T loss) {
+    loss /= norm_;
+    return loss;
+  }
+};
 
 // Out = max(X, 0) - X * Labels + log(1 + exp(-abs(X)))
 template <typename DeviceContext, typename T>
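
The three hand-rolled CUDA kernels above (GPUSigmoidForward, the cub-based Sum, and Div) are folded into plain functors that phi's generic launchers can drive: SigmoidFwdFunctor and SigmoidBwdFunctor each return the loss/gradient plus a 0/1 count through a phi::Array<T, 2>, NonzeroFunctor feeds the count reduction, and DivFunctor handles the final scaling. The loss keeps the numerically stable form max(x, 0) - x*label + log(1 + exp(-|x|)). A standalone host-side sketch of why that form is used (stable_bce/naive_bce are illustrative names, not part of the patch):

#include <cmath>
#include <cstdio>

// Stable form from SigmoidFwdFunctor: max(x, 0) - x*label + log(1 + exp(-|x|)).
float stable_bce(float x, float label) {
  float term1 = x > 0.f ? x : 0.f;
  return term1 - x * label + std::log(1.f + std::exp(-std::fabs(x)));
}

// Naive form: exp() overflows and log() hits zero for large |x|.
float naive_bce(float x, float label) {
  float s = 1.f / (1.f + std::exp(-x));
  return -label * std::log(s) - (1.f - label) * std::log(1.f - s);
}

int main() {
  for (float x : {-100.f, -1.f, 0.f, 1.f, 100.f}) {
    std::printf("x = %6.1f  stable = %10.6f  naive = %10.6f\n", x,
                stable_bce(x, 1.f), naive_bce(x, 1.f));
  }
}

In float, the naive form returns inf at x = -100 (exp(100) overflows) and NaN at x = 100 (0 * log(0)), while the stable form returns 100 and roughly 0 as expected.
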
@@ -123,20 +140,48 @@ class GPUSigmoidCrossEntropyWithLogitsKernel : public framework::OpKernel<T> {
     bool normalize = context.Attr<bool>("normalize");
 
     // Temporary memory
-    auto cnt_ptr = memory::Alloc(dev_ctx, Labels->numel() * sizeof(T));
-    T *counts = reinterpret_cast<T *>(cnt_ptr->ptr());
-
+    Tensor *counts_tensor = new Tensor();
+    counts_tensor->mutable_data<T>(context.GetPlace(),
+                                   Labels->numel() * sizeof(T));
+    counts_tensor->Resize(Out->dims());
     int limit = Out->numel();
     int blocks = NumBlocks(limit);
     int threads = kNumCUDAThreads;
-    GPUSigmoidForward<T><<<blocks, threads, 0, dev_ctx.stream()>>>(
-        X->data<T>(), Labels->data<T>(), ignore_index, limit, out_data, counts);
+    std::vector<const framework::Tensor *> ins = {X, Labels};
+    std::vector<framework::Tensor *> outs = {Out, counts_tensor};
+    auto functor = SigmoidFwdFunctor<T>(ignore_index);
+    constexpr int Size = 2;
+    phi::funcs::ElementwiseKernel<T, decltype(functor), Size>(dev_ctx, ins,
+                                                              &outs, functor);
     if (normalize) {
-      auto norm_ptr = memory::Alloc(dev_ctx, sizeof(T));
-      T *norm = reinterpret_cast<T *>(norm_ptr->ptr());
-      Sum<T, kNumCUDAThreads><<<1, kNumCUDAThreads, 0, dev_ctx.stream()>>>(
-          counts, limit, static_cast<T>(1e-5), norm);
-      Div<T><<<blocks, threads, 0, dev_ctx.stream()>>>(out_data, limit, norm);
+      T *counts = counts_tensor->mutable_data<T>(context.GetPlace());
+      Tensor *norm_tensor = new Tensor();
+      norm_tensor->mutable_data<T>(context.GetPlace(), sizeof(T));
+      auto dims = phi::vectorize(counts_tensor->dims());
+      std::vector<int> reduce_dim = {};
+      for (int i = 0; i < dims.size(); i++) {
+        reduce_dim.push_back(i);
+      }
+
+      TensorReduceImpl<T, T, kps::AddFunctor, NonzeroFunctor<T>>(
+          context.cuda_device_context(), *counts_tensor, norm_tensor,
+          NonzeroFunctor<T>(), reduce_dim, dev_ctx.stream());
+      T *norm = norm_tensor->mutable_data<T>(context.GetPlace());
+      auto norm_cpu_mem = memory::Alloc(platform::CPUPlace(), sizeof(T));
+      T *norm_cpu_ptr = reinterpret_cast<T *>(norm_cpu_mem->ptr());
+      memory::Copy(platform::CPUPlace(), norm_cpu_ptr, dev_ctx.GetPlace(), norm,
+                   sizeof(T), dev_ctx.stream());
+      auto eps = static_cast<T>(1e-5);
+      *norm_cpu_ptr = *norm_cpu_ptr > eps ? *norm_cpu_ptr : eps;
+
+      std::vector<const framework::Tensor *> div_ins = {Out};
+      std::vector<framework::Tensor *> div_outs = {Out};
+      auto div_functor = DivFunctor<T>(*norm_cpu_ptr);
+      phi::funcs::ElementwiseKernel<T>(dev_ctx, div_ins, &div_outs,
+                                       div_functor);
+
+      delete norm_tensor;
+      delete counts_tensor;
     }
   }
 };
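
In the forward kernel, a single ElementwiseKernel launch now writes Out and the counts tensor together, and the normalize branch replaces the one-block cub Sum with a TensorReduceImpl over every axis of counts (NonzeroFunctor mapped through kps::AddFunctor), copies the scalar norm to the host, clamps it at eps = 1e-5, and divides Out in place via DivFunctor. A rough host-side sketch of those semantics (normalize_loss is an illustrative name; the real code runs on the GPU stream):

#include <algorithm>
#include <vector>

// Mirrors the normalize branch: norm = max(#non-ignored elements, eps),
// then loss[i] /= norm for every element.
void normalize_loss(std::vector<float> &loss, const std::vector<float> &counts) {
  float norm = 0.f;
  for (float c : counts) norm += (c != 0.f) ? 1.f : 0.f;  // NonzeroFunctor + AddFunctor
  norm = std::max(norm, 1e-5f);                           // eps clamp after the D2H copy
  for (float &l : loss) l /= norm;                        // DivFunctor over the tensor
}

int main() {
  std::vector<float> loss = {2.f, 4.f, 6.f};
  std::vector<float> counts = {1.f, 0.f, 1.f};  // one ignored element
  normalize_loss(loss, counts);                 // divides by max(2, 1e-5)
}
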
@@ -157,22 +202,48 @@ class GPUSigmoidCrossEntropyWithLogitsGradKernel
 
     auto &dev_ctx = context.cuda_device_context();
     // Temporary memory
-    auto cnt_ptr = memory::Alloc(dev_ctx, X->numel() * sizeof(T));
-    T *counts = reinterpret_cast<T *>(cnt_ptr->ptr());
+    Tensor *counts_tensor = new Tensor();
+    counts_tensor->mutable_data<T>(context.GetPlace(),
+                                   Labels->numel() * sizeof(T));
+    counts_tensor->Resize(dX->dims());
 
     int limit = dX->numel();
     int blocks = NumBlocks(limit);
     int threads = kNumCUDAThreads;
-    GPUSigmoidBackward<T><<<blocks, threads, 0, dev_ctx.stream()>>>(
-        X->data<T>(), Labels->data<T>(), ignore_index, dOut->data<T>(), limit,
-        dx_data, counts);
+    std::vector<const framework::Tensor *> ins = {X, Labels, dOut};
+    std::vector<framework::Tensor *> outs = {dX, counts_tensor};
+    auto functor = SigmoidBwdFunctor<T>(ignore_index);
+    constexpr int Size = 2;
+    phi::funcs::ElementwiseKernel<T, decltype(functor), Size>(dev_ctx, ins,
+                                                              &outs, functor);
     bool normalize = context.Attr<bool>("normalize");
     if (normalize) {
-      auto norm_ptr = memory::Alloc(dev_ctx, sizeof(T));
-      T *norm = reinterpret_cast<T *>(norm_ptr->ptr());
-      Sum<T, kNumCUDAThreads><<<1, kNumCUDAThreads, 0, dev_ctx.stream()>>>(
-          counts, limit, static_cast<T>(1e-5), norm);
-      Div<T><<<blocks, threads, 0, dev_ctx.stream()>>>(dx_data, limit, norm);
+      T *counts = counts_tensor->mutable_data<T>(context.GetPlace());
+      Tensor *norm_tensor = new Tensor();
+      norm_tensor->mutable_data<T>(context.GetPlace(), sizeof(T));
+      auto dims = phi::vectorize(counts_tensor->dims());
+      std::vector<int> reduce_dim = {};
+      for (int i = 0; i < dims.size(); i++) {
+        reduce_dim.push_back(i);
+      }
+
+      TensorReduceImpl<T, T, kps::AddFunctor, NonzeroFunctor<T>>(
+          context.cuda_device_context(), *counts_tensor, norm_tensor,
+          NonzeroFunctor<T>(), reduce_dim, dev_ctx.stream());
+      T *norm = norm_tensor->mutable_data<T>(context.GetPlace());
+      auto norm_cpu_mem = memory::Alloc(platform::CPUPlace(), sizeof(T));
+      T *norm_cpu_ptr = reinterpret_cast<T *>(norm_cpu_mem->ptr());
+      memory::Copy(platform::CPUPlace(), norm_cpu_ptr, dev_ctx.GetPlace(), norm,
+                   sizeof(T), dev_ctx.stream());
+      auto eps = static_cast<T>(1e-5);
+      *norm_cpu_ptr = *norm_cpu_ptr > eps ? *norm_cpu_ptr : eps;
+
+      std::vector<const framework::Tensor *> div_ins = {dX};
+      std::vector<framework::Tensor *> div_outs = {dX};
+      auto div_functor = DivFunctor<T>(*norm_cpu_ptr);
+      phi::funcs::ElementwiseKernel<T>(dev_ctx, div_ins, &div_outs,
+                                       div_functor);
+      delete norm_tensor;
    }
  }
 };
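
The backward path mirrors the forward one: SigmoidBwdFunctor computes dx = dout * (sigmoid(x) - label), zeroing positions that match ignore_index, and the same reduce/clamp/divide sequence normalizes dX. A quick finite-difference check that this is indeed the gradient of the stable forward loss (illustrative code, not from the patch):

#include <cmath>
#include <cstdio>

// Forward loss in double, same stable form as the CUDA functor.
double loss(double x, double label) {
  return (x > 0 ? x : 0) - x * label + std::log(1 + std::exp(-std::fabs(x)));
}

int main() {
  const double h = 1e-6, label = 1.0;
  for (double x : {-2.0, 0.5, 3.0}) {
    double numeric = (loss(x + h, label) - loss(x - h, label)) / (2 * h);
    double analytic = 1.0 / (1.0 + std::exp(-x)) - label;  // sigmoid(x) - label
    std::printf("x = %5.2f  numeric = %.8f  analytic = %.8f\n", x, numeric,
                analytic);
  }
}

The numeric and analytic columns match closely for both signs of x, confirming the analytic form the functor applies.
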