20template <
class FunctionBase = context::StageFunction>
21struct PyStageFunction : FunctionBase, bp::wrapper<FunctionBase> {
22 using Scalar =
typename FunctionBase::Scalar;
23 using Data = StageFunctionDataTpl<Scalar>;
27 template <
typename... Args>
28 PyStageFunction(Args &&...args) : FunctionBase(std::forward<Args>(args)...) {}
30 void evaluate(
const ConstVectorRef &x,
const ConstVectorRef &u,
31 const ConstVectorRef &y, Data &data)
const override {
36 const ConstVectorRef &y, Data &data)
const override {
42 const ConstVectorRef &u,
43 const ConstVectorRef &y,
44 const ConstVectorRef &lbda,
45 Data &data)
const override {
47 x, u, y, lbda, boost::ref(data));
54 shared_ptr<Data> default_createData()
const {
55 return FunctionBase::createData();
59template <
typename UFunction = context::UnaryFunction>
60struct PyUnaryFunction : UFunction, bp::wrapper<UFunction> {
61 using Scalar =
typename UFunction::Scalar;
63 std::is_base_of_v<UnaryFunctionTpl<Scalar>, UFunction>,
64 "Template parameter UFunction must derive from UnaryFunctionTpl<>.");
67 using Data = StageFunctionDataTpl<Scalar>;
69 using UFunction::UFunction;
71 void evaluate(
const ConstVectorRef &x, Data &data)
const override {
81 const ConstVectorRef &lbda,
82 Data &data)
const override {
84 lbda, boost::ref(data));
87 void default_computeVectorHessianProducts(
const ConstVectorRef &x,
88 const ConstVectorRef &lbda,
90 UFunction::computeVectorHessianProducts(x, lbda, data);
96template <
typename Class>
98 using Scalar =
typename Class::Scalar;
99 using FS = FunctionSliceXprTpl<Scalar, Class>;
101 template <
typename Iterator,
typename Fn>
103 while (range.start != range.stop) {
105 std::advance(range.start, range.step);
110 static auto get_slice(shared_ptr<Class>
const &fn, bp::slice slice_obj) {
111 std::vector<int> indices((
unsigned)fn->nr);
112 std::iota(indices.begin(), indices.end(), 0);
113 auto bounds = slice_obj.get_indices(indices.cbegin(), indices.cend());
114 std::vector<int> out{};
117 return std::make_shared<FS>(fn, out);
121 return std::make_shared<FS>(fn, idx);
125 std::vector<int>
const &indices) {
126 return std::make_shared<FS>(fn, indices);
129 template <
typename... Args>
void visit(bp::class_<Args...> &cl)
const {
132 .def(
"__getitem__", &
get_slice, bp::args(
"self",
"sl"));
#define ALIGATOR_PYTHON_OVERRIDE(ret_type, cname, fname,...)
Define the body of a virtual function override. This is meant to reduce boilerplate code when exposin...
#define ALIGATOR_PYTHON_OVERRIDE_PURE(ret_type, pyname,...)
Define the body of a virtual function override. This is meant to reduce boilerplate code when exposin...
Base definitions for ternary functions.
#define ALIGATOR_DYNAMIC_TYPEDEFS(Scalar)
virtual void computeJacobians(const ConstVectorRef &x, const ConstVectorRef &u, const ConstVectorRef &y, Data &data) const=0
Compute Jacobians of this function.
virtual void computeVectorHessianProducts(const ConstVectorRef &x, const ConstVectorRef &u, const ConstVectorRef &y, const ConstVectorRef &lbda, Data &data) const
Compute the vector-hessian products of this function.
StageFunctionDataTpl< Scalar > Data
virtual shared_ptr< Data > createData() const
Instantiate a Data object.
virtual void evaluate(const ConstVectorRef &x, const ConstVectorRef &u, const ConstVectorRef &y, Data &data) const=0
Evaluate the function.
virtual void evaluate(const ConstVectorRef &x, Data &data) const=0
virtual void computeJacobians(const ConstVectorRef &x, Data &data) const=0
virtual void computeVectorHessianProducts(const ConstVectorRef &, const ConstVectorRef &, Data &) const
StageFunctionDataTpl< Scalar > Data
FunctionSliceXprTpl< Scalar, Class > FS
static auto get_slice(shared_ptr< Class > const &fn, bp::slice slice_obj)
void visit(bp::class_< Args... > &cl) const
static auto get_from_indices(shared_ptr< Class > const &fn, std::vector< int > const &indices)
static auto do_with_slice(Fn &&fun, bp::slice::range< Iterator > &range)
static auto get_from_index(shared_ptr< Class > const &fn, const int idx)
typename Class::Scalar Scalar
#define ALIGATOR_UNARY_FUNCTION_INTERFACE(Scalar)