@@ -282,7 +282,9 @@ bool BincountOpInferSymbolicShape(
"The 'shape' of Input(Weights) must be 1-D tensor. "
"But the dimension of Input(Weights) is [%d]",
weights_dims.size()));
infer_context->AddEqualCstr(weights_dims[0], x_dims[0]);
if (x_dims[0] != 0) {
infer_context->AddEqualCstr(weights_dims[0], x_dims[0]);
}
}

symbol::DimExpr out_unknown = infer_context->GetNextSymName();
@@ -82,22 +82,32 @@ bool AddNOpInferSymbolicShape(pir::Operation *op,
"should be larger than 0. But received X's dimensions %d.",
inputs_shape.size()));
symbol::TensorShapeOrDataDimExprs candidate_shape = inputs_shape.front();
std::vector<symbol::DimExpr> candidate_shape_vec = candidate_shape.shape();
for (size_t i = 1; i < inputs_shape.size(); ++i) {
// 0D tensor
if (inputs_shape[i].shape().size() == 0) {
continue;
}
if (candidate_shape.shape().size() == 0) {
candidate_shape = inputs_shape[i];
candidate_shape_vec = candidate_shape.shape();
continue;
}
for (size_t j = 0; j < candidate_shape.shape().size(); ++j) {
infer_context->AddEqualCstr(candidate_shape.shape()[j],
inputs_shape[i].shape()[j]);
for (size_t j = 0; j < candidate_shape_vec.size(); ++j) {
if (candidate_shape_vec[j] != 0) {
if (inputs_shape[i].shape()[j] != 0) {
infer_context->AddEqualCstr(candidate_shape_vec[j],
inputs_shape[i].shape()[j]);
} else {
candidate_shape_vec[j] = symbol::DimExpr{0};
}
}
}
}
infer_context->SetShapeOrDataForValue(
op->result(0), symbol::ShapeOrDataDimExprs{candidate_shape});
op->result(0),
symbol::ShapeOrDataDimExprs{
symbol::TensorShapeOrDataDimExprs(candidate_shape_vec)});

return true;
}
22 changes: 12 additions & 10 deletions paddle/phi/infermeta/binary.cc
@@ -295,16 +295,18 @@ void BincountInferMeta(const MetaTensor& x,
"But the dimension of Input(Weights) is [%d]",
weights_dim.size()));

PADDLE_ENFORCE_EQ(
weights_dim[0],
input_dim[0],
common::errors::InvalidArgument(
"The 'shape' of Input(Weights) must be equal to the 'shape' of "
"Input(X)."
"But received: the 'shape' of Input(Weights) is [%s],"
"the 'shape' of Input(X) is [%s]",
weights_dim,
input_dim));
if (input_dim[0] != 0) {
PADDLE_ENFORCE_EQ(
weights_dim[0],
input_dim[0],
common::errors::InvalidArgument(
"The 'shape' of Input(Weights) must be equal to the 'shape' of "
"Input(X)."
"But received: the 'shape' of Input(Weights) is [%s],"
"the 'shape' of Input(X) is [%s]",
weights_dim,
input_dim));
}
}
out->set_dims(common::make_ddim({-1}));
if (weights) {
31 changes: 21 additions & 10 deletions paddle/phi/infermeta/multiary.cc
@@ -415,19 +415,30 @@ void AddNInferMeta(const std::vector<const MetaTensor*>& x,
continue;
}
is_all_0d_tensor = false;
if (common::product(in_dim) == 0) {
// use the first dimension
if (i == 0) {
in_dim = x_dim;
} else {
if (config.is_runtime) {
PADDLE_ENFORCE_EQ(in_dim,
x_dim,
common::errors::InvalidArgument(
"The input tensor X of AddNOp must"
" have same shape. But received X[0]'s shape = "
"[%s], X[%d]'s shape = [%s].",
in_dim,
i,
x_dim));
for (int j = 0; j < x_dim.size(); ++j) {
if (in_dim[j] != 0) {
if (x_dim[j] == 0) {
// update the 0 dim
in_dim[j] = 0;
} else {
PADDLE_ENFORCE_EQ(
in_dim[j],
x_dim[j],
common::errors::InvalidArgument(
"The input tensor X of AddNOp must"
" have same shape. But received X[0]'s shape = "
"[%s], X[%d]'s shape = [%s].",
in_dim,
i,
x_dim));
}
}
}
} else {
PADDLE_ENFORCE_EQ(
in_dim.size(),
3 changes: 3 additions & 0 deletions paddle/phi/kernels/cpu/add_n_kernel.cc
@@ -25,6 +25,9 @@ void AddNKernel(const Context& dev_ctx,
DenseTensor* out) {
size_t in_num = x.size();
dev_ctx.template Alloc<T>(out);
if (out && out->numel() == 0) {
return;
}

bool in_place = false;
if (!x.empty() && x[0]->initialized() && DenseTensor::classof(x[0])) {
7 changes: 5 additions & 2 deletions paddle/phi/kernels/cpu/bincount_kernel.cc
@@ -16,6 +16,7 @@

#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/full_kernel.h"
#include "paddle/phi/kernels/funcs/math_function.h"

namespace phi {
@@ -33,9 +34,11 @@ void BincountInner(const Context& dev_ctx,
auto input_numel = input->numel();

if (input_data == nullptr) {
phi::DDim out_dim{0};
phi::DDim out_dim{minlength};
output->Resize(out_dim);
dev_ctx.template Alloc<InputT>(output);
// Since minlength may be > 0, fill the output with zeros.
phi::Full<InputT, Context>(
dev_ctx, phi::IntArray(common::vectorize(output->dims())), 0, output);
return;
}

Expand Down
2 changes: 2 additions & 0 deletions paddle/phi/kernels/cpu/compare_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ inline void CompareAllKernelImpl(const Context& dev_ctx,

if (x.dims() != y.dims()) {
out_data[0] = false;
} else if (x.numel() == 0) { // shapes are equal and numel is 0, so return true
out_data[0] = true;
} else {
DenseTensor tmp;
tmp.Resize(x.dims());
4 changes: 4 additions & 0 deletions paddle/phi/kernels/gpu/add_n_kernel.cu
@@ -79,6 +79,10 @@ template <typename T, typename Context>
void AddNKernel(const Context &dev_ctx,
const std::vector<const TensorBase *> &x,
DenseTensor *out) {
if (out && out->numel() == 0) {
dev_ctx.template Alloc<T>(out);
return;
}
const size_t in_num = x.size();
for (int i = 0; i < in_num; ++i) {
PADDLE_ENFORCE_EQ(
7 changes: 4 additions & 3 deletions paddle/phi/kernels/gpu/bincount_kernel.cu
@@ -17,9 +17,9 @@
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/full_kernel.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/math_function.h"

namespace phi {

using phi::PADDLE_CUDA_NUM_THREADS;
@@ -99,9 +99,10 @@ void BincountCUDAInner(const Context& dev_ctx,
int64_t input_numel = static_cast<int64_t>(input->numel());

if (input_data == nullptr) {
phi::DDim out_dim{0};
phi::DDim out_dim{minlength};
output->Resize(out_dim);
dev_ctx.template Alloc<T>(output);
phi::Full<InputT, Context>(
dev_ctx, phi::IntArray(common::vectorize(output->dims())), 0, output);
return;
}

Expand Down
6 changes: 6 additions & 0 deletions paddle/phi/kernels/kps/compare_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,12 @@ inline void CompareAllKernelImpl(const Context& ctx,
thrust::fill(out_dev_ptr, out_dev_ptr + 1, false);
return;
}
// shapes are equal and numel is 0, so return true
if (x.numel() == 0) {
thrust::device_ptr<bool> out_dev_ptr(out_data);
thrust::fill(out_dev_ptr, out_dev_ptr + 1, true);
return;
}

DenseTensor tmp;
tmp.Resize(x.dims());
Expand Down
1 change: 1 addition & 0 deletions paddle/phi/kernels/xpu/add_n_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ void AddNKernel(const Context& dev_ctx,
using XPUType = typename XPUTypeTrait<T>::Type;
size_t in_num = x.size();
dev_ctx.template Alloc<T>(out);
if (out && out->numel() == 0) return;

bool in_place = false;
if (x.size() > 0 && x[0]->initialized() && DenseTensor::classof(x[0])) {
Expand Down
36 changes: 36 additions & 0 deletions test/legacy_test/test_add_n_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,5 +82,41 @@ def test_add_n_api(self):
np.testing.assert_allclose(y_np_32, y_np_gt, rtol=1e-06)


class TestAddnOp_ZeroSize(unittest.TestCase):
def setUp(self):
np.random.seed(20)
self.l = 2
self.x_np = np.random.random([self.l, 0, 256])

def check_main(self, x_np, dtype, axis=None, mixed_dtype=False):
paddle.disable_static()
x = []
for i in range(x_np.shape[0]):
if mixed_dtype and i == 0:
val = paddle.to_tensor(x_np[i].astype('float32'))
else:
val = paddle.to_tensor(x_np[i].astype(dtype))
val.stop_gradient = False
x.append(val)

y = paddle.add_n(x)
x_g = paddle.grad(y, x)
y_np = y.numpy().astype(dtype)
x_g_np = []
for val in x_g:
x_g_np.append(val.numpy().astype(dtype))
paddle.enable_static()
return y_np, x_g_np

def test_add_n_zerosize(self):
if not paddle.is_compiled_with_cuda():
return
y_np_32, x_g_np_32 = self.check_main(self.x_np, 'float32')

np.testing.assert_allclose(y_np_32.shape, [0, 256])
for i in range(len(x_g_np_32)):
np.testing.assert_allclose(x_g_np_32[i].shape, [0, 256])


if __name__ == "__main__":
unittest.main()
7 changes: 7 additions & 0 deletions test/legacy_test/test_bincount_op.py
@@ -318,5 +318,12 @@ def test_static_and_infer(self):
np.testing.assert_allclose(static_out[0], infer_out)


class TestBincountOp_ZeroSize(TestBincountOp):
Contributor
TestBincountOp defines test_check_output as def test_check_output(self): self.check_output(check_pir=True, check_symbol_infer=False). Because symbol_infer is not checked, the new code at paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc:285 is not covered. Please override this method in the new test and enable check_symbol_infer.
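A minimal sketch of the override being requested, reusing the check_output flags already present in this test file (whether check_symbol_infer=True can actually pass for bincount is addressed in the reply below):

    def test_check_output(self):
        # Re-run the standard output check with symbolic shape inference enabled,
        # so the new zero-size branch in binary_infer_sym.cc is exercised.
        self.check_output(check_pir=True, check_symbol_infer=True)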

Contributor Author
bincount uses out_unknown, so the output shape is only derived when the kernel runs; setting check_symbol_infer to True fails.
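For illustration, the output extent of bincount depends on the data values rather than only on the input shape, so it cannot be pinned to a concrete symbolic dimension ahead of time (a small NumPy sketch):

import numpy as np

# Two inputs with the same shape (2,) produce different output lengths,
# because the result has max(input) + 1 bins (or minlength, if larger).
print(np.bincount(np.array([1, 2])).shape)  # (3,)
print(np.bincount(np.array([1, 5])).shape)  # (6,)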

def init_test_case(self):
self.minlength = 0
self.np_input = np.random.randint(low=0, high=20, size=0)
self.Out = np.bincount(self.np_input, minlength=self.minlength)


if __name__ == "__main__":
unittest.main()
20 changes: 20 additions & 0 deletions test/legacy_test/test_compare_reduce_op.py
@@ -96,12 +96,32 @@ def test_output(self):
globals()[cls_name] = Cls


def create_test_equal_class_zero_size(op_type, typename, callback):
class Cls(op_test.OpTest):
def setUp(self):
x = np.random.random(size=(0, 7)).astype(typename)
y = np.random.random(size=(0, 7)).astype(typename)
z = callback(x, y)
self.python_api = paddle.tensor.equal_all
self.inputs = {'X': x, 'Y': y}
self.outputs = {'Out': z}
self.op_type = op_type

def test_output(self):
self.check_output(check_pir=True)

cls_name = "{}_{}_{}".format(op_type, typename, 'equal_all_zero_size')
Cls.__name__ = cls_name
globals()[cls_name] = Cls


np_equal = lambda _x, _y: np.array(np.array_equal(_x, _y))

for _type_name in {'float32', 'float64', 'int32', 'int64', 'bool'}:
create_test_not_equal_class('equal_all', _type_name, np_equal)
create_test_equal_class('equal_all', _type_name, np_equal)
create_test_dim1_class('equal_all', _type_name, np_equal)
create_test_equal_class_zero_size('equal_all', _type_name, np_equal)


class TestEqualReduceAPI(unittest.TestCase):