11#ifdef ALIGATOR_MULTITHREADING
21template <
typename _Scalar>
22class ParallelRiccatiSolver :
public RiccatiSolverBase<_Scalar> {
24 using Scalar = _Scalar;
26 using Base = RiccatiSolverBase<Scalar>;
27 using StageFactorVec = std::vector<StageFactor<Scalar>>;
30 using Impl = ProximalRiccatiKernel<Scalar>;
31 using KnotType = LQRKnotTpl<Scalar>;
33 using BlkMat = BlkMatrix<MatrixXs, -1, -1>;
34 using BlkVec = BlkMatrix<VectorXs, -1, 1>;
36 explicit ParallelRiccatiSolver(LQRProblemTpl<Scalar> &problem,
37 const uint num_threads);
39 void allocateLeg(uint start, uint end,
bool last_leg);
41 static void setupKnot(KnotType &knot,
const Scalar mudyn) {
42 ALIGATOR_TRACY_ZONE_SCOPED;
44 knot.Gx = knot.A.transpose();
45 knot.Gu = knot.B.transpose();
47 knot.Gth.diagonal().setConstant(-mudyn);
51 bool backward(
const Scalar mudyn,
const Scalar mueq);
53 inline void collapseFeedback() {
54 using RowMatrix = Eigen::Matrix<Scalar, -1, -1, Eigen::RowMajor>;
55 StageFactor<Scalar> &d = datas[0];
56 Eigen::Ref<RowMatrix> K = d.fb.blockRow(0);
57 Eigen::Ref<RowMatrix> Kth = d.fth.blockRow(0);
62 auto &Up1t = condensedKktSystem.subdiagonal[1];
63 K.noalias() -= Kth * Up1t;
66 struct condensed_system_t {
67 std::vector<MatrixXs> subdiagonal;
68 std::vector<MatrixXs> diagonal;
69 std::vector<MatrixXs> superdiagonal;
72 struct condensed_system_factor {
73 std::vector<MatrixXs> diagonalFacs;
74 std::vector<MatrixXs> upFacs;
75 std::vector<Eigen::BunchKaufman<MatrixXs>> ldlt;
79 void assembleCondensedSystem(
const Scalar mudyn);
81 bool forward(VectorOfVectors &xs, VectorOfVectors &us, VectorOfVectors &vs,
82 VectorOfVectors &lbdas,
83 const std::optional<ConstVectorRef> & = std::nullopt)
const;
85 void cycleAppend(
const KnotType &knot);
86 VectorRef getFeedforward(
size_t i) {
return datas[i].ff.matrix(); }
87 RowMatrixRef getFeedback(
size_t i) {
return datas[i].fb.matrix(); }
93 condensed_system_t condensedKktSystem;
95 condensed_system_factor condensedFacs;
97 BlkVec condensedKktRhs, condensedKktSolution;
/// Size the condensed tridiagonal system. Presumably @p dims lists the
/// dimension of each diagonal block -- TODO confirm in the implementation.
void initializeTridiagSystem(const std::vector<long> &dims);
103 LQRProblemTpl<Scalar> *problem_;
110#ifdef ALIGATOR_ENABLE_TEMPLATE_INSTANTIATION
111#include "aligator/gar/parallel-solver.txx"
#define ALIGATOR_NOMALLOC_SCOPED
#define ALIGATOR_DYNAMIC_TYPEDEFS_WITH_ROW_TYPES(Scalar)