Skip to content

Commit 8132805

Browse files
authored
[0-size Tensor Job2 No.2、3、10] Add 0-size Tensor support for equal_all [fluid_ops] (#73550)
* Fix * Fix * Fix * Fix
1 parent 4ffa573 commit 8132805

File tree

14 files changed

+138
-30
lines changed

14 files changed

+138
-30
lines changed

paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -282,7 +282,9 @@ bool BincountOpInferSymbolicShape(
282282
"The 'shape' of Input(Weights) must be 1-D tensor. "
283283
"But the dimension of Input(Weights) is [%d]",
284284
weights_dims.size()));
285-
infer_context->AddEqualCstr(weights_dims[0], x_dims[0]);
285+
if (x_dims[0] != 0) {
286+
infer_context->AddEqualCstr(weights_dims[0], x_dims[0]);
287+
}
286288
}
287289

288290
symbol::DimExpr out_unknown = infer_context->GetNextSymName();

paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -82,22 +82,32 @@ bool AddNOpInferSymbolicShape(pir::Operation *op,
8282
"should be larger than 0. But received X's dimensions %d.",
8383
inputs_shape.size()));
8484
symbol::TensorShapeOrDataDimExprs candidate_shape = inputs_shape.front();
85+
std::vector<symbol::DimExpr> candidate_shape_vec = candidate_shape.shape();
8586
for (size_t i = 1; i < inputs_shape.size(); ++i) {
8687
// 0D tensor
8788
if (inputs_shape[i].shape().size() == 0) {
8889
continue;
8990
}
9091
if (candidate_shape.shape().size() == 0) {
9192
candidate_shape = inputs_shape[i];
93+
candidate_shape_vec = candidate_shape.shape();
9294
continue;
9395
}
94-
for (size_t j = 0; j < candidate_shape.shape().size(); ++j) {
95-
infer_context->AddEqualCstr(candidate_shape.shape()[j],
96-
inputs_shape[i].shape()[j]);
96+
for (size_t j = 0; j < candidate_shape_vec.size(); ++j) {
97+
if (candidate_shape_vec[j] != 0) {
98+
if (inputs_shape[i].shape()[j] != 0) {
99+
infer_context->AddEqualCstr(candidate_shape_vec[j],
100+
inputs_shape[i].shape()[j]);
101+
} else {
102+
candidate_shape_vec[j] = symbol::DimExpr{0};
103+
}
104+
}
97105
}
98106
}
99107
infer_context->SetShapeOrDataForValue(
100-
op->result(0), symbol::ShapeOrDataDimExprs{candidate_shape});
108+
op->result(0),
109+
symbol::ShapeOrDataDimExprs{
110+
symbol::TensorShapeOrDataDimExprs(candidate_shape_vec)});
101111

102112
return true;
103113
}

paddle/phi/infermeta/binary.cc

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -295,16 +295,18 @@ void BincountInferMeta(const MetaTensor& x,
295295
"But the dimension of Input(Weights) is [%d]",
296296
weights_dim.size()));
297297

