Skip to content

Commit 02268b1

Browse files
committed
Fix
1 parent 3003eb2 commit 02268b1

File tree

4 files changed

+22
-2
lines changed

4 files changed

+22
-2
lines changed

paddle/phi/kernels/cpu/send_u_recv_grad_kernel.cc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
#include "paddle/phi/core/kernel_registry.h"
2121
#include "paddle/phi/kernels/cpu/graph_send_recv_funcs.h"
22+
#include "paddle/phi/kernels/full_kernel.h"
2223

2324
namespace phi {
2425

@@ -131,6 +132,11 @@ void SendURecvGradKernel(const Context& dev_ctx,
131132
dev_ctx.template Alloc<T>(x_grad);
132133
return;
133134
}
135+
if (src_index.numel() == 0) {
136+
phi::Full<T, Context>(
137+
dev_ctx, phi::IntArray(common::vectorize(x_grad->dims())), 0, x_grad);
138+
return;
139+
}
134140
auto index_type = src_index.dtype();
135141
if (index_type == phi::DataType::INT32) {
136142
GraphSendRecvGradOpKernelLaunchHelper<Context, T, int32_t>(

paddle/phi/kernels/cpu/send_u_recv_kernel.cc

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,11 @@ void GraphSendRecvOpKernelLaunchHelper(const Context& dev_ctx,
112112
T* p_output = out->data<T>();
113113
const size_t& memset_bytes = memset_size * sizeof(T);
114114
memset(p_output, 0, memset_bytes);
115-
if (x.numel() == 0) {
115+
if (x.numel() == 0 || src_index.numel() == 0) {
116+
if (out->numel() != 0) {
117+
phi::Full<T, Context>(
118+
dev_ctx, phi::IntArray(common::vectorize(out->dims())), 0, out);
119+
}
116120
if (dst_count) {
117121
int64_t input_size = out_size <= 0 ? src_dims[0] : out_size;
118122
// dst_count shape [-1], need to Resize

paddle/phi/kernels/gpu/send_u_recv_grad_kernel.cu

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include "paddle/common/hostdevice.h"
2121
#include "paddle/phi/backends/gpu/gpu_context.h"
2222
#include "paddle/phi/core/kernel_registry.h"
23+
#include "paddle/phi/kernels/full_kernel.h"
2324
#include "paddle/phi/kernels/gpu/graph_send_recv_funcs.h"
2425

2526
namespace phi {
@@ -108,6 +109,11 @@ void SendURecvGradKernel(const Context& dev_ctx,
108109
dev_ctx.template Alloc<T>(x_grad);
109110
return;
110111
}
112+
if (src_index.numel() == 0) {
113+
phi::Full<T, Context>(
114+
dev_ctx, phi::IntArray(common::vectorize(x_grad->dims())), 0, x_grad);
115+
return;
116+
}
111117
auto index_type = src_index.dtype();
112118
if (index_type == phi::DataType::INT32) {
113119
GraphSendRecvGradOpCUDAKernelLaunchHelper<Context, T, int32_t>(

paddle/phi/kernels/gpu/send_u_recv_kernel.cu

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,11 @@ void GraphSendRecvOpCUDAKernelLaunchHelper(const Context& dev_ctx,
7070
} else if (reduce_op == "MIN") {
7171
constant_functor(dev_ctx, out, std::numeric_limits<T>::max());
7272
}
73-
if (x.numel() == 0) {
73+
if (x.numel() == 0 || src_index.numel() == 0) {
74+
if (out->numel() != 0) {
75+
phi::Full<T, Context>(
76+
dev_ctx, phi::IntArray(common::vectorize(out->dims())), 0, out);
77+
}
7478
if (dst_count) {
7579
int64_t input_size = out_size <= 0 ? src_dims[0] : out_size;
7680
// dst_count shape [-1], need to Resize

0 commit comments

Comments
 (0)