Skip to content

Commit b913cd1

Browse files
authored
[PHI] Fix gather_nd and scatter_nd for big tensor (#73335)
1 parent 8ebe012 commit b913cd1

File tree

2 files changed

+22
-20
lines changed

2 files changed

+22
-20
lines changed

paddle/phi/kernels/funcs/gather.cu.h

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -37,16 +37,17 @@ __global__ void GatherNdCUDAKernel(const T* input,
3737
size_t remain_size,
3838
size_t slice_size,
3939
size_t end_size) {
40-
int total_size = remain_size * slice_size;
41-
int idx = (blockIdx.x * blockDim.x + threadIdx.x) * VecSize;
42-
int64_t stride = blockDim.x * gridDim.x * VecSize;
40+
size_t total_size = remain_size * slice_size;
41+
size_t idx =
42+
(static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x) * VecSize;
43+
size_t stride = static_cast<size_t>(blockDim.x) * gridDim.x * VecSize;
4344

4445
#pragma unroll
4546
for (; idx < total_size; idx += stride) {
46-
int indices_i = idx / slice_size;
47-
int slice_i = idx % slice_size;
48-
int64_t gather_i = 0;
49-
int64_t temp = slice_size;
47+
size_t indices_i = idx / slice_size;
48+
size_t slice_i = idx % slice_size;
49+
size_t gather_i = 0;
50+
size_t gather_stride = slice_size;
5051
#pragma unroll
5152
for (int j = end_size - 1; j >= 0; --j) {
5253
auto index_value = indices[indices_i * end_size + j];
@@ -63,10 +64,10 @@ __global__ void GatherNdCUDAKernel(const T* input,
6364
if (index_value < 0) {
6465
index_value += input_dims[j];
6566
}
66-
gather_i += (index_value * temp);
67-
temp *= input_dims[j];
67+
gather_i += index_value * gather_stride;
68+
gather_stride *= input_dims[j];
6869
}
69-
int64_t input_i = gather_i + slice_i;
70+
size_t input_i = gather_i + slice_i;
7071

7172
using VecType = kps::details::VectorType<T, VecSize>;
7273
const VecType* src = reinterpret_cast<const VecType*>(&input[input_i]);

paddle/phi/kernels/funcs/scatter.cu.h

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -104,16 +104,17 @@ __global__ void ScatterNdCUDAKernel(const T* update,
104104
size_t remain_size,
105105
size_t slice_size,
106106
size_t end_size) {
107-
int total_size = remain_size * slice_size;
108-
int idx = (blockIdx.x * blockDim.x + threadIdx.x) * VecSize;
109-
int64_t stride = blockDim.x * gridDim.x * VecSize;
107+
size_t total_size = remain_size * slice_size;
108+
size_t idx =
109+
(static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x) * VecSize;
110+
size_t stride = static_cast<size_t>(blockDim.x) * gridDim.x * VecSize;
110111

111112
#pragma unroll
112113
for (; idx < total_size; idx += stride) {
113-
int indices_i = idx / slice_size;
114-
int slice_i = idx % slice_size;
115-
int64_t gather_i = 0;
116-
int64_t temp = slice_size;
114+
size_t indices_i = idx / slice_size;
115+
size_t slice_i = idx % slice_size;
116+
size_t gather_i = 0;
117+
size_t gather_stride = slice_size;
117118

118119
#pragma unroll
119120
for (int j = end_size - 1; j >= 0; --j) {
@@ -132,11 +133,11 @@ __global__ void ScatterNdCUDAKernel(const T* update,
132133
index_value += output_dims[j];
133134
}
134135

135-
gather_i += (index_value * temp);
136-
temp *= output_dims[j];
136+
gather_i += index_value * gather_stride;
137+
gather_stride *= output_dims[j];
137138
}
138139

139-
int64_t output_i = gather_i + slice_i;
140+
size_t output_i = gather_i + slice_i;
140141

141142
using VecType = kps::details::VectorType<T, VecSize>;
142143
const VecType* src = reinterpret_cast<const VecType*>(&update[idx]);

0 commit comments

Comments (0)