Commit c86d777

miaoneng, Striges, and rusty1s authored
Use macro for __shfl_* functions for ROCm (#296)
* Use macro for __shfl_* functions
* Update test_matmul.py

Co-authored-by: jytang <striges@users.noreply.github.com>
Co-authored-by: Matthias Fey <matthias.fey@tu-dortmund.de>
1 parent 1bf1276 · commit c86d777

File tree

3 files changed (+21, -20 lines)


csrc/cuda/spmm_cuda.cu

Lines changed: 3 additions & 3 deletions
@@ -63,9 +63,9 @@ __global__ void spmm_kernel(const int64_t *rowptr_data, const int64_t *col_data,
 #pragma unroll
   for (int i = 0; i < 32; i++) {
     // Communication between all threads in a warp.
-    mat_rows[i] = __shfl_sync(FULL_MASK, mat_row, i);
+    mat_rows[i] = SHFL_SYNC(FULL_MASK, mat_row, i);
     if (HAS_VALUE)
-      vals[i] = __shfl_sync(FULL_MASK, val, i);
+      vals[i] = SHFL_SYNC(FULL_MASK, val, i);
   }

 #pragma unroll
@@ -179,7 +179,7 @@ spmm_value_bw_kernel(const int64_t *row_data, const int64_t *rowptr_data,

 #pragma unroll
   for (int i = 32 / 2; i > 0; i /= 2) { // Parallel reduction inside a warp.
-    val += __shfl_down_sync(FULL_MASK, val, i);
+    val += SHFL_DOWN_SYNC(FULL_MASK, val, i);
   }

   if (lane_idx == 0) {
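For readers unfamiliar with the pattern, the sketch below is a hypothetical, standalone illustration (not part of this commit) of the warp-level broadcast the kernel above relies on, written against the same SHFL_SYNC macro the commit introduces. The kernel name warp_broadcast_sum and its buffers are invented for illustration, and a launch with a block size that is a multiple of 32 and all lanes active is assumed.

// Minimal sketch, assuming USE_ROCM is defined when building under HIP/ROCm.
#define FULL_MASK 0xffffffff

#ifdef USE_ROCM
#define SHFL_SYNC(mask, var, srcLane) __shfl(var, srcLane)
#else
#define SHFL_SYNC __shfl_sync
#endif

// Hypothetical kernel: every lane ends up with the sum of all 32 values held
// by its warp, gathered purely through register shuffles (no shared memory).
__global__ void warp_broadcast_sum(const float *in, float *out) {
  float v = in[threadIdx.x]; // each lane loads one value
  float sum = 0.f;
#pragma unroll
  for (int i = 0; i < 32; i++) {
    // Read lane i's copy of `v`; on ROCm this expands to __shfl(v, i),
    // on CUDA to __shfl_sync(FULL_MASK, v, i).
    sum += SHFL_SYNC(FULL_MASK, v, i);
  }
  out[threadIdx.x] = sum;
}

The mask argument is dropped on the ROCm side because the legacy __shfl intrinsics take no mask, which appears to be the reason the commit hides the difference behind a macro instead of calling __shfl_sync directly.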

csrc/cuda/utils.cuh

Lines changed: 12 additions & 10 deletions
@@ -17,13 +17,15 @@ __device__ __inline__ at::Half __shfl_down_sync(const unsigned mask,
   return __shfl_down_sync(mask, var.operator __half(), delta);
 }

-#ifdef USE_ROCM
-__device__ __inline__ at::Half __ldg(const at::Half* ptr) {
-  return __ldg(reinterpret_cast<const __half*>(ptr));
-}
-#define SHFL_UP_SYNC(mask, var, delta) __shfl_up(var, delta)
-#define SHFL_DOWN_SYNC(mask, var, delta) __shfl_down(var, delta)
-#else
-#define SHFL_UP_SYNC __shfl_up_sync
-#define SHFL_DOWN_SYNC __shfl_down_sync
-#endif
+#ifdef USE_ROCM
+__device__ __inline__ at::Half __ldg(const at::Half* ptr) {
+  return __ldg(reinterpret_cast<const __half*>(ptr));
+}
+#define SHFL_UP_SYNC(mask, var, delta) __shfl_up(var, delta)
+#define SHFL_DOWN_SYNC(mask, var, delta) __shfl_down(var, delta)
+#define SHFL_SYNC(mask, var, delta) __shfl(var, delta)
+#else
+#define SHFL_UP_SYNC __shfl_up_sync
+#define SHFL_DOWN_SYNC __shfl_down_sync
+#define SHFL_SYNC __shfl_sync
+#endif
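As a companion, here is an equally hypothetical sketch of the warp reduction that SHFL_DOWN_SYNC supports, the same pattern used in spmm_value_bw_kernel above. warp_reduce_sum is an invented name, a full warp of 32 active lanes is assumed, and the macro definitions simply mirror the utils.cuh diff.

#define FULL_MASK 0xffffffff

#ifdef USE_ROCM
#define SHFL_DOWN_SYNC(mask, var, delta) __shfl_down(var, delta)
#else
#define SHFL_DOWN_SYNC __shfl_down_sync
#endif

// Sums `val` across the 32 lanes of a warp; after the loop, lane 0 holds the
// total. Each step folds the upper half of the remaining range onto the lower
// half (offsets 16, 8, 4, 2, 1).
__device__ float warp_reduce_sum(float val) {
#pragma unroll
  for (int i = 32 / 2; i > 0; i /= 2)
    val += SHFL_DOWN_SYNC(FULL_MASK, val, i);
  return val;
}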

test/test_matmul.py

Lines changed: 6 additions & 7 deletions
@@ -43,14 +43,13 @@ def test_spmm(dtype, device, reduce):
     out = matmul(src, other, reduce)
     out.backward(grad_out)

+    atol = 1e-7
     if dtype == torch.float16 or dtype == torch.bfloat16:
-        assert torch.allclose(expected, out, atol=1e-1)
-        assert torch.allclose(expected_grad_value, value.grad, atol=1e-1)
-        assert torch.allclose(expected_grad_other, other.grad, atol=1e-1)
-    else:
-        assert torch.allclose(expected, out)
-        assert torch.allclose(expected_grad_value, value.grad)
-        assert torch.allclose(expected_grad_other, other.grad)
+        atol = 1e-1
+
+    assert torch.allclose(expected, out, atol=atol)
+    assert torch.allclose(expected_grad_value, value.grad, atol=atol)
+    assert torch.allclose(expected_grad_other, other.grad, atol=atol)


 @pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
