Skip to content

Commit 0d5c9f8

Browse files
authored
Add mixed-precision support
1 parent a50506a commit 0d5c9f8

File tree

18 files changed

+448
-347
lines changed

18 files changed

+448
-347
lines changed

README.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,9 @@ We here provides an entire training example with dummy input [here](examples/exa
107107

108108
You are also welcome to check out our [SPVNAS](https://github.com/mit-han-lab/e3d) project to see how to implement training / inference with real data.
109109

110+
### Mixed Precision (float16) Support
111+
112+
Mixed precision training is supported via `torch.cuda.amp.autocast` and `torch.cuda.amp.GradScaler`. Enabling it can speed up training and reduce GPU memory usage. When you wrap the forward pass and loss computation in a `torch.cuda.amp.autocast` block, feature tensors are automatically converted to float16 where possible. See [examples/example.py](examples/example.py) for a complete example; run it with the `--mixed` flag to enable mixed precision.
110113

111114
## Speed Comparison Between torchsparse and MinkowskiEngine
112115

examples/example.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import torchsparse.nn as spnn
66
from torchsparse import SparseTensor
77
from torchsparse.utils import sparse_collate_fn, sparse_quantize
8+
import argparse
89

910

1011
def generate_random_point_cloud(size=100000, voxel_size=0.2):
@@ -39,7 +40,7 @@ def generate_batched_random_point_clouds(size=100000,
3940
return sparse_collate_fn(batch)
4041

4142

42-
def dummy_train(device):
43+
def dummy_train(device, mixed=False):
4344
model = nn.Sequential(
4445
spnn.Conv3d(4, 32, kernel_size=3, stride=1), spnn.BatchNorm(32),
4546
spnn.ReLU(True), spnn.Conv3d(32, 64, kernel_size=2, stride=2),
@@ -50,21 +51,32 @@ def dummy_train(device):
5051
spnn.ReLU(True), spnn.Conv3d(32, 10, kernel_size=1)).to(device)
5152
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
5253
criterion = nn.CrossEntropyLoss().to(device)
54+
scaler = torch.cuda.amp.GradScaler(enabled=mixed)
5355

5456
print('Starting dummy training...')
5557
for i in range(10):
58+
optimizer.zero_grad()
5659
feed_dict = generate_batched_random_point_clouds()
5760
inputs = feed_dict['lidar'].to(device)
5861
targets = feed_dict['targets'].F.to(device).long()
59-
outputs = model(inputs)
60-
optimizer.zero_grad()
61-
loss = criterion(outputs.F, targets)
62-
loss.backward()
63-
optimizer.step()
62+
with torch.cuda.amp.autocast(enabled=mixed):
63+
outputs = model(inputs)
64+
loss = criterion(outputs.F, targets)
65+
scaler.scale(loss).backward()
66+
scaler.step(optimizer)
67+
scaler.update()
6468
print('[step %d] loss = %f.' % (i, loss.item()))
6569
print('Finished dummy training!')
6670

6771

6872
if __name__ == '__main__':
73+
parser = argparse.ArgumentParser()
74+
parser.add_argument("--mixed", action="store_true")
75+
args = parser.parse_args()
76+
77+
# set seeds for reproducibility
78+
np.random.seed(2021)
79+
torch.manual_seed(2021)
80+
6981
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
70-
dummy_train(device)
82+
dummy_train(device, args.mixed)

setup.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,13 @@
1717
file_lis = [
1818
'torchsparse/src/torchsparse_bindings_gpu.cpp',
1919
'torchsparse/src/convolution/convolution_cpu.cpp',
20-
'torchsparse/src/convolution/convolution.cpp',
20+
'torchsparse/src/convolution/convolution.cu',
2121
'torchsparse/src/convolution/convolution_gpu.cu',
2222
'torchsparse/src/hash/hash_cpu.cpp',
2323
'torchsparse/src/hash/hash.cpp',
2424
'torchsparse/src/hash/hash_gpu.cu',
2525
'torchsparse/src/hashmap/hashmap.cu',
2626
'torchsparse/src/hashmap/hashmap_cpu.cpp',
27-
'torchsparse/src/interpolation/devox.cpp',
2827
'torchsparse/src/interpolation/devox_gpu.cu',
2928
'torchsparse/src/interpolation/devox_deterministic.cpp',
3029
'torchsparse/src/interpolation/devox_deterministic_gpu.cu',
@@ -35,7 +34,6 @@
3534
'torchsparse/src/others/count.cpp',
3635
'torchsparse/src/others/count_gpu.cu',
3736
'torchsparse/src/others/count_cpu.cpp',
38-
'torchsparse/src/others/insertion.cpp',
3937
'torchsparse/src/others/insertion_gpu.cu',
4038
'torchsparse/src/others/insertion_cpu.cpp',
4139
'torchsparse/src/others/query.cpp',

torchsparse/nn/functional/conv.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import torch
44
import torchsparse_backend
55
from torch.autograd import Function
6+
from torch.cuda.amp import custom_fwd, custom_bwd
67
from torchsparse import *
78
from torchsparse.nn.functional.convert_neighbor_map import *
89
from torchsparse.nn.functional.downsample import *
@@ -15,6 +16,7 @@
1516

1617
class SpConvolution(Function):
1718
@staticmethod
19+
@custom_fwd(cast_inputs=torch.half)
1820
def forward(ctx,
1921
features,
2022
kernel,
@@ -27,11 +29,13 @@ def forward(ctx,
2729
if not transpose:
2830
out = torch.zeros(sizes[1],
2931
kernel.size(-1),
32+
dtype=features.dtype,
3033
device=features.device)
3134
else:
3235
# tbd: ensure the original, upsampled size to be the same.
3336
out = torch.zeros(sizes[0],
3437
kernel.size(-1),
38+
dtype=features.dtype,
3539
device=features.device)
3640

3741
if 'cuda' in str(features.device):
@@ -61,12 +65,13 @@ def forward(ctx,
6165
return out
6266

6367
@staticmethod
68+
@custom_bwd
6469
def backward(ctx, grad_out):
6570
features, kernel, neighbor_map, neighbor_offset, transpose = ctx.for_backwards
6671
K, c_in, c_out = kernel.size()
6772
N_in = features.size(0)
68-
grad_features = torch.zeros(N_in, c_in, device=features.device)
69-
grad_kernel = torch.zeros(K, c_in, c_out, device=kernel.device)
73+
grad_features = torch.zeros(N_in, c_in, device=features.device, dtype=features.dtype)
74+
grad_kernel = torch.zeros(K, c_in, c_out, device=kernel.device, dtype=features.dtype)
7075

7176
if 'cuda' in str(features.device):
7277
torchsparse_backend.sparseconv_backward(features, grad_features,

torchsparse/nn/functional/devox.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import torch
22
import torchsparse_backend
33
from torch.autograd import Function
4+
from torch.cuda.amp import custom_fwd, custom_bwd
45

56
__all__ = ['spdevoxelize', 'calc_ti_weights']
67

@@ -61,6 +62,7 @@ def calc_ti_weights(pc, idx_query, scale=1.0):
6162

6263
class DevoxelizationGPU(Function):
6364
@staticmethod
65+
@custom_fwd(cast_inputs=torch.half)
6466
def forward(ctx, feat, indices, weights):
6567
if 'cuda' in str(feat.device):
6668
out = torchsparse_backend.devoxelize_forward(
@@ -77,6 +79,7 @@ def forward(ctx, feat, indices, weights):
7779
return out
7880

7981
@staticmethod
82+
@custom_bwd
8083
def backward(ctx, grad_out):
8184
indices, weights, n = ctx.for_backwards
8285

torchsparse/nn/functional/downsample.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import torchsparse_backend
33
from torch.autograd import Function
44
from torchsparse.nn.functional.hash import *
5+
from torchsparse.nn.functional.voxelize import spvoxelize
56

67
__all__ = ['spdownsample']
78

@@ -23,8 +24,7 @@ def forward(ctx, coords, ratio):
2324
# rounding is necessary
2425
# gpu
2526
if 'cuda' in str(coords.device):
26-
uq_coords = torch.round(
27-
torchsparse_backend.insertion_forward(coords_new.float(), inv,
27+
uq_coords = torch.round(spvoxelize(coords_new.float(), inv,
2828
cnt))
2929
elif 'cpu' in str(coords.device):
3030
uq_coords = torch.round(

torchsparse/nn/functional/voxelize.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,28 @@
1+
import torch
12
import torchsparse_backend
23
from torch.autograd import Function
4+
from torch.cuda.amp import custom_fwd, custom_bwd
35
from torchsparse.nn.functional.hash import *
46

57
__all__ = ['spvoxelize']
68

79

810
class VoxelizeGPU(Function):
911
@staticmethod
12+
@custom_fwd(cast_inputs=torch.half)
1013
def forward(ctx, feat, idx, cnt):
11-
out = torchsparse_backend.insertion_forward(feat.float().contiguous(),
14+
out = torchsparse_backend.insertion_forward(feat.contiguous(),
1215
idx.int().contiguous(),
1316
cnt)
1417
ctx.for_backwards = (idx.int().contiguous(), cnt, feat.shape[0])
1518
return out
1619

1720
@staticmethod
21+
@custom_bwd
1822
def backward(ctx, top_grad):
1923
idx, cnt, N = ctx.for_backwards
2024
bottom_grad = torchsparse_backend.insertion_backward(
21-
top_grad.float().contiguous(), idx, cnt, N)
25+
top_grad.contiguous(), idx, cnt, N)
2226
return bottom_grad, None, None
2327

2428

torchsparse/src/common/gpu.cuh

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include <exception>
1414
#include <iostream>
1515
#include <vector>
16+
#include <torch/torch.h>
1617

1718

1819
//
@@ -103,6 +104,11 @@ template <typename Dtype1, typename Dtype2>
103104
void print(const thrust::device_vector<Dtype1> &v1,
104105
const thrust::device_vector<Dtype2> &v2);
105106

107+
// atomicadd for half types (from aten/src/THC/THCAtomics.cuh)
108+
static inline __device__ at::Half atomicAdd(at::Half *address, at::Half val) {
109+
return atomicAdd(reinterpret_cast<__half*>(address), val);
110+
}
111+
106112
// atomicAdd for double on CUDA arch < 600 (natively supported from arch 600 onward)
107113
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 600
108114
#else

torchsparse/src/convolution/convolution.cpp

Lines changed: 0 additions & 154 deletions
This file was deleted.

0 commit comments

Comments
 (0)