@@ -152,6 +152,21 @@ __device__ __forceinline__ void ThreadReduce(phi::Array<const T*, Num> arrs,
   }
 }
 
+template <typename T>
+__device__ __forceinline__ void ReduceMeanAndVar(T* mean, T* var, T x_mean,
+                                                 T x_var, int size) {
+  const int nc = blockIdx.x;
+  x_mean = kps::details::BlockXReduce<T, kps::AddFunctor<T>>(
+      x_mean, kps::AddFunctor<T>());
+  x_var = kps::details::BlockXReduce<T, kps::AddFunctor<T>>(
+      x_var, kps::AddFunctor<T>());
+  __syncthreads();
+  if (threadIdx.x == 0) {
+    mean[nc] = static_cast<T>(x_mean / size);
+    var[nc] = static_cast<T>(x_var / size);
+  }
+}
+
 template <typename T>
 __global__ void ScalarGetMeanAndVarNCHW(const T* x, T* mean, T* var, int size) {
   int i = blockIdx.x;
@@ -162,10 +177,7 @@ __global__ void ScalarGetMeanAndVarNCHW(const T* x, T* mean, T* var, int size) {
     x_mean += val;
     x_var += val * val;
   }
-  x_mean /= size;
-  x_var /= size;
-  CudaAtomicAddWithWarp(&mean[i], x_mean);
-  CudaAtomicAddWithWarp(&var[i], x_var);
+  ReduceMeanAndVar<T>(mean, var, x_mean, x_var, size);
 }
 
 template <typename T, typename AccT, int VecSize>
@@ -174,21 +186,12 @@ __global__ void VectorizedGetMeanAndVarNCHW(const T* x, T* mean, T* var,
   int i = blockIdx.x;
   AccT x_mean = static_cast<AccT>(0);
   AccT x_var = static_cast<AccT>(0);
-  const int input_offset = ((uint64_t)x) % ALIGN_BYTES / sizeof(T);
   x += i * size;
+  const int input_offset = ((uint64_t)x) % ALIGN_BYTES / sizeof(T);
   phi::Array<const T*, 1> ins;
   ins[0] = x;
   ThreadReduce<T, AccT, VecSize, 1>(ins, size, input_offset, &x_mean, &x_var);
-
-  x_mean = kps::details::BlockXReduce<AccT, kps::AddFunctor<AccT>>(
-      x_mean, kps::AddFunctor<AccT>());
-  x_var = kps::details::BlockXReduce<AccT, kps::AddFunctor<AccT>>(
-      x_var, kps::AddFunctor<AccT>());
-  __syncthreads();
-  if (threadIdx.x == 0) {
-    mean[i] = static_cast<T>(x_mean / size);
-    var[i] = static_cast<T>(x_var / size);
-  }
+  ReduceMeanAndVar<AccT>(mean, var, x_mean, x_var, size);
 }
 
 template <typename T, int flags>
@@ -272,10 +275,6 @@ class GroupNormKernel<platform::CUDADeviceContext, T>
     auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
     Tensor temp_var;
     temp_var.mutable_data<T>(var->dims(), ctx.GetPlace());
-
-    set_zero(dev_ctx, mean, static_cast<T>(0));
-    set_zero(dev_ctx, &temp_var, static_cast<T>(0));
-
     auto* x_data = x->data<T>();
     auto* y_data = y->data<T>();
     auto* mean_data = mean->data<T>();
@@ -319,7 +318,7 @@ class GroupNormKernel<platform::CUDADeviceContext, T>
       block_size_nchw = std::max(block_size_nchw, kps::details::kWarpSize);
       dim3 grids(x_dims[0] * groups);
       dim3 blocks(block_size_nchw);
-      if (size < vec_size) {
+      if (size < vec_size * block_size_nchw) {
         ScalarGetMeanAndVarNCHW<T><<<grids, blocks, 0, dev_ctx.stream()>>>(
             x_data, mean_data, temp_var_data, size);
       } else {
@@ -328,6 +327,8 @@ class GroupNormKernel<platform::CUDADeviceContext, T>
             x_data, mean_data, temp_var_data, size);
       }
     } else {
+      set_zero(dev_ctx, mean, static_cast<T>(0));
+      set_zero(dev_ctx, &temp_var, static_cast<T>(0));
       GroupNormForwardGetMeanAndVar<T><<<grid, threads, 0, dev_ctx.stream()>>>(
           x_data, x_dims[0], C, W, imsize, groups, group_size, mean_data,
           temp_var_data);
@@ -424,24 +425,15 @@ __global__ void VectorizedGetDsDbCUDAKernel(int imsize, const T* x, const T* dy,
   int i = blockIdx.x;
   AccT ds_sum = static_cast<AccT>(0);
   AccT db_sum = static_cast<AccT>(0);
-  const int input_offset = ((uint64_t)x) % ALIGN_BYTES / sizeof(T);
   x += i * imsize;
+  const int input_offset = ((uint64_t)x) % ALIGN_BYTES / sizeof(T);
 
   phi::Array<const T*, 2> ins;
   ins[0] = x;
   ins[1] = dy;
   ThreadReduce<T, AccT, VecSize, 2>(ins, imsize, input_offset, &db_sum,
                                     &ds_sum);
-
-  ds_sum = kps::details::BlockXReduce<AccT, kps::AddFunctor<AccT>>(
-      ds_sum, kps::AddFunctor<AccT>());
-  db_sum = kps::details::BlockXReduce<AccT, kps::AddFunctor<AccT>>(
-      db_sum, kps::AddFunctor<AccT>());
-  __syncthreads();
-  if (threadIdx.x == 0) {
-    ds[i] = ds_sum;
-    db[i] = db_sum;
-  }
+  ReduceMeanAndVar<AccT>(db, ds, db_sum, ds_sum, 1);
 }
 
 template <typename T>
@@ -455,8 +447,7 @@ __global__ void ScalarGetDsDbCUDAKernel(int imsize, const T* x, const T* dy,
     ds_sum += dy[index] * x[index];
     db_sum += dy[index];
   }
-  CudaAtomicAddWithWarp(&ds[nc], ds_sum);
-  CudaAtomicAddWithWarp(&db[nc], db_sum);
+  ReduceMeanAndVar<T>(db, ds, db_sum, ds_sum, 1);
 }
 
 template <typename T>
@@ -641,13 +632,7 @@ class GroupNormGradKernel<platform::CUDADeviceContext, T>
       }
       block_size_nchw = std::max(block_size_nchw, kps::details::kWarpSize);
       dim3 blocks(block_size_nchw);
-      if (imsize < vec_size) {
-        if (d_scale) {
-          set_zero(dev_ctx, d_scale, static_cast<T>(0));
-        }
-        if (d_bias) {
-          set_zero(dev_ctx, d_bias, static_cast<T>(0));
-        }
+      if (imsize < vec_size * block_size_nchw) {
         ScalarGetDsDbCUDAKernel<
             T><<<x_dims[0] * C, blocks, 0, dev_ctx.stream()>>>(
             imsize, x_data, dy_data, ds_data, db_data);
@@ -687,7 +672,6 @@ class GroupNormGradKernel<platform::CUDADeviceContext, T>
             imsize, C, group_size, groups, p1_data, p2_data, p3_data, x_data,
             dy_data, d_x_data);
       }
-
     } else {
       if (d_scale) {
         set_zero(dev_ctx, d_scale, static_cast<T>(0));
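
Note on the pattern above: the diff folds the duplicated block-level reductions (and the old CudaAtomicAddWithWarp/set_zero path) into one ReduceMeanAndVar helper built on kps::details::BlockXReduce, with one thread block per (n, g) slot; in the backward DsDb kernels it is reused with size == 1, so it writes plain block sums instead of averages. For readers without the Paddle kernel-primitive headers, the following is a minimal standalone CUDA sketch of the same idea. BlockReduceSum and ReduceMeanAndVarSketch are hypothetical stand-ins, not the Paddle API, and assume blockDim.x is a multiple of the warp size (which holds here because block_size_nchw is clamped to at least kWarpSize and rounded to a power of two).

#include <cuda_runtime.h>

// Hypothetical stand-in for kps::details::BlockXReduce with AddFunctor:
// sums `val` across the thread block via warp shuffles plus shared memory.
// Assumes blockDim.x is a multiple of 32 and at most 1024.
template <typename T>
__device__ T BlockReduceSum(T val) {
  __shared__ T warp_sums[32];  // one partial sum per warp
  const int lane = threadIdx.x % 32;
  const int warp = threadIdx.x / 32;

  // Intra-warp tree reduction.
  for (int offset = 16; offset > 0; offset >>= 1) {
    val += __shfl_down_sync(0xffffffff, val, offset);
  }
  if (lane == 0) warp_sums[warp] = val;
  __syncthreads();

  // First warp reduces the per-warp partial sums.
  const int num_warps = blockDim.x / 32;
  val = (threadIdx.x < num_warps) ? warp_sums[lane] : static_cast<T>(0);
  if (warp == 0) {
    for (int offset = 16; offset > 0; offset >>= 1) {
      val += __shfl_down_sync(0xffffffff, val, offset);
    }
  }
  return val;  // final sum is valid in thread 0
}

// Same shape as the new ReduceMeanAndVar helper: each block owns one
// (n, g) output slot and thread 0 writes the averaged block sums.
template <typename T>
__device__ void ReduceMeanAndVarSketch(T* mean, T* var, T x_mean, T x_var,
                                       int size) {
  const int nc = blockIdx.x;
  x_mean = BlockReduceSum(x_mean);
  x_var = BlockReduceSum(x_var);
  if (threadIdx.x == 0) {
    mean[nc] = x_mean / size;
    var[nc] = x_var / size;
  }
}

Because every block writes its own mean[nc]/var[nc] slot exactly once, the outputs no longer need to be zero-initialized before the NCHW kernels, which is why the set_zero calls move into (or disappear from) the non-NCHW branches above.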