@@ -61,8 +61,8 @@ __global__ void GraphSendRecvCUDAKernel(const T* params,
6161 const  IndexT* src_indices,
6262 const  IndexT* dst_indices,
6363 T* output,
64-  size_t  index_size,
65-  size_t  slice_size,
64+  int64_t  index_size,
65+  int64_t  slice_size,
6666 Functor functor) {
6767 CUDA_KERNEL_LOOP_TYPE (i, index_size * slice_size, int64_t ) {
6868 int64_t  indices_i = i / slice_size;
@@ -78,8 +78,8 @@ __global__ void GraphSendRecvCUDAKernel(const T* params,
7878//  For max: reset output entries equal to std::numeric_limits<T>::lowest() to 0
7979template  <typename  T>
8080__global__ void  InputResetMaxCUDAKernel (T* output,
81-  size_t  input_size,
82-  size_t  slice_size) {
81+  int64_t  input_size,
82+  int64_t  slice_size) {
8383 CUDA_KERNEL_LOOP_TYPE (i, input_size * slice_size, int64_t ) {
8484 if  (*(output + i) == std::numeric_limits<T>::lowest ()) {
8585 *(output + i) = 0 ;
@@ -90,8 +90,8 @@ __global__ void InputResetMaxCUDAKernel(T* output,
9090//  For min: reset output entries equal to std::numeric_limits<T>::max() to 0
9191template  <typename  T>
9292__global__ void  InputResetMinCUDAKernel (T* output,
93-  size_t  input_size,
94-  size_t  slice_size) {
93+  int64_t  input_size,
94+  int64_t  slice_size) {
9595 CUDA_KERNEL_LOOP_TYPE (i, input_size * slice_size, int64_t ) {
9696 if  (*(output + i) == std::numeric_limits<T>::max ()) {
9797 *(output + i) = 0 ;
@@ -130,8 +130,8 @@ __global__ void ManipulateMeanGradCUDAKernel(const T* params,
130130 const  IndexT* src_indices,
131131 const  IndexT* dst_indices,
132132 T* output,
133-  size_t  index_size,
134-  size_t  slice_size,
133+  int64_t  index_size,
134+  int64_t  slice_size,
135135 const  int32_t * dst_count) {
136136 CUDA_KERNEL_LOOP_TYPE (i, index_size * slice_size, int64_t ) {
137137 int64_t  indices_i = i / slice_size;
0 commit comments