aligator  0.9.0
A primal-dual augmented Lagrangian-type solver for nonlinear trajectory optimization.
Loading...
Searching...
No Matches
parallel-solver.hpp
Go to the documentation of this file.
1
#pragma once

#include <optional>
#include <vector>

namespace aligator {
namespace gar {

#ifdef ALIGATOR_MULTITHREADING
/// @brief Multi-threaded Riccati solver for LQR problems, built on top of
/// RiccatiSolverBase.
///
/// Per-stage factorization data is kept in @ref datas. The pieces of the
/// horizon are coupled through a block-tridiagonal "condensed" KKT system
/// (@ref condensedKktSystem), its factorization (@ref condensedFacs), and the
/// associated right-hand side / solution vectors.
///
/// @note Only declared when ALIGATOR_MULTITHREADING is defined.
/// @note NOTE(review): the horizon is presumably split into legs, one per
///       thread (see allocateLeg()); confirm against parallel-solver.txx.
/// @note NOTE(review): MatrixXs, VectorXs, VectorRef, RowMatrixRef,
///       VectorOfVectors and ConstVectorRef are not declared in the visible
///       source; they are presumably introduced by a typedef macro such as
///       ALIGATOR_DYNAMIC_TYPEDEFS_WITH_ROW_TYPES(Scalar) — confirm.
template <typename _Scalar>
class ParallelRiccatiSolver : public RiccatiSolverBase<_Scalar> {
public:
  using Scalar = _Scalar;
  using Base = RiccatiSolverBase<Scalar>;
  using StageFactorVec = std::vector<StageFactor<Scalar>>;
  /// Per-stage factorization data; each entry exposes at least the
  /// feedforward `ff`, feedback `fb` and parametric feedback `fth` blocks.
  StageFactorVec datas;

  using Impl = ProximalRiccatiKernel<Scalar>;
  using KnotType = LQRKnotTpl<Scalar>;

  using BlkMat = BlkMatrix<MatrixXs, -1, -1>;
  using BlkVec = BlkMatrix<VectorXs, -1, 1>;

  /// @param problem     LQR problem to solve.
  /// @param num_threads Number of threads to use.
  explicit ParallelRiccatiSolver(LQRProblemTpl<Scalar> &problem,
                                 const uint num_threads);

  /// @brief Allocate storage for one leg of the horizon, spanning stages
  /// @p start to @p end (bound convention defined out of line).
  /// @param last_leg Whether this is the final leg of the horizon.
  void allocateLeg(uint start, uint end, bool last_leg);

  /// @brief Fill the parameter-coupling terms of @p knot from its dynamics.
  ///
  /// Sets `Gx = A^T`, `Gu = B^T`, `Gth` to zero with `-mudyn` on its main
  /// diagonal, and `gamma = f`.
  /// NOTE(review): presumably this makes the knot's parameter act as the
  /// dynamics multiplier at the leg boundary — confirm.
  static void setupKnot(KnotType &knot, const Scalar mudyn) {
    ALIGATOR_TRACY_ZONE_SCOPED;
    knot.Gx = knot.A.transpose();
    knot.Gu = knot.B.transpose();
    knot.Gth.setZero();
    knot.Gth.diagonal().setConstant(-mudyn);
    knot.gamma = knot.f;
  }

  /// @brief Backward sweep (defined out of line).
  /// @param mudyn, mueq Regularization parameters (by name: dynamics and
  ///                    equality constraints — confirm in the .txx).
  /// @return Success flag, as defined by the out-of-line implementation.
  bool backward(const Scalar mudyn, const Scalar mueq);

