8#include "aligator/tracy.hpp"
13#ifdef ALIGATOR_MULTITHREADING
23template <
typename _Scalar>
/// Alias for the (presumed) base solver class.
28 using Base = RiccatiSolverBase<Scalar>;
/// Container holding one StageFactor per stage.
29 using StageFactorVec = std::vector<StageFactor<Scalar>>;
/// Alias for the proximal Riccati kernel routines.
32 using Kernel = ProximalRiccatiKernel<Scalar>;
/// A single knot of the LQ problem.
33 using KnotType = LqrKnotTpl<Scalar>;
// Block matrix / block vector types. The -1 template arguments are presumably
// dynamic block counts (rows, cols) -- confirm against BlkMatrix's definition.
35 using BlkMat = BlkMatrix<MatrixXs, -1, -1>;
36 using BlkVec = BlkMatrix<VectorXs, -1, 1>;
/// @brief Build a parallel Riccati solver for @p problem.
/// @param problem     LQ problem, taken by non-const reference (presumably kept
///                    as a non-owning pointer in problem_ -- confirm).
/// @param num_threads number of threads to use (presumably also the number of
///                    legs the horizon is split into -- confirm).
38 explicit ParallelRiccatiSolver(LqrProblemTpl<Scalar> &problem,
39 const uint num_threads);
/// @brief Allocate solver storage for one leg of the horizon.
/// @param start    first knot index of the leg
/// @param end      end knot index of the leg (inclusive vs. exclusive not
///                 visible here -- confirm in the implementation)
/// @param last_leg whether this is the final leg of the horizon
41 void allocateLeg(uint start, uint end,
bool last_leg);
/// @brief Fill the parameter-coupling blocks of @p knot in place.
/// Sets knot.Gx = A^T, knot.Gu = B^T, and the diagonal of knot.Gth to -mudyn.
/// @param knot  knot modified in place
/// @param mudyn scalar whose negation is written on the diagonal of knot.Gth
43 static void setupKnot(KnotType &knot,
const Scalar mudyn) {
// Tracy profiling zone covering this function.
44 ALIGATOR_TRACY_ZONE_SCOPED;
// The parameter blocks take the transposes of knot.A and knot.B (presumably
// the dynamics matrices, so the parameter couples to the previous leg -- confirm).
46 knot.Gx = knot.A.to_const_map().transpose();
47 knot.Gu = knot.B.to_const_map().transpose();
49 knot.Gth.diagonal().setConstant(-mudyn);
/// @brief Run the backward (factorization) pass.
/// @param mudyn scalar parameter (presumably the dynamics proximal/penalty weight)
/// @param mueq  scalar parameter (presumably the equality-constraint weight)
/// @return a success flag (exact meaning not visible here -- confirm)
53 bool backward(
const Scalar mudyn,
const Scalar mueq);
/// @brief Fold the first stage's parametric feedback into its feedback gain.
/// Computes K <- K - Kth * condensedKktSystem.subdiagonal[1]. K and Kth are
/// the first block rows of datas[0].fb and datas[0].fth. Both are writable
/// Eigen::Ref views, so datas[0] is updated in place.
55 inline void collapseFeedback() {
// Row-major dynamic matrix type. The storage that blockRow() exposes is
// assumed to be row-major -- confirm against the StageFactor definition.
56 using RowMatrix = Eigen::Matrix<
Scalar, -1, -1, Eigen::RowMajor>;
57 StageFactor<Scalar> &d = datas[0];
58 Eigen::Ref<RowMatrix> K = d.fb.blockRow(0);
59 Eigen::Ref<RowMatrix> Kth = d.fth.blockRow(0);
64 auto &Up1t = condensedKktSystem.subdiagonal[1];
// noalias() lets Eigen write the product straight into K without a temporary.
// This is only valid because K does not share storage with Kth or Up1t.
65 K.noalias() -= Kth * Up1t;
/// Condensed KKT system, stored by its three block bands.
/// Presumably there is one entry per leg boundary -- confirm in
/// initializeTridiagSystem().
68 struct condensed_system_t {
/// Blocks below the main diagonal.
69 std::vector<MatrixXs> subdiagonal;
/// Blocks on the main diagonal.
70 std::vector<MatrixXs> diagonal;
/// Blocks above the main diagonal.
71 std::vector<MatrixXs> superdiagonal;
/// Storage for the factorization of the condensed KKT system.
74 struct condensed_system_factor {
// Factored diagonal blocks and the associated upper off-diagonal blocks.
// Their exact content depends on the factorization routine, which is not shown here.
75 std::vector<MatrixXs> diagonalFacs;
76 std::vector<MatrixXs> upFacs;
/// Symmetric-indefinite (Bunch-Kaufman) factorizations, presumably one per
/// diagonal block -- confirm.
77 std::vector<Eigen::BunchKaufman<MatrixXs>> ldlt;
/// @brief Assemble the condensed KKT system (presumably into condensedKktSystem).
/// @param mudyn scalar parameter (presumably the dynamics proximal weight; it
///              is the same argument name used by setupKnot()).
81 void assembleCondensedSystem(
const Scalar mudyn);
/// @brief Run the forward pass and write the solution trajectories.
/// @param xs, us, vs, lbdas output trajectories, passed by non-const reference
///        (presumably states, controls, constraint multipliers and co-states -- confirm).
/// The unnamed trailing argument is an optional constant vector, defaulting to
/// std::nullopt (presumably the problem parameter -- confirm).
/// @return a success flag (exact meaning not visible here -- confirm)
83 bool forward(VectorOfVectors &xs, VectorOfVectors &us, VectorOfVectors &vs,
84 VectorOfVectors &lbdas,
85 const std::optional<ConstVectorRef> & = std::nullopt)
const;
/// @brief Append @p knot to the solver's horizon. The name suggests a cyclic
/// shift (receding horizon); confirm in the implementation.
/// @param knot knot to append, passed by const reference
87 void cycleAppend(
const KnotType &knot);
88 VectorRef getFeedforward(
size_t i) {
return datas[i].ff.matrix(); }
89 RowMatrixRef getFeedback(
size_t i) {
return datas[i].fb.matrix(); }
/// Condensed KKT system (presumably filled by assembleCondensedSystem()).
95 condensed_system_t condensedKktSystem;
/// Factorization data for condensedKktSystem.
97 condensed_system_factor condensedFacs;
/// Right-hand side, solution and error vectors of the condensed system.
/// condensedErr is presumably the residual -- confirm.
99 BlkVec condensedKktRhs, condensedKktSolution, condensedErr;
/// Tolerance for the condensed solve, default 1e-11. It is presumably compared
/// against condensedErr, for example to trigger refinement -- confirm in the implementation.
101 Scalar condensedThreshold{1e-11};
/// @brief Size the block-tridiagonal condensed system from block dimensions.
/// @param dims per-block dimensions (presumably one per diagonal block -- confirm)
104 void initializeTridiagSystem(
const std::vector<long> &dims);
/// Pointer to the LQ problem being solved. Ownership is not established here;
/// presumably it is non-owning and set from the constructor's `problem` argument -- confirm.
107 LqrProblemTpl<Scalar> *problem_;
114#ifdef ALIGATOR_ENABLE_TEMPLATE_INSTANTIATION
115#include "aligator/gar/parallel-solver.txx"
#define ALIGATOR_NOMALLOC_SCOPED
#define ALIGATOR_DYNAMIC_TYPEDEFS_WITH_ROW_TYPES(Scalar)
::aligator::context::Scalar Scalar