44 changes: 24 additions & 20 deletions paddle/fluid/inference/tensorrt/plugin/group_norm_op_plugin.cu
@@ -28,7 +28,10 @@ namespace tensorrt {
namespace plugin {
using DataLayout = phi::DataLayout;

static inline int32_t divUp(int32_t m, int32_t n) { return (m + n - 1) / n; }
template <typename T>
static inline T divUp(T m, T n) {
return (m + n - 1) / n;
}

static inline __device__ __host__ float sigmoid(float x) {
return 1.F / (1.F + expf(-x));
@@ -54,12 +57,12 @@ struct GroupSumsOp {
}
};

static int32_t findMaxDivisor(int32_t n, int32_t maxAllowedDivisor) {
int32_t maxDivisor = -1;
for (int32_t i = 1; i <= std::sqrt(n); i++) {
static int64_t findMaxDivisor(int64_t n, int64_t maxAllowedDivisor) {
int64_t maxDivisor = -1;
for (int64_t i = 1; i <= std::sqrt(n); i++) {
if (n % i == 0) {
int32_t divisor1 = n / i;
int32_t divisor2 = i;
int64_t divisor1 = n / i;
int64_t divisor2 = i;

if (divisor1 > maxDivisor && divisor1 < maxAllowedDivisor) {
maxDivisor = divisor1;
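
For context on the widened helpers, here is a minimal standalone sketch (not part of this diff) of why divUp is templated and fed 64-bit values: for large 3-D activations the per-channel spatial size d * h * w alone can exceed INT32_MAX, while the channel-side math can stay 32-bit. The sizes and blocksPerDHW value below are made up for illustration.

#include <cstdint>
#include <cstdio>

// Same shape as the helper above: one template serves both the 32-bit
// channel arithmetic and the 64-bit dhw arithmetic.
template <typename T>
static inline T divUp(T m, T n) {
  return (m + n - 1) / n;
}

int main() {
  // Hypothetical 1300 x 1300 x 1300 volume: the spatial element count no
  // longer fits in int32_t, so the division must run in 64 bits.
  int64_t dhw = 1300LL * 1300 * 1300;   // 2,197,000,000 > INT32_MAX
  int64_t blocksPerDHW = 1000;          // assumed block count for the sketch
  std::printf("dhwPerBlock = %lld\n",
              static_cast<long long>(divUp(dhw, blocksPerDHW)));

  // Channel-side math stays small and can keep 32-bit instantiations.
  int32_t cPerBlock = 320;
  std::printf("threads per block = %d\n",
              divUp(cPerBlock, 2));  // two channels per thread, as above
  return 0;
}
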
@@ -90,9 +93,9 @@ __global__ void groupNormNCHW32SumKernelQDQ(
int32_t ci = blockIdx.x * params.cPerBlock + threadIdx.x * 2;

// The first activation loaded by that block.
int32_t dhwBegin = blockIdx.y * params.dhwPerBlock;
int64_t dhwBegin = blockIdx.y * params.dhwPerBlock;
// The last activation loaded by that block.
int32_t dhwEnd = min(dhwBegin + params.dhwPerBlock, params.dhw);
int64_t dhwEnd = min(dhwBegin + params.dhwPerBlock, params.dhw);

// The sums.
float sum = 0.F;
@@ -102,11 +105,10 @@

// nchw32 layout
// batch offset + channel offset
int nc_offset = static_cast<int64_t>(ni) * params.dhwc +
ci / 32 * params.dhw * 32 + ci % 32;
int64_t nc_offset = ni * params.dhwc + ci / 32 * params.dhw * 32 + ci % 32;

// Iterate over the activations to compute the sums.
for (int32_t dhwi = dhwBegin; dhwi < dhwEnd; ++dhwi) {
for (int64_t dhwi = dhwBegin; dhwi < dhwEnd; ++dhwi) {
// The offset.
int64_t offset = nc_offset + static_cast<int64_t>(dhwi) * 32;

@@ -233,15 +235,15 @@ __global__ void groupNormNCHW32ScaleKernelQDQ(
float invStdDev = rsqrtf(var + params.eps);

// The first activation loaded by that block.
int32_t dhwBegin = blockIdx.y * params.dhwPerBlock;
int64_t dhwBegin = blockIdx.y * params.dhwPerBlock;
// The last activation loaded by that block.
int32_t dhwEnd = min(dhwBegin + params.dhwPerBlock, params.dhw);
int64_t dhwEnd = min(dhwBegin + params.dhwPerBlock, params.dhw);

// nchw32 layout
int c_offset = ci / 32 * params.dhw * 32 + ci % 32;
int64_t c_offset = ci / 32 * params.dhw * 32 + ci % 32;

// Iterate over the activations to compute the sums.
for (int32_t dhwi = dhwBegin; dhwi < dhwEnd; ++dhwi) {
for (int64_t dhwi = dhwBegin; dhwi < dhwEnd; ++dhwi) {
// The src/dst offset.
int64_t offset = static_cast<int64_t>(ni) * params.dhwc + c_offset +
static_cast<int64_t>(dhwi) * 32;
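
For reference, the per-element arithmetic the scale kernel above applies once the per-group mean and variance have been reduced is roughly the following (standalone sketch, not the plugin's code; the fused-activation flag is my own name for illustration).

#include <cmath>

// y = (x - mean_g) * rsqrt(var_g + eps) * gamma_c + beta_c, optionally
// followed by y * sigmoid(y) when a SiLU/Swish activation is fused in.
inline float groupNormScale(float x, float mean, float var, float eps,
                            float gamma, float beta, bool withSwish) {
  const float invStdDev = 1.0f / std::sqrt(var + eps);
  float y = (x - mean) * invStdDev * gamma + beta;
  if (withSwish) {
    y *= 1.0f / (1.0f + std::exp(-y));  // sigmoid(y)
  }
  return y;
}

int main() {
  // Tiny numeric check: a value equal to the group mean maps to beta.
  const float y = groupNormScale(0.5f, 0.5f, 2.0f, 1e-5f, 3.0f, 0.25f, false);
  return (y == 0.25f) ? 0 : 1;
}
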
@@ -581,7 +583,7 @@ int GroupNormPluginDynamic::enqueue(
const auto input_ddim = common::make_ddim(input_shape);

int C = input_shape[1];
int image_size = input_shape[2] * input_shape[3];
int64_t image_size = input_shape[2] * input_shape[3];
int batchSize = input_shape[0];

PADDLE_ENFORCE_EQ(
@@ -693,8 +695,9 @@ int GroupNormPluginDynamic::enqueue(
// params_.w = input_desc[0].dims.d[3];
// params_.c = input_desc[0].dims.d[1];
params_.groups = groups_;
params_.dhw = params_.d * params_.h * params_.w;
const int32_t blocksPerDHW = findMaxDivisor(params_.dhw, maxBlocksPerDHW);
params_.dhw = static_cast<int64_t>(params_.d) * params_.h * params_.w;
const int64_t blocksPerDHW =
findMaxDivisor(params_.dhw, static_cast<int64_t>(maxBlocksPerDHW));
params_.dhwPerBlock = divUp(params_.dhw, blocksPerDHW);
params_.cPerBlock = cPerBlock;
params_.cPerGroup = params_.c / params_.groups;
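
Side note on the cast pattern used above: int64_t dhw = d * h * w; would not be enough, because the right-hand side is still evaluated in 32-bit int and may wrap before the assignment widens it. Casting the first factor promotes every subsequent multiplication. A minimal sketch with made-up extents:

#include <cstdint>
#include <cstdio>

int main() {
  const int d = 1024, h = 2048, w = 2048;  // hypothetical extents, product 2^32

  // int64_t dhw = d * h * w;  // 32-bit multiply: wraps before widening
  const int64_t dhw = static_cast<int64_t>(d) * h * w;  // 64-bit throughout

  std::printf("dhw = %lld\n", static_cast<long long>(dhw));  // 4294967296
  return 0;
}
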
@@ -774,8 +777,9 @@ int GroupNormPluginDynamic::enqueue(
// params_.w = input_desc[0].dims.d[3];
// params_.c = input_desc[0].dims.d[1];
params_.groups = groups_;
params_.dhw = params_.d * params_.h * params_.w;
const int32_t blocksPerDHW = findMaxDivisor(params_.dhw, maxBlocksPerDHW);
params_.dhw = static_cast<int64_t>(params_.d) * params_.h * params_.w;
const int64_t blocksPerDHW =
findMaxDivisor(params_.dhw, static_cast<int64_t>(maxBlocksPerDHW));
params_.dhwPerBlock = divUp(params_.dhw, blocksPerDHW);
params_.cPerBlock = cPerBlock;
params_.cPerGroup = params_.c / params_.groups;
@@ -75,14 +75,17 @@ nvinfer1::DataType PrelnGroupnormActPluginDynamic::getOutputDataType(

int PrelnGroupnormActPluginDynamic::initialize() TRT_NOEXCEPT { return 0; }

static inline int32_t divUp(int32_t m, int32_t n) { return (m + n - 1) / n; }
template <typename T>
static inline T divUp(T m, T n) {
return (m + n - 1) / n;
}

static int32_t findMaxDivisor(int32_t n, int32_t maxAllowedDivisor) {
int32_t maxDivisor = -1;
for (int32_t i = 1; i <= std::sqrt(n); i++) {
static int64_t findMaxDivisor(int64_t n, int64_t maxAllowedDivisor) {
int64_t maxDivisor = -1;
for (int64_t i = 1; i <= std::sqrt(n); i++) {
if (n % i == 0) {
int32_t divisor1 = n / i;
int32_t divisor2 = i;
int64_t divisor1 = n / i;
int64_t divisor2 = i;

if (divisor1 > maxDivisor && divisor1 < maxAllowedDivisor) {
maxDivisor = divisor1;
@@ -137,16 +140,16 @@ __global__ void prelnGroupNormNDHWCSumKernel(
int32_t ci = blockIdx.x * params.cPerBlock + threadIdx.x * 2;

// The first activation loaded by that block.
int32_t dhwBegin = blockIdx.y * params.dhwPerBlock;
int64_t dhwBegin = blockIdx.y * params.dhwPerBlock;
// The last activation loaded by that block.
int32_t dhwEnd = min(dhwBegin + params.dhwPerBlock, params.dhw);
int64_t dhwEnd = min(dhwBegin + params.dhwPerBlock, params.dhw);

// The sums.
float sum = 0.F;
float sumSq = 0.F;

// Iterate over the activations to compute the sums.
for (int32_t dhwi = dhwBegin; dhwi < dhwEnd; ++dhwi) {
for (int64_t dhwi = dhwBegin; dhwi < dhwEnd; ++dhwi) {
// The offset.
int64_t offset = static_cast<int64_t>(ni) * params.dhwc +
static_cast<int64_t>(dhwi) * params.c + ci;
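
Analogous sketch for the channels-last layout used by this plugin (not part of the diff; the helper name is mine): in NDHWC the spatial index is multiplied by the full channel count, so the offset grows even faster and likewise needs int64_t.

#include <cstdint>

// Element (n, s, c) in NDHWC layout, where s indexes the d*h*w spatial
// positions and c the channel.
__host__ __device__ inline int64_t ndhwcOffset(
    int32_t n, int64_t s, int32_t c, int64_t channels, int64_t dhwc) {
  return static_cast<int64_t>(n) * dhwc + s * channels + c;
}

int main() {
  // Hypothetical: batch 3, spatial index ~1.5e9, channel 17 of 128.
  const int64_t dhw = 1500000000LL, channels = 128;
  const int64_t off = ndhwcOffset(3, dhw - 1, 17, channels, dhw * channels);
  return off > 0 ? 0 : 1;
}
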
@@ -306,12 +309,12 @@ __global__ void prelnGroupNormNDHWCScaleKernel(
float invStdDev = rsqrtf(var + params.eps);

// The first activation loaded by that block.
int32_t dhwBegin = blockIdx.y * params.dhwPerBlock;
int64_t dhwBegin = blockIdx.y * params.dhwPerBlock;
// The last activation loaded by that block.
int32_t dhwEnd = min(dhwBegin + params.dhwPerBlock, params.dhw);
int64_t dhwEnd = min(dhwBegin + params.dhwPerBlock, params.dhw);

// Iterate over the activations to compute the sums.
for (int32_t dhwi = dhwBegin; dhwi < dhwEnd; ++dhwi) {
for (int64_t dhwi = dhwBegin; dhwi < dhwEnd; ++dhwi) {
// The src/dst offset.
int64_t offset = (int64_t)ni * params.dhwc + dhwi * params.c + ci;

@@ -465,8 +468,9 @@ int PrelnGroupnormActPluginDynamic::enqueue(
// params_.w = input_desc[0].dims.d[3];
// params_.c = input_desc[0].dims.d[1];
params_.groups = groups_;
params_.dhw = params_.d * params_.h * params_.w;
const int32_t blocksPerDHW = findMaxDivisor(params_.dhw, maxBlocksPerDHW);
params_.dhw = static_cast<int64_t>(params_.d) * params_.h * params_.w;
const int64_t blocksPerDHW =
findMaxDivisor(params_.dhw, static_cast<int64_t>(maxBlocksPerDHW));
params_.dhwPerBlock = divUp(params_.dhw, blocksPerDHW);
params_.cPerBlock = cPerBlock;
params_.cPerGroup = params_.c / params_.groups;
@@ -86,14 +86,17 @@ nvinfer1::DataType SkipGroupnormActPluginDynamic::getOutputDataType(
}
int SkipGroupnormActPluginDynamic::initialize() TRT_NOEXCEPT { return 0; }

static inline int32_t divUp(int32_t m, int32_t n) { return (m + n - 1) / n; }
template <typename T>
static inline T divUp(T m, T n) {
return (m + n - 1) / n;
}

static int32_t findMaxDivisor(int32_t n, int32_t maxAllowedDivisor) {
int32_t maxDivisor = -1;
for (int32_t i = 1; i <= std::sqrt(n); i++) {
static int64_t findMaxDivisor(int64_t n, int64_t maxAllowedDivisor) {
int64_t maxDivisor = -1;
for (int64_t i = 1; i <= std::sqrt(n); i++) {
if (n % i == 0) {
int32_t divisor1 = n / i;
int32_t divisor2 = i;
int64_t divisor1 = n / i;
int64_t divisor2 = i;

if (divisor1 > maxDivisor && divisor1 < maxAllowedDivisor) {
maxDivisor = divisor1;
@@ -148,16 +151,16 @@ __global__ void skipGroupNormNDHWCSumKernel(
int32_t ci = blockIdx.x * params.cPerBlock + threadIdx.x * 2;

// The first activation loaded by that block.
int32_t dhwBegin = blockIdx.y * params.dhwPerBlock;
int64_t dhwBegin = blockIdx.y * params.dhwPerBlock;
// The last activation loaded by that block.
int32_t dhwEnd = min(dhwBegin + params.dhwPerBlock, params.dhw);
int64_t dhwEnd = min(dhwBegin + params.dhwPerBlock, params.dhw);

// The sums.
float sum = 0.F;
float sumSq = 0.F;

// Iterate over the activations to compute the sums.
for (int32_t dhwi = dhwBegin; dhwi < dhwEnd; ++dhwi) {
for (int64_t dhwi = dhwBegin; dhwi < dhwEnd; ++dhwi) {
// The offset.
int64_t offset = static_cast<int64_t>(ni) * params.dhwc +
static_cast<int64_t>(dhwi) * params.c + ci;
@@ -318,12 +321,12 @@ __global__ void skipGroupNormNDHWCScaleKernel(
float invStdDev = rsqrtf(var + params.eps);

// The first activation loaded by that block.
int32_t dhwBegin = blockIdx.y * params.dhwPerBlock;
int64_t dhwBegin = blockIdx.y * params.dhwPerBlock;
// The last activation loaded by that block.
int32_t dhwEnd = min(dhwBegin + params.dhwPerBlock, params.dhw);
int64_t dhwEnd = min(dhwBegin + params.dhwPerBlock, params.dhw);

// Iterate over the activations to compute the sums.
for (int32_t dhwi = dhwBegin; dhwi < dhwEnd; ++dhwi) {
for (int64_t dhwi = dhwBegin; dhwi < dhwEnd; ++dhwi) {
// The src/dst offset.
int64_t offset = (int64_t)ni * params.dhwc + dhwi * params.c + ci;

@@ -475,8 +478,9 @@ int SkipGroupnormActPluginDynamic::enqueue(
// params_.w = input_desc[0].dims.d[3];
// params_.c = input_desc[0].dims.d[1];
params_.groups = groups_;
params_.dhw = params_.d * params_.h * params_.w;
const int32_t blocksPerDHW = findMaxDivisor(params_.dhw, maxBlocksPerDHW);
params_.dhw = static_cast<int64_t>(params_.d) * params_.h * params_.w;
const int64_t blocksPerDHW =
findMaxDivisor(params_.dhw, static_cast<int64_t>(maxBlocksPerDHW));
params_.dhwPerBlock = divUp(params_.dhw, blocksPerDHW);
params_.cPerBlock = cPerBlock;
params_.cPerGroup = params_.c / params_.groups;
66 changes: 34 additions & 32 deletions paddle/phi/kernels/gpu/group_norm_grad_kernel.cu
@@ -181,27 +181,27 @@ __global__ void GetScaleBiasGradientCUDAKernel(int64_t N,
const AccT* db,
T* d_scale,
T* d_bias) {
// TODO(guoxiangmin) :add check when C / block >= gridDim.x
const int64_t c = blockIdx.x * blockDim.x + threadIdx.x;
if (c < C) {
const int G = group;
const int64_t D = C / G;
AccT sum1 = static_cast<AccT>(0);
AccT sum2 = static_cast<AccT>(0);
for (int64_t n = 0; n < N; ++n) {
const int64_t nc = n * C + c;
const int64_t ng = n * G + c / D;
sum1 +=
(d_scale == nullptr)
? AccT(0)
: ((ds[nc] - db[nc] * (mean[ng])) * (rsqrt((var[ng]) + epsilon)));
sum2 += (d_bias == nullptr) ? AccT(0) : db[nc];
}
if (d_scale != nullptr) {
d_scale[c] = static_cast<T>(sum1);
}
if (d_bias != nullptr) {
d_bias[c] = static_cast<T>(sum2);
for (int64_t c = blockIdx.x * blockDim.x + threadIdx.x; c < C;
c += gridDim.x * blockDim.x) {
if (c < C) {
const int G = group;
const int64_t D = C / G;
AccT sum1 = static_cast<AccT>(0);
AccT sum2 = static_cast<AccT>(0);
for (int64_t n = 0; n < N; ++n) {
const int64_t nc = n * C + c;
const int64_t ng = n * G + c / D;
sum1 += (d_scale == nullptr) ? AccT(0)
: ((ds[nc] - db[nc] * (mean[ng])) *
(rsqrt((var[ng]) + epsilon)));
sum2 += (d_bias == nullptr) ? AccT(0) : db[nc];
}
if (d_scale != nullptr) {
d_scale[c] = static_cast<T>(sum1);
}
if (d_bias != nullptr) {
d_bias[c] = static_cast<T>(sum2);
}
}
}
}
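
The rewritten kernel above switches from a one-shot if (c < C) check to a grid-stride loop, which is what retires the old TODO: even when the launch caps the grid below ceil(C / block), every channel is still visited. A minimal standalone sketch of the same pattern (not the Paddle kernel):

#include <cstdint>
#include <cuda_runtime.h>

__global__ void gridStrideFill(float* out, int64_t C) {
  // Each thread starts at its global index and strides by the total number
  // of launched threads until the whole range [0, C) is covered.
  for (int64_t c = blockIdx.x * blockDim.x + threadIdx.x; c < C;
       c += static_cast<int64_t>(gridDim.x) * blockDim.x) {
    out[c] = static_cast<float>(c);
  }
}

int main() {
  const int64_t C = 1 << 20;
  float* out = nullptr;
  cudaMalloc(&out, C * sizeof(float));
  const int block = 256;
  // Deliberately launch far fewer blocks than ceil(C / block); the stride
  // loop covers the remainder.
  gridStrideFill<<<128, block>>>(out, C);
  cudaDeviceSynchronize();
  cudaFree(out);
  return 0;
}
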
@@ -407,17 +407,19 @@ void GroupNormGradKernel(const Context& dev_ctx,
if (d_scale || d_bias) {
const int block = 256;
GetScaleBiasGradientCUDAKernel<T, AccT>
<<<(C + block - 1) / block, block, 0, dev_ctx.stream()>>>(
x_dims[0],
C,
groups,
epsilon,
mean_data,
var_data,
ds_data,
db_data,
d_scale_data,
d_bias_data);
<<<std::min(max_grid_x, (C + block - 1) / block),
block,
0,
dev_ctx.stream()>>>(x_dims[0],
C,
groups,
epsilon,
mean_data,
var_data,
ds_data,
db_data,
d_scale_data,
d_bias_data);
}

if (d_x_data != nullptr) {
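
And the matching launch-side change: the grid is clamped to the device's x-dimension limit instead of being sized purely from C. In this file max_grid_x presumably comes from the GPU context; below is a plain-CUDA sketch of computing the same cap (not the Paddle code).

#include <algorithm>
#include <cstdint>
#include <cuda_runtime.h>

// Clamp the block count for `work_items` items to the device's grid limit;
// a grid-stride kernel makes up for any capped blocks.
static inline int64_t cappedGrid(int64_t work_items, int block, int device) {
  int max_grid_x = 0;
  cudaDeviceGetAttribute(&max_grid_x, cudaDevAttrMaxGridDimX, device);
  const int64_t needed = (work_items + block - 1) / block;
  return std::min<int64_t>(max_grid_x, needed);
}

int main() {
  // Hypothetical work size; on current GPUs the x-limit is 2^31 - 1, on very
  // old ones it was 65535, so the clamp is what keeps the launch legal.
  const int64_t grid = cappedGrid(3000000000LL, /*block=*/256, /*device=*/0);
  return grid > 0 ? 0 : 1;
}
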