Skip to content

Commit dd9d720

Browse files
authored
fix group_norm address misalignment (#40657)
* Fix group_norm address misalignment
* Fix vectorize
* Fix code
* Fix vectorize length
* Optimize code
1 parent c63e03b commit dd9d720

File tree

1 file changed

+25
-41
lines changed

1 file changed

+25
-41
lines changed

paddle/fluid/operators/group_norm_op.cu

Lines changed: 25 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,21 @@ __device__ __forceinline__ void ThreadReduce(phi::Array<const T*, Num> arrs,
152152
}
153153
}
154154

// Block-wide reduction that folds each thread's partial sums into the final
// per-group mean and variance slots. One block handles one (n, c) group,
// selected by blockIdx.x; only thread 0 writes the results.
//
// mean, var: output arrays indexed by blockIdx.x.
// x_mean, x_var: this thread's partial sums (sum of x, and sum of x*x in the
//   forward kernels; the backward kernels reuse this helper to reduce db/ds
//   partial sums, passing size == 1 so no averaging occurs).
// size: divisor applied after the reduction (number of reduced elements).
//
// NOTE(review): var[nc] receives the averaged second-moment sum, not the
// variance itself — presumably a later step subtracts mean^2; confirm
// against the call sites outside this view.
template <typename T>
__device__ __forceinline__ void ReduceMeanAndVar(T* mean, T* var, T x_mean,
                                                 T x_var, int size) {
  const int nc = blockIdx.x;
  // Sum the per-thread partials across the whole block.
  x_mean = kps::details::BlockXReduce<T, kps::AddFunctor<T>>(
      x_mean, kps::AddFunctor<T>());
  x_var = kps::details::BlockXReduce<T, kps::AddFunctor<T>>(
      x_var, kps::AddFunctor<T>());
  // Barrier before the single-thread write; assumes BlockXReduce leaves any
  // shared storage it uses consistent after this point.
  __syncthreads();
  if (threadIdx.x == 0) {
    mean[nc] = static_cast<T>(x_mean / size);
    var[nc] = static_cast<T>(x_var / size);
  }
}
169+
155170
template <typename T>
156171
__global__ void ScalarGetMeanAndVarNCHW(const T* x, T* mean, T* var, int size) {
157172
int i = blockIdx.x;
@@ -162,10 +177,7 @@ __global__ void ScalarGetMeanAndVarNCHW(const T* x, T* mean, T* var, int size) {
162177
x_mean += val;
163178
x_var += val * val;
164179
}
165-
x_mean /= size;
166-
x_var /= size;
167-
CudaAtomicAddWithWarp(&mean[i], x_mean);
168-
CudaAtomicAddWithWarp(&var[i], x_var);
180+
ReduceMeanAndVar<T>(mean, var, x_mean, x_var, size);
169181
}
170182

171183
template <typename T, typename AccT, int VecSize>
@@ -174,21 +186,12 @@ __global__ void VectorizedGetMeanAndVarNCHW(const T* x, T* mean, T* var,
174186
int i = blockIdx.x;
175187
AccT x_mean = static_cast<AccT>(0);
176188
AccT x_var = static_cast<AccT>(0);
177-
const int input_offset = ((uint64_t)x) % ALIGN_BYTES / sizeof(T);
178189
x += i * size;
190+
const int input_offset = ((uint64_t)x) % ALIGN_BYTES / sizeof(T);
179191
phi::Array<const T*, 1> ins;
180192
ins[0] = x;
181193
ThreadReduce<T, AccT, VecSize, 1>(ins, size, input_offset, &x_mean, &x_var);
182-
183-
x_mean = kps::details::BlockXReduce<AccT, kps::AddFunctor<AccT>>(
184-
x_mean, kps::AddFunctor<AccT>());
185-
x_var = kps::details::BlockXReduce<AccT, kps::AddFunctor<AccT>>(
186-
x_var, kps::AddFunctor<AccT>());
187-
__syncthreads();
188-
if (threadIdx.x == 0) {
189-
mean[i] = static_cast<T>(x_mean / size);
190-
var[i] = static_cast<T>(x_var / size);
191-
}
194+
ReduceMeanAndVar<AccT>(mean, var, x_mean, x_var, size);
192195
}
193196

