11template <
typename MatrixType_,
int UpLo_ = Lower>
struct BunchKaufman;
14template <
typename MatrixType_,
int UpLo_>
15struct traits<BunchKaufman<MatrixType_, UpLo_>> : traits<MatrixType_> {
16 typedef MatrixXpr XprKind;
17 typedef SolverStorage StorageKind;
18 typedef int StorageIndex;
22template <
typename MatrixType,
typename IndicesType>
23ComputationInfo bunch_kaufman_in_place_unblocked(MatrixType &a,
26 using Scalar =
typename MatrixType::Scalar;
27 using Real =
typename Eigen::NumTraits<Scalar>::Real;
28 const Real alpha = (Real(1) + numext::sqrt(Real(17))) / Real(8);
33 const Index n = a.rows();
37 if (numext::abs(numext::real(a(0, 0))) == Real(0)) {
38 return NumericalIssue;
40 a(0, 0) = Real(1) / numext::real(a(0, 0));
48 Real abs_akk = numext::abs(numext::real(a(k, k)));
50 Real colmax = Real(0);
53 colmax = a.col(k).segment(k + 1, n - k - 1).cwiseAbs().maxCoeff(&imax);
58 if (numext::maxi(abs_akk, colmax) == Real(0)) {
59 return NumericalIssue;
61 if (abs_akk >= colmax * alpha) {
64 Real rowmax = Real(0);
66 rowmax = a.row(imax).segment(k, imax - k).cwiseAbs().maxCoeff();
68 if (n - imax - 1 > 0) {
69 rowmax = numext::maxi(rowmax, a.col(imax)
70 .segment(imax + 1, n - imax - 1)
75 if (abs_akk >= (alpha * colmax) * (colmax / rowmax)) {
77 }
else if (numext::abs(numext::real(a(imax, imax))) >= alpha * rowmax) {
85 Index kk = k + k_step - 1;
89 .segment(kp + 1, n - kp - 1)
90 .swap(a.col(kp).segment(kp + 1, n - kp - 1));
92 for (Index j = kk + 1; j < kp; ++j) {
94 a(j, kk) = numext::conj(a(kp, j));
95 a(kp, j) = numext::conj(tmp);
97 a(kp, kk) = numext::conj(a(kp, kk));
98 swap(a(kk, kk), a(kp, kp));
101 swap(a(k + 1, k), a(kp, k));
106 Real d11 = Real(1) / numext::real(a(k, k));
109 auto x = a.middleRows(k + 1, n - k - 1).col(k);
111 a.middleRows(k + 1, n - k - 1).middleCols(k + 1, n - k - 1);
113 for (Index j = 0; j < n - k - 1; ++j) {
114 Scalar d11xj = numext::conj(x(j)) * d11;
115 for (Index i = j; i < n - k - 1; ++i) {
117 trailing(i, j) -= d11xj * xi;
119 trailing(j, j) =
Scalar(numext::real(trailing(j, j)));
124 Real d21_abs = numext::abs(a(k + 1, k));
125 Real d21_inv = Real(1) / d21_abs;
126 Real d11 = d21_inv * numext::real(a(k + 1, k + 1));
127 Real d22 = d21_inv * numext::real(a(k, k));
129 Real t = Real(1) / ((d11 * d22) - Real(1));
130 Real d = t * d21_inv;
131 Scalar d21 = a(k + 1, k) * d21_inv;
133 a(k, k) =
Scalar(d11 * d);
134 a(k + 1, k) = -d21 * d;
135 a(k + 1, k + 1) =
Scalar(d22 * d);
137 for (Index j = k + 2; j < n; ++j) {
138 Scalar wk = ((a(j, k) * d11) - (a(j, k + 1) * d21)) * d;
140 ((a(j, k + 1) * d22) - (a(j, k) * numext::conj(d21))) * d;
142 for (Index i = j; i < n; ++i) {
144 a(i, k) * numext::conj(wk) + a(i, k + 1) * numext::conj(wkp1);
146 a(j, j) =
Scalar(numext::real(a(j, j)));
159 pivots[k] =
static_cast<int>(kp);
161 pivots[k] =
static_cast<int>(-1 - kp);
162 pivots[k + 1] =
static_cast<int>(-1 - kp);
171template <
typename MatrixType,
typename WType,
typename IndicesType>
173bunch_kaufman_in_place_one_block(MatrixType &a, WType &w, IndicesType &pivots,
174 Index &pivot_count, Index &processed_cols) {
175 using Scalar =
typename MatrixType::Scalar;
176 using Real =
typename Eigen::NumTraits<Scalar>::Real;
180 Real alpha = (Real(1) + numext::sqrt(Real(17))) / Real(8);
189 while (k < n && k + 1 < nb) {
190 w.col(k).segment(k, n - k) = a.col(k).segment(k, n - k);
192 auto w_row = w.row(k).segment(0, k);
193 auto w_col = w.col(k).segment(k, n - k);
194 w_col.noalias() -= a.block(k, 0, n - k, k) * w_row.transpose();
196 w(k, k) =
Scalar(numext::real(w(k, k)));
199 Real abs_akk = numext::abs(numext::real(w(k, k)));
201 Real colmax = Real(0);
204 colmax = w.col(k).segment(k + 1, n - k - 1).cwiseAbs().maxCoeff(&imax);
209 if (numext::maxi(abs_akk, colmax) == Real(0)) {
210 return NumericalIssue;
212 if (abs_akk >= colmax * alpha) {
215 w.col(k + 1).segment(k, imax - k) =
216 a.row(imax).segment(k, imax - k).adjoint();
217 w.col(k + 1).segment(imax, n - imax) =
218 a.col(imax).segment(imax, n - imax);
221 auto w_row = w.row(imax).segment(0, k);
222 auto w_col = w.col(k + 1).segment(k, n - k);
223 w_col.noalias() -= a.block(k, 0, n - k, k) * w_row.transpose();
225 w(imax, k + 1) =
Scalar(numext::real(w(imax, k + 1)));
227 Real rowmax = Real(0);
229 rowmax = w.col(k + 1).segment(k, imax - k).cwiseAbs().maxCoeff();
231 if (n - imax - 1 > 0) {
232 rowmax = numext::maxi(rowmax, w.col(k + 1)
233 .segment(imax + 1, n - imax - 1)
238 if (abs_akk >= (alpha * colmax) * (colmax / rowmax)) {
240 }
else if (numext::abs(numext::real(w(imax, k + 1))) >=
243 w.col(k).segment(k, n - k) = w.col(k + 1).segment(k, n - k);
250 Index kk = k + k_step - 1;
254 a(kp, kp) = a(kk, kk);
255 for (Index j = kk + 1; j < kp; ++j) {
256 a(kp, j) = numext::conj(a(j, kk));
258 a.col(kp).segment(kp + 1, n - kp - 1) =
259 a.col(kk).segment(kp + 1, n - kp - 1);
260 a.row(kk).segment(0, k).swap(a.row(kp).segment(0, k));
261 w.row(kk).segment(0, kk + 1).swap(w.row(kp).segment(0, kk + 1));
265 a.col(k).segment(k, n - k) = w.col(k).segment(k, n - k);
267 Real d11 = Real(1) / numext::real(w(k, k));
269 auto x = a.middleRows(k + 1, n - k - 1).col(k);
271 w.col(k).segment(k + 1, n - k - 1) =
272 w.col(k).segment(k + 1, n - k - 1).conjugate();
274 Real d21_abs = numext::abs(w(k + 1, k));
275 Real d21_inv = Real(1) / d21_abs;
276 Real d11 = d21_inv * numext::real(w(k + 1, k + 1));
277 Real d22 = d21_inv * numext::real(w(k, k));
279 Real t = Real(1) / ((d11 * d22) - Real(1));
280 Scalar d21 = w(k + 1, k) * d21_inv;
281 Real d = t * d21_inv;
283 a(k, k) =
Scalar(d11 * d);
284 a(k + 1, k) = -d21 * d;
285 a(k + 1, k + 1) =
Scalar(d22 * d);
287 for (Index j = k + 2; j < n; ++j) {
288 Scalar wk = ((w(j, k) * d11) - (w(j, k + 1) * d21)) * d;
290 ((w(j, k + 1) * d22) - (w(j, k) * numext::conj(d21))) * d;
296 w.col(k).segment(k + 1, n - k - 1) =
297 w.col(k).segment(k + 1, n - k - 1).conjugate();
298 w.col(k + 1).segment(k + 2, n - k - 2) =
299 w.col(k + 1).segment(k + 2, n - k - 2).conjugate();
308 pivots[k] =
static_cast<int>(kp);
310 pivots[k] =
static_cast<int>(-1 - kp);
311 pivots[k + 1] =
static_cast<int>(-1 - kp);
317 auto a_left = a.bottomRows(n - k).leftCols(k);
318 auto a_right = a.bottomRows(n - k).rightCols(n - k);
320 a_right.template triangularView<Lower>() -=
321 a_left * w.block(k, 0, n - k, k).transpose();
327 Index jp = pivots[j];
338 a.row(jp).segment(0, j + 1).swap(a.row(jj).segment(0, j + 1));
346template <
typename MatrixType,
typename VecType,
typename IndicesType,
347 typename WorkspaceType>
348ComputationInfo bunch_kaufman_in_place(MatrixType &a, VecType &subdiag,
349 IndicesType &pivots, WorkspaceType &w,
350 Index &pivot_count) {
353 const Index blocksize = w.cols();
358 Index k_pivot_count = 0;
359 auto a_block = a.block(k, k, n - k, n - k);
360 auto pivots_block = pivots.segment(k, n - k);
361 ComputationInfo info = InvalidInput;
362 if (blocksize != 0 && blocksize < n - k) {
363 info = internal::bunch_kaufman_in_place_one_block(
364 a_block, w, pivots_block, k_pivot_count, kb);
366 info = internal::bunch_kaufman_in_place_unblocked(a_block, pivots_block,
370 if (info != Success) {
374 for (Index j = k; j < k + kb; ++j) {
375 auto &p = pivots.coeffRef(j);
381 p +=
static_cast<int>(k);
383 p -=
static_cast<int>(k);
387 pivot_count += k_pivot_count;
391 using Scalar =
typename MatrixType::Scalar;
396 subdiag(k) = a(k + 1, k);
397 subdiag(k + 1) =
Scalar(0);
411 a.row(k + 1).segment(0, k).swap(a.row(p).segment(0, k));
414 a.row(k).segment(0, k).swap(a.row(p).segment(0, k));
422template <
typename MatrixType,
bool Conjugate>
struct BK_Traits;
424template <
typename MatrixType>
struct BK_Traits<MatrixType, false> {
425 typedef TriangularView<const MatrixType, UnitLower> MatrixL;
426 typedef TriangularView<
const typename MatrixType::AdjointReturnType,
429 static inline MatrixL getL(
const MatrixType &m) {
return MatrixL(m); }
430 static inline MatrixU getU(
const MatrixType &m) {
431 return MatrixU(m.adjoint());
435template <
typename MatrixType>
struct BK_Traits<MatrixType, true> {
436 typedef typename MatrixType::ConjugateReturnType ConjugateReturnType;
437 typedef TriangularView<const ConjugateReturnType, UnitLower> MatrixL;
438 typedef TriangularView<
const typename MatrixType::TransposeReturnType,
441 static inline MatrixL getL(
const MatrixType &m) {
442 return MatrixL(m.conjugate());
444 static inline MatrixU getU(
const MatrixType &m) {
445 return MatrixU(m.transpose());
449template <
bool Conjugate,
typename MatrixType,
typename VecType,
450 typename IndicesType,
typename Rhs>
451void bunch_kaufman_solve_in_place(MatrixType
const &L, VecType
const &subdiag,
452 IndicesType
const &pivots, Rhs &x) {
462 x.row(k + 1).swap(x.row(p));
465 x.row(k).swap(x.row(p));
470 using Traits = BK_Traits<MatrixType, Conjugate>;
472 Traits::getL(L).solveInPlace(x);
478 using Scalar =
typename MatrixType::Scalar;
479 using Real =
typename Eigen::NumTraits<Scalar>::Real;
481 Scalar akp1k = subdiag(k);
482 Real ak = numext::real(L(k, k));
483 Real akp1 = numext::real(L(k + 1, k + 1));
486 akp1k = numext::conj(akp1k);
489 for (Index j = 0; j < x.cols(); ++j) {
491 Scalar xkp1 = x(k + 1, j);
493 x(k, j) = xk * ak + xkp1 * numext::conj(akp1k);
494 x(k + 1, j) = xkp1 * akp1 + xk * akp1k;
499 x.row(k) *= numext::real(L(k, k));
504 Traits::getU(L).solveInPlace(x);
512 x.row(k).swap(x.row(p));
515 x.row(k).swap(x.row(p));
521template <
typename MatrixType_,
int UpLo_>
530 using Base = SolverBase<BunchKaufman>;
538 typename Transpositions<RowsAtCompileTime,
548 , m_isInitialized(false)
549 , m_info(ComputationInfo::InvalidInput)
553 : m_matrix(size, size)
557 , m_isInitialized(false)
558 , m_info(ComputationInfo::InvalidInput)
560 , m_workspace(size, m_blocksize) {}
562 template <
typename InputType>
564 : m_matrix(matrix.
rows(), matrix.
cols())
565 , m_subdiag(matrix.
rows())
567 , m_pivots(matrix.
rows())
568 , m_isInitialized(false)
569 , m_info(ComputationInfo::InvalidInput)
571 , m_workspace(matrix.
rows(), m_blocksize) {
572 this->
compute(matrix.derived());
577 EIGEN_DEVICE_FUNC
inline Index
rows() const noexcept {
578 return m_matrix.rows();
582 EIGEN_DEVICE_FUNC
inline Index
cols() const noexcept {
583 return m_matrix.cols();
592 template <
typename InputType>
595 ComputationInfo
info()
const {
return m_info; }
597#ifdef EIGEN_PARSED_BY_DOXYGEN
598 template <
typename Rhs>
599 inline const Solve<LDLT, Rhs> solve(
const MatrixBase<Rhs> &b)
const;
602#ifndef EIGEN_PARSED_BY_DOXYGEN
603 template <
typename RhsType,
typename DstType>
606 template <
bool Conjugate,
typename RhsType,
typename DstType>
609 template <
typename RhsType>
618 bool m_isInitialized;
619 ComputationInfo m_info;
624#ifndef EIGEN_PARSED_BY_DOXYGEN
625template <
typename MatrixType_,
int UpLo_>
626template <
typename RhsType,
typename DstType>
628 DstType &dst)
const {
630 internal::bunch_kaufman_solve_in_place<false>(this->m_matrix, this->m_subdiag,
631 this->m_pivots, dst);
634template <
typename MatrixType_,
int UpLo_>
635template <
bool Conjugate,
typename RhsType,
typename DstType>
637 const RhsType &rhs, DstType &dst)
const {
639 internal::bunch_kaufman_solve_in_place<!Conjugate>(
640 this->m_matrix, this->m_subdiag, this->m_pivots, dst);
643template <
typename MatrixType_,
int UpLo_>
644template <
typename RhsType>
646 Eigen::MatrixBase<RhsType> &bAndX)
const {
647 bAndX = this->solve(bAndX);
653template <
typename MatrixType_,
int UpLo_>
654template <
typename InputType>
657 eigen_assert(a.rows() == a.cols());
659 this->m_matrix.resize(n, n);
660 this->m_subdiag.resize(n);
661 this->m_pivots.resize(n);
663 this->m_matrix.setZero();
664 this->m_subdiag.setZero();
665 this->m_pivots.setZero();
667 this->m_workspace.setZero(n, this->m_blocksize);
669 this->m_matrix.template triangularView<Lower>() =
670 a.derived().template triangularView<UpLo_>();
671 this->m_info = internal::bunch_kaufman_in_place(
672 this->m_matrix, this->m_subdiag, this->m_pivots, this->m_workspace,
673 this->m_pivot_count);
674 this->m_isInitialized =
true;
::aligator::context::Scalar Scalar
typename Transpositions< RowsAtCompileTime, MaxRowsAtCompileTime >::IndicesType IndicesType
SolverBase< BunchKaufman > Base
EIGEN_DEVICE_FUNC Index cols() const noexcept
static constexpr Index BlockSize
typename MatrixType::PlainObject PlainObject
const MatrixType & matrixLDLT() const
void _solve_impl_transposed(const RhsType &rhs, DstType &dst) const
BunchKaufman(const EigenBase< InputType > &matrix)
const IndicesType & pivots() const
void _solve_impl(const RhsType &rhs, DstType &dst) const
EIGEN_DEVICE_FUNC Index rows() const noexcept
const VecType & subdiag() const
bool solveInPlace(Eigen::MatrixBase< RhsType > &bAndX) const
ComputationInfo info() const
BunchKaufman & compute(const EigenBase< InputType > &matrix)
PermutationMatrix< RowsAtCompileTime, MaxRowsAtCompileTime > PermutationType
Matrix< Scalar, RowsAtCompileTime, 1, Eigen::DontAlign, MaxRowsAtCompileTime > VecType
typename Transpositions< RowsAtCompileTime, MaxRowsAtCompileTime >::IndicesType IndicesType
static constexpr Index BlockSize
typename MatrixType::PlainObject PlainObject
void _solve_impl_transposed(const RhsType &rhs, DstType &dst) const
void _solve_impl(const RhsType &rhs, DstType &dst) const
bool solveInPlace(Eigen::MatrixBase< RhsType > &bAndX) const
Matrix< Scalar, RowsAtCompileTime, 1, Eigen::DontAlign, MaxRowsAtCompileTime > VecType