298-
PADDLE_ENFORCE_EQ(
299-
weights_dim[0],
300-
input_dim[0],
301-
common::errors::InvalidArgument(
302-
"The 'shape' of Input(Weights) must be equal to the 'shape' of "
303-
"Input(X)."
304-
"But received: the 'shape' of Input(Weights) is [%s],"
305-
"the 'shape' of Input(X) is [%s]",
306-
weights_dim,
307-
input_dim));
298+
if (input_dim[0] != 0) {
299+
PADDLE_ENFORCE_EQ(
300+
weights_dim[0],
301+
input_dim[0],
302+
common::errors::InvalidArgument(
303+
"The 'shape' of Input(Weights) must be equal to the 'shape' of "
304+
"Input(X)."
305+
"But received: the 'shape' of Input(Weights) is [%s],"
306+
"the 'shape' of Input(X) is [%s]",
307+
weights_dim,
308+
input_dim));
309+
}
308310
}
309311
out->set_dims(common::make_ddim({-1}));
310312
if (weights) {

paddle/phi/infermeta/multiary.cc

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -415,19 +415,30 @@ void AddNInferMeta(const std::vector<const MetaTensor*>& x,
415415
continue;
416416
}
417417
is_all_0d_tensor = false;
418-
if (common::product(in_dim) == 0) {
418+
// use the first dimension
419+
if (i == 0) {
419420
in_dim = x_dim;
420421
} else {
421422
if (config.is_runtime) {
422-
PADDLE_ENFORCE_EQ(in_dim,
423-
x_dim,
424-
common::errors::InvalidArgument(
425-
"The input tensor X of AddNOp must"
426-
" have same shape. But received X[0]'s shape = "
427-
"[%s], X[%d]'s shape = [%s].",
428-
in_dim,
429-
i,
430-
x_dim));
423+
for (int j = 0; j < x_dim.size(); ++j) {
424+
if (in_dim[j] != 0) {
425+
if (x_dim[j] == 0) {
426+
// update the 0 dim
427+
in_dim[j] = 0;
428+
} else {
429+
PADDLE_ENFORCE_EQ(
430+
in_dim[j],
431+
x_dim[j],
432+
common::errors::InvalidArgument(
433+
"The input tensor X of AddNOp must"
434+
" have same shape. But received X[0]'s shape = "
435+
"[%s], X[%d]'s shape = [%s].",
436+
in_dim,
437+
i,
438+
x_dim));
439+
}
440+
}
441+
}
431442
} else {
432443
PADDLE_ENFORCE_EQ(
433444
in_dim.size(),

paddle/phi/kernels/cpu/add_n_kernel.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,9 @@ void AddNKernel(const Context& dev_ctx,
2525
DenseTensor* out) {
2626
size_t in_num = x.size();
2727
dev_ctx.template Alloc<T>(out);
28+
if (out && out->numel() == 0) {
29+
return;
30+
}
2831

2932
bool in_place = false;
3033
if (!x.empty() && x[0]->initialized() && DenseTensor::classof(x[0])) {

paddle/phi/kernels/cpu/bincount_kernel.cc

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
#include "paddle/phi/backends/cpu/cpu_context.h"
1818
#include "paddle/phi/core/kernel_registry.h"
19+
#include "paddle/phi/kernels/full_kernel.h"
1920
#include "paddle/phi/kernels/funcs/math_function.h"
2021

2122
namespace phi {
@@ -33,9 +34,11 @@ void BincountInner(const Context& dev_ctx,
3334
auto input_numel = input->numel();
3435

3536
if (input_data == nullptr) {
36-
phi::DDim out_dim{0};
37+
phi::DDim out_dim{minlength};
3738
output->Resize(out_dim);
38-
dev_ctx.template Alloc<InputT>(output);
39+
// Since minlength may >0 , so fill with 0.
40+
phi::Full<InputT, Context>(
41+
dev_ctx, phi::IntArray(common::vectorize(output->dims())), 0, output);
3942
return;
4043
}
4144

paddle/phi/kernels/cpu/compare_kernel.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,8 @@ inline void CompareAllKernelImpl(const Context& dev_ctx,
7070

7171
if (x.dims() != y.dims()) {
7272
out_data[0] = false;
73+
} else if (x.numel() == 0) { // shape equal and numel is 0, return true
74+
out_data[0] = true;
7375
} else {
7476
DenseTensor tmp;
7577
tmp.Resize(x.dims());

paddle/phi/kernels/gpu/add_n_kernel.cu

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,10 @@ template <typename T, typename Context>
7979
void AddNKernel(const Context &dev_ctx,
8080
const std::vector<const TensorBase *> &x,
8181
DenseTensor *out) {
82+
if (out && out->numel() == 0) {
83+
dev_ctx.template Alloc<T>(out);
84+
return;
85+
}
8286
const size_t in_num = x.size();
8387
for (int i = 0; i < in_num; ++i) {
8488
PADDLE_ENFORCE_EQ(

paddle/phi/kernels/gpu/bincount_kernel.cu

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,9 @@
1717
#include "paddle/phi/backends/gpu/gpu_context.h"
1818
#include "paddle/phi/backends/gpu/gpu_primitives.h"
1919
#include "paddle/phi/core/kernel_registry.h"
20+
#include "paddle/phi/kernels/full_kernel.h"
2021
#include "paddle/phi/kernels/funcs/eigen/common.h"
2122
#include "paddle/phi/kernels/funcs/math_function.h"
22-
2323
namespace phi {
2424

2525
using phi::PADDLE_CUDA_NUM_THREADS;
@@ -99,9 +99,10 @@ void BincountCUDAInner(const Context& dev_ctx,
9999
int64_t input_numel = static_cast<int64_t>(input->numel());
100100

101101
if (input_data == nullptr) {
102-
phi::DDim out_dim{0};
102+
phi::DDim out_dim{minlength};
103103
output->Resize(out_dim);
104-
dev_ctx.template Alloc<T>(output);
104+
phi::Full<InputT, Context>(
105+
dev_ctx, phi::IntArray(common::vectorize(output->dims())), 0, output);
105106
return;
106107
}
107108

paddle/phi/kernels/kps/compare_kernel.cu

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,12 @@ inline void CompareAllKernelImpl(const Context& ctx,
8080
thrust::fill(out_dev_ptr, out_dev_ptr + 1, false);
8181
return;
8282
}
83+
// shape equal and numel is 0, return true
84+
if (x.numel() == 0) {
85+
thrust::device_ptr<bool> out_dev_ptr(out_data);
86+
thrust::fill(out_dev_ptr, out_dev_ptr + 1, true);
87+
return;
88+
}
8389

8490
DenseTensor tmp;
8591
tmp.Resize(x.dims());

0 commit comments

Comments
 (0)