Skip to content

Commit 3c7b550

Browse files
committed
mean accuracy fix
1 parent 4aee08b commit 3c7b550

File tree

2 files changed

+77
-1
lines changed

2 files changed

+77
-1
lines changed

paddle/phi/kernels/reduce_mean_kernel.cc

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
#include "paddle/phi/backends/all_context.h"
1818
#include "paddle/phi/core/kernel_registry.h"
19+
#include "paddle/phi/kernels/cast_kernel.h"
1920
#include "paddle/phi/kernels/reduce_kernel_impl.h"
2021

2122
namespace phi {
@@ -27,7 +28,24 @@ void MeanKernel(const Context& dev_ctx,
2728
bool keep_dim,
2829
DenseTensor* out) {
2930
bool reduce_all = recompute_reduce_all(x, dims);
30-
MeanRawKernel<T>(dev_ctx, x, dims, keep_dim, reduce_all, out);
31+
if (std::is_same<T, int>::value || std::is_same<T, int64_t>::value ||
32+
std::is_same<T, bool>::value) {
33+
using Type =
34+
typename std::conditional<std::is_same<T, int>::value ||
35+
std::is_same<T, int64_t>::value ||
36+
std::is_same<T, bool>::value,
37+
float,
38+
T>::type;
39+
DenseTensor x_float =
40+
phi::Cast<T, Context>(dev_ctx, x, phi::DataType::FLOAT32);
41+
DenseTensor* out_float = new DenseTensor();
42+
out_float->Resize(out->dims());
43+
MeanRawKernel<Type>(
44+
dev_ctx, x_float, dims, keep_dim, reduce_all, out_float);
45+
phi::CastKernel<Type, Context>(dev_ctx, *out_float, x.dtype(), out);
46+
} else {
47+
MeanRawKernel<T>(dev_ctx, x, dims, keep_dim, reduce_all, out);
48+
}
3149
}
3250

3351
} // namespace phi

test/legacy_test/test_mean_op.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -806,6 +806,64 @@ def test_errors(self):
806806
)
807807

808808

809+
class TestMeanAPIInt32(unittest.TestCase):
810+
def setUp(self):
811+
self.x_shape = [2, 3, 4, 5]
812+
self.dtype = "int32"
813+
self.x_np = np.random.randint(-1, 10000, self.x_shape).astype(
814+
self.dtype
815+
)
816+
self.places = [paddle.CPUPlace()]
817+
if core.is_compiled_with_cuda():
818+
self.places.append(paddle.CUDAPlace(0))
819+
820+
def test_dygraph(self):
821+
for place in self.places:
822+
with base.dygraph.guard(place):
823+
x = paddle.to_tensor(self.x_np)
824+
out = paddle.mean(x=x)
825+
np.testing.assert_equal(
826+
out.numpy(),
827+
np.mean(self.x_np.astype("float32")).astype(self.dtype),
828+
)
829+
830+
def test_static(self):
831+
paddle.enable_static()
832+
for place in self.places:
833+
with base.program_guard(base.Program(), base.Program()):
834+
x = paddle.static.data(
835+
"x", shape=self.x_shape, dtype=self.dtype
836+
)
837+
out = paddle.mean(x=x)
838+
exe = base.Executor(place)
839+
res = exe.run(feed={"x": self.x_np}, fetch_list=[out])
840+
np.testing.assert_equal(
841+
res[0], np.mean(self.x_np.astype("float32")).astype(self.dtype)
842+
)
843+
844+
845+
class TestMeanAPIInt64(TestMeanAPIInt32):
846+
def setUp(self):
847+
self.x_shape = [2, 3, 4, 5]
848+
self.dtype = "int64"
849+
self.x_np = np.random.randint(-1, 10000, self.x_shape).astype(
850+
self.dtype
851+
)
852+
self.places = [paddle.CPUPlace()]
853+
if core.is_compiled_with_cuda():
854+
self.places.append(paddle.CUDAPlace(0))
855+
856+
857+
class TestMeanAPIBool(TestMeanAPIInt32):
858+
def setUp(self):
859+
self.x_shape = [2, 3, 4, 5]
860+
self.dtype = "bool"
861+
self.x_np = np.random.uniform(-1, 1, self.x_shape).astype(self.dtype)
862+
self.places = [paddle.CPUPlace()]
863+
if core.is_compiled_with_cuda():
864+
self.places.append(paddle.CUDAPlace(0))
865+
866+
809867
class TestMeanWithTensorAxis1(TestReduceOPTensorAxisBase):
810868
def init_data(self):
811869
self.pd_api = paddle.mean

0 commit comments

Comments
 (0)