aligator  0.6.1
A primal-dual augmented Lagrangian-type solver for nonlinear trajectory optimization.
Loading...
Searching...
No Matches
functions.hpp
Go to the documentation of this file.
1
#pragma once

#include <stdexcept>
3
6
9
10namespace aligator {
11namespace python {
12namespace internal {
20template <class FunctionBase = context::StageFunction>
21struct PyStageFunction : FunctionBase, bp::wrapper<FunctionBase> {
22 using Scalar = typename FunctionBase::Scalar;
23 using Data = StageFunctionDataTpl<Scalar>;
25
26 // Use perfect forwarding to the FunctionBase constructors.
27 template <typename... Args>
28 PyStageFunction(Args &&...args) : FunctionBase(std::forward<Args>(args)...) {}
29
30 void evaluate(const ConstVectorRef &x, const ConstVectorRef &u,
31 const ConstVectorRef &y, Data &data) const override {
32 ALIGATOR_PYTHON_OVERRIDE_PURE(void, "evaluate", x, u, y, boost::ref(data));
33 }
34
35 void computeJacobians(const ConstVectorRef &x, const ConstVectorRef &u,
36 const ConstVectorRef &y, Data &data) const override {
37 ALIGATOR_PYTHON_OVERRIDE_PURE(void, "computeJacobians", x, u, y,
38 boost::ref(data));
39 }
40
41 void computeVectorHessianProducts(const ConstVectorRef &x,
42 const ConstVectorRef &u,
43 const ConstVectorRef &y,
44 const ConstVectorRef &lbda,
45 Data &data) const override {
46 ALIGATOR_PYTHON_OVERRIDE(void, FunctionBase, computeVectorHessianProducts,
47 x, u, y, lbda, boost::ref(data));
48 }
49
50 shared_ptr<Data> createData() const override {
51 ALIGATOR_PYTHON_OVERRIDE(shared_ptr<Data>, FunctionBase, createData, );
52 }
53
54 shared_ptr<Data> default_createData() const {
55 return FunctionBase::createData();
56 }
57};
58
59template <typename UFunction = context::UnaryFunction>
60struct PyUnaryFunction : UFunction, bp::wrapper<UFunction> {
61 using Scalar = typename UFunction::Scalar;
62 static_assert(
63 std::is_base_of_v<UnaryFunctionTpl<Scalar>, UFunction>,
64 "Template parameter UFunction must derive from UnaryFunctionTpl<>.");
67 using Data = StageFunctionDataTpl<Scalar>;
68
69 using UFunction::UFunction;
70
71 void evaluate(const ConstVectorRef &x, Data &data) const override {
72 ALIGATOR_PYTHON_OVERRIDE_PURE(void, "evaluate", x, boost::ref(data));
73 }
74
75 void computeJacobians(const ConstVectorRef &x, Data &data) const override {
76 ALIGATOR_PYTHON_OVERRIDE_PURE(void, "computeJacobians", x,
77 boost::ref(data));
78 }
79
80 void computeVectorHessianProducts(const ConstVectorRef &x,
81 const ConstVectorRef &lbda,
82 Data &data) const override {
83 ALIGATOR_PYTHON_OVERRIDE(void, UFunction, computeVectorHessianProducts, x,
84 lbda, boost::ref(data));
85 }
86
87 void default_computeVectorHessianProducts(const ConstVectorRef &x,
88 const ConstVectorRef &lbda,
89 Data &data) const {
90 UFunction::computeVectorHessianProducts(x, lbda, data);
91 }
92};
93
94} // namespace internal
95
96template <typename Class>
97struct SlicingVisitor : bp::def_visitor<SlicingVisitor<Class>> {
98 using Scalar = typename Class::Scalar;
99 using FS = FunctionSliceXprTpl<Scalar, Class>;
100
101 template <typename Iterator, typename Fn>
102 static auto do_with_slice(Fn &&fun, bp::slice::range<Iterator> &range) {
103 while (range.start != range.stop) {
104 fun(*range.start);
105 std::advance(range.start, range.step);
106 }
107 fun(*range.start);
108 }
109
110 static auto get_slice(shared_ptr<Class> const &fn, bp::slice slice_obj) {
111 std::vector<int> indices((unsigned)fn->nr);
112 std::iota(indices.begin(), indices.end(), 0);
113 auto bounds = slice_obj.get_indices(indices.cbegin(), indices.cend());
114 std::vector<int> out{};
115
116 do_with_slice([&](int i) { out.push_back(i); }, bounds);
117 return std::make_shared<FS>(fn, out);
118 }
119
120 static auto get_from_index(shared_ptr<Class> const &fn, const int idx) {
121 return std::make_shared<FS>(fn, idx);
122 }
123
124 static auto get_from_indices(shared_ptr<Class> const &fn,
125 std::vector<int> const &indices) {
126 return std::make_shared<FS>(fn, indices);
127 }
128
129 template <typename... Args> void visit(bp::class_<Args...> &cl) const {
130 cl.def("__getitem__", &get_from_index, bp::args("self", "idx"))
131 .def("__getitem__", &get_from_indices, bp::args("self", "indices"))
132 .def("__getitem__", &get_slice, bp::args("self", "sl"));
133 }
134};
135
136} // namespace python
137} // namespace aligator
#define ALIGATOR_PYTHON_OVERRIDE(ret_type, cname, fname,...)
Define the body of a virtual function override. This is meant to reduce boilerplate code when exposing virtual member functions to Python.
Definition macros.hpp:50
#define ALIGATOR_PYTHON_OVERRIDE_PURE(ret_type, pyname,...)
Define the body of a virtual function override. This is meant to reduce boilerplate code when exposing virtual member functions to Python.
Definition macros.hpp:41
Base definitions for ternary functions.
#define ALIGATOR_DYNAMIC_TYPEDEFS(Scalar)
Definition math.hpp:18
Main package namespace.
virtual void computeJacobians(const ConstVectorRef &x, const ConstVectorRef &u, const ConstVectorRef &y, Data &data) const=0
Compute Jacobians of this function.
virtual void computeVectorHessianProducts(const ConstVectorRef &x, const ConstVectorRef &u, const ConstVectorRef &y, const ConstVectorRef &lbda, Data &data) const
Compute the vector-hessian products of this function.
StageFunctionDataTpl< Scalar > Data
virtual shared_ptr< Data > createData() const
Instantiate a Data object.
virtual void evaluate(const ConstVectorRef &x, const ConstVectorRef &u, const ConstVectorRef &y, Data &data) const=0
Evaluate the function.
virtual void evaluate(const ConstVectorRef &x, Data &data) const=0
virtual void computeJacobians(const ConstVectorRef &x, Data &data) const=0
virtual void computeVectorHessianProducts(const ConstVectorRef &, const ConstVectorRef &, Data &) const
StageFunctionDataTpl< Scalar > Data
FunctionSliceXprTpl< Scalar, Class > FS
Definition functions.hpp:99
static auto get_slice(shared_ptr< Class > const &fn, bp::slice slice_obj)
void visit(bp::class_< Args... > &cl) const
static auto get_from_indices(shared_ptr< Class > const &fn, std::vector< int > const &indices)
static auto do_with_slice(Fn &&fun, bp::slice::range< Iterator > &range)
static auto get_from_index(shared_ptr< Class > const &fn, const int idx)
typename Class::Scalar Scalar
Definition functions.hpp:98
#define ALIGATOR_UNARY_FUNCTION_INTERFACE(Scalar)