Skip to content

Commit 47fa806

Browse files
authored
【0D output】support 0D output for matrix_rank/multi_dot (#52861)
* support_0D_output_for_matrix_rank_multi_dot, test=allcase * add 0D output test for matrix_rank and multi_dot test=allcase * fix assert error, test=allcase * fix test error, test=allcase * fix other test error, test=allcase * fix other test error, test=allcase * fix test error, test=allcase * fix matrix_rank and multi_dot test error test=allcase * fix test error test=allcase * fix zero dim test, test=allcase * add static backward test for multi_dot, test=allcase * add tol 2d broadcast test case, test=allcase
1 parent 07878a3 commit 47fa806

File tree

4 files changed

+117
-3
lines changed

4 files changed

+117
-3
lines changed

paddle/phi/infermeta/binary.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ static void BinarySameInputDimsCheck(const MetaTensor& x,
7272
static DDim CheckAndGetOutputDim(const DDim& dim_x) {
7373
auto x_vec = phi::vectorize(dim_x);
7474
if (x_vec.size() == 2) {
75-
return phi::make_ddim({1});
75+
return phi::make_ddim({});
7676
}
7777
x_vec.erase(x_vec.end() - 2, x_vec.end());
7878
return phi::make_ddim(x_vec);

paddle/phi/infermeta/multiary.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2345,7 +2345,7 @@ void MultiDotInferMeta(const std::vector<const MetaTensor*>& x,
23452345
// If the last tensor is 1D of size n view it as a column vector (n, 1)
23462346
if (last_dim.size() == 1) {
23472347
last_dim = phi::make_ddim({static_cast<int>(last_dim[0]), 1});
2348-
out_dim = is_vector ? phi::make_ddim({1}) : phi::make_ddim({first_dim[0]});
2348+
out_dim = is_vector ? phi::make_ddim({}) : phi::make_ddim({first_dim[0]});
23492349
} else {
23502350
out_dim = is_vector ? phi::make_ddim({last_dim[1]})
23512351
: phi::make_ddim({first_dim[0], last_dim[1]});

paddle/phi/infermeta/unary.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ namespace detail {
3838
static DDim CheckAndGetOutputDim(const DDim& dim_x) {
3939
auto x_vec = phi::vectorize(dim_x);
4040
if (x_vec.size() == 2) {
41-
return phi::make_ddim({1});
41+
return phi::make_ddim({});
4242
}
4343
x_vec.erase(x_vec.end() - 2, x_vec.end());
4444
return phi::make_ddim(x_vec);

python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2123,6 +2123,23 @@ def test_linalg_slogdet(self):
21232123
self.assertTrue(out1.shape, [2, 3])
21242124
self.assertTrue(x1.grad.shape, [3, 3, 3])
21252125

2126+
def test_multi_dot(self):
2127+
a = paddle.randn([4])
2128+
a.stop_gradient = False
2129+
b = paddle.randn([4, 5])
2130+
b.stop_gradient = False
2131+
c = paddle.randn([5])
2132+
c.stop_gradient = False
2133+
2134+
out = paddle.linalg.multi_dot([a, b, c])
2135+
out.retain_grads()
2136+
out.backward()
2137+
2138+
self.assertEqual(out.shape, [])
2139+
self.assertEqual(a.grad.shape, [4])
2140+
self.assertEqual(b.grad.shape, [4, 5])
2141+
self.assertEqual(c.grad.shape, [5])
2142+
21262143

21272144
class TestSundryAPIStatic(unittest.TestCase):
21282145
def setUp(self):
@@ -3710,6 +3727,26 @@ def test_linalg_slogdet(self):
37103727
self.assertEqual(res[0].shape, (2, 3))
37113728
self.assertEqual(res[1].shape, (3, 3, 3))
37123729

3730+
@prog_scope()
3731+
def test_multi_dot(self):
3732+
a = paddle.randn([4])
3733+
a.stop_gradient = False
3734+
b = paddle.randn([4, 5])
3735+
b.stop_gradient = False
3736+
c = paddle.randn([5])
3737+
c.stop_gradient = False
3738+
3739+
out = paddle.linalg.multi_dot([a, b, c])
3740+
paddle.static.append_backward(out.sum())
3741+
prog = paddle.static.default_main_program()
3742+
res = self.exe.run(
3743+
prog, fetch_list=[out, a.grad_name, b.grad_name, c.grad_name]
3744+
)
3745+
self.assertEqual(res[0].shape, ())
3746+
self.assertEqual(res[1].shape, (4,))
3747+
self.assertEqual(res[2].shape, (4, 5))
3748+
self.assertEqual(res[3].shape, (5,))
3749+
37133750

37143751
# Use to test API whose zero-dim input tensors don't have grad and not need to test backward in OpTest.
37153752
class TestNoBackwardAPI(unittest.TestCase):
@@ -3901,6 +3938,38 @@ def test_unique(self):
39013938
self.assertEqual(inverse.shape, [1])
39023939
self.assertEqual(counts.shape, [1])
39033940

3941+
def test_matrix_rank(self):
3942+
x = paddle.eye(10)
3943+
x.stop_gradient = False
3944+
out = paddle.linalg.matrix_rank(x)
3945+
3946+
self.assertEqual(out.shape, [])
3947+
np.testing.assert_equal(out, np.array(10))
3948+
3949+
c = paddle.ones(shape=[3, 4, 5])
3950+
c.stop_gradient = False
3951+
out_c = paddle.linalg.matrix_rank(c)
3952+
self.assertEqual(out_c.shape, [3])
3953+
np.testing.assert_equal(out_c, np.array([1, 1, 1]))
3954+
3955+
# 2D, tol->float : OUTPUT 0D
3956+
x_tol = paddle.eye(10)
3957+
x_tol.stop_gradient = False
3958+
out_tol = paddle.linalg.matrix_rank(x_tol, tol=0.1)
3959+
self.assertEqual(out_tol.shape, [])
3960+
3961+
# 3D, tol->float : OUTPUT 1D
3962+
c_tol = paddle.ones(shape=[3, 4, 5])
3963+
c_tol.stop_gradient = False
3964+
out_c_tol = paddle.linalg.matrix_rank(c_tol, tol=0.1)
3965+
self.assertEqual(out_c_tol.shape, [3])
3966+
3967+
tol_2 = paddle.randn([2])
3968+
# 2D, tol->Tensor[1,2] : OUTPUT 1D
3969+
d = paddle.eye(10)
3970+
out_d = paddle.linalg.matrix_rank(d, tol=tol_2)
3971+
self.assertEqual(out_d.shape, [2])
3972+
39043973

39053974
class TestNoBackwardAPIStatic(unittest.TestCase):
39063975
def setUp(self):
@@ -4135,6 +4204,51 @@ def test_unique(self):
41354204
self.assertEqual(res[2].shape, (1,))
41364205
self.assertEqual(res[3].shape, (1,))
41374206

4207+
@prog_scope()
4208+
def test_static_matrix_rank(self):
4209+
# 2D : OUTPUT 0D
4210+
x = paddle.eye(10)
4211+
x.stop_gradient = False
4212+
out = paddle.linalg.matrix_rank(x)
4213+
prog = paddle.static.default_main_program()
4214+
res = self.exe.run(prog, fetch_list=[out])
4215+
self.assertEqual(res[0].shape, ())
4216+
4217+
# 3D : OUTPUT 1D
4218+
c = paddle.ones(shape=[3, 4, 5])
4219+
c.stop_gradient = False
4220+
out_c = paddle.linalg.matrix_rank(c)
4221+
prog = paddle.static.default_main_program()
4222+
self.exe.run(paddle.static.default_startup_program())
4223+
res = self.exe.run(prog, fetch_list=[out_c])
4224+
self.assertEqual(res[0].shape, (3,))
4225+
4226+
# 2D, tol->float : OUTPUT 0D
4227+
x_tol = paddle.eye(10)
4228+
x_tol.stop_gradient = False
4229+
out_tol = paddle.linalg.matrix_rank(x_tol, tol=0.1)
4230+
prog = paddle.static.default_main_program()
4231+
res = self.exe.run(prog, fetch_list=[out_tol])
4232+
self.assertEqual(res[0].shape, ())
4233+
4234+
# 3D, tol->float : OUTPUT 1D
4235+
c_tol = paddle.ones(shape=[3, 4, 5])
4236+
c_tol.stop_gradient = False
4237+
out_c_tol = paddle.linalg.matrix_rank(c_tol, tol=0.1)
4238+
prog = paddle.static.default_main_program()
4239+
self.exe.run(paddle.static.default_startup_program())
4240+
res = self.exe.run(prog, fetch_list=[out_c_tol])
4241+
self.assertEqual(res[0].shape, (3,))
4242+
4243+
tol_2 = paddle.randn([2])
4244+
# 2D, tol->Tensor[1,2] : OUTPUT 1D
4245+
d = paddle.eye(10)
4246+
out_d = paddle.linalg.matrix_rank(d, tol=tol_2)
4247+
prog = paddle.static.default_main_program()
4248+
self.exe.run(paddle.static.default_startup_program())
4249+
res = self.exe.run(prog, fetch_list=[out_d])
4250+
self.assertEqual(res[0].shape, (2,))
4251+
41384252

41394253
unary_apis_with_complex_input = [
41404254
paddle.real,

0 commit comments

Comments (0)