Skip to content

Commit 83531ae

Browse files
authored
reinforce send_ue_recv for big tensor (#74213)
1 parent aa15502 commit 83531ae

File tree

2 files changed

+12
-12
lines changed

2 files changed

+12
-12
lines changed

paddle/phi/kernels/gpu/graph_send_recv_funcs.h

Lines changed: 8 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -61,8 +61,8 @@ __global__ void GraphSendRecvCUDAKernel(const T* params,
6161
const IndexT* src_indices,
6262
const IndexT* dst_indices,
6363
T* output,
64-
size_t index_size,
65-
size_t slice_size,
64+
int64_t index_size,
65+
int64_t slice_size,
6666
Functor functor) {
6767
CUDA_KERNEL_LOOP_TYPE(i, index_size * slice_size, int64_t) {
6868
int64_t indices_i = i / slice_size;
@@ -78,8 +78,8 @@ __global__ void GraphSendRecvCUDAKernel(const T* params,
7878
// For max
7979
template <typename T>
8080
__global__ void InputResetMaxCUDAKernel(T* output,
81-
size_t input_size,
82-
size_t slice_size) {
81+
int64_t input_size,
82+
int64_t slice_size) {
8383
CUDA_KERNEL_LOOP_TYPE(i, input_size * slice_size, int64_t) {
8484
if (*(output + i) == std::numeric_limits<T>::lowest()) {
8585
*(output + i) = 0;
@@ -90,8 +90,8 @@ __global__ void InputResetMaxCUDAKernel(T* output,
9090
// For min
9191
template <typename T>
9292
__global__ void InputResetMinCUDAKernel(T* output,
93-
size_t input_size,
94-
size_t slice_size) {
93+
int64_t input_size,
94+
int64_t slice_size) {
9595
CUDA_KERNEL_LOOP_TYPE(i, input_size * slice_size, int64_t) {
9696
if (*(output + i) == std::numeric_limits<T>::max()) {
9797
*(output + i) = 0;
@@ -130,8 +130,8 @@ __global__ void ManipulateMeanGradCUDAKernel(const T* params,
130130
const IndexT* src_indices,
131131
const IndexT* dst_indices,
132132
T* output,
133-
size_t index_size,
134-
size_t slice_size,
133+
int64_t index_size,
134+
int64_t slice_size,
135135
const int32_t* dst_count) {
136136
CUDA_KERNEL_LOOP_TYPE(i, index_size * slice_size, int64_t) {
137137
int64_t indices_i = i / slice_size;

paddle/phi/kernels/gpu/graph_send_ue_recv_funcs.h

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -138,14 +138,14 @@ __global__ void GraphSendUERecvCUDAKernel(const T* x_data,
138138
bool use_bcast,
139139
ComputeFunctor cfunctor,
140140
ReduceFunctor rfunctor) {
141-
IndexT ty = blockIdx.y * blockDim.y + threadIdx.y;
142-
const IndexT stride_y = blockDim.y * gridDim.y;
141+
IndexT ty = static_cast<IndexT>(blockIdx.y) * blockDim.y + threadIdx.y;
142+
const IndexT stride_y = static_cast<IndexT>(blockDim.y) * gridDim.y;
143143

144144
while (ty < index_size) {
145145
IndexT src = src_indices[ty];
146146
IndexT dst = dst_indices[ty];
147-
int64_t tx = blockIdx.x * blockDim.x + threadIdx.x;
148-
int64_t stride_x = blockDim.x * gridDim.x;
147+
int64_t tx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
148+
int64_t stride_x = blockDim.x * static_cast<int64_t>(gridDim.x);
149149

150150
const T* x_off = x_data + src * x_len;
151151
const T* e_off = e_data + ty * e_len;

0 commit comments

Comments
 (0)