proxsuite-nlp  0.10.0
A primal-dual augmented Lagrangian-type solver for nonlinear programming on manifolds.
Loading...
Searching...
No Matches
gemmt.hpp
Go to the documentation of this file.
1
3#pragma once
4
6
7namespace proxsuite {
8namespace nlp {
9namespace linalg {
10namespace backend {
11
/// Declares the static member function template `fn` shared by every GemmT
/// specialization. `fn` takes the destination, the two operands (as Eigen
/// MatrixBase references) and the scalar factor. The names of those four
/// parameters are given by the macro arguments; an empty argument yields an
/// unnamed parameter.
#define PROXSUITE_NLP_GEMMT_SIGNATURE(Scalar, dst, lhs, rhs, alpha)           \
  template <typename DstType, typename LhsType, typename RhsType>            \
  static void fn(Eigen::MatrixBase<DstType> &dst,                            \
                 Eigen::MatrixBase<LhsType> const &lhs,                      \
                 Eigen::MatrixBase<RhsType> const &rhs, Scalar alpha)
17
18template <typename Scalar, BlockKind LHS, BlockKind RHS> struct GemmT {
20 PROXSUITE_NLP_GEMMT_SIGNATURE(Scalar, /*dst*/, /*lhs*/, /*rhs*/, /*alpha*/) {}
21};
22
23template <typename Scalar> struct GemmT<Scalar, Diag, Diag> {
25 // dst is diagonal
26 PROXSUITE_NLP_GEMMT_SIGNATURE(Scalar, dst, lhs, rhs, alpha) {
27 auto v = lhs.diagonal().cwiseProduct(rhs.transpose().diagonal());
28 isize n = v.rows();
29 dst.diagonal().head(n) += alpha * v;
30 }
31};
32
33template <typename Scalar> struct GemmT<Scalar, Diag, TriL> {
35 // dst is triu
36 PROXSUITE_NLP_GEMMT_SIGNATURE(Scalar, dst, lhs, rhs, alpha) {
37 // dst.template triangularView<Eigen::Upper>() +=
38 // alpha * (lhs.diagonal().asDiagonal() *
39 // rhs.template triangularView<Eigen::Lower>().transpose());
40
41 isize n = dst.cols();
42
43 for (isize j = 0; j < n; ++j) {
44 dst.col(j).head(j + 1) += alpha * lhs.diagonal().cwiseProduct(
45 rhs.transpose().col(j).head(j + 1));
46 }
47 }
48};
49
50template <typename Scalar> struct GemmT<Scalar, Diag, TriU> {
52 // dst is tril
53 PROXSUITE_NLP_GEMMT_SIGNATURE(Scalar, dst, lhs, rhs, alpha) {
54 // dst.template triangularView<Eigen::Lower>() +=
55 // alpha * (lhs.diagonal().asDiagonal() *
56 // rhs.template triangularView<Eigen::Upper>().transpose());
57
58 isize m = dst.rows();
59 isize n = dst.cols();
60
61 for (isize j = 0; j < n; ++j) {
62 dst.col(j).tail(m - j) += alpha * lhs.diagonal().cwiseProduct(
63 rhs.transpose().col(j).tail(m - j));
64 }
65 }
66};
67
68template <typename Scalar> struct GemmT<Scalar, Diag, Dense> {
70 // dst is dense
71 PROXSUITE_NLP_GEMMT_SIGNATURE(Scalar, dst, lhs, rhs, alpha) {
72 dst += alpha * (lhs.diagonal().asDiagonal() * rhs.transpose());
73 }
74};
75
76template <typename Scalar> struct GemmT<Scalar, TriL, Diag> {
78 // dst is tril
79 PROXSUITE_NLP_GEMMT_SIGNATURE(Scalar, dst, lhs, rhs, alpha) {
80 // dst.template triangularView<Eigen::Lower>() +=
81 // alpha * (lhs.template triangularView<Eigen::Lower>() *
82 // rhs.diagonal().asDiagonal());
83
84 isize m = dst.rows();
85 isize n = dst.cols();
86
87 for (isize j = 0; j < n; ++j) {
88 dst.col(j).tail(m - j) += (alpha * rhs(j, j)) * lhs.col(j).tail(m - j);
89 }
90 }
91};
92
93template <typename Scalar> struct GemmT<Scalar, TriL, TriL> {
95 // dst is dense
96 PROXSUITE_NLP_GEMMT_SIGNATURE(Scalar, dst, lhs, rhs, alpha) {
97 // PERF
98 // dst += alpha * (lhs.template triangularView<Eigen::Lower>() *
99 // rhs.transpose().template
100 // triangularView<Eigen::Upper>());
101 dst.noalias() +=
102 alpha * (lhs * rhs.transpose().template triangularView<Eigen::Upper>());
103 }
104};
105
106template <typename Scalar> struct GemmT<Scalar, TriL, TriU> {
108 // dst is tril
109 PROXSUITE_NLP_GEMMT_SIGNATURE(Scalar, dst, lhs, rhs, alpha) {
110 // PERF
111 // dst += alpha * (lhs.template triangularView<Eigen::Lower>() *
112 // rhs.transpose().template
113 // triangularView<Eigen::Lower>());
114 dst.template triangularView<Eigen::Lower>() +=
115 alpha * (lhs * rhs.transpose().template triangularView<Eigen::Lower>());
116 }
117};
118
119template <typename Scalar> struct GemmT<Scalar, TriL, Dense> {
121 // dst is dense
122 PROXSUITE_NLP_GEMMT_SIGNATURE(Scalar, dst, lhs, rhs, alpha) {
123 dst.noalias() +=
124 lhs.template triangularView<Eigen::Lower>() * (alpha * rhs.transpose());
125 }
126};
127
128template <typename Scalar> struct GemmT<Scalar, TriU, Diag> {
130 // dst is triu
131 PROXSUITE_NLP_GEMMT_SIGNATURE(Scalar, dst, lhs, rhs, alpha) {
132 // dst.template triangularView<Eigen::Lower>() +=
133 // alpha * (lhs.template triangularView<Eigen::Lower>() *
134 // rhs.diagonal().asDiagonal());
135
136 isize n = dst.cols();
137
138 for (isize j = 0; j < n; ++j) {
139 dst.col(j).head(j + 1) += (alpha * rhs(j, j)) * lhs.col(j).head(j + 1);
140 }
141 }
142};
143
144template <typename Scalar> struct GemmT<Scalar, TriU, TriL> {
146 // dst is triu
147 PROXSUITE_NLP_GEMMT_SIGNATURE(Scalar, dst, lhs, rhs, alpha) {
148 // PERF
149 // dst.template triangularView<Eigen::Upper>() +=
150 // alpha * (lhs.template triangularView<Eigen::Upper>() *
151 // rhs.transpose().triangularView<Eigen::Upper>());
152 dst.template triangularView<Eigen::Upper>() +=
153 alpha * (lhs * rhs.transpose().template triangularView<Eigen::Upper>());
154 }
155};
156
157template <typename Scalar> struct GemmT<Scalar, TriU, TriU> {
159 // dst is dense
160 PROXSUITE_NLP_GEMMT_SIGNATURE(Scalar, dst, lhs, rhs, alpha) {
161 // PERF
162 // dst.noalias() += alpha * (lhs.template triangularView<Eigen::Upper>() *
163 // rhs.transpose().template
164 // triangularView<Eigen::Lower>());
165 dst.noalias() +=
166 alpha * (lhs * rhs.transpose().template triangularView<Eigen::Lower>());
167 }
168};
169
170template <typename Scalar> struct GemmT<Scalar, TriU, Dense> {
172 // dst is dense
173 PROXSUITE_NLP_GEMMT_SIGNATURE(Scalar, dst, lhs, rhs, alpha) {
174 dst.noalias() += (lhs.template triangularView<Eigen::Upper>() *
175 (alpha * rhs.transpose()));
176 }
177};
178
179template <typename Scalar> struct GemmT<Scalar, Dense, Diag> {
181 PROXSUITE_NLP_GEMMT_SIGNATURE(Scalar, dst, lhs, rhs, alpha) {
182 dst.noalias() += alpha * (lhs * rhs.transpose().diagonal().asDiagonal());
183 }
184};
185
186template <typename Scalar> struct GemmT<Scalar, Dense, TriL> {
188 PROXSUITE_NLP_GEMMT_SIGNATURE(Scalar, dst, lhs, rhs, alpha) {
189 dst.noalias() +=
190 alpha * (lhs * rhs.transpose().template triangularView<Eigen::Upper>());
191 }
192};
193
194template <typename Scalar> struct GemmT<Scalar, Dense, TriU> {
196 PROXSUITE_NLP_GEMMT_SIGNATURE(Scalar, dst, lhs, rhs, alpha) {
197 dst.noalias() +=
198 alpha * (lhs * rhs.transpose().template triangularView<Eigen::Lower>());
199 }
200};
201
202template <typename Scalar> struct GemmT<Scalar, Dense, Dense> {
204 PROXSUITE_NLP_GEMMT_SIGNATURE(Scalar, dst, lhs, rhs, alpha) {
205 dst.noalias() += alpha * lhs * rhs.transpose();
206 }
207};
208
209template <typename Scalar, typename DstDerived, typename LhsDerived,
210 typename RhsDerived>
211inline void gemmt(Eigen::MatrixBase<DstDerived> &dst,
212 Eigen::MatrixBase<LhsDerived> const &lhs,
213 Eigen::MatrixBase<RhsDerived> const &rhs, BlockKind lhs_kind,
214 BlockKind rhs_kind, Scalar alpha) {
215 // dst += alpha * lhs * rhs.T
216 switch (lhs_kind) {
217 case Zero: {
218 switch (rhs_kind) {
219 case Zero:
220 GemmT<Scalar, Zero, Zero>::fn(dst, lhs, rhs, alpha);
221 break;
222 case Diag:
223 GemmT<Scalar, Zero, Diag>::fn(dst, lhs, rhs, alpha);
224 break;
225 case TriL:
226 GemmT<Scalar, Zero, TriL>::fn(dst, lhs, rhs, alpha);
227 break;
228 case TriU:
229 GemmT<Scalar, Zero, TriU>::fn(dst, lhs, rhs, alpha);
230 break;
231 case Dense:
232 GemmT<Scalar, Zero, Dense>::fn(dst, lhs, rhs, alpha);
233 break;
234 }
235 break;
236 }
237 case Diag: {
238 switch (rhs_kind) {
239 case Zero:
240 GemmT<Scalar, Diag, Zero>::fn(dst, lhs, rhs, alpha);
241 break;
242 case Diag:
243 GemmT<Scalar, Diag, Diag>::fn(dst, lhs, rhs, alpha);
244 break;
245 case TriL:
246 GemmT<Scalar, Diag, TriL>::fn(dst, lhs, rhs, alpha);
247 break;
248 case TriU:
249 GemmT<Scalar, Diag, TriU>::fn(dst, lhs, rhs, alpha);
250 break;
251 case Dense:
252 GemmT<Scalar, Diag, Dense>::fn(dst, lhs, rhs, alpha);
253 break;
254 }
255 break;
256 }
257 case TriL: {
258 switch (rhs_kind) {
259 case Zero:
260 GemmT<Scalar, TriL, Zero>::fn(dst, lhs, rhs, alpha);
261 break;
262 case Diag:
263 GemmT<Scalar, TriL, Diag>::fn(dst, lhs, rhs, alpha);
264 break;
265 case TriL:
266 GemmT<Scalar, TriL, TriL>::fn(dst, lhs, rhs, alpha);
267 break;
268 case TriU:
269 GemmT<Scalar, TriL, TriU>::fn(dst, lhs, rhs, alpha);
270 break;
271 case Dense:
272 GemmT<Scalar, TriL, Dense>::fn(dst, lhs, rhs, alpha);
273 break;
274 }
275 break;
276 }
277 case TriU: {
278 switch (rhs_kind) {
279 case Zero:
280 GemmT<Scalar, TriU, Zero>::fn(dst, lhs, rhs, alpha);
281 break;
282 case Diag:
283 GemmT<Scalar, TriU, Diag>::fn(dst, lhs, rhs, alpha);
284 break;
285 case TriL:
286 GemmT<Scalar, TriU, TriL>::fn(dst, lhs, rhs, alpha);
287 break;
288 case TriU:
289 GemmT<Scalar, TriU, TriU>::fn(dst, lhs, rhs, alpha);
290 break;
291 case Dense:
292 GemmT<Scalar, TriU, Dense>::fn(dst, lhs, rhs, alpha);
293 break;
294 }
295 break;
296 }
297 case Dense: {
298 switch (rhs_kind) {
299 case Zero:
300 GemmT<Scalar, Dense, Zero>::fn(dst, lhs, rhs, alpha);
301 break;
302 case Diag:
303 GemmT<Scalar, Dense, Diag>::fn(dst, lhs, rhs, alpha);
304 break;
305 case TriL:
306 GemmT<Scalar, Dense, TriL>::fn(dst, lhs, rhs, alpha);
307 break;
308 case TriU:
309 GemmT<Scalar, Dense, TriU>::fn(dst, lhs, rhs, alpha);
310 break;
311 case Dense:
312 GemmT<Scalar, Dense, Dense>::fn(dst, lhs, rhs, alpha);
313 break;
314 }
315 break;
316 }
317 }
318}
319
320} // namespace backend
321} // namespace linalg
322} // namespace nlp
323} // namespace proxsuite
Definition for matrix "kind" enums.
#define PROXSUITE_NLP_DYNAMIC_TYPEDEFS(Scalar)
Definition math.hpp:26
void gemmt(Eigen::MatrixBase< DstDerived > &dst, Eigen::MatrixBase< LhsDerived > const &lhs, Eigen::MatrixBase< RhsDerived > const &rhs, BlockKind lhs_kind, BlockKind rhs_kind, Scalar alpha)
Definition gemmt.hpp:211
BlockKind
Kind of matrix block: zeros, diagonal, lower/upper triangular or dense.
@ Dense
There is no known prior structure; assume a dense block.
@ Diag
The block is diagonal.
@ Zero
All entries in the block are zero.
@ TriL
The block is lower-triangular.
@ TriU
The block is upper-triangular.
Main package namespace.
Definition bcl-params.hpp:5
PROXSUITE_NLP_GEMMT_SIGNATURE(Scalar, dst, lhs, rhs, alpha)
Definition gemmt.hpp:181
PROXSUITE_NLP_GEMMT_SIGNATURE(Scalar, dst, lhs, rhs, alpha)
Definition gemmt.hpp:188
PROXSUITE_NLP_GEMMT_SIGNATURE(Scalar, dst, lhs, rhs, alpha)
Definition gemmt.hpp:196
PROXSUITE_NLP_GEMMT_SIGNATURE(Scalar, dst, lhs, rhs, alpha)
Definition gemmt.hpp:71
PROXSUITE_NLP_GEMMT_SIGNATURE(Scalar, dst, lhs, rhs, alpha)
Definition gemmt.hpp:26
PROXSUITE_NLP_GEMMT_SIGNATURE(Scalar, dst, lhs, rhs, alpha)
Definition gemmt.hpp:36
PROXSUITE_NLP_GEMMT_SIGNATURE(Scalar, dst, lhs, rhs, alpha)
Definition gemmt.hpp:53
PROXSUITE_NLP_GEMMT_SIGNATURE(Scalar, dst, lhs, rhs, alpha)
Definition gemmt.hpp:122
PROXSUITE_NLP_GEMMT_SIGNATURE(Scalar, dst, lhs, rhs, alpha)
Definition gemmt.hpp:79
PROXSUITE_NLP_GEMMT_SIGNATURE(Scalar, dst, lhs, rhs, alpha)
Definition gemmt.hpp:96
PROXSUITE_NLP_GEMMT_SIGNATURE(Scalar, dst, lhs, rhs, alpha)
Definition gemmt.hpp:109
PROXSUITE_NLP_GEMMT_SIGNATURE(Scalar, dst, lhs, rhs, alpha)
Definition gemmt.hpp:173
PROXSUITE_NLP_GEMMT_SIGNATURE(Scalar, dst, lhs, rhs, alpha)
Definition gemmt.hpp:131
PROXSUITE_NLP_GEMMT_SIGNATURE(Scalar, dst, lhs, rhs, alpha)
Definition gemmt.hpp:147
PROXSUITE_NLP_GEMMT_SIGNATURE(Scalar, dst, lhs, rhs, alpha)
Definition gemmt.hpp:160