aligator  0.9.0
A primal-dual augmented Lagrangian-type solver for nonlinear trajectory optimization.
Loading...
Searching...
No Matches
function-xpr-slice.hpp
Go to the documentation of this file.
1
3#pragma once
4
7#include <proxsuite-nlp/third-party/polymorphic_cxx14.hpp>
8
9namespace aligator {
10
/// Implementation-detail namespace.
11namespace detail {
/// Forward declaration; the struct itself is defined further down in this
/// file, after the FunctionSliceXprTpl specialisations that inherit from it.
12template <typename Base> struct slice_impl_tpl;
13}
14
/// Forward declaration of the data struct whose constructor takes a
/// FunctionSliceXprTpl (see its members and initialiser list below).
15template <typename Scalar> struct FunctionSliceDataTpl;
16
20template <typename Scalar, typename Base = StageFunctionTpl<Scalar>>
22
23template <typename Scalar>
25 : StageFunctionTpl<Scalar>,
26 detail::slice_impl_tpl<StageFunctionTpl<Scalar>> {
32
/// Construct a slice of @p func restricted to the output entries listed in
/// @p indices.
///
/// The base class is built from func->ndx1, func->nu and indices.size(), so
/// the third base argument equals the number of selected indices (the cast
/// narrows the container's size type to int). @p func and @p indices are then
/// handed to the SliceImpl base.
33 FunctionSliceXprTpl(xyz::polymorphic<Base> func,
34 std::vector<int> const &indices)
35 : Base(func->ndx1, func->nu, (int)indices.size()),
36 SliceImpl(func, indices) {}
37
/// Convenience constructor for a single index: delegates to the
/// vector-of-indices constructor with the one-element vector {idx}.
38 FunctionSliceXprTpl(xyz::polymorphic<Base> func, const int idx)
39 : FunctionSliceXprTpl(func, std::vector<int>{idx}) {}
40
/// Evaluate the function (override of the base-class virtual).
///
/// Pure forwarding wrapper: the work is done by evaluate_impl(), which takes
/// the data object first, followed by the state @p x and control @p u.
41 void evaluate(const ConstVectorRef &x, const ConstVectorRef &u,
42 BaseData &data) const override {
43
44 this->evaluate_impl(data, x, u);
45 }
46
/// Compute Jacobians of this function (override of the base-class virtual).
///
/// Pure forwarding wrapper around computeJacobians_impl(), which receives
/// @p data first and then the arguments @p x and @p u.
47 void computeJacobians(const ConstVectorRef &x, const ConstVectorRef &u,
48 BaseData &data) const override {
49 this->computeJacobians_impl(data, x, u);
50 }
51
/// Compute the vector-hessian products of this function (override of the
/// base-class virtual).
///
/// Forwards to computeVectorHessianProducts_impl(). Note the reordering: the
/// implementation takes @p data, then @p lbda, then the function arguments
/// @p x and @p u, whereas this interface takes @p lbda after @p u.
52 void computeVectorHessianProducts(const ConstVectorRef &x,
53 const ConstVectorRef &u,
54 const ConstVectorRef &lbda,
55 BaseData &data) const override {
56 this->computeVectorHessianProducts_impl(data, lbda, x, u);
57 }
58
/// Instantiate a Data object for this slice.
///
/// Allocates a Data constructed from *this and returns it through a
/// shared_ptr to the base data type. (The Data alias is declared on a line
/// not shown here -- presumably FunctionSliceDataTpl<Scalar>; confirm.)
59 shared_ptr<BaseData> createData() const override {
60 return std::make_shared<Data>(*this);
61 }
62};
63
64template <typename Scalar>
66 : UnaryFunctionTpl<Scalar>,
67 detail::slice_impl_tpl<UnaryFunctionTpl<Scalar>> {
73
/// Construct a slice of the unary function @p func restricted to the output
/// entries listed in @p indices.
///
/// The base class is built from func->ndx1, func->nu and indices.size() (cast
/// to int); @p func and @p indices are then handed to the SliceImpl base.
74 FunctionSliceXprTpl(xyz::polymorphic<Base> func,
75 std::vector<int> const &indices)
76 : Base(func->ndx1, func->nu, (int)indices.size()),
77 SliceImpl(func, indices) {}
78
/// Convenience constructor for a single index: delegates to the
/// vector-of-indices constructor with the one-element vector {idx}.
79 FunctionSliceXprTpl(xyz::polymorphic<Base> func, const int idx)
80 : FunctionSliceXprTpl(func, std::vector<int>{idx}) {}
81
/// Evaluate the function (unary overload: state @p x only, no control).
///
/// Pure forwarding wrapper around evaluate_impl(), which takes the data
/// object first and then @p x.
82 void evaluate(const ConstVectorRef &x, BaseData &data) const override {
83
84 this->evaluate_impl(data, x);
85 }
86
/// Compute Jacobians of this function (unary overload: @p x only).
///
/// Pure forwarding wrapper around computeJacobians_impl(), which receives
/// @p data first and then @p x.
87 void computeJacobians(const ConstVectorRef &x,
88 BaseData &data) const override {
89 this->computeJacobians_impl(data, x);
90 }
91
/// Compute the vector-hessian products of this function (unary overload).
///
/// Forwards to computeVectorHessianProducts_impl(), which takes @p data, then
/// @p lbda, then the function argument @p x -- i.e. @p lbda moves ahead of
/// @p x relative to this interface.
92 void computeVectorHessianProducts(const ConstVectorRef &x,
93 const ConstVectorRef &lbda,
94 BaseData &data) const override {
95 this->computeVectorHessianProducts_impl(data, lbda, x);
96 }
97
/// Instantiate a Data object for this slice.
///
/// Allocates a Data constructed from *this and returns it through a
/// shared_ptr to the base data type. (The Data alias is declared on a line
/// not shown here -- presumably FunctionSliceDataTpl<Scalar>; confirm.)
98 shared_ptr<BaseData> createData() const override {
99 return std::make_shared<Data>(*this);
100 }
101};
102
103template <typename Scalar>
/// Data object of the wrapped function; the constructor below fills it with
/// the result of obj.func->createData().
108 shared_ptr<BaseData> sub_data;
/// Vector constructed from obj.nr in the constructor below. Presumably a
/// work buffer holding multiplier entries for the wrapped function -- the
/// code that uses it is not in this section; confirm.
109 VectorXs lbda_sub;
110
/// Constructor, templated on the function base type. Its signature line is
/// absent from this section; the initialiser list reads from a parameter
/// named obj. It forwards obj.ndx1, obj.nu and obj.nr to BaseData, then
/// creates sub_data from the wrapped function and lbda_sub from obj.nr.
111 template <typename Base>
113 : BaseData(obj.ndx1, obj.nu, obj.nr), sub_data(obj.func->createData()),
114 lbda_sub(obj.nr) {}
115};
116
117namespace detail {
/// Slicing and indexing of a function's output.
///
/// Mixin inherited by the FunctionSliceXprTpl specialisations above: it
/// stores the wrapped function and the selected indices, and declares the
/// variadic *_impl helpers that the public evaluate / computeJacobians /
/// computeVectorHessianProducts overrides forward to. Only declarations
/// appear here; the definitions are not part of this section (presumably in
/// the .hxx included at the end of the file).
119template <typename Base> struct slice_impl_tpl {
120 using Scalar = typename Base::Scalar;
123
125
/// The wrapped function whose output is being sliced.
126 xyz::polymorphic<Base> func;
/// The selected indices; the derived classes use indices.size() as the
/// third argument of their base-class constructor.
128 std::vector<int> indices;
129
/// Store @p func together with a list of indices (declaration only).
130 slice_impl_tpl(xyz::polymorphic<Base> func, std::vector<int> const &indices);
/// Single-index overload (declaration only).
131 slice_impl_tpl(xyz::polymorphic<Base> func, int idx);
132
133protected:
/// Backend of evaluate(); @p args are the function arguments (x, u) or (x)
/// as passed by the derived classes.
134 template <typename... Args>
135 void evaluate_impl(BaseData &data, Args &&...args) const;
136
/// Backend of computeJacobians(); same argument convention as
/// evaluate_impl().
137 template <typename... Args>
138 void computeJacobians_impl(BaseData &data, Args &&...args) const;
139
/// Backend of computeVectorHessianProducts(). The first line of this
/// declaration is absent from this section; the derived classes call it as
/// computeVectorHessianProducts_impl(data, lbda, args...).
140 template <typename... Args>
142 const ConstVectorRef &lbda,
143 Args &&...args) const;
144};
145} // namespace detail
146
147} // namespace aligator
148
149#include "aligator/modelling/function-xpr-slice.hxx"
150
151#ifdef ALIGATOR_ENABLE_TEMPLATE_INSTANTIATION
152#include "aligator/modelling/function-xpr-slice.txx"
153#endif
Base definitions for ternary functions.
Main package namespace.
FunctionSliceDataTpl(FunctionSliceXprTpl< Scalar, Base > const &obj)
FunctionSliceXprTpl(xyz::polymorphic< Base > func, const int idx)
void computeVectorHessianProducts(const ConstVectorRef &x, const ConstVectorRef &u, const ConstVectorRef &lbda, BaseData &data) const override
Compute the vector-hessian products of this function.
shared_ptr< BaseData > createData() const override
Instantiate a Data object.
void evaluate(const ConstVectorRef &x, const ConstVectorRef &u, BaseData &data) const override
Evaluate the function.
void computeJacobians(const ConstVectorRef &x, const ConstVectorRef &u, BaseData &data) const override
Compute Jacobians of this function.
FunctionSliceXprTpl(xyz::polymorphic< Base > func, std::vector< int > const &indices)
void computeJacobians(const ConstVectorRef &x, BaseData &data) const override
FunctionSliceXprTpl(xyz::polymorphic< Base > func, const int idx)
void evaluate(const ConstVectorRef &x, BaseData &data) const override
void computeVectorHessianProducts(const ConstVectorRef &x, const ConstVectorRef &lbda, BaseData &data) const override
FunctionSliceXprTpl(xyz::polymorphic< Base > func, std::vector< int > const &indices)
shared_ptr< BaseData > createData() const override
Instantiate a Data object.
Represents a function of which the output is a subset of another function, for instance $x \mapsto f_{\{0, 1, 3\}}(x)$ where $f$ is given.
Base struct for function data.
Definition fwd.hpp:62
Class representing ternary functions $f(x, u, x')$.
Definition fwd.hpp:56
Represents unary functions of the form $f(x)$, with no control (or next-state) arguments.
Definition fwd.hpp:59
Slicing and indexing of a function's output.
slice_impl_tpl(xyz::polymorphic< Base > func, std::vector< int > const &indices)
void computeJacobians_impl(BaseData &data, Args &&...args) const
void evaluate_impl(BaseData &data, Args &&...args) const
void computeVectorHessianProducts_impl(BaseData &data, const ConstVectorRef &lbda, Args &&...args) const
slice_impl_tpl(xyz::polymorphic< Base > func, int idx)