66 AT_ASSERTM (x.device().is_cuda(), #x " must be CUDA tensor" )
77#define CHECK_INPUT (x ) AT_ASSERTM(x, " Input mismatch" )
88
9- __device__ __inline__ at::Half
10- __shfl_sync (const unsigned mask, const at::Half var, const int srcLane) {
11- return __shfl_sync (mask, var.operator __half (), srcLane);
9+ __device__ __inline__ at::Half __shfl_up_sync (const unsigned mask,
10+ const at::Half var,
11+ const unsigned int delta) {
12+ return __shfl_up_sync (mask, var.operator __half (), delta);
1213}
1314
1415__device__ __inline__ at::Half __shfl_down_sync (const unsigned mask,
@@ -17,6 +18,27 @@ __device__ __inline__ at::Half __shfl_down_sync(const unsigned mask,
1718 return __shfl_down_sync (mask, var.operator __half (), delta);
1819}
1920
21+ __device__ __inline__ at::Half __shfl_sync (const unsigned mask,
22+ const at::Half var,
23+ const int delta) {
24+ return __shfl_sync (mask, var.operator __half (), delta);
25+ }
26+
27+ __device__ __inline__ at::Half __shfl_up (const at::Half var,
28+ const unsigned int delta) {
29+ return __shfl_up (var.operator __half (), delta);
30+ }
31+
32+ __device__ __inline__ at::Half __shfl_down (const at::Half var,
33+ const unsigned int delta) {
34+ return __shfl_down (var.operator __half (), delta);
35+ }
36+
37+ __device__ __inline__ at::Half
38+ __shfl (const at::Half var, const int delta) {
39+ return __shfl (var.operator __half (), delta);
40+ }
41+
2042#ifdef USE_ROCM
2143__device__ __inline__ at::Half __ldg (const at::Half* ptr) {
2244 return __ldg (reinterpret_cast <const __half*>(ptr));
0 commit comments