@@ -63,6 +63,18 @@ def test_nonzero_api_as_tuple(self):
6363 expect_out = np .array ([0 , 1 ])
6464 np .testing .assert_allclose (expect_out , np .array (res ), rtol = 1e-05 )
6565
66+ data = np .zeros ([10 , 3 , 0 ], dtype = "float32" )
67+ with program_guard (Program (), Program ()):
68+ x = paddle .static .data (name = 'x' , shape = [10 , 3 , 0 ], dtype = 'float32' )
69+ if not paddle .framework .use_pir_api ():
70+ x .desc .set_need_check_feed (False )
71+ y = paddle .nonzero (x , as_tuple = True )
72+ self .assertEqual (type (y ), tuple )
73+ self .assertEqual (len (y ), 3 )
74+ expect_out = np .zeros ([0 ])
75+ for item in y :
76+ np .testing .assert_array_equal (expect_out , item )
77+
6678 def test_nonzero_api (self ):
6779 paddle .enable_static ()
6880 data = np .array ([[1 , 0 ], [0 , 1 ]], dtype = "float32" )
@@ -181,5 +193,26 @@ def return_outputs(self):
181193 return {'Out' : np .transpose (np .nonzero (self .inputs ['Condition' ]))}
182194
183195
196+ class TestZeroSizeOp (TestNonzeroOp ):
197+
198+ def init_shape (self ):
199+ self .shape = [0 , 10 ]
200+
201+ def init_dtype (self ):
202+ self .dtype = np .float64
203+
204+
205+ class TestZeroSizeOpCase2 (TestNonzeroOp ):
206+
207+ def init_shape (self ):
208+ self .shape = [0 , 10 ]
209+
210+ def init_dtype (self ):
211+ self .dtype = np .float64
212+
213+ def test_check_output (self ):
214+ self .check_output (check_pir = True , check_symbol_infer = True )
215+
216+
184217if __name__ == "__main__" :
185218 unittest .main ()
0 commit comments