Skip to content
13 changes: 13 additions & 0 deletions test/legacy_test/test_lu_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,19 @@ def test_zero_size(self):
self._test_dygraph()


class TestLUAPI_ZeroSize(unittest.TestCase):
def test_zero_size1(self):
self.x_shape = (2, 0, 12)
self.dtype = "float32"
paddle.disable_static()
a = np.random.randn(*self.x_shape)
x = paddle.to_tensor(a, dtype=self.dtype, stop_gradient=False)
lu, p, info = paddle.linalg.lu(x, get_infos=True)
loss = lu.sum()
loss.backward()
self.assertEqual(x.grad.shape, x.shape)


if __name__ == "__main__":
paddle.enable_static()
unittest.main()
Loading