Skip to content

Commit 3138857

Browse files
authored
Enable ROCm builds (rusty1s#282)
1 parent d1aee18 commit 3138857

File tree

4 files changed

+50
-7
lines changed

4 files changed

+50
-7
lines changed

csrc/cuda/atomics.cuh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ static inline __device__ void atomAdd(float *address, float val) {
55
}
66

77
static inline __device__ void atomAdd(double *address, double val) {
8-
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 600 || CUDA_VERSION < 8000)
8+
#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 600 || CUDA_VERSION < 8000))
99
unsigned long long int *address_as_ull = (unsigned long long int *)address;
1010
unsigned long long int old = *address_as_ull;
1111
unsigned long long int assumed;

csrc/cuda/utils.cuh

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,3 +16,14 @@ __device__ __inline__ at::Half __shfl_down_sync(const unsigned mask,
1616
const unsigned int delta) {
1717
return __shfl_down_sync(mask, var.operator __half(), delta);
1818
}
19+
20+
#ifdef USE_ROCM
21+
__device__ __inline__ at::Half __ldg(const at::Half* ptr) {
22+
return __ldg(reinterpret_cast<const __half*>(ptr));
23+
}
24+
#define SHFL_UP_SYNC(mask, var, delta) __shfl_up(var, delta)
25+
#define SHFL_DOWN_SYNC(mask, var, delta) __shfl_down(var, delta)
26+
#else
27+
#define SHFL_UP_SYNC __shfl_up_sync
28+
#define SHFL_DOWN_SYNC __shfl_down_sync
29+
#endif

csrc/version.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,12 @@
44
#include <torch/script.h>
55

66
#ifdef WITH_CUDA
7+
#ifdef USE_ROCM
8+
#include <hip/hip_version.h>
9+
#else
710
#include <cuda.h>
811
#endif
12+
#endif
913

1014
#include "macros.h"
1115

@@ -22,7 +26,11 @@ PyMODINIT_FUNC PyInit__version_cpu(void) { return NULL; }
2226
namespace sparse {
2327
SPARSE_API int64_t cuda_version() noexcept {
2428
#ifdef WITH_CUDA
29+
#ifdef USE_ROCM
30+
return HIP_VERSION;
31+
#else
2532
return CUDA_VERSION;
33+
#endif
2634
#else
2735
return -1;
2836
#endif

setup.py

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,9 @@
1818
__version__ = '0.6.15'
1919
URL = 'https://github.com/rusty1s/pytorch_sparse'
2020

21-
WITH_CUDA = torch.cuda.is_available() and CUDA_HOME is not None
21+
WITH_CUDA = False
22+
if torch.cuda.is_available():
23+
WITH_CUDA = CUDA_HOME is not None or torch.version.hip
2224
suffices = ['cpu', 'cuda'] if WITH_CUDA else ['cpu']
2325
if os.getenv('FORCE_CUDA', '0') == '1':
2426
suffices = ['cuda', 'cpu']
@@ -40,9 +42,12 @@ def get_extensions():
4042

4143
extensions_dir = osp.join('csrc')
4244
main_files = glob.glob(osp.join(extensions_dir, '*.cpp'))
45+
# remove generated 'hip' files, in case of rebuilds
46+
main_files = [path for path in main_files if 'hip' not in path]
4347

4448
for main, suffix in product(main_files, suffices):
4549
define_macros = [('WITH_PYTHON', None)]
50+
undef_macros = []
4651

4752
if sys.platform == 'win32':
4853
define_macros += [('torchsparse_EXPORTS', None)]
@@ -84,13 +89,26 @@ def get_extensions():
8489
define_macros += [('WITH_CUDA', None)]
8590
nvcc_flags = os.getenv('NVCC_FLAGS', '')
8691
nvcc_flags = [] if nvcc_flags == '' else nvcc_flags.split(' ')
87-
nvcc_flags += ['--expt-relaxed-constexpr', '-O2']
92+
nvcc_flags += ['-O2']
8893
extra_compile_args['nvcc'] = nvcc_flags
94+
if torch.version.hip:
95+
# USE_ROCM was added to later versions of PyTorch
96+
# Define here to support older PyTorch versions as well:
97+
define_macros += [('USE_ROCM', None)]
98+
undef_macros += ['__HIP_NO_HALF_CONVERSIONS__']
99+
else:
100+
nvcc_flags += ['--expt-relaxed-constexpr']
89101

90-
if sys.platform == 'win32':
91-
extra_link_args += ['cusparse.lib']
102+
if torch.version.hip:
103+
if sys.platform == 'win32':
104+
extra_link_args += ['hipsparse.lib']
105+
else:
106+
extra_link_args += ['-lhipsparse', '-l', 'hipsparse']
92107
else:
93-
extra_link_args += ['-lcusparse', '-l', 'cusparse']
108+
if sys.platform == 'win32':
109+
extra_link_args += ['cusparse.lib']
110+
else:
111+
extra_link_args += ['-lcusparse', '-l', 'cusparse']
94112

95113
name = main.split(os.sep)[-1][:-4]
96114
sources = [main]
@@ -111,6 +129,7 @@ def get_extensions():
111129
sources,
112130
include_dirs=[extensions_dir, phmap_dir],
113131
define_macros=define_macros,
132+
undef_macros=undef_macros,
114133
extra_compile_args=extra_compile_args,
115134
extra_link_args=extra_link_args,
116135
libraries=libraries,
@@ -129,6 +148,11 @@ def get_extensions():
129148
'pytest-cov',
130149
]
131150

151+
# work-around hipify abs paths
152+
include_package_data = True
153+
if torch.cuda.is_available() and torch.version.hip:
154+
include_package_data = False
155+
132156
setup(
133157
name='torch_sparse',
134158
version=__version__,
@@ -155,5 +179,5 @@ def get_extensions():
155179
BuildExtension.with_options(no_python_abi_suffix=True, use_ninja=False)
156180
},
157181
packages=find_packages(),
158-
include_package_data=True,
182+
include_package_data=include_package_data,
159183
)

0 commit comments

Comments (0)