// Forward declaration of the BunchKaufman decomposition class template.
// UpLo_ selects which triangle of the input matrix compute() reads; it
// defaults to Lower.
// NOTE(review): the numeric prefixes on lines throughout this file (the "9"
// below, "12", "20", ...) look like original line numbers fused into the text,
// and many original lines are missing; this copy does not compile as-is.
9template <
typename MatrixType_,
int UpLo_ = Lower>
struct BunchKaufman;
// Eigen traits for BunchKaufman: inherits the traits of the decomposed matrix
// type and declares the decomposition a matrix expression with solver storage
// and int storage indices.
// NOTE(review): the closing "};" of this specialization is missing from this
// copy.
12template <
typename MatrixType_,
int UpLo_>
13struct traits<
BunchKaufman<MatrixType_, UpLo_>> : traits<MatrixType_> {
14 typedef MatrixXpr XprKind;
15 typedef SolverStorage StorageKind;
16 typedef int StorageIndex;
// Unblocked, in-place Bunch-Kaufman (diagonal pivoting) factorization of the
// lower triangle of `a`.
//
// For each column k, a 1x1 or a 2x2 pivot is chosen by comparing magnitudes
// against the threshold `alpha`. The chosen row/column is symmetrically
// interchanged into place, and the lower triangle of the trailing submatrix is
// updated. On exit, as far as the visible code shows:
//   - 1x1 pivot: a(k, k) holds the reciprocal of the (real) pivot, and
//     pivots[k] == kp;
//   - 2x2 pivot: a(k, k), a(k + 1, k) and a(k + 1, k + 1) hold the inverse of
//     the 2x2 pivot block, and pivots[k] == pivots[k + 1] == -1 - kp.
// Returns NumericalIssue when a pivot candidate and the rest of its column
// are all zero.
//
// NOTE(review): many original lines are missing here (the fused line numbers
// jump, e.g. 21 -> 24 and 26 -> 31). Not visible: the remaining parameters
// (presumably `pivots` and a pivot counter), the main loop header, the
// kp/k_step bookkeeping and the closing braces.
20template <
typename MatrixType,
typename IndicesType>
21ComputationInfo bunch_kaufman_in_place_unblocked(MatrixType &a,
24 using Scalar =
typename MatrixType::Scalar;
25 using Real =
typename Eigen::NumTraits<Scalar>::Real;
// Pivoting threshold (1 + sqrt(17)) / 8 (about 0.64), used in every
// 1x1-vs-2x2 pivot decision below.
26 const Real alpha = (Real(1) + numext::sqrt(Real(17))) / Real(8);
31 const Index n = a.rows();
// Presumably the n == 1 case (its guard is not visible). A zero diagonal is a
// NumericalIssue; otherwise the diagonal is replaced by its reciprocal.
35 if (numext::abs(numext::real(a(0, 0))) == Real(0)) {
36 return NumericalIssue;
38 a(0, 0) = Real(1) / numext::real(a(0, 0));
// Candidate 1x1 pivot magnitude, and the largest magnitude below it in
// column k (its row index is returned in imax).
46 Real abs_akk = numext::abs(numext::real(a(k, k)));
48 Real colmax = Real(0);
51 colmax = a.col(k).segment(k + 1, n - k - 1).cwiseAbs().maxCoeff(&imax);
// Column k is entirely zero: no usable pivot exists.
56 if (numext::maxi(abs_akk, colmax) == Real(0)) {
57 return NumericalIssue;
// Diagonal dominates its column: a(k, k) is used as a 1x1 pivot
// (branch body not visible).
59 if (abs_akk >= colmax * alpha) {
// Otherwise compute rowmax: the largest off-diagonal magnitude in row/column
// imax of the trailing submatrix (read from the lower triangle only).
62 Real rowmax = Real(0);
64 rowmax = a.row(imax).segment(k, imax - k).cwiseAbs().maxCoeff();
66 if (n - imax - 1 > 0) {
67 rowmax = numext::maxi(rowmax, a.col(imax)
68 .segment(imax + 1, n - imax - 1)
// Keep a(k, k) as a 1x1 pivot when abs_akk >= alpha * colmax^2 / rowmax.
// Otherwise try a(imax, imax) as a 1x1 pivot when it is large relative to
// rowmax. The branch bodies and the final fallback (presumably the 2x2 pivot)
// are not visible.
73 if (abs_akk >= (alpha * colmax) * (colmax / rowmax)) {
75 }
else if (numext::abs(numext::real(a(imax, imax))) >= alpha * rowmax) {
// kk is the last row/column of the current pivot block (k_step is presumably
// 1 or 2). Row/column kp is swapped into position kk. Entries that move across
// the diagonal are conjugated, because only the lower triangle is stored.
83 Index kk = k + k_step - 1;
87 .segment(kp + 1, n - kp - 1)
88 .swap(a.col(kp).segment(kp + 1, n - kp - 1));
90 for (Index j = kk + 1; j < kp; ++j) {
91 Scalar tmp = a(j, kk);
92 a(j, kk) = numext::conj(a(kp, j));
93 a(kp, j) = numext::conj(tmp);
95 a(kp, kk) = numext::conj(a(kp, kk));
96 swap(a(kk, kk), a(kp, kp));
// Presumably only for a 2x2 pivot: move the matching entry of column k into
// row k + 1.
99 swap(a(k + 1, k), a(kp, k));
// 1x1 pivot: store d11 = 1 / real(a(k, k)). Then apply the rank-1 update
// trailing(i, j) -= x(i) * conj(x(j)) * d11 to the lower triangle of the
// trailing block (xi is presumably x(i); its definition is not visible).
// The updated diagonal is forced to stay real.
104 Real d11 = Real(1) / numext::real(a(k, k));
105 a(k, k) = Scalar(d11);
107 auto x = a.middleRows(k + 1, n - k - 1).col(k);
109 a.middleRows(k + 1, n - k - 1).middleCols(k + 1, n - k - 1);
111 for (Index j = 0; j < n - k - 1; ++j) {
112 Scalar d11xj = numext::conj(x(j)) * d11;
113 for (Index i = j; i < n - k - 1; ++i) {
115 trailing(i, j) -= d11xj * xi;
117 trailing(j, j) = Scalar(numext::real(trailing(j, j)));
// 2x2 pivot with off-diagonal b = a(k + 1, k): overwrite the block with its
// inverse. With det = a(k, k) * a(k + 1, k + 1) - |b|^2, the stored values are
// a(k, k) <- a(k + 1, k + 1) / det, a(k + 1, k) <- -b / det and
// a(k + 1, k + 1) <- a(k, k) / det. Note that d11 is formed from
// a(k + 1, k + 1) and d22 from a(k, k), both scaled by 1 / |b|.
122 Real d21_abs = numext::abs(a(k + 1, k));
123 Real d21_inv = Real(1) / d21_abs;
124 Real d11 = d21_inv * numext::real(a(k + 1, k + 1));
125 Real d22 = d21_inv * numext::real(a(k, k));
127 Real t = Real(1) / ((d11 * d22) - Real(1));
128 Real d = t * d21_inv;
129 Scalar d21 = a(k + 1, k) * d21_inv;
131 a(k, k) = Scalar(d11 * d);
132 a(k + 1, k) = -d21 * d;
133 a(k + 1, k + 1) = Scalar(d22 * d);
// Rank-2 update of the lower triangle of the trailing block, using the
// multipliers wk and wkp1 (wkp1's declaration and part of the update
// expression are not visible). Each updated diagonal entry is forced real.
135 for (Index j = k + 2; j < n; ++j) {
136 Scalar wk = ((a(j, k) * d11) - (a(j, k + 1) * d21)) * d;
138 ((a(j, k + 1) * d22) - (a(j, k) * numext::conj(d21))) * d;
140 for (Index i = j; i < n; ++i) {
142 a(i, k) * numext::conj(wk) + a(i, k + 1) * numext::conj(wkp1);
144 a(j, j) = Scalar(numext::real(a(j, j)));
// Record the interchange. A 1x1 pivot stores kp. A 2x2 pivot stores the
// negative value -1 - kp in both of its entries.
157 pivots[k] =
static_cast<int>(kp);
159 pivots[k] =
static_cast<int>(-1 - kp);
160 pivots[k + 1] =
static_cast<int>(-1 - kp);
// Factorizes one panel of columns of the lower triangle of `a`, accumulating
// updated pivot columns in the workspace `w`.
//
// Columns are processed while `k + 1 < nb`. The definition of nb is not
// visible; presumably it is the workspace width. Each candidate column is
// first copied into `w` and updated with the contributions of the columns
// already processed in this panel; the pivot choice is then made on `w`.
// Pivot storage in `a` and the encoding in `pivots` match
// bunch_kaufman_in_place_unblocked: reciprocal / inverse 2x2 block on the
// diagonal, kp or -1 - kp in pivots.
// After the panel loop, the lower triangle of the trailing submatrix is
// updated as a_right -= a_left * w^T. The panel's interchanges are also
// applied to the rows of the columns to the left.
// processed_cols presumably receives the number of columns factored (its
// assignment is not visible).
// NOTE(review): the return-type line (original line 170) and many statements
// are missing from this copy.
169template <
typename MatrixType,
typename WType,
typename IndicesType>
171bunch_kaufman_in_place_one_block(MatrixType &a, WType &w, IndicesType &pivots,
172 Index &pivot_count, Index &processed_cols) {
173 using Scalar =
typename MatrixType::Scalar;
174 using Real =
typename Eigen::NumTraits<Scalar>::Real;
// Same pivoting threshold as the unblocked routine: (1 + sqrt(17)) / 8.
178 Real alpha = (Real(1) + numext::sqrt(Real(17))) / Real(8);
187 while (k < n && k + 1 < nb) {
// Bring column k into w and subtract the updates from earlier panel columns.
// The diagonal is forced to stay real.
188 w.col(k).segment(k, n - k) = a.col(k).segment(k, n - k);
190 auto w_row = w.row(k).segment(0, k);
191 auto w_col = w.col(k).segment(k, n - k);
192 w_col.noalias() -= a.block(k, 0, n - k, k) * w_row.transpose();
194 w(k, k) = Scalar(numext::real(w(k, k)));
// Candidate 1x1 pivot magnitude, and the largest magnitude below it in the
// updated column (its row index is returned in imax).
197 Real abs_akk = numext::abs(numext::real(w(k, k)));
199 Real colmax = Real(0);
202 colmax = w.col(k).segment(k + 1, n - k - 1).cwiseAbs().maxCoeff(&imax);
// Column k is entirely zero: no usable pivot exists.
207 if (numext::maxi(abs_akk, colmax) == Real(0)) {
208 return NumericalIssue;
210 if (abs_akk >= colmax * alpha) {
// Assemble column imax of the trailing matrix in w.col(k + 1). The part above
// imax comes from row imax (conjugated via adjoint()); the rest comes from
// column imax. It is then updated like column k, and its diagonal forced real.
213 w.col(k + 1).segment(k, imax - k) =
214 a.row(imax).segment(k, imax - k).adjoint();
215 w.col(k + 1).segment(imax, n - imax) =
216 a.col(imax).segment(imax, n - imax);
219 auto w_row = w.row(imax).segment(0, k);
220 auto w_col = w.col(k + 1).segment(k, n - k);
221 w_col.noalias() -= a.block(k, 0, n - k, k) * w_row.transpose();
223 w(imax, k + 1) = Scalar(numext::real(w(imax, k + 1)));
// rowmax: largest off-diagonal magnitude in the updated column imax.
225 Real rowmax = Real(0);
227 rowmax = w.col(k + 1).segment(k, imax - k).cwiseAbs().maxCoeff();
229 if (n - imax - 1 > 0) {
230 rowmax = numext::maxi(rowmax, w.col(k + 1)
231 .segment(imax + 1, n - imax - 1)
236 if (abs_akk >= (alpha * colmax) * (colmax / rowmax)) {
238 }
else if (numext::abs(numext::real(w(imax, k + 1))) >=
// Presumably the "a(imax, imax) as 1x1 pivot" case: the updated column imax
// becomes the pivot column.
241 w.col(k).segment(k, n - k) = w.col(k + 1).segment(k, n - k);
// Symmetric interchange of row/column kp with kk (the last column of the
// pivot block), in both a and the already-computed rows of w. Entries that
// cross the diagonal are conjugated.
248 Index kk = k + k_step - 1;
252 a(kp, kp) = a(kk, kk);
253 for (Index j = kk + 1; j < kp; ++j) {
254 a(kp, j) = numext::conj(a(j, kk));
256 a.col(kp).segment(kp + 1, n - kp - 1) =
257 a.col(kk).segment(kp + 1, n - kp - 1);
258 a.row(kk).segment(0, k).swap(a.row(kp).segment(0, k));
259 w.row(kk).segment(0, kk + 1).swap(w.row(kp).segment(0, kk + 1));
// 1x1 pivot: copy the updated column back into a, then store the reciprocal
// of the pivot on the diagonal.
263 a.col(k).segment(k, n - k) = w.col(k).segment(k, n - k);
265 Real d11 = Real(1) / numext::real(w(k, k));
266 a(k, k) = Scalar(d11);
267 auto x = a.middleRows(k + 1, n - k - 1).col(k);
// Conjugate the stored w column, presumably so the final
// a_left * w^T product yields the Hermitian update.
269 w.col(k).segment(k + 1, n - k - 1) =
270 w.col(k).segment(k + 1, n - k - 1).conjugate();
// 2x2 pivot: same inverse-of-block formulas as the unblocked routine, read
// from the updated columns in w.
272 Real d21_abs = numext::abs(w(k + 1, k));
273 Real d21_inv = Real(1) / d21_abs;
274 Real d11 = d21_inv * numext::real(w(k + 1, k + 1));
275 Real d22 = d21_inv * numext::real(w(k, k));
277 Real t = Real(1) / ((d11 * d22) - Real(1));
278 Scalar d21 = w(k + 1, k) * d21_inv;
279 Real d = t * d21_inv;
281 a(k, k) = Scalar(d11 * d);
282 a(k + 1, k) = -d21 * d;
283 a(k + 1, k + 1) = Scalar(d22 * d);
285 for (Index j = k + 2; j < n; ++j) {
286 Scalar wk = ((w(j, k) * d11) - (w(j, k + 1) * d21)) * d;
288 ((w(j, k + 1) * d22) - (w(j, k) * numext::conj(d21))) * d;
// Conjugate both stored w columns (below the pivot block); see the 1x1 case.
294 w.col(k).segment(k + 1, n - k - 1) =
295 w.col(k).segment(k + 1, n - k - 1).conjugate();
296 w.col(k + 1).segment(k + 2, n - k - 2) =
297 w.col(k + 1).segment(k + 2, n - k - 2).conjugate();
// Record the interchange: kp for a 1x1 pivot; -1 - kp in both entries for a
// 2x2 pivot.
306 pivots[k] =
static_cast<int>(kp);
308 pivots[k] =
static_cast<int>(-1 - kp);
309 pivots[k + 1] =
static_cast<int>(-1 - kp);
// Apply the panel's accumulated update to the lower triangle of the trailing
// submatrix in a single matrix product.
315 auto a_left = a.bottomRows(n - k).leftCols(k);
316 auto a_right = a.bottomRows(n - k).rightCols(n - k);
318 a_right.template triangularView<Lower>() -=
319 a_left * w.block(k, 0, n - k, k).transpose();
// Swap rows jj and jp over columns [0, j], propagating the interchanges to the
// factored columns. The surrounding loop, the computation of jj and the
// decoding of negative pivots are not visible.
325 Index jp = pivots[j];
336 a.row(jp).segment(0, j + 1).swap(a.row(jj).segment(0, j + 1));
// Drives the in-place Bunch-Kaufman factorization of the lower triangle of
// `a`.
//
// For the trailing block starting at column k:
//   - if a workspace is available and the block is wider than it
//     (blocksize != 0 && blocksize < n - k), one panel is factored with
//     bunch_kaufman_in_place_one_block;
//   - otherwise the rest is finished with bunch_kaufman_in_place_unblocked.
// Pivot indices come back relative to k and are shifted to global indices.
// Non-negative entries presumably get `+= k`; negative -1 - kp entries get
// `-= k`, which keeps the 2x2 encoding. Per-block pivot counts are summed into
// pivot_count. A non-Success info presumably aborts with that status.
//
// A second pass handles the pivots. For each 2x2 pivot block (presumably), the
// off-diagonal entry a(k + 1, k) is moved into subdiag(k), and it and
// subdiag(k + 1) are zeroed. For a 1x1 pivot, subdiag(k) is zeroed. The
// interchanges are then applied to the already-factored columns left of k.
// NOTE(review): the loop headers, the definitions of n, k and kb, and the
// branch conditions are missing from this copy.
344template <
typename MatrixType,
typename VecType,
typename IndicesType,
345 typename WorkspaceType>
346ComputationInfo bunch_kaufman_in_place(MatrixType &a, VecType &subdiag,
347 IndicesType &pivots, WorkspaceType &w,
348 Index &pivot_count) {
351 const Index blocksize = w.cols();
356 Index k_pivot_count = 0;
357 auto a_block = a.block(k, k, n - k, n - k);
358 auto pivots_block = pivots.segment(k, n - k);
359 ComputationInfo info = InvalidInput;
360 if (blocksize != 0 && blocksize < n - k) {
361 info = internal::bunch_kaufman_in_place_one_block(
362 a_block, w, pivots_block, k_pivot_count, kb);
364 info = internal::bunch_kaufman_in_place_unblocked(a_block, pivots_block,
368 if (info != Success) {
// Convert the block-relative pivot indices in [k, k + kb) to global ones.
372 for (Index j = k; j < k + kb; ++j) {
373 auto &p = pivots.coeffRef(j);
379 p +=
static_cast<int>(k);
381 p -=
static_cast<int>(k);
385 pivot_count += k_pivot_count;
389 using Scalar =
typename MatrixType::Scalar;
// Presumably the 2x2 pivot case: move the block's off-diagonal into subdiag.
394 subdiag(k) = a(k + 1, k);
395 subdiag(k + 1) = Scalar(0);
396 a(k + 1, k) = Scalar(0);
// Presumably the 1x1 pivot case: no off-diagonal entry.
399 subdiag(k) = Scalar(0);
// Apply the interchange to the columns left of the pivot.
409 a.row(k + 1).segment(0, k).swap(a.row(p).segment(0, k));
412 a.row(k).segment(0, k).swap(a.row(p).segment(0, k));
// Selects how the stored factor is viewed during a solve: as the unit lower
// factor L and the unit upper factor U, either as stored (Conjugate == false)
// or conjugated (Conjugate == true). Both cases are specialized below.
420template <
typename MatrixType,
bool Conjugate>
struct BK_Traits;
// Non-conjugated views. L is the unit-lower triangle of the stored matrix. U is
// the unit-upper triangle of its adjoint, so U == L^H.
// NOTE(review): the rest of the MatrixU typedef (original lines 425-426) and
// the closing braces are missing from this copy.
422template <
typename MatrixType>
struct BK_Traits<MatrixType, false> {
423 typedef TriangularView<const MatrixType, UnitLower> MatrixL;
424 typedef TriangularView<
const typename MatrixType::AdjointReturnType,
427 static inline MatrixL getL(
const MatrixType &m) {
return MatrixL(m); }
428 static inline MatrixU getU(
const MatrixType &m) {
429 return MatrixU(m.adjoint());
// Conjugated views. L is the unit-lower triangle of conj(m). U is the
// unit-upper triangle of m^T, which equals conj(L)^H. These views are used
// when the solve must work with the conjugated factorization.
// NOTE(review): the rest of the MatrixU typedef (original lines 437-438) and
// the closing braces are missing from this copy.
433template <
typename MatrixType>
struct BK_Traits<MatrixType, true> {
434 typedef typename MatrixType::ConjugateReturnType ConjugateReturnType;
435 typedef TriangularView<const ConjugateReturnType, UnitLower> MatrixL;
436 typedef TriangularView<
const typename MatrixType::TransposeReturnType,
439 static inline MatrixL getL(
const MatrixType &m) {
440 return MatrixL(m.conjugate());
442 static inline MatrixU getU(
const MatrixType &m) {
443 return MatrixU(m.transpose());
// Applies the stored Bunch-Kaufman factorization to x in place. Presumably
// this solves A x = b, with b passed in through x.
//
// The inputs are those produced by bunch_kaufman_in_place:
//   - the strictly lower part of L is the unit lower factor;
//   - its diagonal holds the inverse of the block-diagonal factor; for a 2x2
//     pivot this is the diagonal pair plus subdiag(k);
//   - pivots holds the interchanges (-1 - kp entries for 2x2 pivots).
// Steps:
//   1. apply the row interchanges to x;
//   2. solve with the unit lower factor (Traits::getL);
//   3. multiply by the stored inverse of the block-diagonal factor;
//   4. solve with the unit upper factor (Traits::getU);
//   5. undo the interchanges in reverse.
// With Conjugate == true the conjugated factors are used (see BK_Traits), and
// the stored off-diagonal entry is conjugated as well.
// NOTE(review): the loops, the pivot decoding and the definition of xk are
// missing from this copy.
447template <
bool Conjugate,
typename MatrixType,
typename VecType,
448 typename IndicesType,
typename Rhs>
449void bunch_kaufman_solve_in_place(MatrixType
const &L, VecType
const &subdiag,
450 IndicesType
const &pivots, Rhs &x) {
// Forward pass of the interchanges (2x2 and 1x1 pivot cases).
460 x.row(k + 1).swap(x.row(p));
463 x.row(k).swap(x.row(p));
468 using Traits = BK_Traits<MatrixType, Conjugate>;
470 Traits::getL(L).solveInPlace(x);
476 using Scalar =
typename MatrixType::Scalar;
477 using Real =
typename Eigen::NumTraits<Scalar>::Real;
// 2x2 block of the stored inverse: real diagonal pair plus the off-diagonal
// from subdiag.
479 Scalar akp1k = subdiag(k);
480 Real ak = numext::real(L(k, k));
481 Real akp1 = numext::real(L(k + 1, k + 1));
// Presumably applied only when Conjugate is true.
484 akp1k = numext::conj(akp1k);
// Multiply each right-hand-side column by the Hermitian 2x2 inverse block.
487 for (Index j = 0; j < x.cols(); ++j) {
489 Scalar xkp1 = x(k + 1, j);
491 x(k, j) = xk * ak + xkp1 * numext::conj(akp1k);
492 x(k + 1, j) = xkp1 * akp1 + xk * akp1k;
// 1x1 pivot: the diagonal already holds 1 / d, so this is a multiplication.
497 x.row(k) *= numext::real(L(k, k));
502 Traits::getU(L).solveInPlace(x);
// Backward pass undoing the interchanges.
510 x.row(k).swap(x.row(p));
513 x.row(k).swap(x.row(p));
// BunchKaufman: decomposition of a self-adjoint (possibly indefinite) matrix
// with Bunch-Kaufman diagonal pivoting, exposed through SolverBase.
// NOTE(review): the class head and many declarations are missing from this
// copy (the fused line numbers jump, e.g. 519 -> 528 -> 536 -> 542).
519template <
typename MatrixType_,
int UpLo_>
528 using Base = SolverBase<BunchKaufman>;
536 typename Transpositions<RowsAtCompileTime,
// Default constructor: empty storage, not initialized; m_info stays
// InvalidInput until compute() runs.
542 : m_matrix(), m_subdiag(), m_pivot_count(0), m_pivots(),
543 m_isInitialized(false), m_info(ComputationInfo::InvalidInput),
544 m_blocksize(), m_workspace() {}
// Preallocating constructor (declaration line not visible): sizes the storage
// for a size x size matrix without factorizing. m_workspace is sized from
// m_blocksize, whose initializer (original line 548) is not visible.
546 : m_matrix(size, size), m_subdiag(size), m_pivot_count(0), m_pivots(size),
547 m_isInitialized(false), m_info(ComputationInfo::InvalidInput),
549 m_workspace(size, m_blocksize) {}
// Constructs from a matrix: sizes the storage from its dimensions, then
// factorizes it immediately via compute().
551 template <
typename InputType>
553 : m_matrix(matrix.
rows(), matrix.
cols()), m_subdiag(matrix.
rows()),
554 m_pivot_count(0), m_pivots(matrix.
rows()), m_isInitialized(false),
555 m_info(ComputationInfo::InvalidInput),
557 m_workspace(matrix.
rows(), m_blocksize) {
558 this->
compute(matrix.derived());
// Dimensions of the stored square factor.
563 EIGEN_DEVICE_FUNC
inline Index
rows() const EIGEN_NOEXCEPT {
564 return m_matrix.rows();
568 EIGEN_DEVICE_FUNC
inline Index
cols() const EIGEN_NOEXCEPT {
569 return m_matrix.cols();
578 template <
typename InputType>
// Status of the last compute(): Success, or the ComputationInfo reported by
// the factorization.
581 ComputationInfo
info()
const {
return m_info; }
583#ifdef EIGEN_PARSED_BY_DOXYGEN
// NOTE(review): this documentation-only declaration names Solve<LDLT, Rhs>.
// It looks copied from LDLT; presumably it should be
// Solve<BunchKaufman, Rhs>.
584 template <
typename Rhs>
585 inline const Solve<LDLT, Rhs> solve(
const MatrixBase<Rhs> &b)
const;
588#ifndef EIGEN_PARSED_BY_DOXYGEN
// Solve hooks, presumably invoked through SolverBase; defined out of class
// further down.
589 template <
typename RhsType,
typename DstType>
590 void _solve_impl(
const RhsType &rhs, DstType &dst)
const;
592 template <
bool Conjugate,
typename RhsType,
typename DstType>
// Overwrites bAndX with the solution; returns bool.
595 template <
typename RhsType>
596 bool solveInPlace(Eigen::MatrixBase<RhsType> &bAndX)
const;
604 bool m_isInitialized;
605 ComputationInfo m_info;
// Solves A x = rhs into dst using the stored factorization (non-conjugated
// factors).
// NOTE(review): bunch_kaufman_solve_in_place overwrites dst in place, so dst
// must already hold rhs. The copy, and this definition's declaration line
// (original line 613), are missing from this copy.
610#ifndef EIGEN_PARSED_BY_DOXYGEN
611template <
typename MatrixType_,
int UpLo_>
612template <
typename RhsType,
typename DstType>
614 DstType &dst)
const {
616 internal::bunch_kaufman_solve_in_place<false>(this->m_matrix, this->m_subdiag,
617 this->m_pivots, dst);
// Transposed/adjoint solve. The factored matrix is self-adjoint (only its
// lower triangle is stored, with conjugated mirror entries), so A^H == A and
// A^T == conj(A). Assuming Eigen's convention that Conjugate == true means
// adjoint().solve():
//   - the adjoint case uses the factors unchanged (solve_in_place<false>);
//   - the plain transpose uses the conjugated factors (solve_in_place<true>).
// Hence the template argument is !Conjugate.
// NOTE(review): as in _solve_impl, the copy of rhs into dst and the
// declaration line (original line 622) are missing from this copy.
620template <
typename MatrixType_,
int UpLo_>
621template <
bool Conjugate,
typename RhsType,
typename DstType>
623 const RhsType &rhs, DstType &dst)
const {
625 internal::bunch_kaufman_solve_in_place<!Conjugate>(
626 this->m_matrix, this->m_subdiag, this->m_pivots, dst);
// Overwrites bAndX with the solution of A x = bAndX by delegating to solve().
// NOTE(review): the declaration line (original line 631) and the return
// statement of this bool function are missing from this copy.
629template <
typename MatrixType_,
int UpLo_>
630template <
typename RhsType>
632 Eigen::MatrixBase<RhsType> &bAndX)
const {
633 bAndX = this->solve(bAndX);
// Factorizes the self-adjoint matrix `a`:
//   - asserts it is square, then resizes and zeroes all storage;
//   - uses the blocked algorithm with panel width BlockSize only when
//     n > BlockSize (a workspace width of 0 selects the unblocked path);
//   - copies the UpLo_ triangle of `a` into the lower triangle of m_matrix;
//   - runs internal::bunch_kaufman_in_place.
// m_info receives the factorization status. The object is marked initialized
// even when the status is not Success.
// NOTE(review): with UpLo_ == Upper, the assignment below copies an
// Upper-mode view into a Lower-mode view. It is unclear from here whether
// Eigen mirrors the upper triangle, copies entries positionally, or asserts on
// the mode mismatch; verify (or restrict UpLo_ to Lower).
// The declaration lines (original 641-642) and the definition of n are
// missing from this copy.
639template <
typename MatrixType_,
int UpLo_>
640template <
typename InputType>
643 eigen_assert(a.rows() == a.cols());
645 this->m_matrix.resize(n, n);
646 this->m_subdiag.resize(n);
647 this->m_pivots.resize(n);
649 this->m_matrix.setZero();
650 this->m_subdiag.setZero();
651 this->m_pivots.setZero();
652 this->m_blocksize = n <= BlockSize ? 0 : BlockSize;
653 this->m_workspace.setZero(n, this->m_blocksize);
655 this->m_matrix.template triangularView<Lower>() =
656 a.derived().template triangularView<UpLo_>();
657 this->m_info = internal::bunch_kaufman_in_place(
658 this->m_matrix, this->m_subdiag, this->m_pivots, this->m_workspace,
659 this->m_pivot_count);
660 this->m_isInitialized =
true;