1818__version__ = '0.6.15'
1919URL = 'https://github.com/rusty1s/pytorch_sparse'
2020
21- WITH_CUDA = torch .cuda .is_available () and CUDA_HOME is not None
21+ WITH_CUDA = False
22+ if torch .cuda .is_available ():
23+ WITH_CUDA = CUDA_HOME is not None or torch .version .hip
2224suffices = ['cpu' , 'cuda' ] if WITH_CUDA else ['cpu' ]
2325if os .getenv ('FORCE_CUDA' , '0' ) == '1' :
2426 suffices = ['cuda' , 'cpu' ]
@@ -40,9 +42,12 @@ def get_extensions():
4042
4143 extensions_dir = osp .join ('csrc' )
4244 main_files = glob .glob (osp .join (extensions_dir , '*.cpp' ))
45+ # remove generated 'hip' files, in case of rebuilds
46+ main_files = [path for path in main_files if 'hip' not in path ]
4347
4448 for main , suffix in product (main_files , suffices ):
4549 define_macros = [('WITH_PYTHON' , None )]
50+ undef_macros = []
4651
4752 if sys .platform == 'win32' :
4853 define_macros += [('torchsparse_EXPORTS' , None )]
@@ -84,13 +89,26 @@ def get_extensions():
8489 define_macros += [('WITH_CUDA' , None )]
8590 nvcc_flags = os .getenv ('NVCC_FLAGS' , '' )
8691 nvcc_flags = [] if nvcc_flags == '' else nvcc_flags .split (' ' )
87- nvcc_flags += ['--expt-relaxed-constexpr' , '- O2' ]
92+ nvcc_flags += ['-O2' ]
8893 extra_compile_args ['nvcc' ] = nvcc_flags
94+ if torch .version .hip :
95+ # USE_ROCM was added to later versions of PyTorch
96+ # Define here to support older PyTorch versions as well:
97+ define_macros += [('USE_ROCM' , None )]
98+ undef_macros += ['__HIP_NO_HALF_CONVERSIONS__' ]
99+ else :
100+ nvcc_flags += ['--expt-relaxed-constexpr' ]
89101
90- if sys .platform == 'win32' :
91- extra_link_args += ['cusparse.lib' ]
102+ if torch .version .hip :
103+ if sys .platform == 'win32' :
104+ extra_link_args += ['hipsparse.lib' ]
105+ else :
106+ extra_link_args += ['-lhipsparse' , '-l' , 'hipsparse' ]
92107 else :
93- extra_link_args += ['-lcusparse' , '-l' , 'cusparse' ]
108+ if sys .platform == 'win32' :
109+ extra_link_args += ['cusparse.lib' ]
110+ else :
111+ extra_link_args += ['-lcusparse' , '-l' , 'cusparse' ]
94112
95113 name = main .split (os .sep )[- 1 ][:- 4 ]
96114 sources = [main ]
@@ -111,6 +129,7 @@ def get_extensions():
111129 sources ,
112130 include_dirs = [extensions_dir , phmap_dir ],
113131 define_macros = define_macros ,
132+ undef_macros = undef_macros ,
114133 extra_compile_args = extra_compile_args ,
115134 extra_link_args = extra_link_args ,
116135 libraries = libraries ,
@@ -129,6 +148,11 @@ def get_extensions():
129148 'pytest-cov' ,
130149]
131150
151+ # work-around hipify abs paths
152+ include_package_data = True
153+ if torch .cuda .is_available () and torch .version .hip :
154+ include_package_data = False
155+
132156setup (
133157 name = 'torch_sparse' ,
134158 version = __version__ ,
@@ -155,5 +179,5 @@ def get_extensions():
155179 BuildExtension .with_options (no_python_abi_suffix = True , use_ninja = False )
156180 },
157181 packages = find_packages (),
158- include_package_data = True ,
182+ include_package_data = include_package_data ,
159183)
0 commit comments