  /// @brief Fold the first stage's parametric feedback into its state
  /// feedback: `K -= Kth * U`, where `K` and `Kth` are block row 0 of
  /// `datas[0].fb` and `datas[0].fth`, and
  /// `U = condensedKktSystem.subdiagonal[1]`.
  ///
  /// @pre `datas` is non-empty and `condensedKktSystem.subdiagonal` holds at
  ///      least two blocks; neither condition is checked here.
  inline void collapseFeedback() {
    using RowMatrix = Eigen::Matrix<Scalar, -1, -1, Eigen::RowMajor>;
    StageFactor<Scalar> &d = datas[0];
    Eigen::Ref<RowMatrix> K = d.fb.blockRow(0);
    Eigen::Ref<RowMatrix> Kth = d.fth.blockRow(0);

    // condensedSystem.subdiagonal contains the 'U' factors in the
    // block-tridiag UDUt decomposition
    // and dX[i+1] = -U[i+1]^T dX[i]
    // NOTE(review): condensedFacs.upFacs is documented as holding the
    // transposed U factors too; confirm which container backward() fills.
    auto &Up1t = condensedKktSystem.subdiagonal[1];
    // noalias(): the product does not alias K, so skip the temporary.
    K.noalias() -= Kth * Up1t;
  }

  /// Block-tridiagonal system stored as its sub-, main and super-diagonal
  /// blocks.
  struct condensed_system_t {
    std::vector<MatrixXs> subdiagonal;
    std::vector<MatrixXs> diagonal;
    std::vector<MatrixXs> superdiagonal;
  };

  /// Factorization data for the condensed block-tridiagonal system.
  struct condensed_system_factor {
    std::vector<MatrixXs> diagonalFacs; ///< diagonal factors
    std::vector<MatrixXs> upFacs;       ///< transposed U factors
    /// Bunch-Kaufman factorizations.
    std::vector<Eigen::BunchKaufman<MatrixXs>> ldlt;
  };

  /// @brief Assemble the condensed block-tridiagonal KKT system
  /// (@ref condensedKktSystem); see the out-of-line definition for details.
  void assembleCondensedSystem(const Scalar mudyn);

  /// @brief Forward sweep (defined out of line).
  ///
  /// @p xs, @p us, @p vs and @p lbdas are non-const output arguments. The
  /// trailing optional argument defaults to std::nullopt.
  /// @return Success flag, as defined by the out-of-line implementation.
  bool forward(VectorOfVectors &xs, VectorOfVectors &us, VectorOfVectors &vs,
               VectorOfVectors &lbdas,
               const std::optional<ConstVectorRef> & = std::nullopt) const;

  /// @brief Append @p knot to the problem (by name, presumably in a cyclic
  /// / receding-horizon fashion — confirm in the .txx).
  void cycleAppend(const KnotType &knot);

  /// View of the feedforward term of stage @p i (no bounds check).
  VectorRef getFeedforward(size_t i) { return datas[i].ff.matrix(); }

  /// View of the feedback gain of stage @p i (no bounds check).
  RowMatrixRef getFeedback(size_t i) { return datas[i].fb.matrix(); }

  /// Number of threads used by the solver.
  uint numThreads = 0;

  /// Condensed block-tridiagonal KKT system.
  condensed_system_t condensedKktSystem;
  /// Factorization of @ref condensedKktSystem.
  condensed_system_factor condensedFacs;
  /// Right-hand side of the condensed KKT system.
  BlkVec condensedKktRhs;
  /// Solution of the condensed KKT system.
  BlkVec condensedKktSolution;

  /// @brief Initialize the condensed system storage from the block
  /// dimensions @p dims (presumably one entry per block — confirm).
  void initializeTridiagSystem(const std::vector<long> &dims);

protected:
  /// Problem being solved. Raw, presumably non-owning pointer (the
  /// constructor takes the problem by reference); null until set.
  LQRProblemTpl<Scalar> *problem_ = nullptr;
};
#endif

} // namespace gar
} // namespace aligator
109
// The companion .txx file is only included when the build enables
// ALIGATOR_ENABLE_TEMPLATE_INSTANTIATION.
#ifdef ALIGATOR_ENABLE_TEMPLATE_INSTANTIATION
#include "aligator/gar/parallel-solver.txx"
#endif
#define ALIGATOR_NOMALLOC_SCOPED
#define ALIGATOR_DYNAMIC_TYPEDEFS_WITH_ROW_TYPES(Scalar)
Definition math.hpp:10
Main package namespace.
unsigned int uint
Definition logger.hpp:10