Skip to content

Commit 6ea1f8d

Browse files
committed
[Major] Update convolution.
1 parent 6a53325 commit 6ea1f8d

File tree

2 files changed

+116
-2
lines changed

2 files changed

+116
-2
lines changed

README.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,10 @@
99
year = {2020}
1010
}
1111
```
12+
13+
**[NEW!!]** We are releasing `torchsparse` v1.1, which offers a significant speedup over the previous v1.0 — please have a look!
14+
15+
1216
## Overview
1317

1418
We release `torchsparse`, a high-performance computing library for efficient 3D sparse convolution. This library aims to accelerate sparse computation in 3D, in particular the Sparse Convolution operation.

torchsparse/src/convolution/convolution.cpp

Lines changed: 112 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,60 @@ void ConvolutionForwardGPU(at::Tensor in_feat, at::Tensor out_feat,
2424

2525

2626
int kernel_volume = kernel.size(0);
27+
int in_buffer_size = 1;
28+
bool flag = false;
29+
// memory optimization
30+
if(kernel_volume % 2 && out_nrows == in_feat.size(0)){
31+
flag = true;
32+
in_buffer_size = *std::max_element(neighbor_offset.data_ptr<int>(),
33+
neighbor_offset.data_ptr<int>() + kernel_volume/2);
34+
in_buffer_size = std::max(in_buffer_size,
35+
*std::max_element(neighbor_offset.data_ptr<int>() + kernel_volume/2+1,
36+
neighbor_offset.data_ptr<int>() + kernel_volume));
37+
in_buffer_size = std::max(in_buffer_size, 1);
38+
39+
torch::mm_out(out_feat, in_feat, kernel[kernel_volume / 2]);
40+
}
41+
else{
42+
in_buffer_size = *std::max_element(neighbor_offset.data_ptr<int>(),
43+
neighbor_offset.data_ptr<int>() + kernel_volume);
44+
}
2745

46+
auto options =
47+
torch::TensorOptions().dtype(in_feat.dtype()).device(in_feat.device());
48+
auto in_buffer = torch::zeros({in_buffer_size, in_feat.size(1)}, options);
49+
auto out_buffer = torch::zeros({in_buffer_size, kernel.size(2)}, options);
50+
int cur_offset = 0;
51+
for(int i = 0; i < kernel_volume; i++){
52+
if(flag && (i == kernel_volume / 2)){
53+
cur_offset += 2 * neighbor_offset.data_ptr<int>()[i];
54+
continue;
55+
}
56+
57+
if(neighbor_offset.data_ptr<int>()[i]==0){
58+
continue;
59+
}
60+
61+
auto out_buffer_activated =
62+
torch::from_blob(out_buffer.data_ptr<float>(),
63+
{neighbor_offset.data_ptr<int>()[i], kernel.size(2)}, options);
64+
auto in_buffer_activated =
65+
torch::from_blob(in_buffer.data_ptr<float>(),
66+
{neighbor_offset.data_ptr<int>()[i], in_feat.size(1)}, options);
67+
// gather
68+
gather_launch(in_buffer_activated.size(0), in_feat.size(0), kernel.size(1),
69+
in_feat.data_ptr<float>(), in_buffer_activated.data_ptr<float>(),
70+
neighbor_map.data_ptr<int>() + cur_offset, transpose);
71+
// GEMM
72+
torch::mm_out(out_buffer_activated, in_buffer_activated, kernel[i]);
73+
// scatter
74+
scatter_launch(neighbor_offset.data_ptr<int>()[i], out_nrows, kernel.size(2), out_buffer_activated.data_ptr<float>(),
75+
out_feat.data_ptr<float>(), neighbor_map.data_ptr<int>() + cur_offset, transpose);
76+
cur_offset += 2 * neighbor_offset.data_ptr<int>()[i];
77+
}
78+
79+
80+
/*
2881
cublasHandle_t handle =
2982
//THCState_getCurrentBlasHandle(at::globalContext().getTHCState());
3083
at::cuda::getCurrentCUDABlasHandle();
@@ -35,7 +88,7 @@ void ConvolutionForwardGPU(at::Tensor in_feat, at::Tensor out_feat,
3588
neighbor_offset.data_ptr<int>(), in_feat.size(0), out_feat.size(0),
3689
kernel.size(0), transpose, handle,
3790
at::cuda::getCurrentCUDAStream());
38-
91+
*/
3992

4093

4194
}
@@ -52,7 +105,63 @@ void ConvolutionBackwardGPU(
52105

53106
int kernel_volume = kernel.size(0);
54107
bool flag = false;
108+
int in_buffer_size;
109+
in_buffer_size = *std::max_element(neighbor_offset.data_ptr<int>(),
110+
neighbor_offset.data_ptr<int>() + kernel_volume);
55111

112+
auto options =
113+
torch::TensorOptions().dtype(in_feat.dtype()).device(in_feat.device());
114+
auto in_buffer = torch::zeros({in_buffer_size, in_feat.size(1)}, options);
115+
auto in_grad_buffer = torch::zeros({in_buffer_size, in_feat.size(1)}, options);
116+
auto out_grad_buffer = torch::zeros({in_buffer_size, kernel.size(2)}, options);
117+
118+
119+
int cur_offset = 0;
120+
for(int i = 0; i < kernel_volume; i++){
121+
auto kernel_grad_buffer = grad_kernel[i];
122+
if(flag && (i == kernel_volume / 2)){
123+
cur_offset += 2 * neighbor_offset.data_ptr<int>()[i];
124+
continue;
125+
}
126+
127+
if(neighbor_offset.data_ptr<int>()[i]==0){
128+
continue;
129+
}
130+
131+
auto out_grad_buffer_activated =
132+
torch::from_blob(out_grad_buffer.data_ptr<float>(),
133+
{neighbor_offset.data_ptr<int>()[i], kernel.size(2)}, options);
134+
auto in_grad_buffer_activated =
135+
torch::from_blob(in_grad_buffer.data_ptr<float>(),
136+
{neighbor_offset.data_ptr<int>()[i], in_feat.size(1)}, options);
137+
auto in_buffer_activated =
138+
torch::from_blob(in_buffer.data_ptr<float>(),
139+
{neighbor_offset.data_ptr<int>()[i], in_feat.size(1)}, options);
140+
// gather
141+
142+
gather_launch(out_grad_buffer_activated.size(0), grad_out_feat.size(0), kernel.size(2),
143+
grad_out_feat.data_ptr<float>(), out_grad_buffer_activated.data_ptr<float>(),
144+
neighbor_map.data_ptr<int>() + cur_offset, !transpose);
145+
146+
gather_launch(in_buffer_activated.size(0), in_feat.size(0), kernel.size(1),
147+
in_feat.data_ptr<float>(), in_buffer_activated.data_ptr<float>(),
148+
neighbor_map.data_ptr<int>() + cur_offset, transpose);
149+
150+
// GEMM
151+
//torch::mm_out(out_buffer_activated, in_buffer_activated, kernel[i]);
152+
torch::mm_out(in_grad_buffer_activated, out_grad_buffer_activated, torch::transpose(kernel[i], 0, 1));
153+
torch::mm_out(kernel_grad_buffer, torch::transpose(in_buffer_activated, 0, 1), out_grad_buffer_activated);
154+
// scatter
155+
//grad_kernel[i] = kernel_grad_buffer;
156+
157+
scatter_launch(neighbor_offset.data_ptr<int>()[i], in_feat.size(0), kernel.size(1), in_grad_buffer_activated.data_ptr<float>(),
158+
grad_in_feat.data_ptr<float>(), neighbor_map.data_ptr<int>() + cur_offset, !transpose);
159+
160+
cur_offset += 2 * neighbor_offset.data_ptr<int>()[i];
161+
162+
}
163+
164+
/*
56165
cublasHandle_t handle =
57166
//THCState_getCurrentBlasHandle(at::globalContext().getTHCState());
58167
at::cuda::getCurrentCUDABlasHandle();
@@ -62,7 +171,7 @@ void ConvolutionBackwardGPU(
62171
grad_kernel.data_ptr<float>(), neighbor_map.data_ptr<int>(), neighbor_offset.data_ptr<int>(),
63172
in_feat.size(0), grad_out_feat.size(0), kernel.size(0),
64173
transpose, handle, at::cuda::getCurrentCUDAStream());
65-
174+
*/
66175
}
67176

68177

@@ -72,3 +181,4 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
72181
m.def("sparseconv_backward", &ConvolutionBackwardGPU, "point cloud convolution backward (CUDA)");
73182
}
74183
*/
184+

0 commit comments

Comments (0)