Skip to content
4 changes: 2 additions & 2 deletions paddle/phi/infermeta/multiary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5333,7 +5333,7 @@ void SendUERecvInferMeta(const MetaTensor& x,
dst_index_dims.size()));
}

if (src_index_dims[0] != 0) {
if (src_index_dims[0] != 0 && dst_index_dims[0] != 0) {
PADDLE_ENFORCE_EQ(
src_index_dims[0],
dst_index_dims[0],
Expand Down Expand Up @@ -5421,7 +5421,7 @@ void SendUVInferMeta(const MetaTensor& x,
dst_index_dims.size()));
}

if (src_index_dims[0] != 0) {
if (src_index_dims[0] != 0 && dst_index_dims[0] != 0) {
PADDLE_ENFORCE_EQ(
src_index_dims[0],
dst_index_dims[0],
Expand Down
2 changes: 1 addition & 1 deletion paddle/phi/infermeta/ternary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2517,7 +2517,7 @@ void SendURecvInferMeta(const MetaTensor& x,
dst_index_dims.size()));
}

if (src_index_dims[0] != 0) {
if (src_index_dims[0] != 0 && dst_index_dims[0] != 0) {
PADDLE_ENFORCE_EQ(
src_index_dims[0],
dst_index_dims[0],
Expand Down
16 changes: 13 additions & 3 deletions paddle/phi/kernels/cpu/lu_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/full_kernel.h"
#include "paddle/phi/kernels/funcs/lapack/lapack_function.h"

#include "paddle/phi/kernels/impl/lu_kernel_impl.h"
Expand All @@ -35,9 +36,18 @@ void LUKernel(const Context& dev_ctx,
"but got pivots=False"));

if (x.numel() == 0) {
dev_ctx.template Alloc<int>(infos);
dev_ctx.template Alloc<int>(pivots);
dev_ctx.template Alloc<T>(out);
phi::Full<int, Context>(dev_ctx,
phi::IntArray(common::vectorize(infos->dims())),
static_cast<int>(0),
infos);
phi::Full<int, Context>(dev_ctx,
phi::IntArray(common::vectorize(pivots->dims())),
static_cast<int>(0),
pivots);
phi::Full<T, Context>(dev_ctx,
phi::IntArray(common::vectorize(out->dims())),
static_cast<T>(0),
out);
return;
}
*out = Transpose2DTo6D<Context, T>(dev_ctx, x);
Expand Down
16 changes: 13 additions & 3 deletions paddle/phi/kernels/gpu/lu_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "paddle/phi/backends/dynload/cusolver.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/full_kernel.h"

#include "paddle/phi/common/memory_utils.h"
#include "paddle/phi/kernels/impl/lu_kernel_impl.h"
Expand Down Expand Up @@ -198,9 +199,18 @@ void LUKernel(const Context& dev_ctx,
::common::errors::PreconditionNotMet(
"Invalid input x dimensionality: %d (expected ≥2)", x.dims().size()));
if (x.numel() == 0) {
dev_ctx.template Alloc<int>(infos);
dev_ctx.template Alloc<int>(pivots);
dev_ctx.template Alloc<T>(out);
phi::Full<int, Context>(dev_ctx,
phi::IntArray(common::vectorize(infos->dims())),
static_cast<int>(0),
infos);
phi::Full<int, Context>(dev_ctx,
phi::IntArray(common::vectorize(pivots->dims())),
static_cast<int>(0),
pivots);
phi::Full<T, Context>(dev_ctx,
phi::IntArray(common::vectorize(out->dims())),
static_cast<T>(0),
out);
return;
}
int64_t largest_matrix = (1LL << 31) - 1;
Expand Down
43 changes: 43 additions & 0 deletions test/legacy_test/test_lu_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

import paddle
from paddle import base
from paddle.base import core


def scipy_lu(A, pivot):
Expand Down Expand Up @@ -383,6 +384,48 @@ def test_zero_size1(self):
self.assertEqual(x.grad.shape, x.shape)


class TestLUOp(OpTest):
def config(self):
self.x_shape = [2, 0, 12]
self.pivot = True
self.get_infos = True
self.dtype = "float64"

def setUp(self):
self.op_type = "lu"
self.python_api = paddle.tensor.linalg.lu
self.python_out_sig = ["Out", "Pivots"]
self.config()

A = np.random.random([2, 0, 12]).astype(self.dtype)
self.inputs = {'X': A}
self.attrs = {'pivots': self.pivot}

self.output = np.zeros([2, 0, 12]).astype(self.dtype)
self.Pivots = np.zeros([2, 0]).astype(self.dtype)
self.Infos = np.zeros([2]).astype(self.dtype)

self.outputs = {
'Out': self.output,
'Pivots': self.Pivots,
'Infos': self.Infos,
}

def test_check_output(self):
self.check_output_with_place(paddle.CPUPlace(), check_pir=True)
if core.is_compiled_with_cuda():
self.check_output_with_place(core.CUDAPlace(0), check_pir=True)

def test_check_grad(self):
self.check_grad_with_place(
paddle.CPUPlace(), ['X'], ['Out'], check_pir=True
)
if core.is_compiled_with_cuda():
self.check_grad_with_place(
core.CUDAPlace(0), ['X'], ['Out'], check_pir=True
)


if __name__ == "__main__":
paddle.enable_static()
unittest.main()
16 changes: 16 additions & 0 deletions test/legacy_test/test_mean_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,22 @@ def setUp(self):
self.outputs = {'Out': out_np}


class TestMeanOp_Int32ZeroSize(OpTest):
def setUp(self):
self.op_type = "mean"
self.python_api = paddle.mean
self.dtype = np.int32
self.public_python_api = paddle.mean
self.inputs = {'X': np.array([]).astype(self.dtype)}
self.outputs = {'Out': np.nan}

def test_check_output(self):
self.check_output(check_pir=True)

def test_checkout_grad(self):
self.check_grad(['X'], 'Out', check_pir=True, check_prim_pir=True)


class TestMeanOp_Int64ZeroSize(OpTest):
def setUp(self):
self.op_type = "mean"
Expand Down
Loading