Skip to content

Commit 2d085d2

Browse files
[BIT] nonzero (#72244) (#73010)
* nonzero-size
* fix
* fix
* fix
* fix
* fix
* fix

Co-authored-by: 正在学习 <62892980+cszdrg@users.noreply.github.com>
1 parent 86bb9e7 commit 2d085d2

File tree

6 files changed

+74
-8
lines changed

6 files changed

+74
-8
lines changed

paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2366,13 +2366,27 @@ bool NonzeroOpInferSymbolicShape(
23662366
common::errors::InvalidArgument(
23672367
"Input(x) should have number of dimension at least 1."));
23682368

2369-
std::string sym_name = infer_context->GetNextSymName();
2370-
std::vector<symbol::DimExpr> out_shape{symbol::DimExpr{sym_name},
2371-
symbol::DimExpr{rank}};
2372-
2373-
symbol::ShapeOrDataDimExprs shape_data{
2374-
symbol::TensorShapeOrDataDimExprs(out_shape)};
2375-
infer_context->SetShapeOrDataForValue(op->result(0), shape_data);
2369+
bool zero = 0;
2370+
for (int i = 0; i < rank; i++) {
2371+
if (x_shape[i] == 0) {
2372+
zero = 1;
2373+
break;
2374+
}
2375+
}
2376+
if (zero) {
2377+
std::vector<symbol::DimExpr> out_shape{symbol::DimExpr{0},
2378+
symbol::DimExpr{rank}};
2379+
symbol::ShapeOrDataDimExprs shape_data{
2380+
symbol::TensorShapeOrDataDimExprs(out_shape)};
2381+
infer_context->SetShapeOrDataForValue(op->result(0), shape_data);
2382+
} else {
2383+
std::string sym_name = infer_context->GetNextSymName();
2384+
std::vector<symbol::DimExpr> out_shape{symbol::DimExpr{sym_name},
2385+
symbol::DimExpr{rank}};
2386+
symbol::ShapeOrDataDimExprs shape_data{
2387+
symbol::TensorShapeOrDataDimExprs(out_shape)};
2388+
infer_context->SetShapeOrDataForValue(op->result(0), shape_data);
2389+
}
23762390
return true;
23772391
}
23782392

paddle/phi/infermeta/unary.cc

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2873,7 +2873,12 @@ void NonZeroInferMeta(const MetaTensor& condition, MetaTensor* out) {
28732873
1UL,
28742874
common::errors::InvalidArgument(
28752875
"Input(Condition) should have number of dimension at least 1"));
2876-
out->set_dims(common::make_ddim({-1, rank}));
2876+
if (condition.numel() == 0) {
2877+
out->set_dims(common::make_ddim({0, rank}));
2878+
} else {
2879+
out->set_dims(common::make_ddim({-1, rank}));
2880+
}
2881+
28772882
out->set_dtype(DataType::INT64);
28782883
}
28792884

paddle/phi/kernels/cpu/nonzero_kernel.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,11 @@ void NonZeroKernel(const Context& dev_ctx,
5555
auto dims = condition.dims();
5656
const int rank = dims.size();
5757

58+
if (numel == 0) {
59+
dev_ctx.template Alloc<T>(out);
60+
return;
61+
}
62+
5863
std::vector<int64_t> true_index;
5964
for (auto i = 0; i < numel; i++) {
6065
if (static_cast<bool>(cond_data[i])) {

paddle/phi/kernels/gpu/nonzero_kernel.cu

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,10 @@ template <typename T, typename Context>
6565
void NonZeroKernel(const Context &dev_ctx,
6666
const DenseTensor &condition,
6767
DenseTensor *out) {
68+
if (condition.numel() == 0) {
69+
dev_ctx.template Alloc<T>(out);
70+
return;
71+
}
6872
DenseTensor in_data;
6973
auto dims = condition.dims();
7074
using Functor = IndexFunctor<T, int64_t, int64_t>;

paddle/phi/kernels/xpu/nonzero_kernel.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,11 @@ void NonZeroKernel(const Context& dev_ctx,
3232

3333
using XPUType = typename XPUTypeTrait<T>::Type;
3434

35+
if (numel == 0) {
36+
dev_ctx.template Alloc<T>(out);
37+
return;
38+
}
39+
3540
xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
3641
int64_t* true_num = RAII_GUARD.alloc_l3_or_gm<int64_t>(1);
3742
int64_t* workspace =

test/legacy_test/test_nonzero_api.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,18 @@ def test_nonzero_api_as_tuple(self):
6363
expect_out = np.array([0, 1])
6464
np.testing.assert_allclose(expect_out, np.array(res), rtol=1e-05)
6565

66+
data = np.zeros([10, 3, 0], dtype="float32")
67+
with program_guard(Program(), Program()):
68+
x = paddle.static.data(name='x', shape=[10, 3, 0], dtype='float32')
69+
if not paddle.framework.use_pir_api():
70+
x.desc.set_need_check_feed(False)
71+
y = paddle.nonzero(x, as_tuple=True)
72+
self.assertEqual(type(y), tuple)
73+
self.assertEqual(len(y), 3)
74+
expect_out = np.zeros([0])
75+
for item in y:
76+
np.testing.assert_array_equal(expect_out, item)
77+
6678
def test_nonzero_api(self):
6779
paddle.enable_static()
6880
data = np.array([[1, 0], [0, 1]], dtype="float32")
@@ -181,5 +193,26 @@ def return_outputs(self):
181193
return {'Out': np.transpose(np.nonzero(self.inputs['Condition']))}
182194

183195

196+
class TestZeroSizeOp(TestNonzeroOp):
197+
198+
def init_shape(self):
199+
self.shape = [0, 10]
200+
201+
def init_dtype(self):
202+
self.dtype = np.float64
203+
204+
205+
class TestZeroSizeOpCase2(TestNonzeroOp):
206+
207+
def init_shape(self):
208+
self.shape = [0, 10]
209+
210+
def init_dtype(self):
211+
self.dtype = np.float64
212+
213+
def test_check_output(self):
214+
self.check_output(check_pir=True, check_symbol_infer=True)
215+
216+
184217
if __name__ == "__main__":
185218
unittest.main()

0 commit comments

Comments (0)