194197
template <typename T, int flags>
@@ -272,10 +275,6 @@ class GroupNormKernel<platform::CUDADeviceContext, T>
272275
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
273276
Tensor temp_var;
274277
temp_var.mutable_data<T>(var->dims(), ctx.GetPlace());
275-
276-
set_zero(dev_ctx, mean, static_cast<T>(0));
277-
set_zero(dev_ctx, &temp_var, static_cast<T>(0));
278-
279278
auto* x_data = x->data<T>();
280279
auto* y_data = y->data<T>();
281280
auto* mean_data = mean->data<T>();
@@ -319,7 +318,7 @@ class GroupNormKernel<platform::CUDADeviceContext, T>
319318
block_size_nchw = std::max(block_size_nchw, kps::details::kWarpSize);
320319
dim3 grids(x_dims[0] * groups);
321320
dim3 blocks(block_size_nchw);
322-
if (size < vec_size) {
321+
if (size < vec_size * block_size_nchw) {
323322
ScalarGetMeanAndVarNCHW<T><<<grids, blocks, 0, dev_ctx.stream()>>>(
324323
x_data, mean_data, temp_var_data, size);
325324
} else {
@@ -328,6 +327,8 @@ class GroupNormKernel<platform::CUDADeviceContext, T>
328327
x_data, mean_data, temp_var_data, size);
329328
}
330329
} else {
330+
set_zero(dev_ctx, mean, static_cast<T>(0));
331+
set_zero(dev_ctx, &temp_var, static_cast<T>(0));
331332
GroupNormForwardGetMeanAndVar<T><<<grid, threads, 0, dev_ctx.stream()>>>(
332333
x_data, x_dims[0], C, W, imsize, groups, group_size, mean_data,
333334
temp_var_data);
@@ -424,24 +425,15 @@ __global__ void VectorizedGetDsDbCUDAKernel(int imsize, const T* x, const T* dy,
424425
int i = blockIdx.x;
425426
AccT ds_sum = static_cast<AccT>(0);
426427
AccT db_sum = static_cast<AccT>(0);
427-
const int input_offset = ((uint64_t)x) % ALIGN_BYTES / sizeof(T);
428428
x += i * imsize;
429+
const int input_offset = ((uint64_t)x) % ALIGN_BYTES / sizeof(T);
429430

430431
phi::Array<const T*, 2> ins;
431432
ins[0] = x;
432433
ins[1] = dy;
433434
ThreadReduce<T, AccT, VecSize, 2>(ins, imsize, input_offset, &db_sum,
434435
&ds_sum);
435-
436-
ds_sum = kps::details::BlockXReduce<AccT, kps::AddFunctor<AccT>>(
437-
ds_sum, kps::AddFunctor<AccT>());
438-
db_sum = kps::details::BlockXReduce<AccT, kps::AddFunctor<AccT>>(
439-
db_sum, kps::AddFunctor<AccT>());
440-
__syncthreads();
441-
if (threadIdx.x == 0) {
442-
ds[i] = ds_sum;
443-
db[i] = db_sum;
444-
}
436+
ReduceMeanAndVar<AccT>(db, ds, db_sum, ds_sum, 1);
445437
}
446438

447439
template <typename T>
@@ -455,8 +447,7 @@ __global__ void ScalarGetDsDbCUDAKernel(int imsize, const T* x, const T* dy,
455447
ds_sum += dy[index] * x[index];
456448
db_sum += dy[index];
457449
}
458-
CudaAtomicAddWithWarp(&ds[nc], ds_sum);
459-
CudaAtomicAddWithWarp(&db[nc], db_sum);
450+
ReduceMeanAndVar<T>(db, ds, db_sum, ds_sum, 1);
460451
}
461452

462453
template <typename T>
@@ -641,13 +632,7 @@ class GroupNormGradKernel<platform::CUDADeviceContext, T>
641632
}
642633
block_size_nchw = std::max(block_size_nchw, kps::details::kWarpSize);
643634
dim3 blocks(block_size_nchw);
644-
if (imsize < vec_size) {
645-
if (d_scale) {
646-
set_zero(dev_ctx, d_scale, static_cast<T>(0));
647-
}
648-
if (d_bias) {
649-
set_zero(dev_ctx, d_bias, static_cast<T>(0));
650-
}
635+
if (imsize < vec_size * block_size_nchw) {
651636
ScalarGetDsDbCUDAKernel<
652637
T><<<x_dims[0] * C, blocks, 0, dev_ctx.stream()>>>(
653638
imsize, x_data, dy_data, ds_data, db_data);
@@ -687,7 +672,6 @@ class GroupNormGradKernel<platform::CUDADeviceContext, T>
687672
imsize, C, group_size, groups, p1_data, p2_data, p3_data, x_data,
688673
dy_data, d_x_data);
689674
}
690-
691675
} else {
692676
if (d_scale) {
693677
set_zero(dev_ctx, d_scale, static_cast<T>(0));

0 commit comments

Comments
 (0)