/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <thrust/execution_policy.h>
#include <thrust/remove.h>
#include <thrust/sort.h>
#include <thrust/unique.h>

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_info.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/kernels/funcs/index_impl.cu.h"
#include "paddle/phi/kernels/sparse/convolution_kernel.h"

namespace phi {
namespace sparse {

// TODO(zhangkaihuo): After the GatherCUDAKernel is migrated to phi, replace
// this kernel with phi::GatherCUDAKernel;
// Vectorization can be used to improve read and write bandwidth
/**
 * brief: gather data from params according to indices
 * params: the inputs
 * indices: the indices you want to gather
 * output: the outputs
 * index_size: the size of indices
 * slice_size: slice size corresponding to each index, here is the channel size
**/
template <typename T, typename IndexT = int>
__global__ void GatherKernel(const T* params,
                             const IndexT* indices,
                             T* output,
                             size_t index_size,
                             size_t slice_size) {
  CUDA_KERNEL_LOOP_TYPE(i, index_size * slice_size, int64_t) {
    int64_t indices_i = i / slice_size;
    int64_t slice_i = i - indices_i * slice_size;  // offset inside the slice
    IndexT gather_i = indices[indices_i];
    int64_t params_i = gather_i * slice_size + slice_i;
    *(output + i) = *(params + params_i);
  }
}
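
// Illustrative launch sketch for GatherKernel: one thread per gathered
// element, with slice_size equal to the channel count. LaunchGatherKernel is
// a hypothetical helper shown only as an example; it assumes the caller sizes
// the grid with phi::backends::gpu::GetGpuLaunchConfig1D.
template <typename T, typename IndexT, typename Context>
inline void LaunchGatherKernel(const Context& dev_ctx,
                               const T* params,
                               const IndexT* indices,
                               T* output,
                               size_t index_size,
                               size_t slice_size) {
  auto config = phi::backends::gpu::GetGpuLaunchConfig1D(
      dev_ctx, index_size * slice_size, 1);
  GatherKernel<T, IndexT><<<config.block_per_grid.x,
                            config.thread_per_block.x,
                            0,
                            dev_ctx.stream()>>>(
      params, indices, output, index_size, slice_size);
}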

/**
 * brief: scatter add
 * input: the inputs
 * unique_value: refer to UpdateIndexKernel notes
 * out_index: the output feature index
 * non_zero_num: the number of output features
 * rulebook_len: the length of rulebook
 * channels: the output channel size
 * out: the outputs
**/
template <typename T>
__global__ void ScatterKernel(const T* input,
                              const int* unique_value,
                              const int* out_index,
                              const int non_zero_num,
                              const int rulebook_len,
                              const int channels,
                              T* out) {
  int tid = threadIdx.x + blockIdx.x * blockDim.x;
  for (int i = tid; i < non_zero_num * channels; i += gridDim.x * blockDim.x) {
    int indices_i = i / channels;
    int channels_i = i - indices_i * channels;

    int start = unique_value[indices_i];
    int end = indices_i == non_zero_num - 1 ? rulebook_len
                                            : unique_value[indices_i + 1];
    // max(end-start) = kernel_size
    T sum = static_cast<T>(0);
    for (int j = start; j < end; j++) {
      const int out_feature_i = out_index[j];
      sum += input[out_feature_i * channels + channels_i];
    }
    out[indices_i * channels + channels_i] = sum;
  }
}
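
// Worked example of the segment indexing above (values are illustrative
// only): suppose rulebook_len = 5, non_zero_num = 3, and SortedAndUniqueIndex
// (below) produced
//   out_index    = {1, 4, 2, 0, 3}  // rulebook positions, sorted by output id
//   unique_value = {0, 2, 3}        // segment start of each output feature
// Then output feature 0 accumulates input rows {1, 4} (j in [0, 2)), feature 1
// accumulates row {2} (j in [2, 3)), and the last feature accumulates rows
// {0, 3} (j in [3, rulebook_len)). Each sum is computed per channel.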

template <typename Context>
inline int* SortedAndUniqueIndex(const Context& dev_ctx,
                                 const int* rulebook_ptr,
                                 const int len,
                                 DenseTensor* out_index,
                                 DenseTensor* unique_key,
                                 DenseTensor* unique_value) {
  phi::IndexKernel<int, kps::IdentityFunctor<int>>(
      dev_ctx, out_index, kps::IdentityFunctor<int>());
  phi::IndexKernel<int, kps::IdentityFunctor<int>>(
      dev_ctx, unique_value, kps::IdentityFunctor<int>());

  phi::backends::gpu::GpuMemcpyAsync(unique_key->data<int>(),
                                     rulebook_ptr,
                                     sizeof(int) * len,
#ifdef PADDLE_WITH_HIP
                                     hipMemcpyDeviceToDevice,
#else
                                     cudaMemcpyDeviceToDevice,
#endif
                                     dev_ctx.stream());
// compared with thrust::sort_by_key, thrust::merge_by_key may achieve higher
// performance, but thrust::merge_by_key is limited by the data size
#ifdef PADDLE_WITH_HIP
  thrust::sort_by_key(thrust::hip::par.on(dev_ctx.stream()),
#else
  thrust::sort_by_key(thrust::cuda::par.on(dev_ctx.stream()),
#endif
                      unique_key->data<int>(),
                      unique_key->data<int>() + len,
                      out_index->data<int>());

  // 4. unique
  thrust::pair<int*, int*> new_end =
#ifdef PADDLE_WITH_HIP
      thrust::unique_by_key(thrust::hip::par.on(dev_ctx.stream()),
#else
      thrust::unique_by_key(thrust::cuda::par.on(dev_ctx.stream()),
#endif
                            unique_key->data<int>(),
                            unique_key->data<int>() + len,
                            unique_value->data<int>());
  return new_end.first;
}
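
// End-to-end sketch tying the helpers together. ScatterByRulebook is a
// hypothetical wrapper, shown only to illustrate an assumed call order and
// data layout: sort/unique the rulebook's output indices, then scatter-add
// the per-rulebook-entry features into the compacted output. out_index,
// unique_key and unique_value are assumed pre-allocated int tensors of
// length rulebook_len, and `features` has rulebook_len rows of `channels`.
template <typename T, typename Context>
inline void ScatterByRulebook(const Context& dev_ctx,
                              const int* rulebook_out_indices,
                              const T* features,
                              const int rulebook_len,
                              const int channels,
                              DenseTensor* out_index,
                              DenseTensor* unique_key,
                              DenseTensor* unique_value,
                              T* out) {
  int* unique_key_end = SortedAndUniqueIndex(dev_ctx,
                                             rulebook_out_indices,
                                             rulebook_len,
                                             out_index,
                                             unique_key,
                                             unique_value);
  // the number of unique output indices is the number of non-zero output
  // features
  const int non_zero_num =
      static_cast<int>(unique_key_end - unique_key->data<int>());

  auto config = phi::backends::gpu::GetGpuLaunchConfig1D(
      dev_ctx, static_cast<int64_t>(non_zero_num) * channels, 1);
  ScatterKernel<T><<<config.block_per_grid.x,
                     config.thread_per_block.x,
                     0,
                     dev_ctx.stream()>>>(features,
                                         unique_value->data<int>(),
                                         out_index->data<int>(),
                                         non_zero_num,
                                         rulebook_len,
                                         channels,
                                         out);
}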

}  // namespace sparse
}  // namespace phi