@@ -806,6 +806,64 @@ def test_errors(self):
806806 )
807807
808808
809+ class TestMeanAPIInt32 (unittest .TestCase ):
810+ def setUp (self ):
811+ self .x_shape = [2 , 3 , 4 , 5 ]
812+ self .dtype = "int32"
813+ self .x_np = np .random .randint (- 1 , 10000 , self .x_shape ).astype (
814+ self .dtype
815+ )
816+ self .places = [paddle .CPUPlace ()]
817+ if core .is_compiled_with_cuda ():
818+ self .places .append (paddle .CUDAPlace (0 ))
819+
820+ def test_dygraph (self ):
821+ for place in self .places :
822+ with base .dygraph .guard (place ):
823+ x = paddle .to_tensor (self .x_np )
824+ out = paddle .mean (x = x )
825+ np .testing .assert_equal (
826+ out .numpy (),
827+ np .mean (self .x_np .astype ("float32" )).astype (self .dtype ),
828+ )
829+
830+ def test_static (self ):
831+ paddle .enable_static ()
832+ for place in self .places :
833+ with base .program_guard (base .Program (), base .Program ()):
834+ x = paddle .static .data (
835+ "x" , shape = self .x_shape , dtype = self .dtype
836+ )
837+ out = paddle .mean (x = x )
838+ exe = base .Executor (place )
839+ res = exe .run (feed = {"x" : self .x_np }, fetch_list = [out ])
840+ np .testing .assert_equal (
841+ res [0 ], np .mean (self .x_np .astype ("float32" )).astype (self .dtype )
842+ )
843+
844+
845+ class TestMeanAPIInt64 (TestMeanAPIInt32 ):
846+ def setUp (self ):
847+ self .x_shape = [2 , 3 , 4 , 5 ]
848+ self .dtype = "int64"
849+ self .x_np = np .random .randint (- 1 , 10000 , self .x_shape ).astype (
850+ self .dtype
851+ )
852+ self .places = [paddle .CPUPlace ()]
853+ if core .is_compiled_with_cuda ():
854+ self .places .append (paddle .CUDAPlace (0 ))
855+
856+
857+ class TestMeanAPIBool (TestMeanAPIInt32 ):
858+ def setUp (self ):
859+ self .x_shape = [2 , 3 , 4 , 5 ]
860+ self .dtype = "bool"
861+ self .x_np = np .random .uniform (- 1 , 1 , self .x_shape ).astype (self .dtype )
862+ self .places = [paddle .CPUPlace ()]
863+ if core .is_compiled_with_cuda ():
864+ self .places .append (paddle .CUDAPlace (0 ))
865+
866+
809867class TestMeanWithTensorAxis1 (TestReduceOPTensorAxisBase ):
810868 def init_data (self ):
811869 self .pd_api = paddle .mean
0 commit comments