Skip to content

Commit 60b86b2

Browse files
author
zhangkaihuo
authored
Sparse Conv3d gpu backward (#40143)
Sparse conv3d backward(gpu)
1 parent 3e9601b commit 60b86b2

File tree

9 files changed

+430
-215
lines changed

9 files changed

+430
-215
lines changed

paddle/phi/kernels/sparse/convolution_grad_kernel.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,10 @@ std::vector<DenseTensor> Conv3dGrad(const Context& dev_ctx,
4545
const std::vector<int>& dilations,
4646
const std::vector<int>& strides,
4747
const int groups) {
48-
DenseTensor x_grad = phi::Empty<T, Context>(dev_ctx);
49-
DenseTensor kernel_grad = phi::Empty<T, Context>(dev_ctx);
48+
DenseTensor x_grad =
49+
phi::Empty<Context>(dev_ctx, DenseTensorMeta(x.dtype(), {1}, x.layout()));
50+
DenseTensor kernel_grad = phi::Empty<Context>(
51+
dev_ctx, DenseTensorMeta(kernel.dtype(), {1}, kernel.layout()));
5052
// TODO(zhangkaihuo): call InferMeta func here
5153
Conv3dGradKernel<T, Context>(dev_ctx,
5254
x,

paddle/phi/kernels/sparse/convolution_kernel.h

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -20,18 +20,6 @@ limitations under the License. */
2020
#include "paddle/phi/kernels/empty_kernel.h"
2121

2222
namespace phi {
23-
24-
template <typename T, typename Context>
25-
DenseTensor Empty(const Context& dev_ctx) {
26-
phi::DenseTensor dense_out(
27-
phi::make_intrusive<paddle::experimental::SharedStorage>(
28-
dev_ctx.GetPlace()),
29-
{paddle::experimental::CppTypeToDataType<T>::Type(),
30-
{-1},
31-
DataLayout::NCHW});
32-
return dense_out;
33-
}
34-
3523
namespace sparse {
3624

3725
struct Dims4D {
@@ -149,8 +137,10 @@ SparseCooTensor Conv3d(const Context& dev_ctx,
149137
const std::vector<int>& strides,
150138
const int groups,
151139
DenseTensor* rulebook) {
152-
DenseTensor indices = phi::Empty<T, Context>(dev_ctx);
153-
DenseTensor values = phi::Empty<T, Context>(dev_ctx);
140+
DenseTensor indices = phi::Empty<Context>(
141+
dev_ctx, DenseTensorMeta(DataType::INT32, {1}, DataLayout::NCHW));
142+
DenseTensor values =
143+
phi::Empty<Context>(dev_ctx, DenseTensorMeta(x.dtype(), {1}, x.layout()));
154144
SparseCooTensor coo(indices, values, x.dims());
155145
Conv3dKernel<T, Context>(
156146
dev_ctx, x, kernel, paddings, dilations, strides, groups, &coo, rulebook);

paddle/phi/kernels/sparse/cpu/convolution.h

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -45,9 +45,6 @@ void ProductRuleBook(const Context& dev_ctx,
4545
const int64_t non_zero_num = x.nnz();
4646
const auto& non_zero_indices = x.non_zero_indices();
4747
const int* indices_ptr = non_zero_indices.data<int>();
48-
dev_ctx.Alloc(counter_per_kernel,
49-
counter_per_kernel->dtype(),
50-
sizeof(int) * counter_per_kernel->numel());
5148
int* counter_ptr = counter_per_kernel->data<int>();
5249
int kernel_size = kernel_dims[0] * kernel_dims[1] * kernel_dims[2];
5350
memset(counter_ptr, 0, kernel_size * sizeof(int));
@@ -138,8 +135,6 @@ void UpdateRulebookAndOutIndex(const Context& dev_ctx,
138135
x.dtype(), {out_non_zero_num, out_channels}, x.layout());
139136
phi::DenseTensor out_indices = phi::Empty(dev_ctx, std::move(indices_meta));
140137
phi::DenseTensor out_values = phi::Empty(dev_ctx, std::move(values_meta));
141-
dev_ctx.Alloc(
142-
&out_indices, out_indices.dtype(), out_indices.numel() * sizeof(int));
143138
int* out_indices_ptr = out_indices.data<int>();
144139
int i = 0;
145140
for (auto it = out_indexs.begin(); it != out_indexs.end(); it++, i++) {

paddle/phi/kernels/sparse/cpu/convolution_grad_kernel.cc

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ limitations under the License. */
1414

1515
#include "paddle/phi/kernels/sparse/convolution_grad_kernel.h"
1616
#include "paddle/phi/kernels/funcs/blas/blas.h"
17+
#include "paddle/phi/kernels/funcs/math_function.h"
1718
#include "paddle/phi/kernels/sparse/cpu/convolution.h"
1819

1920
namespace phi {
@@ -60,15 +61,8 @@ void Conv3dGradKernel(const Context& dev_ctx,
6061
phi::DenseTensor out_grad_features =
6162
phi::Empty(dev_ctx, std::move(out_grad_features_meta));
6263

63-
dev_ctx.Alloc(
64-
&in_features, in_features.dtype(), sizeof(T) * in_features.numel());
6564
T* in_features_ptr = in_features.data<T>();
66-
dev_ctx.Alloc(
67-
&d_x_features, d_x_features.dtype(), sizeof(T) * d_x_features.numel());
6865
T* d_x_features_ptr = d_x_features.data<T>();
69-
dev_ctx.Alloc(&out_grad_features,
70-
out_grad_features.dtype(),
71-
sizeof(T) * out_grad_features.numel());
7266
T* out_grad_features_ptr = out_grad_features.data<T>();
7367
kernel_grad->Resize(kernel_dims);
7468
dev_ctx.Alloc(
@@ -156,12 +150,11 @@ void Conv3dGradKernel(const Context& dev_ctx,
156150
} // namespace sparse
157151
} // namespace phi
158152

159-
PD_REGISTER_KERNEL(sparse_conv_grad,
153+
PD_REGISTER_KERNEL(sparse_conv3d_grad,
160154
CPU,
161155
ALL_LAYOUT,
162156
phi::sparse::Conv3dGradKernel,
163157
float,
164158
double) {
165159
kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);
166-
kernel->InputAt(3).SetDataLayout(phi::DataLayout::SPARSE_COO);
167160
}

paddle/phi/kernels/sparse/cpu/convolution_kernel.cc

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,6 @@ void Conv3dKernel(const Context& dev_ctx,
8181
phi::Empty(dev_ctx, std::move(in_features_meta));
8282
phi::DenseTensor out_features =
8383
phi::Empty(dev_ctx, std::move(out_features_meta));
84-
dev_ctx.Alloc(&in_features, x.dtype(), sizeof(T) * in_features.numel());
85-
dev_ctx.Alloc(&out_features, x.dtype(), sizeof(T) * out_features.numel());
8684
T* in_features_ptr = in_features.data<T>();
8785
T* out_features_ptr = out_features.data<T>();
8886

@@ -128,9 +126,6 @@ void Conv3dKernel(const Context& dev_ctx,
128126
}
129127

130128
// 4. scatter
131-
dev_ctx.Alloc(out->mutable_non_zero_elements(),
132-
out->mutable_non_zero_elements()->dtype(),
133-
sizeof(T) * in_features.numel());
134129
T* out_values_ptr = out->mutable_non_zero_elements()->data<T>();
135130
memset(out_values_ptr, 0, sizeof(T) * out->nnz() * out_channels);
136131
Scatter<T>(out_features_ptr,
Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#pragma once
16+
17+
#include <thrust/execution_policy.h>
18+
#include <thrust/remove.h>
19+
#include <thrust/sort.h>
20+
#include <thrust/unique.h>
21+
22+
#include "paddle/phi/backends/gpu/gpu_context.h"
23+
#include "paddle/phi/backends/gpu/gpu_info.h"
24+
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
25+
#include "paddle/phi/kernels/funcs/index_impl.cu.h"
26+
#include "paddle/phi/kernels/sparse/convolution_kernel.h"
27+
28+
namespace phi {
29+
namespace sparse {
30+
31+
// TODO(zhangkaihuo): After the GatherCUDAKernel is migrated to phi, replace
32+
// this kernel with phi::GatherCUDAKernel;
33+
// Vectorization can be used to improve read and write bandwidth
34+
/**
35+
* brief: gather data from params according to indices
36+
* params: the inputs
37+
* indices: the indices you want to gather
38+
* output: the outputs
39+
* index_size: the size of indices
40+
* slice_size: slice size corresponding to each index, here is the channel size
41+
**/
42+
template <typename T, typename IndexT = int>
43+
__global__ void GatherKernel(const T* params,
44+
const IndexT* indices,
45+
T* output,
46+
size_t index_size,
47+
size_t slice_size) {
48+
CUDA_KERNEL_LOOP_TYPE(i, index_size * slice_size, int64_t) {
49+
int64_t indices_i = i / slice_size;
50+
int64_t slice_i = i - indices_i * slice_size; // offset inside the slice
51+
IndexT gather_i = indices[indices_i];
52+
int64_t params_i = gather_i * slice_size + slice_i;
53+
*(output + i) = *(params + params_i);
54+
}
55+
}
56+
57+
/**
58+
* brief: scatter add
59+
* input: the inputs
60+
* unique_value: refer to UpdateIndexKernel notes
61+
* out_index: the output feature index
62+
* non_zero_num: the number of output features
63+
* rulebook_len: the length of rulebook
64+
* channels: the output channel size
65+
* out: the outputs
66+
**/
67+
template <typename T>
68+
__global__ void ScatterKernel(const T* input,
69+
const int* unique_value,
70+
const int* out_index,
71+
const int non_zero_num,
72+
const int rulebook_len,
73+
const int channels,
74+
T* out) {
75+
int tid = threadIdx.x + blockIdx.x * blockDim.x;
76+
for (int i = tid; i < non_zero_num * channels; i += gridDim.x * blockDim.x) {
77+
int indices_i = i / channels;
78+
int channels_i = i - indices_i * channels;
79+
80+
int start = unique_value[indices_i];
81+
int end = indices_i == non_zero_num - 1 ? rulebook_len
82+
: unique_value[indices_i + 1];
83+
// max(end-start) = kernel_size
84+
T sum = static_cast<T>(0);
85+
for (int j = start; j < end; j++) {
86+
const int out_feature_i = out_index[j];
87+
sum += input[out_feature_i * channels + channels_i];
88+
}
89+
out[indices_i * channels + channels_i] = sum;
90+
}
91+
}
92+
93+
template <typename Context>
94+
inline int* SortedAndUniqueIndex(const Context& dev_ctx,
95+
const int* rulebook_ptr,
96+
const int len,
97+
DenseTensor* out_index,
98+
DenseTensor* unique_key,
99+
DenseTensor* unique_value) {
100+
phi::IndexKernel<int, kps::IdentityFunctor<int>>(
101+
dev_ctx, out_index, kps::IdentityFunctor<int>());
102+
phi::IndexKernel<int, kps::IdentityFunctor<int>>(
103+
dev_ctx, unique_value, kps::IdentityFunctor<int>());
104+
105+
phi::backends::gpu::GpuMemcpyAsync(unique_key->data<int>(),
106+
rulebook_ptr,
107+
sizeof(int) * len,
108+
#ifdef PADDLE_WITH_HIP
109+
hipMemcpyDeviceToDevice,
110+
#else
111+
cudaMemcpyDeviceToDevice,
112+
#endif
113+
dev_ctx.stream());
114+
// compared with thrust::sort_by_key, thrust::merge_by_key may achieved higher
115+
// performance, but thrust::merge_by_key limited by data size
116+
#ifdef PADDLE_WITH_HIP
117+
thrust::sort_by_key(thrust::hip::par.on(dev_ctx.stream()),
118+
#else
119+
thrust::sort_by_key(thrust::cuda::par.on(dev_ctx.stream()),
120+
#endif
121+
unique_key->data<int>(),
122+
unique_key->data<int>() + len,
123+
out_index->data<int>());
124+
125+
// 4. unique
126+
thrust::pair<int*, int*> new_end =
127+
#ifdef PADDLE_WITH_HIP
128+
thrust::unique_by_key(thrust::hip::par.on(dev_ctx.stream()),
129+
#else
130+
thrust::unique_by_key(thrust::cuda::par.on(dev_ctx.stream()),
131+
#endif
132+
unique_key->data<int>(),
133+
unique_key->data<int>() + len,
134+
unique_value->data<int>());
135+
return new_end.first;
136+
}
137+
138+
} // namespace sparse
139+
} // namespace phi

0 commit comments

Comments
 (0)