aligator  0.6.1
A primal-dual augmented Lagrangian-type solver for nonlinear trajectory optimization.
Loading...
Searching...
No Matches
quad-costs.hpp
Go to the documentation of this file.
1#pragma once
2
4#include <proxsuite-nlp/modelling/spaces/vector-space.hpp>
5
6namespace aligator {
7
8template <typename Scalar> struct QuadraticCostDataTpl;
9
11template <typename _Scalar> struct QuadraticCostTpl : CostAbstractTpl<_Scalar> {
12public:
13 using Scalar = _Scalar;
15 using Base = CostAbstractTpl<Scalar>;
16 using CostData = CostDataAbstractTpl<Scalar>;
17
18 using Data = QuadraticCostDataTpl<Scalar>;
19 using VectorSpace = proxsuite::nlp::VectorSpaceTpl<Scalar, Eigen::Dynamic>;
20
22 MatrixXs Wxx_;
24 MatrixXs Wuu_;
25
26protected:
28 MatrixXs Wxu_;
29
30public:
32 VectorXs interp_x;
34 VectorXs interp_u;
35
36 static auto get_vector_space(Eigen::Index nx) {
37 return std::make_shared<VectorSpace>((int)nx);
38 }
39
40 QuadraticCostTpl(const ConstMatrixRef &w_x, const ConstMatrixRef &w_u,
41 const ConstVectorRef &interp_x,
42 const ConstVectorRef &interp_u)
43 : Base(get_vector_space(w_x.cols()), (int)w_u.cols()), Wxx_(w_x),
44 Wuu_(w_u), Wxu_(this->ndx(), this->nu), interp_x(interp_x),
46 debug_check_dims();
47 Wxu_.setZero();
48 }
49
50 QuadraticCostTpl(const ConstMatrixRef &w_x, const ConstMatrixRef &w_u,
51 const ConstMatrixRef &w_cross,
52 const ConstVectorRef &interp_x,
53 const ConstVectorRef &interp_u)
54 : Base(get_vector_space(w_x.cols()), (int)w_u.cols()), Wxx_(w_x),
55 Wuu_(w_u), Wxu_(w_cross), interp_x(interp_x), interp_u(interp_u),
56 has_cross_term_(true) {
57 debug_check_dims();
58 }
59
60 QuadraticCostTpl(const ConstMatrixRef &w_x, const ConstMatrixRef &w_u)
61 : QuadraticCostTpl(w_x, w_u, VectorXs::Zero(w_x.cols()),
62 VectorXs::Zero(w_u.cols())) {}
63
64 QuadraticCostTpl(const ConstMatrixRef &w_x, const ConstMatrixRef &w_u,
65 const ConstMatrixRef &w_cross)
66 : QuadraticCostTpl(w_x, w_u, w_cross, VectorXs::Zero(w_x.cols()),
67 VectorXs::Zero(w_u.cols())) {}
68
69 void evaluate(const ConstVectorRef &x, const ConstVectorRef &u,
70 CostData &data) const {
71 Data &d = static_cast<Data &>(data);
72 d.w_times_x_.noalias() = Wxx_ * x;
73 d.w_times_u_.noalias() = Wuu_ * u;
74 if (has_cross_term_) {
75 d.cross_x_.noalias() = Wxu_ * u;
76 d.cross_u_.noalias() = Wxu_.transpose() * x;
77
78 d.w_times_x_ += d.cross_x_;
79 d.w_times_u_ += d.cross_u_;
80 }
81 data.value_ = Scalar(0.5) * x.dot(d.w_times_x_ + 2 * interp_x) +
82 Scalar(0.5) * u.dot(d.w_times_u_ + 2 * interp_u);
83 }
84
85 void computeGradients(const ConstVectorRef &, const ConstVectorRef &,
86 CostData &data) const {
87 Data &d = static_cast<Data &>(data);
88 d.Lx_ = d.w_times_x_ + interp_x;
89 d.Lu_ = d.w_times_u_ + interp_u;
90 }
91
92 void computeHessians(const ConstVectorRef &, const ConstVectorRef &,
93 CostData &) const {}
94
95 shared_ptr<CostData> createData() const {
96 auto data = std::make_shared<Data>(this->ndx(), this->nu);
97 data->Lxx_ = Wxx_;
98 data->Luu_ = Wuu_;
99 data->Lxu_ = Wxu_;
100 data->Lux_ = Wxu_.transpose();
101 return data;
102 }
103
104 ConstMatrixRef getCrossWeights() const { return Wxu_; }
105 void setCrossWeight(const ConstMatrixRef &w) {
106 Wxu_ = w;
107 has_cross_term_ = true;
108 debug_check_dims();
109 }
110
112 bool hasCrossTerm() const { return has_cross_term_; }
113
114protected:
117
118private:
119 static void _check_dim_equal(long n, long m, const std::string &msg = "") {
120 if (n != m)
121 ALIGATOR_RUNTIME_ERROR(fmt::format(
122 "Dimensions inconsistent: got {:d} and {:d}{}.\n", n, m, msg));
123 }
124
125 void debug_check_dims() const {
126 _check_dim_equal(Wxx_.rows(), Wxx_.cols(), " for x weights");
127 _check_dim_equal(Wuu_.rows(), Wuu_.cols(), " for u weights");
128 _check_dim_equal(Wxu_.rows(), this->ndx(), " for cross-term weight");
129 _check_dim_equal(Wxu_.cols(), this->nu, " for cross-term weight");
130 _check_dim_equal(interp_x.rows(), Wxx_.rows(),
131 " for x weights and intercept");
132 _check_dim_equal(interp_u.rows(), Wuu_.rows(),
133 " for u weights and intercept");
134 }
135};
136
137template <typename Scalar>
139 using Base = CostDataAbstractTpl<Scalar>;
140 using VectorXs = typename Base::VectorXs;
142
143 QuadraticCostDataTpl(const int nx, const int nu)
144 : Base(nx, nu), w_times_x_(nx), w_times_u_(nu), cross_x_(nu),
145 cross_u_(nu) {
146 w_times_x_.setZero();
147 w_times_u_.setZero();
148 cross_x_.setZero();
149 cross_u_.setZero();
150 }
151};
152
153} // namespace aligator
154
155#ifdef ALIGATOR_ENABLE_TEMPLATE_INSTANTIATION
156#include "./quad-costs.txx"
157#endif
#define ALIGATOR_RUNTIME_ERROR(msg)
Definition exceptions.hpp:6
Main package namespace.
Stage costs for control problems.
int nu
Control dimension.
QuadraticCostDataTpl(const int nx, const int nu)
typename Base::VectorXs VectorXs
Euclidean quadratic cost.
VectorXs interp_x
Affine term in $x$.
QuadraticCostDataTpl< Scalar > Data
QuadraticCostTpl(const ConstMatrixRef &w_x, const ConstMatrixRef &w_u)
VectorXs interp_u
Affine term in $u$.
void computeGradients(const ConstVectorRef &, const ConstVectorRef &, CostData &data) const
Compute the cost gradients $(\ell_x, \ell_u)$.
QuadraticCostTpl(const ConstMatrixRef &w_x, const ConstMatrixRef &w_u, const ConstMatrixRef &w_cross)
shared_ptr< CostData > createData() const
void evaluate(const ConstVectorRef &x, const ConstVectorRef &u, CostData &data) const
Evaluate the cost function.
MatrixXs Wxu_
Weight $N$ for the cross term $x^\top N u$.
bool has_cross_term_
Whether a cross term exists.
ConstMatrixRef getCrossWeights() const
bool hasCrossTerm() const
Whether a cross term exists.
QuadraticCostTpl(const ConstMatrixRef &w_x, const ConstMatrixRef &w_u, const ConstVectorRef &interp_x, const ConstVectorRef &interp_u)
static auto get_vector_space(Eigen::Index nx)
proxsuite::nlp::VectorSpaceTpl< Scalar, Eigen::Dynamic > VectorSpace
void computeHessians(const ConstVectorRef &, const ConstVectorRef &, CostData &) const
Compute the cost Hessians $(\ell_{ij})_{i,j \in \{x,u\}}$.
void setCrossWeight(const ConstMatrixRef &w)
QuadraticCostTpl(const ConstMatrixRef &w_x, const ConstMatrixRef &w_u, const ConstMatrixRef &w_cross, const ConstVectorRef &interp_x, const ConstVectorRef &interp_u)