12#ifdef ALIGATOR_MULTITHREADING
22template <
typename _Scalar>
23class ParallelRiccatiSolver :
public RiccatiSolverBase<_Scalar> {
25 using Scalar = _Scalar;
27 using Base = RiccatiSolverBase<Scalar>;
28 using StageFactorVec = std::vector<StageFactor<Scalar>>;
31 using Impl = ProximalRiccatiKernel<Scalar>;
32 using KnotType = LQRKnotTpl<Scalar>;
34 using BlkMat = BlkMatrix<MatrixXs, -1, -1>;
35 using BlkVec = BlkMatrix<VectorXs, -1, 1>;
37 explicit ParallelRiccatiSolver(LQRProblemTpl<Scalar> &problem,
38 const uint num_threads);
40 void allocateLeg(uint start, uint end,
bool last_leg);
/// @brief Write the parameter-coupling terms of a knot.
/// Sets Gx = A^T, Gu = B^T and writes -mudyn on the diagonal of Gth.
/// NOTE(review): this copy is truncated. The embedded line numbers jump
/// (42, 45, 46, 48, then 52) and the closing brace is absent. Restore the
/// missing statements from upstream before building.
42 static void setupKnot(KnotType &knot,
const Scalar mudyn) {
// Transposed dynamics Jacobians become the knot's parameter-coupling
// matrices (interpretation inferred from the names — confirm).
45 knot.Gx = knot.A.transpose();
46 knot.Gu = knot.B.transpose();
// Only the diagonal of Gth is written here. Off-diagonal entries keep their
// previous values unless set in lines not visible in this copy.
48 knot.Gth.diagonal().setConstant(-mudyn);
52 bool backward(
const Scalar mudyn,
const Scalar mueq);
/// @brief Update the first stage's feedback gain in place.
/// Computes K -= Kth * Up1t. Here K and Kth are block row 0 of datas[0].fb
/// and datas[0].fth, and Up1t is condensedKktSystem.subdiagonal[1].
/// Presumably this folds the parametric feedback into the state feedback.
/// NOTE(review): this copy is truncated. The embedded line numbers skip 59-62
/// and 65-66, and the closing brace is absent.
54 inline void collapseFeedback() {
55 using RowMatrix = Eigen::Matrix<Scalar, -1, -1, Eigen::RowMajor>;
56 StageFactor<Scalar> &d = datas[0];
// A non-const Eigen::Ref never makes a temporary copy, so updates through K
// are written directly into d.fb.
57 Eigen::Ref<RowMatrix> K = d.fb.blockRow(0);
58 Eigen::Ref<RowMatrix> Kth = d.fth.blockRow(0);
63 auto &Up1t = condensedKktSystem.subdiagonal[1];
// noalias() evaluates the product straight into K. This relies on K (from
// d.fb) not overlapping Kth (from d.fth) or Up1t (from condensedKktSystem).
64 K.noalias() -= Kth * Up1t;
/// Condensed KKT matrix stored as three lists of dense blocks, presumably
/// block-tridiagonal (cf. initializeTridiagSystem).
/// NOTE(review): the closing `};` is missing from this copy.
67 struct condensed_system_t {
// Blocks below the main diagonal (collapseFeedback reads subdiagonal[1]).
68 std::vector<MatrixXs> subdiagonal;
// Main-diagonal blocks.
69 std::vector<MatrixXs> diagonal;
// Blocks above the main diagonal.
70 std::vector<MatrixXs> superdiagonal;
/// Factorization data for the condensed KKT system.
/// NOTE(review): the closing `};` is missing from this copy.
73 struct condensed_system_factor {
// Presumably the factored diagonal blocks — confirm.
74 std::vector<MatrixXs> diagonalFacs;
// Presumably the factored upper off-diagonal blocks — confirm.
75 std::vector<MatrixXs> upFacs;
// Bunch-Kaufman (pivoted symmetric-indefinite LDL^T) factorizations,
// presumably one per diagonal block.
76 std::vector<Eigen::BunchKaufman<MatrixXs>> ldlt;
80 void assembleCondensedSystem(
const Scalar mudyn);
82 bool forward(VectorOfVectors &xs, VectorOfVectors &us, VectorOfVectors &vs,
83 VectorOfVectors &lbdas,
84 const std::optional<ConstVectorRef> & = std::nullopt)
const;
86 VectorRef getFeedforward(
size_t i) {
return datas[i].ff.matrix(); }
87 RowMatrixRef getFeedback(
size_t i) {
return datas[i].fb.matrix(); }
93 condensed_system_t condensedKktSystem;
95 condensed_system_factor condensedFacs;
97 BlkVec condensedKktRhs, condensedKktSolution;
/// @brief Size the condensed (presumably block-tridiagonal) system storage.
/// @param dims Block dimensions, presumably one entry per diagonal block —
///             confirm against the definition.
void initializeTridiagSystem(const std::vector<long> &dims);
103 LQRProblemTpl<Scalar> *problem_;
110#include "aligator/gar/parallel-solver.hxx"
112#ifdef ALIGATOR_ENABLE_TEMPLATE_INSTANTIATION
113#include "aligator/gar/parallel-solver.txx"
#define ALIGATOR_NOMALLOC_SCOPED
#define ALIGATOR_DYNAMIC_TYPEDEFS_WITH_ROW_TYPES(Scalar)