Skip to content

Commit 3efb8db

Browse files
authored
revert swiglu dx zero set (#73680)
1 parent f6222c9 commit 3efb8db

File tree

2 files changed

+2
-9
lines changed

2 files changed

+2
-9
lines changed

paddle/phi/kernels/swiglu_grad_kernel.h

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -45,14 +45,6 @@ void SwiGLUGradKernel(const Context &dev_ctx,
4545
}
4646
return;
4747
}
48-
if (dy && dy->numel() == 0) {
49-
dev_ctx.template Alloc<T>(dy);
50-
if (dx) {
51-
phi::Full<T, Context>(
52-
dev_ctx, phi::IntArray(common::vectorize(dx->dims())), 0, dx);
53-
}
54-
return;
55-
}
5648
const auto *x_ptr = x.data<T>();
5749
const auto *dz_ptr = dz.data<T>();
5850
auto *dx_ptr = dx ? dev_ctx.template Alloc<T>(dx) : nullptr;

test/legacy_test/test_swiglu.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -297,6 +297,7 @@ def test_swiglu(self):
297297
self.assertEqual(out[1].shape, y.shape)
298298

299299

300+
'''
300301
class TestSwigluOp_ZeroSize(OpTest):
301302
def config(self):
302303
self.x_shape = (0, 128)
@@ -338,6 +339,6 @@ def config(self):
338339
self.y_shape = (0, 128)
339340
self.out_shape = (0, 128)
340341
341-
342+
'''
342343
if __name__ == "__main__":
343344
unittest.main()

0 commit comments

Comments
 (0)