12 changes: 10 additions & 2 deletions paddle/phi/infermeta/binary.cc
@@ -4229,7 +4229,11 @@ void LstsqInferMeta(const MetaTensor& x,
          m,
          y_dims[y_rank - 2]));

  rank->set_dims(common::make_ddim(batch_dims_vec));
  if (x.numel() == 0 || y.numel() == 0) {
    rank->set_dims(common::make_ddim({0}));
  } else {
    rank->set_dims(common::make_ddim(batch_dims_vec));
  }

  if (m > n && driver != "gelsy") {
    if (driver == "gelss" || driver == "gelsd") {
@@ -4245,7 +4249,11 @@ void LstsqInferMeta(const MetaTensor& x,
  residuals->set_dtype(y.dtype());

  batch_dims_vec.emplace_back(std::min(m, n));
  singular_values->set_dims(common::make_ddim(batch_dims_vec));
  if (x.numel() == 0 || y.numel() == 0) {
    singular_values->set_dims(common::make_ddim({0}));
  } else {
    singular_values->set_dims(common::make_ddim(batch_dims_vec));
  }
  singular_values->set_dtype(y.dtype());

  batch_dims_vec[x_rank - 2] = n;
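Taken together, the two branches above special-case 0-size operands in LstsqInferMeta: rank and singular_values are declared with shape [0] instead of the usual batch shape. A minimal sketch of the resulting shapes (my illustration, not part of this diff, assuming a build that includes the patch):

import paddle

paddle.set_device('cpu')
x = paddle.zeros([5, 0], dtype='float64')  # m = 5, n = 0, so numel() == 0
y = paddle.zeros([5, 0], dtype='float64')
_, _, rank, singular_values = paddle.linalg.lstsq(x, y, driver='gelss')
print(rank.shape)             # [0] rather than the batch shape
print(singular_values.shape)  # [0]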
22 changes: 21 additions & 1 deletion paddle/phi/kernels/cpu/lstsq_kernel.cc
@@ -18,13 +18,13 @@

#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/full_kernel.h"
#include "paddle/phi/kernels/funcs/complex_functors.h"
#include "paddle/phi/kernels/funcs/lapack/lapack_function.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/impl/lstsq_kernel_impl.h"
#include "paddle/phi/kernels/lstsq_kernel.h"
#include "paddle/phi/kernels/transpose_kernel.h"

namespace phi {

enum class LapackDriverType : int { Gels, Gelsd, Gelsy, Gelss };
@@ -40,6 +40,26 @@ void LstsqKernel(const Context& dev_ctx,
                 DenseTensor* rank,
                 DenseTensor* singular_values) {
  using ValueType = phi::dtype::Real<T>;
  if (x.numel() == 0 || y.numel() == 0) {
    if (solution)
      Full<T, Context>(dev_ctx,
                       phi::IntArray(common::vectorize(solution->dims())),
                       0,
                       solution);
    if (rank)
      Full<int64_t, Context>(
          dev_ctx, phi::IntArray(common::vectorize(rank->dims())), 0, rank);
    if (residuals)
      GetResidualsTensor<Context, T>(
          dev_ctx, x, y, driver_string, solution, residuals, rank);
    if (singular_values)
      Full<T, Context>(
          dev_ctx,
          phi::IntArray(common::vectorize(singular_values->dims())),
          0,
          singular_values);
    return;
  }

  static auto driver_type = std::unordered_map<std::string, LapackDriverType>(
      {{"gels", LapackDriverType::Gels},
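The CPU kernel now returns early for 0-size operands, zero-filling whichever outputs were requested instead of handing an empty problem to LAPACK; residuals still go through GetResidualsTensor so their shape logic stays in one place. A rough usage sketch (assuming a build with this patch; shapes taken from the new tests at the end of the PR):

import paddle

paddle.set_device('cpu')
x = paddle.randn([0, 100], dtype='float64')  # zero rows, so numel() == 0
y = paddle.randn([0, 50], dtype='float64')
solution, residuals, rank, singular_values = paddle.linalg.lstsq(
    x, y, driver='gelsd'
)
print(solution.shape)  # [100, 50], filled with zeros by the early return
assert bool((solution == 0).all())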
6 changes: 6 additions & 0 deletions paddle/phi/kernels/gpu/affine_grid_grad_kernel.cu
@@ -251,6 +251,12 @@ void AffineGridGradCUDAKernel(const Context& dev_ctx,
                              DenseTensor* output) {
  auto* theta = &input;
  auto theta_size = theta->dims().size();
  if (output->numel() == 0 || input.numel() == 0) {
    dev_ctx.template Alloc<T>(output);
    phi::funcs::SetConstant<phi::GPUContext, T>()(
        dev_ctx, output, static_cast<T>(0));
    return;
  }
  if (theta_size == 4) {
    AffineGridGrad4DCUDAKernel<T, Context>(
        dev_ctx, input, outputShape, align_corners, output);
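With the guard above, a 0-size theta (or output) short-circuits the CUDA backward kernel and produces a zero-filled gradient. A small dygraph sketch of the intended behavior (my example; assumes a CUDA build with this patch):

import paddle
import paddle.nn.functional as F

theta = paddle.randn([0, 2, 3], dtype='float32')  # empty batch
theta.stop_gradient = False
grid = F.affine_grid(theta, [0, 3, 8, 8], align_corners=False)
grid.sum().backward()
print(theta.grad.shape)  # [0, 2, 3], zero-filled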
11 changes: 11 additions & 0 deletions paddle/phi/kernels/gpu/affine_grid_kernel.cu
@@ -22,6 +22,7 @@
#include "paddle/phi/backends/gpu/gpu_primitives.h"
#include "paddle/phi/common/int_array.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/full_kernel.h"
#include "paddle/phi/kernels/funcs/affine_grid_utils.h"

namespace phi {
@@ -138,6 +139,11 @@ void AffineGrid4DCUDAKernel(const Context& dev_ctx,
  w = size_attr[3];
  output->Resize(common::make_ddim({n, h, w, 2}));
  T* out_data = dev_ctx.template Alloc<T>(output);
  if (input.numel() == 0) {
    phi::Full<T, Context>(
        dev_ctx, phi::IntArray(common::vectorize(output->dims())), 0, output);
    return;
  }

  T h_step;
  T w_step;
@@ -188,6 +194,11 @@ void AffineGrid5DCUDAKernel(const Context& dev_ctx,
  w = size_attr[4];
  output->Resize(common::make_ddim({n, d, h, w, 3}));
  T* out_data = dev_ctx.template Alloc<T>(output);
  if (input.numel() == 0) {
    phi::Full<T, Context>(
        dev_ctx, phi::IntArray(common::vectorize(output->dims())), 0, output);
    return;
  }

  T d_step;
  T h_step;
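Both the 4-D and 5-D forward kernels now resize the output first and only then bail out with a zero-filled grid when theta is empty, so the grid keeps its inferred shape. A sketch of the expected shapes (my example, assuming a CUDA build with this patch):

import paddle
import paddle.nn.functional as F

theta_2d = paddle.randn([0, 2, 3])  # N = 0, 4-D case
print(F.affine_grid(theta_2d, [0, 3, 4, 4]).shape)     # [0, 4, 4, 2]

theta_3d = paddle.randn([0, 3, 4])  # N = 0, 5-D case
print(F.affine_grid(theta_3d, [0, 1, 2, 4, 4]).shape)  # [0, 2, 4, 4, 3]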
23 changes: 21 additions & 2 deletions paddle/phi/kernels/gpu/lstsq_kernel.cu
@@ -18,6 +18,7 @@

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/full_kernel.h"
#include "paddle/phi/kernels/funcs/slice.h"
#include "paddle/phi/kernels/impl/lstsq_kernel_impl.h"
#include "paddle/phi/kernels/impl/qr_kernel_impl.h"
@@ -26,7 +27,6 @@
#include "paddle/phi/kernels/matmul_kernel.h"
#include "paddle/phi/kernels/transpose_kernel.h"
#include "paddle/phi/kernels/triangular_solve_kernel.h"

namespace phi {

enum class LapackDriverType : int { Gels, Gelsd, Gelsy, Gelss };
@@ -41,6 +41,26 @@ void LstsqKernel(const Context& dev_ctx,
                 DenseTensor* residuals,
                 DenseTensor* rank,
                 DenseTensor* singular_values) {
  if (x.numel() == 0 || y.numel() == 0) {
    if (solution)
      Full<T, Context>(dev_ctx,
                       phi::IntArray(common::vectorize(solution->dims())),
                       0,
                       solution);
    if (rank)
      Full<int64_t, Context>(
          dev_ctx, phi::IntArray(common::vectorize(rank->dims())), 0, rank);
    if (residuals)
      GetResidualsTensor<Context, T>(
          dev_ctx, x, y, driver_string, solution, residuals, rank);
    if (singular_values)
      Full<T, Context>(
          dev_ctx,
          phi::IntArray(common::vectorize(singular_values->dims())),
          0,
          singular_values);
    return;
  }
  auto x_dims = x.dims();
  auto y_dims = y.dims();
  int dim_size = x_dims.size();
@@ -158,7 +178,6 @@
    phi::Copy<Context>(
        dev_ctx, solu_tensor, dev_ctx.GetPlace(), true, solution);
  }

  if (batch_count == 1) solution->Resize(common::make_ddim({n, nrhs}));
  GetResidualsTensor<Context, T>(
      dev_ctx, x, y, driver_string, solution, residuals, rank);
8 changes: 8 additions & 0 deletions paddle/phi/kernels/gpudnn/affine_grid_grad_kernel.cu
@@ -22,6 +22,7 @@
#include "paddle/phi/backends/gpu/gpu_primitives.h"
#include "paddle/phi/common/int_array.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/full_kernel.h"

namespace phi {

@@ -40,6 +41,13 @@ void AffineGridGradCudnnKernel(const Context& dev_ctx,
      common::errors::InvalidArgument(
          "Only support for CUDAPlace.Please switch your context from "
          "CPUPlace to CUDAPlace or update your cudnn."));
  if (output_grad.numel() == 0 || input_grad->numel() == 0) {
    phi::Full<T, Context>(dev_ctx,
                          phi::IntArray(common::vectorize(input_grad->dims())),
                          0,
                          input_grad);
    return;
  }
  auto handle = dev_ctx.cudnn_handle();
  auto& theta_grad = input_grad;

7 changes: 6 additions & 1 deletion paddle/phi/kernels/gpudnn/affine_grid_kernel.cu
@@ -21,7 +21,7 @@
#include "paddle/phi/backends/gpu/gpu_primitives.h"
#include "paddle/phi/common/int_array.h"
#include "paddle/phi/core/kernel_registry.h"

#include "paddle/phi/kernels/full_kernel.h"
namespace phi {

using ScopedSpatialTransformerDescriptor =
@@ -51,6 +51,11 @@ void AffineGridCudnnKernel(const Context& dev_ctx,
  h_size_data[3] = size_attr[3];
  output->Resize(common::make_ddim({n, h_size_data[2], h_size_data[3], 2}));
  T* output_data = dev_ctx.template Alloc<T>(output);
  if (input.numel() == 0) {
    phi::Full<T, Context>(
        dev_ctx, phi::IntArray(common::vectorize(output->dims())), 0, output);
    return;
  }
  ScopedSpatialTransformerDescriptor st_desc;
  cudnnSpatialTransformerDescriptor_t cudnn_st_desc =
      st_desc.descriptor<T>(4, h_size_data);
2 changes: 1 addition & 1 deletion paddle/phi/kernels/impl/lstsq_kernel_impl.h
@@ -71,7 +71,7 @@ inline void GetResidualsTensor(const DeviceContext& dev_ctx,

  if (m > n && driver != "gelsy") {
    bool compute_residuals = true;
    if (driver == "gelss" || driver == "gelsd") {
    if ((driver == "gelss" || driver == "gelsd") && rank->numel() != 0) {
      if (dim_size == 2) {
        compute_residuals = rank->data<int>()[0] == n;
      } else {
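The extra rank->numel() != 0 condition keeps GetResidualsTensor from dereferencing the data pointer of a 0-size rank tensor; when rank is empty, compute_residuals simply stays true. In Python terms the guarded decision now reads roughly like this (a paraphrase for illustration, not the actual implementation):

import numpy as np

def should_compute_residuals(driver: str, rank: np.ndarray, n: int) -> bool:
    compute_residuals = True
    # gelss/gelsd inspect rank only when it is non-empty; an empty rank
    # (0-size inputs) now skips the data access that used to fault here.
    if driver in ("gelss", "gelsd") and rank.size != 0:
        compute_residuals = bool((rank == n).all())
    return compute_residuals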
16 changes: 16 additions & 0 deletions python/paddle/tensor/linalg.py
@@ -13,6 +13,7 @@
# limitations under the License.
from __future__ import annotations

import math
from typing import TYPE_CHECKING, Literal

import numpy as np
@@ -1744,6 +1745,21 @@ def empty_tensor(input, shape):
            raise ValueError(
                "only support x is nonempty tensor in static graph mode"
            )
        # reshape([]) is invalid,
        # so use reshape([0]) and sum to get a scalar when shape is []
        old_size = input.numel()
        if len(shape) == 0 and old_size == 0:
            return input.reshape([0]).sum()
        new_size = math.prod(shape)
        # a 0-size Tensor cannot be reshaped to a non-0-size Tensor
        if new_size > 0 and old_size == 0:
            tmp = paddle.concat(
                [
                    input.flatten(),
                    paddle.zeros([new_size], dtype=input.dtype),
                ]
            )
            return tmp.reshape(shape)
        return input.reshape(shape)
    raise ValueError(
        "only support x is nonempty tensor in static graph mode"
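The branch added to empty_tensor handles two dygraph edge cases: reshaping a 0-size tensor to a scalar (reshape([]) is rejected, so it reshapes to [0] and sums), and growing a 0-size tensor into a non-empty shape (by padding with zeros before the reshape). A standalone sketch of the same logic, using a hypothetical expand_empty helper name:

import math

import paddle

def expand_empty(input, shape):
    # Hypothetical free-standing version of the branch added above.
    old_size = input.numel()
    if len(shape) == 0 and old_size == 0:
        # reshape([]) is invalid for a 0-size tensor; sum() yields the scalar.
        return input.reshape([0]).sum()
    new_size = math.prod(shape)
    if new_size > 0 and old_size == 0:
        # Pad with zeros so the 0-size tensor can fill a non-empty shape.
        tmp = paddle.concat(
            [input.flatten(), paddle.zeros([new_size], dtype=input.dtype)]
        )
        return tmp.reshape(shape)
    return input.reshape(shape)

print(expand_empty(paddle.zeros([0, 3]), []).shape)      # [] (0-D scalar)
print(expand_empty(paddle.zeros([0, 3]), [2, 2]).shape)  # [2, 2]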
31 changes: 31 additions & 0 deletions test/legacy_test/test_linalg_cond.py
@@ -163,6 +163,37 @@ def test_dygraph_empty_tensor_input(self):
        test_dygraph_assert_true(self, x_list_m_n, p_list_m_n)


class TestCondZeroSizeTensor(unittest.TestCase):
    def setUp(self):
        self.shape = [0, 3]
        self.dtype = 'float32'
        self.p = 2
        self.expect_shape = []

    def _init_data(self):
        self.x = paddle.randn(self.shape, dtype=self.dtype)
        self.x.stop_gradient = False

    def _test_cond(self):
        res = paddle.linalg.cond(self.x, self.p)
        np.testing.assert_allclose(res.shape, self.expect_shape)
        loss = res.sum()
        loss.backward()
        np.testing.assert_allclose(self.x.grad.shape, self.x.shape)

    def test_dygraph(self):
        self._init_data()
        self._test_cond()


class TestCondZeroSizeTensor1(TestCondZeroSizeTensor):
    def setUp(self):
        self.shape = [8, 9, 0, 3]
        self.dtype = 'float32'
        self.p = 2
        self.expect_shape = [8, 9]


if __name__ == "__main__":
    paddle.enable_static()
    unittest.main()
29 changes: 28 additions & 1 deletion test/legacy_test/test_linalg_lstsq_op.py
@@ -134,7 +134,7 @@ def assert_np_close(self):
            and self.driver != "gelsy"
        ):
            np.testing.assert_allclose(
                self._result_residuals, self._output_residuals, rtol=1e-5
                self._result_residuals, self._output_residuals, rtol=1e-3
            )
        if self.driver in ("gelsy", "gelsd", "gelss"):
            np.testing.assert_allclose(
@@ -290,6 +290,33 @@ def init_config(self):
        self._input_shape_2 = (50, 300)


class LinalgLstsqTestZeroSize(LinalgLstsqTestCase):
    def init_config(self):
        self.dtype = 'float64'
        self.rcond = 1e-15
        self.driver = "gelsd"
        self._input_shape_1 = (0, 100)
        self._input_shape_2 = (0, 50)


class LinalgLstsqTestZeroSize1(LinalgLstsqTestZeroSize):
    def init_config(self):
        self.dtype = 'float64'
        self.rcond = 1e-15
        self.driver = "gels"
        self._input_shape_1 = (10, 7, 0)
        self._input_shape_2 = (10, 7, 6)


class LinalgLstsqTestZeroSize2(LinalgLstsqTestZeroSize):
    def init_config(self):
        self.dtype = 'float64'
        self.rcond = 1e-15
        self.driver = "gelss"
        self._input_shape_1 = (5, 0)
        self._input_shape_2 = (5, 0)


class TestLinalgLstsqAPIError(unittest.TestCase):
    def setUp(self):
        pass