aligator  0.14.0
A primal-dual augmented Lagrangian-type solver for nonlinear trajectory optimization.
Loading...
Searching...
No Matches
quad-costs.hpp
Go to the documentation of this file.
1#pragma once
2
7
8namespace aligator {
9
10template <typename Scalar> struct QuadraticCostDataTpl;
11
13template <typename _Scalar> struct QuadraticCostTpl : CostAbstractTpl<_Scalar> {
14public:
15 using Scalar = _Scalar;
19
23
25 MatrixXs Wxx_;
27 MatrixXs Wuu_;
28
29protected:
31 MatrixXs Wxu_;
32
33public:
35 VectorXs interp_x;
37 VectorXs interp_u;
38
// Helper used by the constructors below to build the space argument handed to
// the Base class from the number of columns of the state weight matrix.
// NOTE(review): this listing skips source line 40, which holds the function
// body, so the returned expression is not visible here.
39 static auto get_vector_space(Eigen::Index nx) {
41 }
42
// Constructor without a cross term.
// w_x / w_u are the state and control weight matrices, interp_x / interp_u the
// intercept vectors. The Base class is sized from w_x.cols() and w_u.cols().
43 QuadraticCostTpl(const ConstMatrixRef &w_x, const ConstMatrixRef &w_u,
44 const ConstVectorRef &interp_x,
45 const ConstVectorRef &interp_u)
46 : Base(get_vector_space(w_x.cols()), (int)w_u.cols())
47 , Wxx_(w_x)
48 , Wuu_(w_u)
// The cross-term matrix is only sized here; it is zeroed in the body.
49 , Wxu_(this->ndx(), this->nu)
// NOTE(review): listing lines 50-51 are not shown; presumably they initialize
// the interp_x / interp_u members from the arguments -- confirm in the header.
52 , has_cross_term_(false) {
// Reports inconsistent weight / intercept dimensions.
53 debug_check_dims();
54 Wxu_.setZero();
55 }
56
// Constructor with a cross term.
// w_x / w_u are the state and control weight matrices, w_cross the state-control
// cross weight (checked below to be ndx x nu), interp_x / interp_u the
// intercept vectors.
57 QuadraticCostTpl(const ConstMatrixRef &w_x, const ConstMatrixRef &w_u,
58 const ConstMatrixRef &w_cross,
59 const ConstVectorRef &interp_x,
60 const ConstVectorRef &interp_u)
61 : Base(get_vector_space(w_x.cols()), (int)w_u.cols())
62 , Wxx_(w_x)
63 , Wuu_(w_u)
64 , Wxu_(w_cross)
// NOTE(review): listing lines 65-66 are not shown; presumably they initialize
// the interp_x / interp_u members from the arguments -- confirm in the header.
67 , has_cross_term_(true) {
// Reports inconsistent weight / intercept dimensions.
68 debug_check_dims();
69 }
70
71 QuadraticCostTpl(const ConstMatrixRef &w_x, const ConstMatrixRef &w_u)
72 : QuadraticCostTpl(w_x, w_u, VectorXs::Zero(w_x.cols()),
73 VectorXs::Zero(w_u.cols())) {}
74
75 QuadraticCostTpl(const ConstMatrixRef &w_x, const ConstMatrixRef &w_u,
76 const ConstMatrixRef &w_cross)
77 : QuadraticCostTpl(w_x, w_u, w_cross, VectorXs::Zero(w_x.cols()),
78 VectorXs::Zero(w_u.cols())) {}
79
80 void evaluate(const ConstVectorRef &x, const ConstVectorRef &u,
81 CostData &data) const {
82 Data &d = static_cast<Data &>(data);
83 d.w_times_x_.noalias() = Wxx_ * x;
84 d.w_times_u_.noalias() = Wuu_ * u;
85 if (has_cross_term_) {
86 d.cross_x_.noalias() = Wxu_ * u;
87 d.cross_u_.noalias() = Wxu_.transpose() * x;
88
89 d.w_times_x_ += d.cross_x_;
90 d.w_times_u_ += d.cross_u_;
91 }
92 data.value_ = Scalar(0.5) * x.dot(d.w_times_x_ + 2 * interp_x) +
93 Scalar(0.5) * u.dot(d.w_times_u_ + 2 * interp_u);
94 }
95
96 void computeGradients(const ConstVectorRef &, const ConstVectorRef &,
97 CostData &data) const {
98 Data &d = static_cast<Data &>(data);
99 d.Lx_ = d.w_times_x_ + interp_x;
100 d.Lu_ = d.w_times_u_ + interp_u;
101 }
102
103 void computeHessians(const ConstVectorRef &, const ConstVectorRef &,
104 CostData &) const {}
105
106 shared_ptr<CostData> createData() const {
107 auto data = std::make_shared<Data>(this->ndx(), this->nu);
108 data->Lxx_ = Wxx_;
109 data->Luu_ = Wuu_;
110 data->Lxu_ = Wxu_;
111 data->Lux_ = Wxu_.transpose();
112 return data;
113 }
114
115 ConstMatrixRef getCrossWeights() const { return Wxu_; }
116 void setCrossWeight(const ConstMatrixRef &w) {
117 Wxu_ = w;
118 has_cross_term_ = true;
119 debug_check_dims();
120 }
121
123 bool hasCrossTerm() const { return has_cross_term_; }
124
125protected:
128
129private:
130 static void _check_dim_equal(long n, long m, const std::string &msg = "") {
131 if (n != m)
132 ALIGATOR_RUNTIME_ERROR("Dimensions inconsistent: got {:d} and {:d}{}.\n",
133 n, m, msg);
134 }
135
136 void debug_check_dims() const {
137 _check_dim_equal(Wxx_.rows(), Wxx_.cols(), " for x weights");
138 _check_dim_equal(Wuu_.rows(), Wuu_.cols(), " for u weights");
139 _check_dim_equal(Wxu_.rows(), this->ndx(), " for cross-term weight");
140 _check_dim_equal(Wxu_.cols(), this->nu, " for cross-term weight");
141 _check_dim_equal(interp_x.rows(), Wxx_.rows(),
142 " for x weights and intercept");
143 _check_dim_equal(interp_u.rows(), Wuu_.rows(),
144 " for u weights and intercept");
145 }
146};
147
148template <typename Scalar>
151 using VectorXs = typename Base::VectorXs;
153
154 QuadraticCostDataTpl(const int nx, const int nu)
155 : Base(nx, nu)
156 , w_times_x_(nx)
157 , w_times_u_(nu)
158 , cross_x_(nu)
159 , cross_u_(nu) {
160 w_times_x_.setZero();
161 w_times_u_.setZero();
162 cross_x_.setZero();
163 cross_u_.setZero();
164 }
165};
166
167} // namespace aligator
168
169#ifdef ALIGATOR_ENABLE_TEMPLATE_INSTANTIATION
170#include "./quad-costs.txx"
171#endif
#define ALIGATOR_RUNTIME_ERROR(...)
Definition exceptions.hpp:7
Main package namespace.
CostAbstractTpl(U &&space, const int nu)
Data struct for CostAbstractTpl.
CostDataAbstractTpl(const int ndx, const int nu)
Base class for manifolds, to use in cost funcs, solvers...
CostDataAbstractTpl< Scalar > Base
QuadraticCostDataTpl(const int nx, const int nu)
typename Base::VectorXs VectorXs
QuadraticCostDataTpl< Scalar > Data
QuadraticCostTpl(const ConstMatrixRef &w_x, const ConstMatrixRef &w_u)
void computeGradients(const ConstVectorRef &, const ConstVectorRef &, CostData &data) const
Compute the cost gradients $(\ell_x, \ell_u)$.
QuadraticCostTpl(const ConstMatrixRef &w_x, const ConstMatrixRef &w_u, const ConstMatrixRef &w_cross)
shared_ptr< CostData > createData() const
void evaluate(const ConstVectorRef &x, const ConstVectorRef &u, CostData &data) const
Evaluate the cost function.
ConstMatrixRef getCrossWeights() const
ManifoldAbstractTpl< Scalar > Manifold
CostAbstractTpl< Scalar > Base
bool hasCrossTerm() const
Whether a cross term exists.
CostDataAbstractTpl< Scalar > CostData
QuadraticCostTpl(const ConstMatrixRef &w_x, const ConstMatrixRef &w_u, const ConstVectorRef &interp_x, const ConstVectorRef &interp_u)
static auto get_vector_space(Eigen::Index nx)
void computeHessians(const ConstVectorRef &, const ConstVectorRef &, CostData &) const
Compute the cost Hessians $(\ell_{xx}, \ell_{xu}, \ell_{ux}, \ell_{uu})$.
::aligator::VectorSpaceTpl< Scalar, Eigen::Dynamic > VectorSpace
void setCrossWeight(const ConstMatrixRef &w)
QuadraticCostTpl(const ConstMatrixRef &w_x, const ConstMatrixRef &w_u, const ConstMatrixRef &w_cross, const ConstVectorRef &interp_x, const ConstVectorRef &interp_u)
Standard Euclidean vector space.