Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 34 additions & 13 deletions paddle/phi/kernels/gpu/batch_norm_grad_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -858,15 +858,20 @@ void BatchNormGradRawKernel(const Context &ctx,
// ctx.GetPlace()),
// epsilon, saved_mean_data, saved_var_data));
#else
// CUDNN only support small batch size
// const size_t CUDNN_PER_ACTIVATION_THRESHOLD = 131070;
const size_t CUDNN_PER_ACTIVATION_THRESHOLD = 10240;
const size_t CUDNN_SPATIAL_THRESHOLD = 880801;
const bool use_native_kernel =
((x_dims.size() == 2 && N >= CUDNN_PER_ACTIVATION_THRESHOLD) ||
(x_dims.size() == 3 && N >= CUDNN_SPATIAL_THRESHOLD));
if (use_native_kernel) {
if (x_dims.size() == 2) {
}
// CUDNN only support small batch size
// const size_t CUDNN_PER_ACTIVATION_THRESHOLD = 131070;
const size_t CUDNN_PER_ACTIVATION_THRESHOLD = 10240;
const size_t CUDNN_SPATIAL_THRESHOLD = 880801;
bool use_native_nhwc =
d_x ? (x_dims.size() == 4 && compute_format == DataLayout::kNHWC)
: false;
const bool use_native_kernel =
((x_dims.size() == 2 && N >= CUDNN_PER_ACTIVATION_THRESHOLD) ||
(x_dims.size() == 3 && N >= CUDNN_SPATIAL_THRESHOLD));
if (use_native_nhwc || (d_x && d_scale && d_bias)) {
if (use_native_kernel || use_native_nhwc) {
if (x_dims.size() == 2 || use_native_nhwc) {
dim3 block;
dim3 grid;
const int block_size = 512;
Expand Down Expand Up @@ -937,6 +942,21 @@ void BatchNormGradRawKernel(const Context &ctx,
flag_ptr);
}
// 2. reduce_sum(x, dy, mean) => dscale, dbias
BatchNormParamType<T> *dscale = nullptr;
BatchNormParamType<T> *dbias = nullptr;
bool with_scale = false;
if (d_scale && d_bias) {
dscale = ctx.template Alloc<BatchNormParamType<T>>(d_scale);
dbias = ctx.template Alloc<BatchNormParamType<T>>(d_bias);
} else {
DenseTensor dscale_mem =
phi::Empty<BatchNormParamType<T>, Context>(ctx, {C});
DenseTensor dbias_mem =
phi::Empty<BatchNormParamType<T>, Context>(ctx, {C});
dscale = dscale_mem.data<BatchNormParamType<T>>();
dbias = dbias_mem.data<BatchNormParamType<T>>();
}

BNBackward2DChannelLastStage2<T, block_size>
<<<grid, block, 0, ctx.stream()>>>(
transformed_d_y.template data<T>(),
Expand All @@ -948,8 +968,8 @@ void BatchNormGradRawKernel(const Context &ctx,
H * W * D,
epsilon,
block_data_ptr,
ctx.template Alloc<BatchNormParamType<T>>(d_scale),
ctx.template Alloc<BatchNormParamType<T>>(d_bias),
dscale,
dbias,
flag_ptr);

// 3. elementwise_mul(scale, mean, inv_var, dy, dscale, dbias) => dx
Expand All @@ -958,8 +978,8 @@ void BatchNormGradRawKernel(const Context &ctx,
transformed_d_y.template data<T>(),
transformed_x.template data<T>(),
scale.template data<BatchNormParamType<T>>(),
d_scale->data<BatchNormParamType<T>>(),
d_bias->data<BatchNormParamType<T>>(),
dscale,
dbias,
mean_ptr,
variance_ptr,
C,
Expand Down Expand Up @@ -1169,6 +1189,7 @@ void BatchNormGradRawKernel(const Context &ctx,
paddle::platform::dynload::cudnnDestroyTensorDescriptor(
bn_param_desc_));
#endif

} else {
const auto *running_mean = mean.get_ptr();
const auto *running_var = variance.get_ptr();
Expand Down