Commit a065a24

[2.0 API] Enhance affine grid operator (#26385)

* Enhance affine grid operator: 1. add a CUDA kernel; 2. add an align_corners option. test=develop
* Move the new affine_grid API to the functional module. test=develop
* Add CUDA kernel for affine_grid. test=develop
* Add more unit tests for the grid sample API. test=develop

1 parent 6f69fbc commit a065a24

File tree: 6 files changed, +534 −26 lines

paddle/fluid/operators/affine_grid_op.cc

Lines changed: 14 additions & 3 deletions
@@ -28,10 +28,15 @@ using Tensor = framework::Tensor;
 
 template <typename T>
 struct Linspace<paddle::platform::CPUDeviceContext, T> {
-  void operator()(T start, T end, int count, framework::Tensor* numbers,
+  void operator()(T start, T end, int count, bool align_corners,
+                  framework::Tensor* numbers,
                   const framework::ExecutionContext& ctx) {
     T* number_data = numbers->mutable_data<T>({count}, platform::CPUPlace());
     T slice = (end - start) / (T)(count - 1);
+    if (!align_corners) {
+      slice = (end - start) / (T)count;
+      start *= (T)(count - 1) / (T)count;
+    }
     for (int i = 0; i < count; ++i) {
       number_data[i] = start + (T)i * slice;
     }
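For intuition about the two sampling schemes, here is a minimal standalone sketch in plain C++ (no Paddle dependencies; the helper name make_linspace is hypothetical). With align_corners=true the endpoints land exactly on -1 and 1; with align_corners=false the samples sit at pixel centers strictly inside [-1, 1]:

#include <cstdio>
#include <initializer_list>
#include <vector>

// Standalone re-implementation of the Linspace logic in the hunk above.
std::vector<float> make_linspace(float start, float end, int count,
                                 bool align_corners) {
  float slice = (end - start) / static_cast<float>(count - 1);
  if (!align_corners) {
    slice = (end - start) / static_cast<float>(count);
    start *= static_cast<float>(count - 1) / static_cast<float>(count);
  }
  std::vector<float> out(count);
  for (int i = 0; i < count; ++i) out[i] = start + i * slice;
  return out;
}

int main() {
  // align_corners=true,  count=5: -1.00 -0.50  0.00  0.50  1.00
  // align_corners=false, count=5: -0.80 -0.40  0.00  0.40  0.80
  for (bool ac : {true, false}) {
    for (float v : make_linspace(-1.f, 1.f, 5, ac)) std::printf("%6.2f", v);
    std::printf("\n");
  }
  return 0;
}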
@@ -130,6 +135,10 @@ class AffineGridOpMaker : public framework::OpProtoAndCheckerMaker {
         "use_cudnn",
         "(bool, default false) Only used in cudnn kernel, need install cudnn")
         .SetDefault(true);
+    AddAttr<bool>("align_corners",
+                  "(bool, default true) Whether to align the corners of the "
+                  "input and output.")
+        .SetDefault(true);
     AddAttr<std::vector<int>>(
         "output_shape",
         "The target output image shape with format [N, C, H, W].")
@@ -164,10 +173,12 @@ class AffineGridOpMaker : public framework::OpProtoAndCheckerMaker {
       [-1.  -0.5  0.   0.5  1. ]
       [-1.  -0.5  0.   0.5  1. ]
      [-1.  -0.5  0.   0.5  1. ]]]
-    C[0] is the coordinates in the height axis and C[1] is the coordinates in the width axis.
+    C[0] is the coordinates in the height axis and C[1] is the
+    coordinates in the width axis.
 
     Step2:
-    Transpose and reshape C to shape [H * W, 2] and append ones to the last dimension. Then we get:
+    Transpose and reshape C to shape [H * W, 2] and append ones to the
+    last dimension. Then we get:
     C_ = [[-1.  -1.   1. ]
           [-0.5 -1.   1. ]
           [ 0.  -1.   1. ]
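To make the documented steps concrete, here is a minimal sketch in plain C++ (independent of Paddle; the 2x3 theta below is an arbitrary illustrative value). Steps 1 and 2 build the homogeneous base grid C_ of shape [H * W, 3], and the final step multiplies it by theta.T to produce the sampling coordinates:

#include <array>
#include <cstdio>
#include <vector>

int main() {
  const int H = 2, W = 3;
  // Step 1 + 2: homogeneous base grid C_ with rows (w_coor, h_coor, 1),
  // coordinates linspaced over [-1, 1] (align_corners=true convention).
  std::vector<std::array<float, 3>> C_;
  for (int h = 0; h < H; ++h)
    for (int w = 0; w < W; ++w)
      C_.push_back({2.f * w / (W - 1) - 1.f, 2.f * h / (H - 1) - 1.f, 1.f});

  // Final step: output = C_ * theta.T, where theta is the 2x3 affine matrix.
  // An identity-like theta maps the base grid onto itself.
  const float theta[2][3] = {{1.f, 0.f, 0.f}, {0.f, 1.f, 0.f}};
  for (const auto& row : C_) {
    float x = theta[0][0] * row[0] + theta[0][1] * row[1] + theta[0][2];
    float y = theta[1][0] * row[0] + theta[1][1] * row[1] + theta[1][2];
    std::printf("(%5.1f, %5.1f)\n", x, y);
  }
  return 0;
}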
paddle/fluid/operators/affine_grid_op.cu (new file)

Lines changed: 209 additions & 0 deletions
@@ -0,0 +1,209 @@
+/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/affine_grid_op.h"
+#include "paddle/fluid/platform/cuda_device_function.h"
+#include "paddle/fluid/platform/gpu_info.h"
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+
+template <typename T>
+__global__ void LinspaceKernel(T start, T step, int64_t size, T* out) {
+  CUDA_KERNEL_LOOP(index, size) { out[index] = start + step * index; }
+}
+
+template <typename T>
+struct Linspace<paddle::platform::CUDADeviceContext, T> {
+  void operator()(T start, T end, int count, bool align_corners,
+                  framework::Tensor* numbers,
+                  const framework::ExecutionContext& ctx) {
+    T* number_data = numbers->mutable_data<T>({count}, ctx.GetPlace());
+    T slice = (end - start) / (T)(count - 1);
+    if (!align_corners) {
+      slice = (end - start) / (T)count;
+      start *= (T)(count - 1) / (T)count;
+    }
+    auto stream = ctx.cuda_device_context().stream();
+    int block = 512;
+    int grid = (count + block - 1) / block;
+    LinspaceKernel<T><<<grid, block, 0, stream>>>(start, slice, count,
+                                                  number_data);
+  }
+};
+
+template <typename T>
+__global__ void affine_grid_kernel(const int count, int n, int out_h, int out_w,
+                                   T h_start, T w_start, T h_step, T w_step,
+                                   const T* theta,  // N, 2, 3
+                                   T* output) {
+  CUDA_KERNEL_LOOP(index, count) {
+    int w = index % out_w;
+    int h = (index / out_w) % out_h;
+    int n = index / (out_w * out_h);
+
+    T h_coor = h_step * static_cast<T>(h) + static_cast<T>(h_start);
+    T w_coor = w_step * static_cast<T>(w) + static_cast<T>(w_start);
+
+    int theta_offset = n * 6;  // 2 * 3
+    // affine from (h_coor, w_coor) to (x, y)
+    output[index * 2] = theta[theta_offset] * h_coor +
+                        theta[theta_offset + 1] * w_coor +
+                        theta[theta_offset + 2];
+    output[index * 2 + 1] = theta[theta_offset + 3] * h_coor +
+                            theta[theta_offset + 4] * w_coor +
+                            theta[theta_offset + 5];
+  }
+}
+
+template <typename T>
+__global__ void affine_grid_grad_kernel(const int count, int n, int out_h,
+                                        int out_w, T h_start, T w_start,
+                                        T h_step, T w_step,
+                                        const T* out_grad,  // N, H, W, 2
+                                        T* theta_grad) {    // N, 2, 3
+  CUDA_KERNEL_LOOP(index, count) {
+    int w = index % out_w;
+    int h = (index / out_w) % out_h;
+    int n = index / (out_w * out_h);
+    T h_coor = h_step * static_cast<T>(h) + static_cast<T>(h_start);
+    T w_coor = w_step * static_cast<T>(w) + static_cast<T>(w_start);
+
+    int theta_offset = n * 6;  // 2 * 3
+    T out_grad_x = out_grad[index * 2];
+    atomicAdd(theta_grad + theta_offset, out_grad_x * h_coor);
+    atomicAdd(theta_grad + theta_offset + 1, out_grad_x * w_coor);
+    atomicAdd(theta_grad + theta_offset + 2, out_grad_x);
+
+    T out_grad_y = out_grad[index * 2 + 1];
+    atomicAdd(theta_grad + theta_offset + 3, out_grad_y * h_coor);
+    atomicAdd(theta_grad + theta_offset + 4, out_grad_y * w_coor);
+    atomicAdd(theta_grad + theta_offset + 5, out_grad_y);
+  }
+}
+
+template <typename T>
+class AffineGridOpCUDAKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* theta = ctx.Input<Tensor>("Theta");
+    int n = theta->dims()[0];
+    auto size_attr = ctx.Attr<std::vector<int>>("output_shape");
+    auto align_corners = ctx.Attr<bool>("align_corners");
+    int h = 0;
+    int w = 0;
+    if (size_attr.size() == 0) {
+      auto* output_shape = ctx.Input<Tensor>("OutputShape");
+      Tensor h_sizes;
+      framework::TensorCopy(*output_shape, platform::CPUPlace(), &h_sizes);
+      const int* h_size_data = h_sizes.data<int>();
+      h = h_size_data[2];
+      w = h_size_data[3];
+    } else {
+      h = size_attr[2];
+      w = size_attr[3];
+    }
+    auto* output = ctx.Output<Tensor>("Output");
+    T* out_data = output->mutable_data<T>({n, h, w, 2}, ctx.GetPlace());
+
+    T h_step;
+    T w_step;
+    T h_start = -1;
+    T w_start = -1;
+    if (align_corners) {
+      h_step = static_cast<T>(2) / static_cast<T>(h - 1);
+      w_step = static_cast<T>(2) / static_cast<T>(w - 1);
+    } else {
+      h_step = static_cast<T>(2) / static_cast<T>(h);
+      w_step = static_cast<T>(2) / static_cast<T>(w);
+
+      h_start *= static_cast<T>(h - 1) / static_cast<T>(h);
+      w_start *= static_cast<T>(w - 1) / static_cast<T>(w);
+    }
+
+    const int count = n * h * w;
+    int block = 512;
+    int grid = (count + block - 1) / block;
+    auto cu_stream = ctx.cuda_device_context().stream();
+    affine_grid_kernel<<<grid, block, 0, cu_stream>>>(
+        count, n, h, w, h_start, w_start, h_step, w_step,
+        theta->data<T>(),  // N, 2, 3
+        out_data);
+  }
+};
+
+template <typename T>
+class AffineGridGradOpCUDAKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto output_grad = ctx.Input<Tensor>(framework::GradVarName("Output"));
+    auto theta_grad = ctx.Output<Tensor>(framework::GradVarName("Theta"));
+    int n = output_grad->dims()[0];
+    auto size_attr = ctx.Attr<std::vector<int>>("output_shape");
+    auto align_corners = ctx.Attr<bool>("align_corners");
+    int h = 0;
+    int w = 0;
+    if (size_attr.size() == 0) {
+      auto* output_shape = ctx.Input<Tensor>("OutputShape");
+      Tensor h_sizes;
+      framework::TensorCopy(*output_shape, platform::CPUPlace(), &h_sizes);
+      const int* h_size_data = h_sizes.data<int>();
+      h = h_size_data[2];
+      w = h_size_data[3];
+    } else {
+      h = size_attr[2];
+      w = size_attr[3];
+    }
+    T* theta_grad_data = theta_grad->mutable_data<T>({n, 2, 3}, ctx.GetPlace());
+    math::SetConstant<paddle::platform::CUDADeviceContext, T>()(
+        ctx.cuda_device_context(), theta_grad, static_cast<T>(0));
+
+    T h_step;
+    T w_step;
+    T h_start = -1;
+    T w_start = -1;
+    if (align_corners) {
+      h_step = static_cast<T>(2) / static_cast<T>(h - 1);
+      w_step = static_cast<T>(2) / static_cast<T>(w - 1);
+    } else {
+      h_step = static_cast<T>(2) / static_cast<T>(h);
+      w_step = static_cast<T>(2) / static_cast<T>(w);
+
+      h_start *= static_cast<T>(h - 1) / static_cast<T>(h);
+      w_start *= static_cast<T>(w - 1) / static_cast<T>(w);
+    }
+    const int count = n * h * w;
+    VLOG(3) << "count: " << count << "; h_step: " << h_step
+            << "; w_step: " << w_step << "; h_start: " << h_start
+            << "; w_start: " << w_start;
+    int block = 512;
+    int grid = (count + block - 1) / block;
+    auto cu_stream = ctx.cuda_device_context().stream();
+    affine_grid_grad_kernel<<<grid, block, 0, cu_stream>>>(
+        count, n, h, w, h_start, w_start, h_step, w_step,
+        output_grad->data<T>(), theta_grad_data);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP_CUDA_KERNEL(affine_grid, ops::AffineGridOpCUDAKernel<float>,
+                        ops::AffineGridOpCUDAKernel<double>);
+REGISTER_OP_CUDA_KERNEL(affine_grid_grad,
+                        ops::AffineGridGradOpCUDAKernel<float>,
+                        ops::AffineGridGradOpCUDAKernel<double>);
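For reference, the forward kernel's flat-index arithmetic can be mirrored on the CPU; a minimal sketch follows (the helper affine_grid_cpu_reference is hypothetical, not a Paddle API), handy as an oracle in unit tests. With block = 512, a count of n * h * w = 1000 points launches grid = (1000 + 511) / 512 = 2 thread blocks, and each thread decodes its flat index into (n, h, w) exactly as in CUDA_KERNEL_LOOP above:

#include <cstdio>
#include <vector>

// Hypothetical CPU reference for the forward CUDA kernel above: same flat
// index decomposition, same affine transform, theta laid out as [N, 2, 3].
void affine_grid_cpu_reference(int n, int out_h, int out_w, float h_start,
                               float w_start, float h_step, float w_step,
                               const float* theta, float* output) {
  const int count = n * out_h * out_w;
  for (int index = 0; index < count; ++index) {
    int w = index % out_w;            // fastest-varying: width
    int h = (index / out_w) % out_h;  // then height
    int b = index / (out_w * out_h);  // slowest: batch
    float h_coor = h_step * h + h_start;
    float w_coor = w_step * w + w_start;
    const float* t = theta + b * 6;  // 2x3 matrix for this batch item
    output[index * 2] = t[0] * h_coor + t[1] * w_coor + t[2];
    output[index * 2 + 1] = t[3] * h_coor + t[4] * w_coor + t[5];
  }
}

int main() {
  // One batch item, 2x3 output grid, align_corners=true style steps.
  const float theta[6] = {1.f, 0.f, 0.f, 0.f, 1.f, 0.f};  // identity affine
  std::vector<float> out(1 * 2 * 3 * 2);
  affine_grid_cpu_reference(1, 2, 3, -1.f, -1.f, 2.f / (2 - 1), 2.f / (3 - 1),
                            theta, out.data());
  for (int i = 0; i < 6; ++i)
    std::printf("(%5.1f, %5.1f)\n", out[i * 2], out[i * 2 + 1]);
  return 0;
}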

paddle/fluid/operators/affine_grid_op.h

Lines changed: 13 additions & 9 deletions
@@ -37,29 +37,33 @@ using Array4 = Eigen::DSizes<int64_t, 4>;
 */
 template <typename DeviceContext, typename T>
 struct Linspace {
-  void operator()(T start, T end, int count, framework::Tensor* numbers,
+  void operator()(T start, T end, int count, bool align_corners,
+                  framework::Tensor* numbers,
                   const framework::ExecutionContext& ctx);
 };
 
 template <typename DeviceContext, typename T>
-inline void GetIdxMap(int n, int h, int w, Tensor* grid,
+inline void GetIdxMap(int n, int h, int w, bool align_corners, Tensor* grid,
                       const framework::ExecutionContext& ctx) {
   auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
   grid->mutable_data<T>({n, h, w, 3}, ctx.GetPlace());
   auto grid_t = EigenTensor<T, 4>::From(*grid);
   // Get indexes of height with shape [height, width, 1]
   Tensor h_idx;
   Linspace<DeviceContext, T> linspace;
-  linspace((T)-1, (T)1, h, &h_idx, ctx);
+  linspace((T)-1, (T)1, h, align_corners, &h_idx, ctx);
   auto h_idx_t = EigenTensor<T, 1>::From(h_idx);
   // Get indexes of width with shape [height, width, 1]
   Tensor w_idx;
-  linspace((T)-1, (T)1, w, &w_idx, ctx);
+  linspace((T)-1, (T)1, w, align_corners, &w_idx, ctx);
   auto w_idx_t = EigenTensor<T, 1>::From(w_idx);
   // Get constant ones tensor with shape [height, width, 1]
   Tensor ones;
   ones.mutable_data<T>({h, w, 1}, ctx.GetPlace());
-  auto ones_t = EigenTensor<T, 3>::From(ones).setConstant((T)1);
+
+  math::SetConstant<DeviceContext, T>()(
+      ctx.template device_context<DeviceContext>(), &ones, static_cast<T>(1));
+  auto ones_t = EigenTensor<T, 3>::From(ones);
   // Get grid tensor with shape [n, h, w, 3] by concatenating h_idx, w_idx and
   // ones
   Tensor w_idx_map;
@@ -74,11 +78,9 @@ inline void GetIdxMap(int n, int h, int w, Tensor* grid,
   Tensor w_h_one_idx_map;
   w_h_one_idx_map.mutable_data<T>({h, w, 3}, ctx.GetPlace());
   auto w_h_one_idx_map_t = EigenTensor<T, 3>::From(w_h_one_idx_map);
-
   w_idx_map_t.device(place) = w_idx_t.reshape(Array2(1, w))
                                   .broadcast(Array2(h, 1))
                                   .reshape(Array3(h, w, 1));
-
   h_idx_map_t.device(place) = h_idx_t.reshape(Array2(1, h))
                                   .broadcast(Array2(w, 1))
                                   .shuffle(Array2(1, 0))
@@ -97,6 +99,7 @@ class AffineGridOpKernel : public framework::OpKernel<T> {
     auto* theta = ctx.Input<Tensor>("Theta");
     int n = theta->dims()[0];
     auto size_attr = ctx.Attr<std::vector<int>>("output_shape");
+    auto align_corners = ctx.Attr<bool>("align_corners");
     int h = 0;
     int w = 0;
     if (size_attr.size() == 0) {
@@ -116,7 +119,7 @@ class AffineGridOpKernel : public framework::OpKernel<T> {
         ctx.template device_context<DeviceContext>(), output,
         static_cast<T>(0));
     Tensor grid;
-    GetIdxMap<DeviceContext, T>(n, h, w, &grid, ctx);
+    GetIdxMap<DeviceContext, T>(n, h, w, align_corners, &grid, ctx);
     // output = grid * theta.T
     // TODO(wanghaoshuang): Refine batched matrix multiply
     auto blas = math::GetBlas<DeviceContext, T>(ctx);
@@ -140,6 +143,7 @@ class AffineGridGradOpKernel : public framework::OpKernel<T> {
     auto theta_grad = ctx.Output<Tensor>(framework::GradVarName("Theta"));
     int n = output_grad->dims()[0];
     auto size_attr = ctx.Attr<std::vector<int>>("output_shape");
+    auto align_corners = ctx.Attr<bool>("align_corners");
     int h = 0;
     int w = 0;
     if (size_attr.size() == 0) {
@@ -158,7 +162,7 @@ class AffineGridGradOpKernel : public framework::OpKernel<T> {
         ctx.template device_context<DeviceContext>(), theta_grad,
         static_cast<T>(0));
     Tensor grid;
-    GetIdxMap<DeviceContext, T>(n, h, w, &grid, ctx);
+    GetIdxMap<DeviceContext, T>(n, h, w, align_corners, &grid, ctx);
     // output = grid * theta.T
     // TODO(wanghaoshuang): Refine batched matrix multiply
     auto blas = math::GetBlas<DeviceContext, T>(ctx);
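Mirroring the CUDA kernels above, each output point is theta · (h_coor, w_coor, 1), so the gradient of the loss with respect to each theta entry is the incoming output gradient times the matching coordinate, summed over all H * W points; the CUDA path performs this sum with atomicAdd, while the CPU path uses a batched matmul against the grid. A minimal serial sketch of that accumulation (illustrative names, single batch item, not a Paddle API):

#include <cstdio>

// Serial analogue of the theta-gradient accumulation for one batch item.
// out_grad has shape [H, W, 2]; theta_grad has shape [2, 3], pre-zeroed.
void accumulate_theta_grad(int out_h, int out_w, float h_start, float w_start,
                           float h_step, float w_step, const float* out_grad,
                           float* theta_grad) {
  for (int h = 0; h < out_h; ++h) {
    for (int w = 0; w < out_w; ++w) {
      float h_coor = h_start + h_step * h;
      float w_coor = w_start + w_step * w;
      const float* g = out_grad + (h * out_w + w) * 2;
      theta_grad[0] += g[0] * h_coor;  // d loss / d theta[0][0]
      theta_grad[1] += g[0] * w_coor;  // d loss / d theta[0][1]
      theta_grad[2] += g[0];           // d loss / d theta[0][2]
      theta_grad[3] += g[1] * h_coor;  // d loss / d theta[1][0]
      theta_grad[4] += g[1] * w_coor;  // d loss / d theta[1][1]
      theta_grad[5] += g[1];           // d loss / d theta[1][2]
    }
  }
}

int main() {
  // 2x2 grid, all-ones upstream gradient: the bias column of theta_grad
  // accumulates H * W = 4, the coordinate columns sum to 0 by symmetry.
  const float out_grad[2 * 2 * 2] = {1, 1, 1, 1, 1, 1, 1, 1};
  float theta_grad[6] = {0};
  accumulate_theta_grad(2, 2, -1.f, -1.f, 2.f, 2.f, out_grad, theta_grad);
  for (float v : theta_grad) std::printf("%4.1f ", v);
  std::printf("\n");
  return 0;
}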
