Skip to content

Commit 888788d

Browse files
pnunna93MISHANMAURYAamcamdPrasanth Nunna
authored
Enable ROCm backend with custom ops integration (#1683)
* Port ROCm changes from multi-backend-refactor branch * Update ops.py * Update functional.py * Update ops.py * Update ops.py * Update ops.py * Update ops.py * Update functional.py * Update ops.py * Update ops.py * Update ops.py * Update ops.py * Update functional.py * Update functional.py * Update functional.py * Update functional.py * Update ops.py * Update ops.py * Update ops.py * Update ops.py * Update ops.py * Update ops.py * Update ops.py * Update ops.py * Update ops.py * Update functional.py * Update functional.py * Update functional.py * Update test_ops.py * Update test_functional.py * Update test_ops.py * Update test_functional.py * Update test_functional.py * Update functional.py * Update functional.py * Update ops.py * Update ops.py * Update test_functional.py * Update test_functional.py * Update cextension.py * Update cuda_specs.py * Update cuda_specs.py * Update test_functional.py * Update test_linear4bit.py * Update test_cuda_setup_evaluator.py * Update test_functional.py * Update modules.py * Update modules.py * Update ops.py * Update test_linear4bit.py * Update ops.py * Update ops.py * Update test_linear4bit.py * Update test_linear4bit.py * Update python-package.yml * Update python-package.yml * Update python-package.yml * Update python-package.yml * Create build-rocm.sh * Update cuda_specs.py * Fix trailing whitespace * Remove conflicts.diff * update for hipblasVersionMajor >=3 * Update test_functional.py * Update test_linear4bit.py * Update test_ops.py * Update main.py * Update test_functional.py * Update test_linear4bit.py * Update test_ops.py * Update test_linear4bit.py * Lint * Lint * Update helpers.py * Update test_functional.py * Update test_linear4bit.py * Update test_ops.py * Lint * Update pythonInterface.cpp * lint fix * lint * Update pythonInterface.cpp * revert permissions change * Fix indentation * Update kernels_hip.cuh * Update kernels.hip * Update ops.hip * Update ops_hip.cuh * Update kernels_hip.cuh * Update kernels.hip * Update kernels.hip * Update ops.hip * Update ops_hip.cuh * Update ops.hip * Update CMakeLists.txt * Update functional.py * Update cextension.py * Update cextension.py --------- Co-authored-by: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Co-authored-by: MISHANMAUYRA <mishanmaurya31081@gmail.com> Co-authored-by: amcamd <andrew.chapman@amd.com> Co-authored-by: Prasanth Nunna <root@banff-cyxtera-s78-1.amd.com>
1 parent a1cd3f6 commit 888788d

File tree

21 files changed

+4763
-77
lines changed

21 files changed

+4763
-77
lines changed

.github/scripts/build-rocm.sh

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
#!/bin/bash
2+
declare build_arch
3+
declare build_os
4+
declare rocm_version
5+
6+
set -xeuo pipefail
7+
bnb_rocm_arch="gfx90a;gfx942;gfx1100"
8+
if [ "${build_os:0:6}" == ubuntu ]; then
9+
image=rocm/dev-ubuntu-22.04:${rocm_version}-complete
10+
echo "Using image $image"
11+
docker run --rm --platform "linux/$build_arch" -i \
12+
-w /src -v "$PWD:/src" "$image" sh -c \
13+
"apt-get update \
14+
&& DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends cmake \
15+
&& cmake -DCOMPUTE_BACKEND=hip -DBNB_ROCM_ARCH=\"${bnb_rocm_arch}\" . \
16+
&& cmake --build ."
17+
fi
18+
19+
output_dir="output/${build_os}/${build_arch}"
20+
mkdir -p "${output_dir}"
21+
(shopt -s nullglob && cp bitsandbytes/*.{so,dylib,dll} "${output_dir}")

.github/workflows/python-package.yml

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,10 +102,55 @@ jobs:
102102
path: output/*
103103
retention-days: 7
104104

105+
build-shared-libs-rocm:
106+
strategy:
107+
matrix:
108+
os: [ubuntu-22.04]
109+
arch: [x86_64]
110+
rocm_version:
111+
["6.1.2", "6.2.4", "6.3.2"]
112+
runs-on: ${{ matrix.os }}
113+
steps:
114+
- uses: actions/checkout@v4
115+
- name: Set up Docker multiarch
116+
uses: docker/setup-qemu-action@v3
117+
- name: Clean up disk space
118+
run: |
119+
sudo rm -rf \
120+
/usr/share/dotnet \
121+
/opt/ghc \
122+
"/usr/local/share/boost" \
123+
"$AGENT_TOOLSDIRECTORY" \
124+
/opt/hostedtoolcache \
125+
/opt/google/chrome \
126+
/opt/microsoft/msedge \
127+
/opt/microsoft/powershell \
128+
/opt/pipx \
129+
/usr/lib/mono \
130+
/usr/local/julia* \
131+
/usr/local/lib/android \
132+
/usr/local/lib/node_modules \
133+
/usr/local/share/chromium \
134+
/usr/local/share/powershell \
135+
/usr/share/swift
136+
- name: Build C++
137+
run: bash .github/scripts/build-rocm.sh
138+
env:
139+
build_os: ${{ matrix.os }}
140+
build_arch: ${{ matrix.arch }}
141+
rocm_version: ${{ matrix.rocm_version }}
142+
- name: Upload build artifact
143+
uses: actions/upload-artifact@v4
144+
with:
145+
name: shared_library_rocm_${{ matrix.os }}_${{ matrix.arch }}_${{ matrix.rocm_version }}
146+
path: output/*
147+
retention-days: 7
148+
105149
build-wheels:
106150
needs:
107151
- build-shared-libs
108152
- build-shared-libs-cuda
153+
- build-shared-libs-rocm
109154
strategy:
110155
matrix:
111156
os: [ubuntu-22.04, ubuntu-22.04-arm, windows-latest, macos-latest]
@@ -173,6 +218,7 @@ jobs:
173218
merge-multiple: true
174219

175220
- name: Inspect tmp directory after downloading artifacts
221+
176222
run: |
177223
ls -alFR tmp/
178224
WHEEL_COUNT=$(find tmp/ -type f -name "*.whl" | wc -l)
@@ -210,6 +256,7 @@ jobs:
210256
- uses: actions/checkout@v4
211257
with:
212258
path: repo
259+
213260
- name: Delete old pre-release (if exists)
214261
run: |
215262
cd repo && gh release delete continuous-release_main --cleanup-tag -y

CMakeLists.txt

Lines changed: 75 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,14 @@ endif()
2525
# Define included source files
2626
set(CPP_FILES csrc/common.cpp csrc/cpu_ops.cpp csrc/pythonInterface.cpp)
2727
set(CUDA_FILES csrc/ops.cu csrc/kernels.cu)
28+
set(HIP_FILES csrc/ops.hip csrc/kernels.hip)
2829
set(MPS_FILES csrc/mps_ops.mm)
2930
set(METAL_FILES csrc/mps_kernels.metal)
3031
# C++ sources are always included
3132
list(APPEND SRC_FILES ${CPP_FILES})
3233

33-
set(COMPUTE_BACKEND "cpu" CACHE STRING "The compute backend to use (cpu, cuda, mps)")
34-
set_property(CACHE COMPUTE_BACKEND PROPERTY STRINGS cpu cuda mps)
34+
set(COMPUTE_BACKEND "cpu" CACHE STRING "The compute backend to use (cpu, cuda, hip, mps)")
35+
set_property(CACHE COMPUTE_BACKEND PROPERTY STRINGS cpu cuda hip mps)
3536
option(PTXAS_VERBOSE "Pass through -v flag to PTX Assembler" OFF)
3637

3738
if(APPLE)
@@ -47,15 +48,25 @@ if(${COMPUTE_BACKEND} STREQUAL "cuda")
4748
message(FATAL_ERROR "CUDA is not supported on macOS" )
4849
endif()
4950
set(BUILD_CUDA ON)
51+
set(BUILD_HIP OFF)
52+
set(BUILD_MPS OFF)
53+
elseif(${COMPUTE_BACKEND} STREQUAL "hip")
54+
if(APPLE)
55+
message(FATAL_ERROR "HIP is not supported on macOS" )
56+
endif()
57+
set(BUILD_CUDA OFF)
58+
set(BUILD_HIP ON)
5059
set(BUILD_MPS OFF)
5160
elseif(${COMPUTE_BACKEND} STREQUAL "mps")
5261
if(NOT APPLE)
5362
message(FATAL_ERROR "MPS is only supported on macOS" )
5463
endif()
5564
set(BUILD_CUDA OFF)
65+
set(BUILD_HIP OFF)
5666
set(BUILD_MPS ON)
5767
else()
5868
set(BUILD_CUDA OFF)
69+
set(BUILD_HIP OFF)
5970
set(BUILD_MPS OFF)
6071
endif()
6172

@@ -160,6 +171,33 @@ if(BUILD_CUDA)
160171

161172
string(APPEND BNB_OUTPUT_NAME "_cuda${CUDA_VERSION_SHORT}")
162173
add_compile_definitions(BUILD_CUDA)
174+
elseif(BUILD_HIP)
175+
enable_language(HIP)
176+
message(STATUS "HIP Compiler: ${CMAKE_HIP_COMPILER}")
177+
if(DEFINED BNB_ROCM_ARCH)
178+
set(CMAKE_HIP_ARCHITECTURES ${BNB_ROCM_ARCH})
179+
else()
180+
if (NOT AMDGPU_TARGETS AND NOT CMAKE_HIP_ARCHITECTURES)
181+
set(CMAKE_HIP_ARCHITECTURES "gfx90a;gfx942;gfx1100")
182+
elseif (AMDGPU_TARGETS AND NOT CMAKE_HIP_ARCHITECTURES)
183+
set(CMAKE_HIP_ARCHITECTURES ${AMDGPU_TARGETS})
184+
endif()
185+
endif()
186+
message(STATUS "HIP Targets: ${CMAKE_HIP_ARCHITECTURES}")
187+
188+
list(APPEND SRC_FILES ${HIP_FILES})
189+
190+
string(APPEND BNB_OUTPUT_NAME "_rocm")
191+
192+
# get hip version
193+
execute_process(COMMAND hipconfig --version OUTPUT_VARIABLE HIP_CONFIG_VERSION)
194+
string(REGEX MATCH "[0-9]+\\.[0-9]+" HIP_VERSION "${HIP_CONFIG_VERSION}")
195+
string(REPLACE "." "" HIP_VERSION_SHORT "${HIP_VERSION}")
196+
197+
string(APPEND BNB_OUTPUT_NAME "${HIP_VERSION_SHORT}")
198+
add_compile_definitions(__HIP_PLATFORM_AMD__)
199+
add_compile_definitions(__HIP_PLATFORM_HCC__)
200+
add_compile_definitions(BUILD_HIP)
163201
elseif(BUILD_MPS)
164202
if(NOT APPLE)
165203
message(FATAL_ERROR "MPS is only supported on macOS" )
@@ -208,6 +246,41 @@ if(BUILD_CUDA)
208246
CUDA_SEPARABLE_COMPILATION ON
209247
)
210248
endif()
249+
if(BUILD_HIP)
250+
if(NOT DEFINED ENV{ROCM_PATH})
251+
set(ROCM_PATH /opt/rocm)
252+
else()
253+
set(ROCM_PATH $ENV{ROCM_PATH})
254+
endif()
255+
list(APPEND CMAKE_PREFIX_PATH ${ROCM_PATH})
256+
macro(find_package_and_print_version PACKAGE_NAME)
257+
find_package("${PACKAGE_NAME}" ${ARGN})
258+
message("${PACKAGE_NAME} VERSION: ${${PACKAGE_NAME}_VERSION}")
259+
endmacro()
260+
find_package_and_print_version(hipblas REQUIRED)
261+
find_package_and_print_version(hiprand REQUIRED)
262+
find_package_and_print_version(hipsparse REQUIRED)
263+
264+
## hacky way of excluding hip::amdhip64 (with it linked many tests unexpectedly fail e.g. adam8bit because of inaccuracies)
265+
set_target_properties(hip::host PROPERTIES INTERFACE_LINK_LIBRARIES "")
266+
set_target_properties(hip-lang::host PROPERTIES INTERFACE_LINK_LIBRARIES "")
267+
set(CMAKE_HIP_IMPLICIT_LINK_LIBRARIES "")
268+
269+
target_include_directories(bitsandbytes PRIVATE ${CMAKE_SOURCE_DIR} ${CMAKE_SOURCE_DIR}/include ${ROCM_PATH}/include /include)
270+
target_link_directories(bitsandbytes PRIVATE ${ROCM_PATH}/lib /lib)
271+
target_link_libraries(bitsandbytes PUBLIC roc::hipblas hip::hiprand roc::hipsparse)
272+
273+
target_compile_definitions(bitsandbytes PUBLIC BNB_USE_HIP)
274+
set_source_files_properties(${HIP_FILES} PROPERTIES LANGUAGE HIP)
275+
set_target_properties(bitsandbytes PROPERTIES LINKER_LANGUAGE CXX)
276+
277+
if(HIP_VERSION VERSION_LESS "6.1")
278+
target_compile_definitions(bitsandbytes PUBLIC NO_HIPBLASLT)
279+
else()
280+
find_package(hipblaslt)
281+
target_link_libraries(bitsandbytes PUBLIC roc::hipblaslt)
282+
endif()
283+
endif()
211284
if(BUILD_MPS)
212285
add_dependencies(bitsandbytes metallib)
213286
target_link_libraries(bitsandbytes objc "-framework Foundation" "-framework Metal" "-framework MetalPerformanceShaders" "-framework MetalPerformanceShadersGraph")

bitsandbytes/backends/cuda/ops.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from bitsandbytes.functional import CUBLAS_Context, _cuda_device_of, _get_tensor_stream, get_ptr
99

1010
from ..._ops import register_kernel
11-
from ...cextension import lib
11+
from ...cextension import HIP_ENVIRONMENT, lib
1212

1313

1414
@register_kernel("bitsandbytes::int8_linear_matmul", "cuda")
@@ -210,7 +210,12 @@ def _get_col_absmax(
210210
@register_kernel("bitsandbytes::quantize_blockwise", "cuda")
211211
def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]:
212212
torch._check_is_size(blocksize)
213-
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
213+
214+
if HIP_ENVIRONMENT:
215+
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128])
216+
else:
217+
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
218+
214219
torch._check(code.dtype == torch.float32, lambda: f"code must be float32, got {code.dtype}")
215220

216221
n = A.numel()
@@ -264,7 +269,11 @@ def _(
264269
def _dequantize_blockwise_impl(
265270
A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype, out: torch.Tensor
266271
) -> None:
267-
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
272+
if HIP_ENVIRONMENT:
273+
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128])
274+
else:
275+
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
276+
268277
torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}")
269278
torch._check(
270279
dtype in [torch.float16, torch.bfloat16, torch.float32],
@@ -294,7 +303,11 @@ def _dequantize_blockwise_impl(
294303
def _(
295304
A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype
296305
) -> tuple[torch.Tensor, torch.Tensor]:
297-
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
306+
if HIP_ENVIRONMENT:
307+
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128])
308+
else:
309+
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
310+
298311
torch._check(quant_type in ["fp4", "nf4"])
299312
torch._check(
300313
A.dtype in [torch.bfloat16, torch.float16, torch.float32],
@@ -372,7 +385,11 @@ def _dequantize_4bit_impl(
372385
dtype: torch.dtype,
373386
out: torch.Tensor,
374387
) -> None:
375-
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
388+
if HIP_ENVIRONMENT:
389+
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128])
390+
else:
391+
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
392+
376393
torch._check(quant_type in ["fp4", "nf4"])
377394
torch._check(
378395
dtype in [torch.bfloat16, torch.float16, torch.float32],

0 commit comments

Comments
 (0)