|
| 1 | +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. |
| 2 | +
|
| 3 | +Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +you may not use this file except in compliance with the License. |
| 5 | +You may obtain a copy of the License at |
| 6 | +
|
| 7 | + http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +
|
| 9 | +Unless required by applicable law or agreed to in writing, software |
| 10 | +distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +See the License for the specific language governing permissions and |
| 13 | +limitations under the License. */ |
| 14 | + |
| 15 | +#include "paddle/fluid/framework/op_registry.h" |
| 16 | +#include "paddle/fluid/operators/affine_grid_op.h" |
| 17 | +#include "paddle/fluid/platform/cuda_device_function.h" |
| 18 | +#include "paddle/fluid/platform/gpu_info.h" |
| 19 | +namespace paddle { |
| 20 | +namespace operators { |
| 21 | + |
| 22 | +using Tensor = framework::Tensor; |
| 23 | + |
| 24 | +template <typename T> |
| 25 | +__global__ void LinspaceKernel(T start, T step, int64_t size, T* out) { |
| 26 | + CUDA_KERNEL_LOOP(index, size) { out[index] = start + step * index; } |
| 27 | +} |
| 28 | + |
| 29 | +template <typename T> |
| 30 | +struct Linspace<paddle::platform::CUDADeviceContext, T> { |
| 31 | + void operator()(T start, T end, int count, bool align_corners, |
| 32 | + framework::Tensor* numbers, |
| 33 | + const framework::ExecutionContext& ctx) { |
| 34 | + T* number_data = numbers->mutable_data<T>({count}, ctx.GetPlace()); |
| 35 | + T slice = (end - start) / (T)(count - 1); |
| 36 | + if (!align_corners) { |
| 37 | + slice = (end - start) / (T)count; |
| 38 | + start *= (T)(count - 1) / (T)count; |
| 39 | + } |
| 40 | + auto stream = ctx.cuda_device_context().stream(); |
| 41 | + int block = 512; |
| 42 | + int grid = (count + block - 1) / block; |
| 43 | + LinspaceKernel<T><<<grid, block, 0, stream>>>(start, slice, count, |
| 44 | + number_data); |
| 45 | + } |
| 46 | +}; |
| 47 | + |
| 48 | +template <typename T> |
| 49 | +__global__ void affine_grid_kernel(const int count, int n, int out_h, int out_w, |
| 50 | + T h_start, T w_start, T h_step, T w_step, |
| 51 | + const T* theta, // N, 2, 3 |
| 52 | + T* output) { |
| 53 | + CUDA_KERNEL_LOOP(index, count) { |
| 54 | + int w = index % out_w; |
| 55 | + int h = (index / out_w) % out_h; |
| 56 | + int n = index / (out_w * out_h); |
| 57 | + |
| 58 | + T h_coor = h_step * static_cast<T>(h) + static_cast<T>(h_start); |
| 59 | + T w_coor = w_step * static_cast<T>(w) + static_cast<T>(w_start); |
| 60 | + |
| 61 | + int theta_offset = n * 6; // 2 * 3; |
| 62 | + // affine from (h_coor, w_coor) to (x, y) |
| 63 | + output[index * 2] = theta[theta_offset] * h_coor + |
| 64 | + theta[theta_offset + 1] * w_coor + |
| 65 | + theta[theta_offset + 2]; |
| 66 | + output[index * 2 + 1] = theta[theta_offset + 3] * h_coor + |
| 67 | + theta[theta_offset + 4] * w_coor + |
| 68 | + theta[theta_offset + 5]; |
| 69 | + } |
| 70 | +} |
| 71 | + |
| 72 | +template <typename T> |
| 73 | +__global__ void affine_grid_grad_kernel(const int count, int n, int out_h, |
| 74 | + int out_w, T h_start, T w_start, |
| 75 | + T h_step, T w_step, |
| 76 | + const T* out_grad, // N, H, W, 2 |
| 77 | + T* theta_grad) { // N, 2, 3 |
| 78 | + CUDA_KERNEL_LOOP(index, count) { |
| 79 | + int w = index % out_w; |
| 80 | + int h = (index / out_w) % out_h; |
| 81 | + int n = index / (out_w * out_h); |
| 82 | + T h_coor = h_step * static_cast<T>(h) + static_cast<T>(h_start); |
| 83 | + T w_coor = w_step * static_cast<T>(w) + static_cast<T>(w_start); |
| 84 | + |
| 85 | + int theta_offset = n * 6; // 2 * 3; |
| 86 | + T out_grad_x = out_grad[index * 2]; |
| 87 | + atomicAdd(theta_grad + theta_offset, out_grad_x * h_coor); |
| 88 | + atomicAdd(theta_grad + theta_offset + 1, out_grad_x * w_coor); |
| 89 | + atomicAdd(theta_grad + theta_offset + 2, out_grad_x); |
| 90 | + |
| 91 | + T out_grad_y = out_grad[index * 2 + 1]; |
| 92 | + atomicAdd(theta_grad + theta_offset + 3, out_grad_y * h_coor); |
| 93 | + atomicAdd(theta_grad + theta_offset + 4, out_grad_y * w_coor); |
| 94 | + atomicAdd(theta_grad + theta_offset + 5, out_grad_y); |
| 95 | + } |
| 96 | +} |
| 97 | + |
| 98 | +template <typename T> |
| 99 | +class AffineGridOpCUDAKernel : public framework::OpKernel<T> { |
| 100 | + public: |
| 101 | + void Compute(const framework::ExecutionContext& ctx) const override { |
| 102 | + auto* theta = ctx.Input<Tensor>("Theta"); |
| 103 | + int n = theta->dims()[0]; |
| 104 | + auto size_attr = ctx.Attr<std::vector<int>>("output_shape"); |
| 105 | + auto align_corners = ctx.Attr<bool>("align_corners"); |
| 106 | + int h = 0; |
| 107 | + int w = 0; |
| 108 | + if (size_attr.size() == 0) { |
| 109 | + auto* output_shape = ctx.Input<Tensor>("OutputShape"); |
| 110 | + Tensor h_sizes; |
| 111 | + framework::TensorCopy(*output_shape, platform::CPUPlace(), &h_sizes); |
| 112 | + const int* h_size_data = h_sizes.data<int>(); |
| 113 | + h = h_size_data[2]; |
| 114 | + w = h_size_data[3]; |
| 115 | + } else { |
| 116 | + h = size_attr[2]; |
| 117 | + w = size_attr[3]; |
| 118 | + } |
| 119 | + auto* output = ctx.Output<Tensor>("Output"); |
| 120 | + T* out_data = output->mutable_data<T>({n, h, w, 2}, ctx.GetPlace()); |
| 121 | + |
| 122 | + T h_step; |
| 123 | + T w_step; |
| 124 | + T h_start = -1; |
| 125 | + T w_start = -1; |
| 126 | + if (align_corners) { |
| 127 | + h_step = static_cast<T>(2) / static_cast<T>(h - 1); |
| 128 | + w_step = static_cast<T>(2) / static_cast<T>(w - 1); |
| 129 | + } else { |
| 130 | + h_step = static_cast<T>(2) / static_cast<T>(h); |
| 131 | + w_step = static_cast<T>(2) / static_cast<T>(w); |
| 132 | + |
| 133 | + h_start *= static_cast<T>(h - 1) / static_cast<T>(h); |
| 134 | + w_start *= static_cast<T>(w - 1) / static_cast<T>(w); |
| 135 | + } |
| 136 | + |
| 137 | + const int count = n * h * w; |
| 138 | + int block = 512; |
| 139 | + int grid = (count + block - 1) / block; |
| 140 | + auto cu_stream = ctx.cuda_device_context().stream(); |
| 141 | + affine_grid_kernel<<<grid, block, 0, cu_stream>>>( |
| 142 | + count, n, h, w, h_start, w_start, h_step, w_step, |
| 143 | + theta->data<T>(), // N, 2, 3 |
| 144 | + out_data); |
| 145 | + } |
| 146 | +}; |
| 147 | + |
| 148 | +template <typename T> |
| 149 | +class AffineGridGradOpCUDAKernel : public framework::OpKernel<T> { |
| 150 | + public: |
| 151 | + void Compute(const framework::ExecutionContext& ctx) const override { |
| 152 | + auto output_grad = ctx.Input<Tensor>(framework::GradVarName("Output")); |
| 153 | + auto theta_grad = ctx.Output<Tensor>(framework::GradVarName("Theta")); |
| 154 | + int n = output_grad->dims()[0]; |
| 155 | + auto size_attr = ctx.Attr<std::vector<int>>("output_shape"); |
| 156 | + auto align_corners = ctx.Attr<bool>("align_corners"); |
| 157 | + int h = 0; |
| 158 | + int w = 0; |
| 159 | + if (size_attr.size() == 0) { |
| 160 | + auto* output_shape = ctx.Input<Tensor>("OutputShape"); |
| 161 | + Tensor h_sizes; |
| 162 | + framework::TensorCopy(*output_shape, platform::CPUPlace(), &h_sizes); |
| 163 | + const int* h_size_data = h_sizes.data<int>(); |
| 164 | + h = h_size_data[2]; |
| 165 | + w = h_size_data[3]; |
| 166 | + } else { |
| 167 | + h = size_attr[2]; |
| 168 | + w = size_attr[3]; |
| 169 | + } |
| 170 | + T* theta_grad_data = theta_grad->mutable_data<T>({n, 2, 3}, ctx.GetPlace()); |
| 171 | + math::SetConstant<paddle::platform::CUDADeviceContext, T>()( |
| 172 | + ctx.cuda_device_context(), theta_grad, static_cast<T>(0)); |
| 173 | + |
| 174 | + T h_step; |
| 175 | + T w_step; |
| 176 | + T h_start = -1; |
| 177 | + T w_start = -1; |
| 178 | + if (align_corners) { |
| 179 | + h_step = static_cast<T>(2) / static_cast<T>(h - 1); |
| 180 | + w_step = static_cast<T>(2) / static_cast<T>(w - 1); |
| 181 | + } else { |
| 182 | + h_step = static_cast<T>(2) / static_cast<T>(h); |
| 183 | + w_step = static_cast<T>(2) / static_cast<T>(w); |
| 184 | + |
| 185 | + h_start *= static_cast<T>(h - 1) / static_cast<T>(h); |
| 186 | + w_start *= static_cast<T>(w - 1) / static_cast<T>(w); |
| 187 | + } |
| 188 | + const int count = n * h * w; |
| 189 | + VLOG(3) << "count: " << count << "; h_step: " << h_step |
| 190 | + << "; w_step: " << w_step << "; h_start: " << h_start |
| 191 | + << "; w_start: " << w_start; |
| 192 | + int block = 512; |
| 193 | + int grid = (count + block - 1) / block; |
| 194 | + auto cu_stream = ctx.cuda_device_context().stream(); |
| 195 | + affine_grid_grad_kernel<<<grid, block, 0, cu_stream>>>( |
| 196 | + count, n, h, w, h_start, w_start, h_step, w_step, |
| 197 | + output_grad->data<T>(), theta_grad_data); |
| 198 | + } |
| 199 | +}; |
| 200 | + |
| 201 | +} // namespace operators |
| 202 | +} // namespace paddle |
| 203 | + |
| 204 | +namespace ops = paddle::operators; |
| 205 | +REGISTER_OP_CUDA_KERNEL(affine_grid, ops::AffineGridOpCUDAKernel<float>, |
| 206 | + ops::AffineGridOpCUDAKernel<double>); |
| 207 | +REGISTER_OP_CUDA_KERNEL(affine_grid_grad, |
| 208 | + ops::AffineGridGradOpCUDAKernel<float>, |
| 209 | + ops::AffineGridGradOpCUDAKernel<double>); |
0 commit comments