@@ -24,7 +24,60 @@ void ConvolutionForwardGPU(at::Tensor in_feat, at::Tensor out_feat,
2424
2525
2626 int kernel_volume = kernel.size (0 );
27+ int in_buffer_size = 1 ;
28+ bool flag = false ;
29+ // memory optimization
30+ if (kernel_volume % 2 && out_nrows == in_feat.size (0 )){
31+ flag = true ;
32+ in_buffer_size = *std::max_element (neighbor_offset.data_ptr <int >(),
33+ neighbor_offset.data_ptr <int >() + kernel_volume/2 );
34+ in_buffer_size = std::max (in_buffer_size,
35+ *std::max_element (neighbor_offset.data_ptr <int >() + kernel_volume/2 +1 ,
36+ neighbor_offset.data_ptr <int >() + kernel_volume));
37+ in_buffer_size = std::max (in_buffer_size, 1 );
38+
39+ torch::mm_out (out_feat, in_feat, kernel[kernel_volume / 2 ]);
40+ }
41+ else {
42+ in_buffer_size = *std::max_element (neighbor_offset.data_ptr <int >(),
43+ neighbor_offset.data_ptr <int >() + kernel_volume);
44+ }
2745
46+ auto options =
47+ torch::TensorOptions ().dtype (in_feat.dtype ()).device (in_feat.device ());
48+ auto in_buffer = torch::zeros ({in_buffer_size, in_feat.size (1 )}, options);
49+ auto out_buffer = torch::zeros ({in_buffer_size, kernel.size (2 )}, options);
50+ int cur_offset = 0 ;
51+ for (int i = 0 ; i < kernel_volume; i++){
52+ if (flag && (i == kernel_volume / 2 )){
53+ cur_offset += 2 * neighbor_offset.data_ptr <int >()[i];
54+ continue ;
55+ }
56+
57+ if (neighbor_offset.data_ptr <int >()[i]==0 ){
58+ continue ;
59+ }
60+
61+ auto out_buffer_activated =
62+ torch::from_blob (out_buffer.data_ptr <float >(),
63+ {neighbor_offset.data_ptr <int >()[i], kernel.size (2 )}, options);
64+ auto in_buffer_activated =
65+ torch::from_blob (in_buffer.data_ptr <float >(),
66+ {neighbor_offset.data_ptr <int >()[i], in_feat.size (1 )}, options);
67+ // gather
68+ gather_launch (in_buffer_activated.size (0 ), in_feat.size (0 ), kernel.size (1 ),
69+ in_feat.data_ptr <float >(), in_buffer_activated.data_ptr <float >(),
70+ neighbor_map.data_ptr <int >() + cur_offset, transpose);
71+ // GEMM
72+ torch::mm_out (out_buffer_activated, in_buffer_activated, kernel[i]);
73+ // scatter
74+ scatter_launch (neighbor_offset.data_ptr <int >()[i], out_nrows, kernel.size (2 ), out_buffer_activated.data_ptr <float >(),
75+ out_feat.data_ptr <float >(), neighbor_map.data_ptr <int >() + cur_offset, transpose);
76+ cur_offset += 2 * neighbor_offset.data_ptr <int >()[i];
77+ }
78+
79+
80+ /*
2881 cublasHandle_t handle =
2982 //THCState_getCurrentBlasHandle(at::globalContext().getTHCState());
3083 at::cuda::getCurrentCUDABlasHandle();
@@ -35,7 +88,7 @@ void ConvolutionForwardGPU(at::Tensor in_feat, at::Tensor out_feat,
3588 neighbor_offset.data_ptr<int>(), in_feat.size(0), out_feat.size(0),
3689 kernel.size(0), transpose, handle,
3790 at::cuda::getCurrentCUDAStream());
38-
91+ */
3992
4093
4194}
@@ -52,7 +105,63 @@ void ConvolutionBackwardGPU(
52105
53106 int kernel_volume = kernel.size (0 );
54107 bool flag = false ;
108+ int in_buffer_size;
109+ in_buffer_size = *std::max_element (neighbor_offset.data_ptr <int >(),
110+ neighbor_offset.data_ptr <int >() + kernel_volume);
55111
112+ auto options =
113+ torch::TensorOptions ().dtype (in_feat.dtype ()).device (in_feat.device ());
114+ auto in_buffer = torch::zeros ({in_buffer_size, in_feat.size (1 )}, options);
115+ auto in_grad_buffer = torch::zeros ({in_buffer_size, in_feat.size (1 )}, options);
116+ auto out_grad_buffer = torch::zeros ({in_buffer_size, kernel.size (2 )}, options);
117+
118+
119+ int cur_offset = 0 ;
120+ for (int i = 0 ; i < kernel_volume; i++){
121+ auto kernel_grad_buffer = grad_kernel[i];
122+ if (flag && (i == kernel_volume / 2 )){
123+ cur_offset += 2 * neighbor_offset.data_ptr <int >()[i];
124+ continue ;
125+ }
126+
127+ if (neighbor_offset.data_ptr <int >()[i]==0 ){
128+ continue ;
129+ }
130+
131+ auto out_grad_buffer_activated =
132+ torch::from_blob (out_grad_buffer.data_ptr <float >(),
133+ {neighbor_offset.data_ptr <int >()[i], kernel.size (2 )}, options);
134+ auto in_grad_buffer_activated =
135+ torch::from_blob (in_grad_buffer.data_ptr <float >(),
136+ {neighbor_offset.data_ptr <int >()[i], in_feat.size (1 )}, options);
137+ auto in_buffer_activated =
138+ torch::from_blob (in_buffer.data_ptr <float >(),
139+ {neighbor_offset.data_ptr <int >()[i], in_feat.size (1 )}, options);
140+ // gather
141+
142+ gather_launch (out_grad_buffer_activated.size (0 ), grad_out_feat.size (0 ), kernel.size (2 ),
143+ grad_out_feat.data_ptr <float >(), out_grad_buffer_activated.data_ptr <float >(),
144+ neighbor_map.data_ptr <int >() + cur_offset, !transpose);
145+
146+ gather_launch (in_buffer_activated.size (0 ), in_feat.size (0 ), kernel.size (1 ),
147+ in_feat.data_ptr <float >(), in_buffer_activated.data_ptr <float >(),
148+ neighbor_map.data_ptr <int >() + cur_offset, transpose);
149+
150+ // GEMM
151+ // torch::mm_out(out_buffer_activated, in_buffer_activated, kernel[i]);
152+ torch::mm_out (in_grad_buffer_activated, out_grad_buffer_activated, torch::transpose (kernel[i], 0 , 1 ));
153+ torch::mm_out (kernel_grad_buffer, torch::transpose (in_buffer_activated, 0 , 1 ), out_grad_buffer_activated);
154+ // scatter
155+ // grad_kernel[i] = kernel_grad_buffer;
156+
157+ scatter_launch (neighbor_offset.data_ptr <int >()[i], in_feat.size (0 ), kernel.size (1 ), in_grad_buffer_activated.data_ptr <float >(),
158+ grad_in_feat.data_ptr <float >(), neighbor_map.data_ptr <int >() + cur_offset, !transpose);
159+
160+ cur_offset += 2 * neighbor_offset.data_ptr <int >()[i];
161+
162+ }
163+
164+ /*
56165 cublasHandle_t handle =
57166 //THCState_getCurrentBlasHandle(at::globalContext().getTHCState());
58167 at::cuda::getCurrentCUDABlasHandle();
@@ -62,7 +171,7 @@ void ConvolutionBackwardGPU(
62171 grad_kernel.data_ptr<float>(), neighbor_map.data_ptr<int>(), neighbor_offset.data_ptr<int>(),
63172 in_feat.size(0), grad_out_feat.size(0), kernel.size(0),
64173 transpose, handle, at::cuda::getCurrentCUDAStream());
65-
174+ */
66175}
67176
68177
@@ -72,3 +181,4 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
72181 m.def("sparseconv_backward", &ConvolutionBackwardGPU, "point cloud convolution backward (CUDA)");
73182}
74183*/
184+
0 commit comments