Skip to content

Commit f93badd

Browse files
bug fix:trapezoid test=develop
1 parent 8f28938 commit f93badd

File tree

3 files changed

+7
-4
lines changed

3 files changed

+7
-4
lines changed

paddle/phi/kernels/funcs/gather.cu.h

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -196,8 +196,9 @@ __global__ void GatherGradGPUKernel(const T* input,
196196
int64_t input_index_dim_size,
197197
int64_t out_index_dim_size,
198198
int64_t size) {
199-
int64_t idx = blockDim.x * blockIdx.x + threadIdx.x;
200-
for (; idx < size; idx += blockDim.x * gridDim.x) {
199+
int64_t idx = static_cast<int64_t>(blockDim.x) * blockIdx.x + threadIdx.x;
200+
const int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
201+
for (; idx < size; idx += stride) {
201202
int64_t inner_dim_index = idx / (outer_dim_size * input_index_dim_size);
202203
int64_t next_idx = idx % (outer_dim_size * input_index_dim_size);
203204
int64_t index_dim_index = next_idx / (outer_dim_size);

paddle/phi/kernels/gpu/elementwise_grad.h

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -243,14 +243,14 @@ static __global__ void SimpleElemwiseSubGradCUDAKernel(const T *dout,
243243
int64_t size,
244244
T *dx,
245245
T *dy) {
246-
int col = BLOCK_ID_X * BLOCK_NUM_X + THREAD_ID_X;
246+
int64_t col = static_cast<int64_t>(BLOCK_ID_X) * BLOCK_NUM_X + THREAD_ID_X;
247247

248248
while (col < size) {
249249
if (dx != nullptr) {
250250
dx[col] = dout[col];
251251
}
252252
dy[col] = -dout[col];
253-
col += BLOCK_NUM_X * GRID_NUM_X;
253+
col += static_cast<int64_t>(BLOCK_NUM_X) * GRID_NUM_X;
254254
}
255255
}
256256

paddle/phi/kernels/gpu/gather_grad_kernel.cu

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -34,6 +34,7 @@ void GatherGradKernel(const Context& dev_ctx,
3434
if (axis_v < 0) {
3535
axis_v += static_cast<int>(x.dims().size());
3636
}
37+
3738
if (axis_v != 0) {
3839
if (index_type == DataType::INT32) {
3940
phi::funcs::GatherV2GradCUDAFunction<T, int32_t>(
@@ -44,6 +45,7 @@ void GatherGradKernel(const Context& dev_ctx,
4445
}
4546
return;
4647
}
48+
4749
dev_ctx.template Alloc<T>(x_grad);
4850
auto dxt = EigenVector<T>::Flatten(*x_grad);
4951
auto& place = *dev_ctx.eigen_device();

0 commit comments

Comments
 (0)