@@ -94,7 +94,7 @@ class ReduceMin {
9494static ReduceMin reduce_min;
9595
9696__global__ void CudaMemsetAsync (int * dest, int value, size_t size) {
97- int tid = threadIdx .x + blockIdx .x * blockDim .x ;
97+ int64_t tid = threadIdx .x + blockIdx .x * blockDim .x ;
9898 if (tid * sizeof (int ) >= size) return ;
9999 dest[tid] = value;
100100}
@@ -117,7 +117,7 @@ __global__ void ScatterAssignGPUKernel(tensor_t* self_data,
117117 int64_t numel_data,
118118 const func_t & reduce_op,
119119 int * thread_ids) {
120- int tid = threadIdx .x + blockIdx .x * blockDim .x ;
120+ int64_t tid = threadIdx .x + blockIdx .x * blockDim .x ;
121121 if (tid >= numel) return ;
122122 int64_t i, j, k; // The i, j, k here is the index of the 3 layers loop
123123 // squeezed from the N layers loop.
@@ -316,7 +316,7 @@ __global__ void ScatterMeanGPUKernel(tensor_t* self_data,
316316 bool include_self,
317317 const func_t & reduce_op,
318318 int * shared_mem) {
319- int tid = threadIdx .x + blockIdx .x * blockDim .x ;
319+ int64_t tid = threadIdx .x + blockIdx .x * blockDim .x ;
320320 if (tid >= numel) return ;
321321
322322 int64_t i, j, k; // The i, j, k here is the index of the 3 layers loop
@@ -639,7 +639,7 @@ __global__ void ScatterInputGradGPUKernel(tensor_t* grad_data,
639639 int64_t outer_dim_size_data,
640640 int64_t numel,
641641 int64_t numel_data) {
642- int tid = threadIdx .x + blockIdx .x * blockDim .x ;
642+ int64_t tid = threadIdx .x + blockIdx .x * blockDim .x ;
643643 if (tid >= numel) return ;
644644 int64_t i, j, k;
645645 i = tid / (select_dim_size * outer_dim_size);
@@ -710,7 +710,7 @@ __global__ void ScatterMulInputGradGPUKernel(tensor_t* grad_data,
710710 int64_t numel,
711711 int64_t numel_grad,
712712 int * thread_ids) {
713- int tid = threadIdx .x + blockIdx .x * blockDim .x ;
713+ int64_t tid = threadIdx .x + blockIdx .x * blockDim .x ;
714714 if (tid >= numel) return ;
715715 int64_t i, j, k;
716716 i = tid / (select_dim_size * outer_dim_size);
@@ -746,7 +746,7 @@ __global__ void ScatterMinMaxInputGradGPUKernel(tensor_t* grad_data,
746746 int64_t numel_grad,
747747 const std::string& reduce,
748748 int * shared_mem) {
749- int tid = threadIdx .x + blockIdx .x * blockDim .x ;
749+ int64_t tid = threadIdx .x + blockIdx .x * blockDim .x ;
750750 if (tid >= numel) return ;
751751 int64_t i, j, k;
752752 i = tid / (select_dim_size * outer_dim_size);
@@ -869,7 +869,7 @@ __global__ void ScatterMeanInputGradGPUKernel(tensor_t* grad_data,
869869 int64_t numel,
870870 int64_t numel_grad,
871871 int * shared_mem) {
872- int tid = threadIdx .x + blockIdx .x * blockDim .x ;
872+ int64_t tid = threadIdx .x + blockIdx .x * blockDim .x ;
873873 if (tid >= numel) return ;
874874 int64_t i, j, k;
875875 i = tid / (select_dim_size * outer_dim_size);
@@ -960,7 +960,7 @@ __global__ void ScatterValueGradGPUKernel(tensor_t* grad_data,
960960 int64_t numel,
961961 int64_t numel_data,
962962 int * thread_ids) {
963- int tid = threadIdx .x + blockIdx .x * blockDim .x ;
963+ int64_t tid = threadIdx .x + blockIdx .x * blockDim .x ;
964964 if (tid >= numel) return ;
965965
966966 int64_t i, j, k;
@@ -1054,7 +1054,7 @@ __global__ void ScatterMeanValueGradGPUKernel(tensor_t* grad_data,
10541054 int64_t numel,
10551055 int64_t numel_self,
10561056 int * shared_mem) {
1057- int tid = threadIdx .x + blockIdx .x * blockDim .x ;
1057+ int64_t tid = threadIdx .x + blockIdx .x * blockDim .x ;
10581058 if (tid >= numel) return ;
10591059
10601060 int64_t i, j, k;
@@ -1088,7 +1088,7 @@ __global__ void ScatterAddValueGradGPUKernel(tensor_t* grad_data,
10881088 int64_t outer_dim_size_self,
10891089 int64_t outer_dim_size_grad,
10901090 int64_t numel) {
1091- int tid = threadIdx .x + blockIdx .x * blockDim .x ;
1091+ int64_t tid = threadIdx .x + blockIdx .x * blockDim .x ;
10921092 if (tid >= numel) return ;
10931093 int64_t i, j, k;
10941094 i = tid / (select_dim_size * outer_dim_size);
@@ -1201,7 +1201,7 @@ __global__ void ScatterMulValueGradGPUKernel(tensor_t* grad_data,
12011201 int64_t outer_dim_size_self,
12021202 int64_t outer_dim_size_grad,
12031203 int64_t numel) {
1204- int tid = threadIdx .x + blockIdx .x * blockDim .x ;
1204+ int64_t tid = threadIdx .x + blockIdx .x * blockDim .x ;
12051205 if (tid >= numel) return ;
12061206 int64_t i, j, k;
12071207 i = tid / (select_dim_size * outer_dim_size);
@@ -1236,7 +1236,7 @@ __global__ void ScatterMinMaxValueGradGPUKernel(tensor_t* grad_data,
12361236 int64_t numel_self,
12371237 bool include_self,
12381238 int * shared_mem) {
1239- int tid = threadIdx .x + blockIdx .x * blockDim .x ;
1239+ int64_t tid = threadIdx .x + blockIdx .x * blockDim .x ;
12401240 if (tid >= numel) return ;
12411241 int64_t i, j, k;
12421242 i = tid / (select_dim_size * outer_dim_size);
0 commit comments