@@ -25,14 +25,69 @@ limitations under the License. */
2525namespace phi {
2626namespace funcs {
2727
28+ template <typename Context, typename T>
29+ struct MapMatrixInverseFunctor {
30+ void operator ()(
31+ const Context& dev_ctx, const T* a_ptr, T* a_inv_ptr, int offset, int n) {
32+ using Matrix =
33+ Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
34+ using EigenMatrixMap = Eigen::Map<Matrix>;
35+ using ConstEigenMatrixMap = Eigen::Map<const Matrix>;
36+
37+ ConstEigenMatrixMap mat (a_ptr + offset, n, n);
38+ EigenMatrixMap mat_inv (a_inv_ptr + offset, n, n);
39+ Eigen::PartialPivLU<Matrix> lu;
40+ lu.compute (mat);
41+
42+ const T min_abs_pivot = lu.matrixLU ().diagonal ().cwiseAbs ().minCoeff ();
43+ PADDLE_ENFORCE_GT (min_abs_pivot,
44+ static_cast <T>(0 ),
45+ errors::InvalidArgument (" Input is not invertible." ));
46+ mat_inv.noalias () = lu.inverse ();
47+ }
48+ };
49+
50+ template <typename Context, typename T>
51+ struct MapMatrixInverseFunctor <Context, phi::dtype::complex <T>> {
52+ void operator ()(const Context& dev_ctx,
53+ const phi::dtype::complex <T>* a_ptr,
54+ phi::dtype::complex <T>* a_inv_ptr,
55+ int offset,
56+ int n) {
57+ using Matrix = Eigen::Matrix<std::complex <T>,
58+ Eigen::Dynamic,
59+ Eigen::Dynamic,
60+ Eigen::RowMajor>;
61+ using EigenMatrixMap = Eigen::Map<Matrix>;
62+ using ConstEigenMatrixMap = Eigen::Map<const Matrix>;
63+ std::complex <T>* std_ptr = new std::complex <T>[n * n];
64+ std::complex <T>* std_inv_ptr = new std::complex <T>[n * n];
65+ for (int i = 0 ; i < n * n; i++) {
66+ *(std_ptr + i) = static_cast <std::complex <T>>(*(a_ptr + offset + i));
67+ }
68+ ConstEigenMatrixMap mat (std_ptr, n, n);
69+ EigenMatrixMap mat_inv (std_inv_ptr, n, n);
70+ Eigen::PartialPivLU<Matrix> lu;
71+ lu.compute (mat);
72+
73+ const T min_abs_pivot = lu.matrixLU ().diagonal ().cwiseAbs ().minCoeff ();
74+ PADDLE_ENFORCE_NE (min_abs_pivot,
75+ static_cast <std::complex <T>>(0 ),
76+ errors::InvalidArgument (" Input is not invertible." ));
77+ mat_inv.noalias () = lu.inverse ();
78+ for (int i = 0 ; i < n * n; i++) {
79+ *(a_inv_ptr + offset + i) =
80+ static_cast <phi::dtype::complex <T>>(*(std_inv_ptr + i));
81+ }
82+ delete[] std_ptr;
83+ delete[] std_inv_ptr;
84+ }
85+ };
86+
2887template <typename Context, typename T>
2988void ComputeInverseEigen (const Context& dev_ctx,
3089 const DenseTensor& a,
3190 DenseTensor* a_inv) {
32- using Matrix =
33- Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
34- using EigenMatrixMap = Eigen::Map<Matrix>;
35- using ConstEigenMatrixMap = Eigen::Map<const Matrix>;
3691 const auto & mat_dims = a.dims ();
3792 const int rank = mat_dims.size ();
3893 int n = mat_dims[rank - 1 ];
@@ -41,17 +96,13 @@ void ComputeInverseEigen(const Context& dev_ctx,
4196 const T* a_ptr = a.data <T>();
4297 T* a_inv_ptr = dev_ctx.template Alloc <T>(a_inv);
4398
99+ // Putting phi::dtype::complex into eigen::matrix has a problem,
100+ // it's not going to get the right result,
101+ // so we're going to convert it to std::complex and
102+ // then we're going to put it into eigen::matrix.
44103 for (int i = 0 ; i < batch_size; ++i) {
45- ConstEigenMatrixMap mat (a_ptr + i * n * n, n, n);
46- EigenMatrixMap mat_inv (a_inv_ptr + i * n * n, n, n);
47- Eigen::PartialPivLU<Matrix> lu;
48- lu.compute (mat);
49-
50- const T min_abs_pivot = lu.matrixLU ().diagonal ().cwiseAbs ().minCoeff ();
51- PADDLE_ENFORCE_GT (min_abs_pivot,
52- static_cast <T>(0 ),
53- errors::InvalidArgument (" Input is not invertible." ));
54- mat_inv.noalias () = lu.inverse ();
104+ MapMatrixInverseFunctor<Context, T> functor;
105+ functor (dev_ctx, a_ptr, a_inv_ptr, i * n * n, n);
55106 }
56107}
57108
0 commit comments