Skip to content

Commit 0aaabf9

Browse files
committed
Fix
1 parent c3355fb commit 0aaabf9

File tree

1 file changed

+13
-5
lines changed

1 file changed

+13
-5
lines changed

paddle/phi/kernels/swiglu_grad_kernel.h

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
#include "paddle/phi/core/dense_tensor.h"
1818
#include "paddle/phi/core/device_context.h"
19+
#include "paddle/phi/kernels/full_kernel.h"
1920

2021
namespace phi {
2122

@@ -36,12 +37,19 @@ void SwiGLUGradKernel(const Context &dev_ctx,
3637
const DenseTensor &dz,
3738
DenseTensor *dx,
3839
DenseTensor *dy) {
39-
if (x.numel() == 0) {
40-
if (dx) {
41-
dev_ctx.template Alloc<T>(dx);
42-
}
40+
if (dx && dx->numel() == 0) {
41+
dev_ctx.template Alloc<T>(dx);
4342
if (dy) {
44-
dev_ctx.template Alloc<T>(dy);
43+
phi::Full<T, Context>(
44+
dev_ctx, phi::IntArray(common::vectorize(dy->dims())), 0, dy);
45+
}
46+
return;
47+
}
48+
if (dy && dy->numel() == 0) {
49+
dev_ctx.template Alloc<T>(dy);
50+
if (dx) {
51+
phi::Full<T, Context>(
52+
dev_ctx, phi::IntArray(common::vectorize(dx->dims())), 0, dx);
4553
}
4654
return;
4755
}

0 commit comments

Comments
 (0)