2 changes: 2 additions & 0 deletions paddle/phi/backends/dynload/cusolver.h
@@ -97,6 +97,8 @@ CUSOLVER_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUSOLVER_WRAP);
__macro(cusolverDnSgetrf); \
__macro(cusolverDnSgetrs); \
__macro(cusolverDnDgetrs); \
__macro(cusolverDnCgetrs); \
__macro(cusolverDnZgetrs); \
__macro(cusolverDnDgetrf); \
__macro(cusolverDnCgetrf); \
__macro(cusolverDnZgetrf); \
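The two new `__macro` entries register `cusolverDnCgetrs` and `cusolverDnZgetrs` with Paddle's lazy dynamic loader, so the symbols are resolved from the cusolver shared library on first use. A quick illustrative check that a local cusolver build exports these entry points — this is ctypes standing in for the C++ dlopen/dlsym machinery, not Paddle's actual loader, and the library soname varies by install:

```python
# Illustrative only: confirm the two entry points bound above are exported.
import ctypes

lib = ctypes.CDLL("libcusolver.so")  # soname may differ, e.g. libcusolver.so.11
for sym in ("cusolverDnCgetrs", "cusolverDnZgetrs"):
    # Attribute lookup on a CDLL performs the symbol resolution.
    print(sym, "found" if hasattr(lib, sym) else "missing")
```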
20 changes: 20 additions & 0 deletions paddle/phi/backends/dynload/lapack.h
@@ -52,6 +52,24 @@ extern "C" void dgetrs_(char *trans,
double *b,
int *ldb,
int *info);
extern "C" void cgetrs_(char *trans,
int *n,
int *nrhs,
std::complex<float> *a,
int *lda,
int *ipiv,
std::complex<float> *b,
int *ldb,
int *info);
extern "C" void zgetrs_(char *trans,
int *n,
int *nrhs,
std::complex<double> *a,
int *lda,
int *ipiv,
std::complex<double> *b,
int *ldb,
int *info);

// evd
extern "C" void zheevd_(char *jobz,
@@ -396,6 +414,8 @@ extern void *lapack_dso_handle;
__macro(zgetrf_); \
__macro(sgetrs_); \
__macro(dgetrs_); \
__macro(cgetrs_); \
__macro(zgetrs_); \
__macro(zheevd_); \
__macro(cheevd_); \
__macro(dsyevd_); \
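`cgetrs_`/`zgetrs_` solve `A x = b` from an existing `getrf` factorization; like all LAPACK routines they expect column-major storage and take the pivot vector produced by `getrf`. A minimal sketch of the same factor-then-solve pair through SciPy's wrappers, which bind the identical LAPACK routines (`cgetrf`/`cgetrs` as exposed by `scipy.linalg.lapack`):

```python
import numpy as np
from scipy.linalg.lapack import cgetrf, cgetrs

# Column-major (Fortran-order) input, as the raw LAPACK interface expects.
a = np.array([[1, 1j], [0, 1]], dtype=np.complex64, order="F")
b = np.array([[1 + 1j], [2]], dtype=np.complex64, order="F")

lu, ipiv, info = cgetrf(a)     # factor A = P @ L @ U; info == 0 on success
x, info = cgetrs(lu, ipiv, b)  # solve A @ x = b using the factors and pivots
print(x.ravel())               # expect [1.-1.j  2.+0.j]
```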
2 changes: 2 additions & 0 deletions paddle/phi/backends/dynload/rocsolver.h
@@ -47,6 +47,8 @@ extern void *rocsolver_dso_handle;
__macro(rocsolver_zpotrs); \
__macro(rocsolver_sgetrs); \
__macro(rocsolver_dgetrs); \
__macro(rocsolver_cgetrs); \
__macro(rocsolver_zgetrs); \
__macro(rocsolver_sgetrf); \
__macro(rocsolver_dgetrf); \
__macro(rocsolver_cgetrf); \
10 changes: 8 additions & 2 deletions paddle/phi/kernels/cpu/lu_solve_grad_kernel.cc
@@ -19,5 +19,11 @@
#include "paddle/phi/kernels/lu_solve_grad_kernel.h"

// Register the CPU backward kernel
PD_REGISTER_KERNEL(
lu_solve_grad, CPU, ALL_LAYOUT, phi::LuSolveGradKernel, float, double) {}
PD_REGISTER_KERNEL(lu_solve_grad,
CPU,
ALL_LAYOUT,
phi::LuSolveGradKernel,
float,
double,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
10 changes: 8 additions & 2 deletions paddle/phi/kernels/cpu/lu_solve_kernel.cc
@@ -78,5 +78,11 @@ void LuSolveKernel(const Context& dev_ctx,
}
} // namespace phi

PD_REGISTER_KERNEL(
lu_solve, CPU, ALL_LAYOUT, phi::LuSolveKernel, float, double) {}
PD_REGISTER_KERNEL(lu_solve,
CPU,
ALL_LAYOUT,
phi::LuSolveKernel,
float,
double,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
10 changes: 8 additions & 2 deletions paddle/phi/kernels/cpu/lu_unpack_grad_kernel.cc
@@ -18,5 +18,11 @@
#include "paddle/phi/kernels/impl/lu_unpack_grad_kernel_impl.h"
#include "paddle/phi/kernels/lu_unpack_grad_kernel.h"

PD_REGISTER_KERNEL(
lu_unpack_grad, CPU, ALL_LAYOUT, phi::LUUnpackGradKernel, float, double) {}
PD_REGISTER_KERNEL(lu_unpack_grad,
CPU,
ALL_LAYOUT,
phi::LUUnpackGradKernel,
float,
double,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
10 changes: 8 additions & 2 deletions paddle/phi/kernels/cpu/lu_unpack_kernel.cc
@@ -18,5 +18,11 @@
#include "paddle/phi/kernels/impl/lu_unpack_kernel_impl.h"
#include "paddle/phi/kernels/lu_unpack_kernel.h"

PD_REGISTER_KERNEL(
lu_unpack, CPU, ALL_LAYOUT, phi::LUUnpackKernel, float, double) {}
PD_REGISTER_KERNEL(lu_unpack,
CPU,
ALL_LAYOUT,
phi::LUUnpackKernel,
float,
double,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
42 changes: 42 additions & 0 deletions paddle/phi/kernels/funcs/lapack/lapack_function.cc
@@ -79,6 +79,48 @@ void lapackLuSolve<float>(char trans,
dynload::sgetrs_(&trans, &n, &nrhs, a, &lda, ipiv, b, &ldb, info);
}

template <>
void lapackLuSolve<phi::dtype::complex<float>>(char trans,
int n,
int nrhs,
phi::dtype::complex<float> *a,
int lda,
int *ipiv,
phi::dtype::complex<float> *b,
int ldb,
int *info) {
dynload::cgetrs_(&trans,
&n,
&nrhs,
reinterpret_cast<std::complex<float> *>(a),
&lda,
ipiv,
reinterpret_cast<std::complex<float> *>(b),
&ldb,
info);
}

template <>
void lapackLuSolve<phi::dtype::complex<double>>(char trans,
int n,
int nrhs,
phi::dtype::complex<double> *a,
int lda,
int *ipiv,
phi::dtype::complex<double> *b,
int ldb,
int *info) {
dynload::zgetrs_(&trans,
&n,
&nrhs,
reinterpret_cast<std::complex<double> *>(a),
&lda,
ipiv,
reinterpret_cast<std::complex<double> *>(b),
&ldb,
info);
}

// eigh
template <>
void lapackEigh<float>(char jobz,
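The specializations above rely on `phi::dtype::complex<T>` being layout-compatible with `std::complex<T>` — a `(real, imag)` pair of adjacent `T`s — so the `reinterpret_cast` is a pure reinterpretation, not a conversion. The same idea in NumPy terms, where `complex64` is exactly two adjacent `float32` values:

```python
import numpy as np

pairs = np.array([[1.0, -2.0], [3.5, 0.0]], dtype=np.float32)  # (real, imag) rows
z = pairs.view(np.complex64)  # no copy: reinterprets each pair as one complex64
print(z.ravel())              # [1.-2.j  3.5+0.j]
```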
4 changes: 2 additions & 2 deletions paddle/phi/kernels/gpu/lu_grad_kernel.cu
@@ -22,7 +22,7 @@
// HIP does not support cusolver in LUKernel
PD_REGISTER_KERNEL(lu_grad, GPU, ALL_LAYOUT, phi::LUGradKernel, float, double) {
}
#else
#else // PADDLE_WITH_CUDA
PD_REGISTER_KERNEL(lu_grad,
GPU,
ALL_LAYOUT,
@@ -31,4 +31,4 @@ PD_REGISTER_KERNEL(lu_grad,
double,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
#endif // PADDLE_WITH_HIP
#endif
13 changes: 13 additions & 0 deletions paddle/phi/kernels/gpu/lu_solve_grad_kernel.cu
@@ -18,5 +18,18 @@
#include "paddle/phi/kernels/impl/lu_solve_grad_kernel_impl.h"
#include "paddle/phi/kernels/lu_solve_grad_kernel.h"

#ifdef PADDLE_WITH_HIP
// blas_impl.hip.h does not support CUBlas<T>::TRSM for complex in
// TriangularSolveKernel
PD_REGISTER_KERNEL(
lu_solve_grad, GPU, ALL_LAYOUT, phi::LuSolveGradKernel, float, double) {}
#else // PADDLE_WITH_CUDA
PD_REGISTER_KERNEL(lu_solve_grad,
GPU,
ALL_LAYOUT,
phi::LuSolveGradKernel,
float,
double,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
#endif
104 changes: 102 additions & 2 deletions paddle/phi/kernels/gpu/lu_solve_kernle.cu
@@ -65,6 +65,51 @@ void rocsolver_getrs<double>(const solverHandle_t& handle,
PADDLE_ENFORCE_GPU_SUCCESS(
dynload::rocsolver_dgetrs(handle, trans, n, nrhs, a, lda, ipiv, b, ldb));
}

template <>
void rocsolver_getrs<dtype::complex<float>>(const solverHandle_t& handle,
rocblas_operation trans,
int n,
int nrhs,
dtype::complex<float>* a,
int lda,
int* ipiv,
dtype::complex<float>* b,
int ldb) {
PADDLE_ENFORCE_GPU_SUCCESS(
dynload::rocsolver_cgetrs(handle,
trans,
n,
nrhs,
reinterpret_cast<rocblas_float_complex*>(a),
lda,
ipiv,
reinterpret_cast<rocblas_float_complex*>(b),
ldb));
}

template <>
void rocsolver_getrs<dtype::complex<double>>(const solverHandle_t& handle,
rocblas_operation trans,
int n,
int nrhs,
dtype::complex<double>* a,
int lda,
int* ipiv,
dtype::complex<double>* b,
int ldb) {
PADDLE_ENFORCE_GPU_SUCCESS(
dynload::rocsolver_zgetrs(handle,
trans,
n,
nrhs,
reinterpret_cast<rocblas_double_complex*>(a),
lda,
ipiv,
reinterpret_cast<rocblas_double_complex*>(b),
ldb));
}

#else
template <typename T>
void cusolver_getrs(const solverHandle_t& handle,
@@ -107,6 +152,55 @@ void cusolver_getrs<double>(const solverHandle_t& handle,
PADDLE_ENFORCE_GPU_SUCCESS(dynload::cusolverDnDgetrs(
handle, trans, n, nrhs, a, lda, ipiv, b, ldb, info));
}

template <>
void cusolver_getrs<dtype::complex<float>>(const solverHandle_t& handle,
cublasOperation_t trans,
int n,
int nrhs,
dtype::complex<float>* a,
int lda,
int* ipiv,
dtype::complex<float>* b,
int ldb,
int* info) {
PADDLE_ENFORCE_GPU_SUCCESS(
dynload::cusolverDnCgetrs(handle,
trans,
n,
nrhs,
reinterpret_cast<cuComplex*>(a),
lda,
ipiv,
reinterpret_cast<cuComplex*>(b),
ldb,
info));
}

template <>
void cusolver_getrs<dtype::complex<double>>(const solverHandle_t& handle,
cublasOperation_t trans,
int n,
int nrhs,
dtype::complex<double>* a,
int lda,
int* ipiv,
dtype::complex<double>* b,
int ldb,
int* info) {
PADDLE_ENFORCE_GPU_SUCCESS(
dynload::cusolverDnZgetrs(handle,
trans,
n,
nrhs,
reinterpret_cast<cuDoubleComplex*>(a),
lda,
ipiv,
reinterpret_cast<cuDoubleComplex*>(b),
ldb,
info));
}

#endif // PADDLE_WITH_HIP

template <typename T, typename Context>
@@ -199,5 +293,11 @@ void LuSolveKernel(const Context& dev_ctx,

} // namespace phi

PD_REGISTER_KERNEL(
lu_solve, GPU, ALL_LAYOUT, phi::LuSolveKernel, float, double) {}
PD_REGISTER_KERNEL(lu_solve,
GPU,
ALL_LAYOUT,
phi::LuSolveKernel,
float,
double,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
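With the `cusolverDn[C|Z]getrs` (and rocsolver) specializations in place, the GPU kernel can factor and solve complex systems end to end. A hedged smoke test, assuming a GPU build of Paddle carrying this change and the `paddle.linalg.lu`/`lu_solve` signatures documented in the `linalg.py` hunk below; the `(lu, pivots)` tuple return of `paddle.linalg.lu` is an assumption here:

```python
import paddle

paddle.set_device("gpu")  # requires a GPU build of Paddle
A = paddle.to_tensor([[3 + 1j, 1 + 0j], [1 + 0j, 2 - 1j]], dtype="complex64")
b = paddle.to_tensor([[9 + 1j], [8 + 0j]], dtype="complex64")

lu, pivots = paddle.linalg.lu(A)           # getrf: LU factors plus pivots
x = paddle.linalg.lu_solve(b, lu, pivots)  # getrs: the path registered above
print((A @ x - b).abs().max())             # residual, expect ~0
```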
10 changes: 8 additions & 2 deletions paddle/phi/kernels/gpu/lu_unpack_grad_kernel.cu
@@ -18,5 +18,11 @@
#include "paddle/phi/kernels/impl/lu_unpack_grad_kernel_impl.h"
#include "paddle/phi/kernels/lu_unpack_grad_kernel.h"

PD_REGISTER_KERNEL(
lu_unpack_grad, GPU, ALL_LAYOUT, phi::LUUnpackGradKernel, float, double) {}
PD_REGISTER_KERNEL(lu_unpack_grad,
GPU,
ALL_LAYOUT,
phi::LUUnpackGradKernel,
float,
double,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
10 changes: 8 additions & 2 deletions paddle/phi/kernels/gpu/lu_unpack_kernel.cu
@@ -18,5 +18,11 @@
#include "paddle/phi/kernels/impl/lu_unpack_kernel_impl.h"
#include "paddle/phi/kernels/lu_unpack_kernel.h"

PD_REGISTER_KERNEL(
lu_unpack, GPU, ALL_LAYOUT, phi::LUUnpackKernel, float, double) {}
PD_REGISTER_KERNEL(lu_unpack,
GPU,
ALL_LAYOUT,
phi::LUUnpackKernel,
float,
double,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
18 changes: 14 additions & 4 deletions python/paddle/tensor/linalg.py
@@ -3610,10 +3610,16 @@ def lu_solve(
given LU decomposition :math:`A` and column vector :math:`b`.

Args:
b (Tensor): Column vector `b` in the above equation. It has shape :math:`(*, m, k)`, where :math:`*` is batch dimensions, with data type float32, float64.
lu (Tensor): LU decomposition. It has shape :math:`(*, m, m)`, where :math:`*` is batch dimensions, that can be decomposed into an upper triangular matrix U and a lower triangular matrix L, with data type float32, float64.
b (Tensor): Column vector `b` in the above equation. It has shape :math:`(*, m, k)`, where :math:`*` is batch dimensions,
with data type float32, float64, complex64, or complex128.

lu (Tensor): LU decomposition. It has shape :math:`(*, m, m)`, where :math:`*` is batch dimensions, and can be decomposed into an upper triangular matrix U and a lower triangular matrix L,
with data type float32, float64, complex64, or complex128.

pivots (Tensor): Permutation matrix P of LU decomposition. It has shape :math:`(*, m)`, where :math:`*` is batch dimensions, that can be converted to a permutation matrix P, with data type int32.

trans (str, optional): The operation applied to A when solving. It can be "N", "T" or "C": "N" means :math:`Ax=b`, "T" means :math:`A^Tx=b`, "C" means :math:`A^Hx=b`. Default is "N".

name (str|None, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
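The three `trans` modes map directly to LAPACK's `getrs` operation argument. A sketch of the semantics using SciPy's `lu_factor`/`lu_solve`, whose `trans=0/1/2` corresponds to "N"/"T"/"C":

```python
import numpy as np
from scipy.linalg import lu_factor, lu_solve

rng = np.random.default_rng(0)
A = rng.standard_normal((3, 3)) + 1j * rng.standard_normal((3, 3))
b = rng.standard_normal((3, 1)) + 1j * rng.standard_normal((3, 1))

lu_piv = lu_factor(A)  # one factorization serves all three modes
for trans, op in [(0, A), (1, A.T), (2, A.conj().T)]:
    x = lu_solve(lu_piv, b, trans=trans)
    assert np.allclose(op @ x, b)  # "N", "T", "C" respectively
```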

@@ -3704,8 +3710,9 @@ def lu_unpack(

Args:
x (Tensor): The LU tensor get from paddle.linalg.lu, which is combined by L and U.
Its data type should be float32, float64, complex64, or complex128.

y (Tensor): Pivots get from paddle.linalg.lu.
y (Tensor): Pivots obtained from paddle.linalg.lu. Its data type should be int32.

unpack_ludata (bool, optional): whether to unpack L and U from x. Default: True.

@@ -3774,7 +3781,10 @@
return P, L, U
else:
check_variable_and_dtype(
x, 'dtype', ['float32', 'float64'], 'lu_unpack'
x,
'dtype',
['float32', 'float64', 'complex64', 'complex128'],
'lu_unpack',
)
helper = LayerHelper('lu_unpack', **locals())
p = helper.create_variable_for_type_inference(dtype=x.dtype)
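For completeness, a hedged sketch of the unpacked factors round-tripping with a complex input — again assuming the `paddle.linalg.lu` tuple return, with `lu_unpack` yielding `P`, `L`, `U` as in the code above:

```python
import paddle

A = paddle.to_tensor([[1 + 1j, 2 + 0j], [3 + 0j, 4 - 1j]], dtype="complex64")
lu, pivots = paddle.linalg.lu(A)
P, L, U = paddle.linalg.lu_unpack(lu, pivots)  # permutation, unit-lower, upper
print((P @ L @ U - A).abs().max())             # reconstruction residual, ~0
```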