@@ -37,16 +37,17 @@ __global__ void GatherNdCUDAKernel(const T* input,
                                    size_t remain_size,
                                    size_t slice_size,
                                    size_t end_size) {
-  int total_size = remain_size * slice_size;
-  int idx = (blockIdx.x * blockDim.x + threadIdx.x) * VecSize;
-  int64_t stride = blockDim.x * gridDim.x * VecSize;
+  size_t total_size = remain_size * slice_size;
+  size_t idx =
+      (static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x) * VecSize;
+  size_t stride = static_cast<size_t>(blockDim.x) * gridDim.x * VecSize;

 #pragma unroll
   for (; idx < total_size; idx += stride) {
-    int indices_i = idx / slice_size;
-    int slice_i = idx % slice_size;
-    int64_t gather_i = 0;
-    int64_t temp = slice_size;
+    size_t indices_i = idx / slice_size;
+    size_t slice_i = idx % slice_size;
+    size_t gather_i = 0;
+    size_t gather_stride = slice_size;
 #pragma unroll
     for (int j = end_size - 1; j >= 0; --j) {
       auto index_value = indices[indices_i * end_size + j];
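
This first hunk widens every quantity that scales with tensor size from 32-bit int to size_t, and casts blockIdx.x before the first multiply so the whole offset computation happens in 64 bits. With int indexing, total_size wraps once the gathered output exceeds INT_MAX (about 2.1e9) elements, and blockIdx.x * blockDim.x is evaluated in unsigned int, which can itself wrap on very large grids before any widening takes place. A minimal sketch of the safe pattern, separate from the Paddle kernel (the kernel name and the fill operation below are illustrative, not from the patch):

#include <cstddef>

// Illustrative kernel: a 64-bit-safe grid-stride loop over n elements.
__global__ void grid_stride_fill(float* out, size_t n) {
  // Cast before multiplying: blockIdx.x * blockDim.x is evaluated in
  // unsigned int and can wrap on large grids before any widening happens.
  size_t idx = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  size_t stride = static_cast<size_t>(blockDim.x) * gridDim.x;
  for (; idx < n; idx += stride) {
    out[idx] = 1.0f;
  }
}

Casting one operand first is what makes the difference; static_cast<size_t>(blockIdx.x * blockDim.x) would still let the 32-bit wrap happen inside the parentheses.
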
@@ -63,10 +64,10 @@ __global__ void GatherNdCUDAKernel(const T* input,
       if (index_value < 0) {
         index_value += input_dims[j];
       }
-      gather_i += (index_value * temp);
-      temp *= input_dims[j];
+      gather_i += index_value * gather_stride;
+      gather_stride *= input_dims[j];
     }
-    int64_t input_i = gather_i + slice_i;
+    size_t input_i = gather_i + slice_i;

     using VecType = kps::details::VectorType<T, VecSize>;
     const VecType* src = reinterpret_cast<const VecType*>(&input[input_i]);
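Besides the type change, the second hunk renames temp to gather_stride, naming what the inner loop actually maintains: the running row-major stride as the multi-index stored in indices is folded into one flat offset, walking from the innermost indexed dimension outward with a full slice of slice_size elements as the unit. A hypothetical host-side rendering of the same offset math (flatten_index and the sample shapes below are made up for illustration):

#include <cstddef>
#include <cstdint>
#include <cstdio>

// Row-major linearization as done by the kernel's inner loop:
// offset = (((i0 * d1 + i1) * d2 + i2) ...) * slice_size, accumulated
// from the last indexed dimension backwards.
size_t flatten_index(const int64_t* index, const int64_t* dims,
                     int end_size, size_t slice_size) {
  size_t offset = 0;
  size_t stride = slice_size;  // one gathered slice is the smallest unit
  for (int j = end_size - 1; j >= 0; --j) {
    offset += static_cast<size_t>(index[j]) * stride;
    stride *= static_cast<size_t>(dims[j]);
  }
  return offset;
}

int main() {
  int64_t dims[] = {4, 5};   // indexed dims; each slice holds 3 elements
  int64_t index[] = {2, 1};
  printf("%zu\n", flatten_index(index, dims, 2, 3));  // prints 33
}

With dims {4, 5} and slice_size 3, index (2, 1) lands at (2 * 5 + 1) * 3 = 33, which is the input_i the kernel then hands to the vectorized copy through VectorType.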