
Commit 9e1f762

Optimize bilinear_interp backward (#39423)
* bilinear_bw init
* optimize code
* optimize
* optimize 2
* optimize functions
* modify func name
1 parent 2c21d24 commit 9e1f762
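
For context: bilinear backward scatters each output gradient into the four input pixels that produced the corresponding output value, weighted by the same lambdas the forward pass uses; atomics are required because neighbouring output elements write overlapping input pixels. A minimal standalone sketch of that accumulation, with hypothetical names and the raw CUDA atomicAdd in place of Paddle's platform::CudaAtomicAdd wrapper:

    // Sketch only (not from the commit): scatter one output gradient g whose
    // source coordinate falls in the cell with top-left corner (h1, w1).
    // h_id / w_id are 0 at the bottom/right edge, so the second tap folds
    // back onto the first instead of reading out of bounds.
    __device__ void ScatterBilinearGrad(float* in, int in_w, int h1, int w1,
                                        int h_id, int w_id, float h1lambda,
                                        float w1lambda, float g) {
      float h0lambda = 1.f - h1lambda;  // weight of the upper row
      float w0lambda = 1.f - w1lambda;  // weight of the left column
      atomicAdd(&in[h1 * in_w + w1], h0lambda * w0lambda * g);
      atomicAdd(&in[h1 * in_w + w1 + w_id], h0lambda * w1lambda * g);
      atomicAdd(&in[(h1 + h_id) * in_w + w1], h1lambda * w0lambda * g);
      atomicAdd(&in[(h1 + h_id) * in_w + w1 + w_id], h1lambda * w1lambda * g);
    }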

File tree

1 file changed (+108 -77 lines)


paddle/fluid/operators/interpolate_v2_op.cu

Lines changed: 108 additions & 77 deletions
@@ -61,13 +61,13 @@ inline platform::GpuLaunchConfig GetGpuLaunchConfig3D(
 
 template <typename T>
 __forceinline__ __device__ void PreCalculatorForLinearInterpInputIndex(
-    int* in_img_idx, int* w_id, T* w1lambda, T* w2lambda, T src_w,
-    const int in_img_w) {
-  src_w = (src_w > 0) ? src_w : 0.f;
-  *in_img_idx = static_cast<int>(src_w);
-  *w_id = (*in_img_idx < in_img_w - 1) ? 1 : 0;
-  *w1lambda = src_w - *in_img_idx;
-  *w2lambda = 1.f - *w1lambda;
+    int* in_img_idx, int* x_id, T* lambda1, T* lambda2, T src_x,
+    const int in_img_x) {
+  src_x = (src_x > 0) ? src_x : 0.f;
+  *in_img_idx = static_cast<int>(src_x);
+  *x_id = (*in_img_idx < in_img_x - 1) ? 1 : 0;
+  *lambda1 = src_x - *in_img_idx;
+  *lambda2 = 1.f - *lambda1;
 }
 
 struct FastDivModForInterpolate {
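
The rename above only clarifies that the helper is axis-agnostic: it turns one source coordinate into a lower index, a neighbour offset, and the two interpolation weights, and both backward kernels below call it once per axis. A host-side sketch of the same arithmetic (standalone illustration, not part of the commit):

    // Clamp the possibly-negative source coordinate, split it into integer
    // index + fractional weight, and zero the neighbour step at the edge.
    void PreCalcLinearIndex(float src_x, int in_img_x, int* in_img_idx,
                            int* x_id, float* lambda1, float* lambda2) {
      src_x = (src_x > 0) ? src_x : 0.f;
      *in_img_idx = static_cast<int>(src_x);         // floor (x >= 0 here)
      *x_id = (*in_img_idx < in_img_x - 1) ? 1 : 0;  // right neighbour exists?
      *lambda1 = src_x - *in_img_idx;                // weight of idx + x_id
      *lambda2 = 1.f - *lambda1;                     // weight of idx
    }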
@@ -670,83 +670,102 @@ __global__ void KeBilinearInterpBwShareMemory(
   }
 }
 
+__device__ __forceinline__ int GetInputIndex(const size_t nc, const int height,
+                                             const int width, const int h,
+                                             const int w) {
+  return (nc * height + h) * width + w;
+}
+
+template <typename T>
+__global__ void KeBilinearInterpNCHWBw(T* in, const int in_h, const int in_w,
+                                       const int out_h, const int out_w,
+                                       const int n, const int num_channels,
+                                       float ratio_h, float ratio_w,
+                                       const T* __restrict__ out,
+                                       const T align_type_value) {
+  int index = threadIdx.x + blockDim.x * blockIdx.x;
+  int stride = blockDim.x * gridDim.x;
+  int num_out = n * num_channels * out_h * out_w;
+  int num_in = n * num_channels * in_h * in_w;
+
+  for (; index < num_out; index += stride) {
+    int index_tmp = index;
+    int w2 = index_tmp % out_w;
+    index_tmp /= out_w;
+    int h2 = index_tmp % out_h;
+    int nc = index_tmp / out_h;
+
+    int h1, y_id;
+    T h1lambda, h0lambda;
+    T src_y = ratio_h * (h2 + align_type_value) - align_type_value;
+
+    PreCalculatorForLinearInterpInputIndex(&h1, &y_id, &h1lambda, &h0lambda,
+                                           src_y, in_h);
+    int w1, x_id;
+    T w1lambda, w0lambda;
+    T src_x = ratio_w * (w2 + align_type_value) - align_type_value;
+    PreCalculatorForLinearInterpInputIndex(&w1, &x_id, &w1lambda, &w0lambda,
+                                           src_x, in_w);
+
+    T d2val = out[index];
+
+    platform::CudaAtomicAdd(in + GetInputIndex(nc, in_h, in_w, h1, w1),
+                            h0lambda * w0lambda * d2val);
+    platform::CudaAtomicAdd(in + GetInputIndex(nc, in_h, in_w, h1, w1 + x_id),
+                            h0lambda * w1lambda * d2val);
+    platform::CudaAtomicAdd(in + GetInputIndex(nc, in_h, in_w, h1 + y_id, w1),
+                            h1lambda * w0lambda * d2val);
+    platform::CudaAtomicAdd(
+        in + GetInputIndex(nc, in_h, in_w, h1 + y_id, w1 + x_id),
+        h1lambda * w1lambda * d2val);
+  }
+}
+
 template <typename T>
 __global__ void KeBilinearInterpBw(T* in, const int in_h, const int in_w,
                                    const T* __restrict__ out, const int out_h,
                                    const int out_w, const int n,
-                                   const int num_channels, float ratio_h,
-                                   float ratio_w, const T align_type_value,
-                                   bool is_nchw) {
+                                   const int out_chw, const int num_channels,
+                                   float ratio_h, float ratio_w,
+                                   const T align_type_value,
+                                   FastDivModForInterpolate divmods) {
   int tid = blockIdx.x * blockDim.x + threadIdx.x;
   int stride = blockDim.x * gridDim.x;
   int in_chw = in_h * in_w * num_channels;
-  int out_chw = num_channels * out_h * out_w;
   int nthreads = n * out_chw;
 
-  if (is_nchw) {
-    for (; tid < nthreads; tid += stride) {
-      int out_id_h = tid / out_chw;
-      int out_id_w = tid % out_chw;
-      const int in_img_size = in_h * in_w;
-      const int out_img_size = out_h * out_w;
-      T value = out[out_id_h * out_chw + out_id_w];
-
-      int channel_id = out_id_w / out_img_size;
-      int out_img_idy = (out_id_w % out_img_size) / out_w;
-      int out_img_idx = tid % out_w;
-      int in_img_idx, in_img_idy, w_id, h_id;
-      T w1lambda, h1lambda, w2lambda, h2lambda;
-
-      T src_w = ratio_w * (out_img_idx + align_type_value) - align_type_value;
-      T src_h = ratio_h * (out_img_idy + align_type_value) - align_type_value;
-
-      PreCalculatorForLinearInterpInputIndex(&in_img_idx, &w_id, &w1lambda,
-                                             &w2lambda, src_w, in_w);
-      PreCalculatorForLinearInterpInputIndex(&in_img_idy, &h_id, &h1lambda,
-                                             &h2lambda, src_h, in_h);
-
-      T* in_pos = &in[out_id_h * in_chw + channel_id * in_img_size +
-                      in_img_idy * in_w + in_img_idx];
-      platform::CudaAtomicAdd(&in_pos[0], h2lambda * w2lambda * value);
-      platform::CudaAtomicAdd(&in_pos[w_id], h2lambda * w1lambda * value);
-      platform::CudaAtomicAdd(&in_pos[h_id * in_w],
-                              h1lambda * w2lambda * value);
-      platform::CudaAtomicAdd(&in_pos[h_id * in_w + w_id],
-                              h1lambda * w1lambda * value);
-    }
-  } else {
-    for (; tid < nthreads; tid += stride) {
-      int out_id_h = tid / out_chw;
-      int out_id_w = tid % out_chw;
-      const int in_img_size = in_h * in_w;
-      const int out_img_size = out_h * out_w;
-      T value = out[out_id_h * out_chw + out_id_w];
-
-      int out_img_idy = out_id_w / (out_w * num_channels);
-      int out_img_idx = out_id_w % (out_w * num_channels) / num_channels;
-      int channel_id = tid % num_channels;
-
-      int in_img_idx, in_img_idy, w_id, h_id;
-      T w1lambda, h1lambda, w2lambda, h2lambda;
-      T src_w = ratio_w * (out_img_idx + align_type_value) - align_type_value;
-      T src_h = ratio_h * (out_img_idy + align_type_value) - align_type_value;
-
-      PreCalculatorForLinearInterpInputIndex(&in_img_idx, &w_id, &w1lambda,
-                                             &w2lambda, src_w, in_w);
-      PreCalculatorForLinearInterpInputIndex(&in_img_idy, &h_id, &h1lambda,
-                                             &h2lambda, src_h, in_h);
-
-      T* in_pos = &in[out_id_h * in_chw + in_img_idy * in_w * num_channels +
-                      in_img_idx * num_channels + channel_id];
-      platform::CudaAtomicAdd(&in_pos[0], h2lambda * w2lambda * value);
-      platform::CudaAtomicAdd(&in_pos[w_id * num_channels],
-                              h2lambda * w1lambda * value);
-      platform::CudaAtomicAdd(&in_pos[h_id * in_w * num_channels],
-                              h1lambda * w2lambda * value);
-      platform::CudaAtomicAdd(
-          &in_pos[h_id * in_w * num_channels + w_id * num_channels],
-          h1lambda * w1lambda * value);
-    }
+  for (; tid < nthreads; tid += stride) {
+    auto out_id_divmod = divmods.output_w_div.Divmod(tid);
+    int out_id_h = out_id_divmod.val[0];
+    int out_id_w = out_id_divmod.val[1];
+
+    int channel_id = divmods.channels_div.Divmod(tid).val[1];
+    auto outimg_id_divmod = divmods.output_wc_div.Divmod(out_id_w);
+    int out_img_idy = outimg_id_divmod.val[0];
+    int out_img_idx =
+        divmods.channels_div.Divmod(outimg_id_divmod.val[1]).val[0];
+
+    int in_img_idx, in_img_idy, w_id, h_id;
+    T w1lambda, h1lambda, w2lambda, h2lambda;
+    T src_w = ratio_w * (out_img_idx + align_type_value) - align_type_value;
+    T src_h = ratio_h * (out_img_idy + align_type_value) - align_type_value;
+
+    PreCalculatorForLinearInterpInputIndex(&in_img_idx, &w_id, &w1lambda,
+                                           &w2lambda, src_w, in_w);
+    PreCalculatorForLinearInterpInputIndex(&in_img_idy, &h_id, &h1lambda,
+                                           &h2lambda, src_h, in_h);
+
+    T value = out[tid];
+    T* in_pos = &in[out_id_h * in_chw + in_img_idy * in_w * num_channels +
+                    in_img_idx * num_channels + channel_id];
+    platform::CudaAtomicAdd(&in_pos[0], h2lambda * w2lambda * value);
+    platform::CudaAtomicAdd(&in_pos[w_id * num_channels],
+                            h2lambda * w1lambda * value);
+    platform::CudaAtomicAdd(&in_pos[h_id * in_w * num_channels],
+                            h1lambda * w2lambda * value);
+    platform::CudaAtomicAdd(
+        &in_pos[h_id * in_w * num_channels + w_id * num_channels],
+        h1lambda * w1lambda * value);
   }
 }
 
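After this hunk, KeBilinearInterpBw serves only the NHWC layout (NCHW is routed to the new KeBilinearInterpNCHWBw above), out_chw is computed once on the host instead of per kernel, and the per-thread / and % of the removed branches are replaced by FastDivMod values, which precompute magic numbers on the host so each division becomes a multiply-and-shift on the device. For reference, a plain-division sketch of the index decomposition the divmods stand in for (hypothetical wrapper names; the arithmetic matches the removed NHWC branch):

    // NHWC index decomposition with ordinary division, for comparison with
    // the Divmod calls above. out_chw == out_h * out_w * num_channels.
    struct OutIndex { int n, c, y, x; };
    OutIndex DecomposeNHWC(int tid, int out_chw, int out_w, int num_channels) {
      OutIndex id;
      id.n = tid / out_chw;                      // sample (batch) index
      int out_id_w = tid % out_chw;              // offset inside one sample
      id.c = tid % num_channels;                 // channel is innermost in NHWC
      id.y = out_id_w / (out_w * num_channels);  // output row
      id.x = (out_id_w % (out_w * num_channels)) / num_channels;  // output col
      return id;
    }
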
@@ -1907,11 +1926,23 @@ static void Interpolate2DCUDABwd(const framework::ExecutionContext& ctx,
                ctx.cuda_device_context().stream()>>>(
           input_grad_data, in_h, in_w, output_grad_data, out_h, out_w, n, c,
           ratio_h, ratio_w, align_type_value, is_nchw);
+    } else if (!optimize_flag & is_nchw) {
+      //
+      const int num_kernels = n * c * out_h * out_w;
+      const int num_threads =
+          std::min(ctx.cuda_device_context().GetMaxThreadsPerBlock(), 1024);
+      KeBilinearInterpNCHWBw<
+          T><<<platform::DivUp(num_kernels, num_threads), num_threads, 0,
+               ctx.cuda_device_context().stream()>>>(
+          input_grad_data, in_h, in_w, out_h, out_w, n, c, ratio_h, ratio_w,
+          output_grad_data, align_type_value);
     } else {
+      int64_t cw = c * out_w;
+      auto interp_divmods = FastDivModForInterpolate(c, out_chw, cw);
       KeBilinearInterpBw<T><<<config.block_per_grid, config.thread_per_block, 0,
                               ctx.cuda_device_context().stream()>>>(
-          input_grad_data, in_h, in_w, output_grad_data, out_h, out_w, n, c,
-          ratio_h, ratio_w, align_type_value, is_nchw);
+          input_grad_data, in_h, in_w, output_grad_data, out_h, out_w, n,
+          out_chw, c, ratio_h, ratio_w, align_type_value, interp_divmods);
     }
   } else if ("bicubic" == interp_method) {
 #ifdef __HIPCC__
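
The new NCHW branch launches one thread per output-gradient element and caps the block size at 1024; because the kernel iterates with a grid-stride loop, the grid size affects only occupancy, not correctness. A sketch of the equivalent sizing arithmetic (assuming platform::DivUp is plain ceiling division; max_threads stands in for GetMaxThreadsPerBlock()):

    #include <algorithm>

    // Hypothetical host-side helper mirroring the launch configuration above.
    int GridBlocksForNCHWBw(int n, int c, int out_h, int out_w,
                            int max_threads) {
      const int num_kernels = n * c * out_h * out_w;         // output elements
      const int num_threads = std::min(max_threads, 1024);   // block size
      return (num_kernels + num_threads - 1) / num_threads;  // DivUp
    }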
