@@ -2123,6 +2123,23 @@ def test_linalg_slogdet(self):
21232123 self .assertTrue (out1 .shape , [2 , 3 ])
21242124 self .assertTrue (x1 .grad .shape , [3 , 3 , 3 ])
21252125
2126+ def test_multi_dot (self ):
2127+ a = paddle .randn ([4 ])
2128+ a .stop_gradient = False
2129+ b = paddle .randn ([4 , 5 ])
2130+ b .stop_gradient = False
2131+ c = paddle .randn ([5 ])
2132+ c .stop_gradient = False
2133+
2134+ out = paddle .linalg .multi_dot ([a , b , c ])
2135+ out .retain_grads ()
2136+ out .backward ()
2137+
2138+ self .assertEqual (out .shape , [])
2139+ self .assertEqual (a .grad .shape , [4 ])
2140+ self .assertEqual (b .grad .shape , [4 , 5 ])
2141+ self .assertEqual (c .grad .shape , [5 ])
2142+
21262143
21272144class TestSundryAPIStatic (unittest .TestCase ):
21282145 def setUp (self ):
@@ -3710,6 +3727,26 @@ def test_linalg_slogdet(self):
37103727 self .assertEqual (res [0 ].shape , (2 , 3 ))
37113728 self .assertEqual (res [1 ].shape , (3 , 3 , 3 ))
37123729
3730+ @prog_scope ()
3731+ def test_multi_dot (self ):
3732+ a = paddle .randn ([4 ])
3733+ a .stop_gradient = False
3734+ b = paddle .randn ([4 , 5 ])
3735+ b .stop_gradient = False
3736+ c = paddle .randn ([5 ])
3737+ c .stop_gradient = False
3738+
3739+ out = paddle .linalg .multi_dot ([a , b , c ])
3740+ paddle .static .append_backward (out .sum ())
3741+ prog = paddle .static .default_main_program ()
3742+ res = self .exe .run (
3743+ prog , fetch_list = [out , a .grad_name , b .grad_name , c .grad_name ]
3744+ )
3745+ self .assertEqual (res [0 ].shape , ())
3746+ self .assertEqual (res [1 ].shape , (4 ,))
3747+ self .assertEqual (res [2 ].shape , (4 , 5 ))
3748+ self .assertEqual (res [3 ].shape , (5 ,))
3749+
37133750
37143751# Use to test API whose zero-dim input tensors don't have grad and not need to test backward in OpTest.
37153752class TestNoBackwardAPI (unittest .TestCase ):
@@ -3901,6 +3938,38 @@ def test_unique(self):
39013938 self .assertEqual (inverse .shape , [1 ])
39023939 self .assertEqual (counts .shape , [1 ])
39033940
3941+ def test_matrix_rank (self ):
3942+ x = paddle .eye (10 )
3943+ x .stop_gradient = False
3944+ out = paddle .linalg .matrix_rank (x )
3945+
3946+ self .assertEqual (out .shape , [])
3947+ np .testing .assert_equal (out , np .array (10 ))
3948+
3949+ c = paddle .ones (shape = [3 , 4 , 5 ])
3950+ c .stop_gradient = False
3951+ out_c = paddle .linalg .matrix_rank (c )
3952+ self .assertEqual (out_c .shape , [3 ])
3953+ np .testing .assert_equal (out_c , np .array ([1 , 1 , 1 ]))
3954+
3955+ # 2D, tol->float : OUTPUT 0D
3956+ x_tol = paddle .eye (10 )
3957+ x_tol .stop_gradient = False
3958+ out_tol = paddle .linalg .matrix_rank (x_tol , tol = 0.1 )
3959+ self .assertEqual (out_tol .shape , [])
3960+
3961+ # 3D, tol->float : OUTPUT 1D
3962+ c_tol = paddle .ones (shape = [3 , 4 , 5 ])
3963+ c_tol .stop_gradient = False
3964+ out_c_tol = paddle .linalg .matrix_rank (c_tol , tol = 0.1 )
3965+ self .assertEqual (out_c_tol .shape , [3 ])
3966+
3967+ tol_2 = paddle .randn ([2 ])
3968+ # 2D, tol->Tensor[1,2] : OUTPUT 1D
3969+ d = paddle .eye (10 )
3970+ out_d = paddle .linalg .matrix_rank (d , tol = tol_2 )
3971+ self .assertEqual (out_d .shape , [2 ])
3972+
39043973
39053974class TestNoBackwardAPIStatic (unittest .TestCase ):
39063975 def setUp (self ):
@@ -4135,6 +4204,51 @@ def test_unique(self):
41354204 self .assertEqual (res [2 ].shape , (1 ,))
41364205 self .assertEqual (res [3 ].shape , (1 ,))
41374206
4207+ @prog_scope ()
4208+ def test_static_matrix_rank (self ):
4209+ # 2D : OUTPUT 0D
4210+ x = paddle .eye (10 )
4211+ x .stop_gradient = False
4212+ out = paddle .linalg .matrix_rank (x )
4213+ prog = paddle .static .default_main_program ()
4214+ res = self .exe .run (prog , fetch_list = [out ])
4215+ self .assertEqual (res [0 ].shape , ())
4216+
4217+ # 3D : OUTPUT 1D
4218+ c = paddle .ones (shape = [3 , 4 , 5 ])
4219+ c .stop_gradient = False
4220+ out_c = paddle .linalg .matrix_rank (c )
4221+ prog = paddle .static .default_main_program ()
4222+ self .exe .run (paddle .static .default_startup_program ())
4223+ res = self .exe .run (prog , fetch_list = [out_c ])
4224+ self .assertEqual (res [0 ].shape , (3 ,))
4225+
4226+ # 2D, tol->float : OUTPUT 0D
4227+ x_tol = paddle .eye (10 )
4228+ x_tol .stop_gradient = False
4229+ out_tol = paddle .linalg .matrix_rank (x_tol , tol = 0.1 )
4230+ prog = paddle .static .default_main_program ()
4231+ res = self .exe .run (prog , fetch_list = [out_tol ])
4232+ self .assertEqual (res [0 ].shape , ())
4233+
4234+ # 3D, tol->float : OUTPUT 1D
4235+ c_tol = paddle .ones (shape = [3 , 4 , 5 ])
4236+ c_tol .stop_gradient = False
4237+ out_c_tol = paddle .linalg .matrix_rank (c_tol , tol = 0.1 )
4238+ prog = paddle .static .default_main_program ()
4239+ self .exe .run (paddle .static .default_startup_program ())
4240+ res = self .exe .run (prog , fetch_list = [out_c_tol ])
4241+ self .assertEqual (res [0 ].shape , (3 ,))
4242+
4243+ tol_2 = paddle .randn ([2 ])
4244+ # 2D, tol->Tensor[1,2] : OUTPUT 1D
4245+ d = paddle .eye (10 )
4246+ out_d = paddle .linalg .matrix_rank (d , tol = tol_2 )
4247+ prog = paddle .static .default_main_program ()
4248+ self .exe .run (paddle .static .default_startup_program ())
4249+ res = self .exe .run (prog , fetch_list = [out_d ])
4250+ self .assertEqual (res [0 ].shape , (2 ,))
4251+
41384252
41394253unary_apis_with_complex_input = [
41404254 paddle .real ,
0 commit comments