/// Declares the signature shared by every GemmT kernel: a static member
/// function template `fn(dst, lhs, rhs, alpha)` taking a mutable `dst`,
/// read-only `lhs` and `rhs` (all `Eigen::MatrixBase` expressions) and a
/// scalar `alpha` of type `Scalar`.
/// The `dst`, `lhs`, `rhs` and `alpha` macro arguments only supply the
/// parameter names; passing them empty yields unnamed parameters (useful for
/// a no-op body that ignores its arguments).
12#define PROXSUITE_NLP_GEMMT_SIGNATURE(Scalar, dst, lhs, rhs, alpha) \
13 template <typename Dst, typename Lhs, typename Rhs> \
14 static void fn(Eigen::MatrixBase<Dst> &dst, \
15 Eigen::MatrixBase<Lhs> const &lhs, \
16 Eigen::MatrixBase<Rhs> const &rhs, Scalar alpha)
/// GemmT<Scalar, LHS, RHS>::fn(dst, lhs, rhs, alpha) adds alpha * lhs * rhs^T
/// (or the part of it a given kernel writes, e.g. one triangle of dst) into
/// dst. The kernel is selected at compile time by the BlockKind of lhs and rhs.
/// The primary template's fn has an empty body, i.e. it is a no-op.
/// NOTE(review): this listing omits the `template <...> struct GemmT<...>`
/// specialization headers, the closing braces, and several body lines (e.g. the
/// definitions of `n` and `m`); the comments below describe only what is visible.
18template <
typename Scalar, BlockKind LHS, BlockKind RHS>
struct GemmT {
20 PROXSUITE_NLP_GEMMT_SIGNATURE(Scalar, , , , ) {}
// Adds alpha * (diag(lhs) .* diag(rhs^T)) to the first n diagonal entries of dst.
26 PROXSUITE_NLP_GEMMT_SIGNATURE(Scalar, dst, lhs, rhs, alpha) {
27 auto v = lhs.diagonal().cwiseProduct(rhs.transpose().diagonal());
29 dst.diagonal().head(n) += alpha * v;
// For each column j, updates rows 0..j of dst (upper triangle incl. diagonal)
// with alpha * diag(lhs) .* (first j+1 entries of column j of rhs^T).
// NOTE(review): lhs.diagonal() is not sliced here, while Eigen's cwiseProduct
// requires operands of equal size; this looks like it should be
// lhs.diagonal().head(j + 1) — confirm against the full source.
36 PROXSUITE_NLP_GEMMT_SIGNATURE(Scalar, dst, lhs, rhs, alpha) {
43 for (isize j = 0; j < n; ++j) {
44 dst.col(j).head(j + 1) += alpha * lhs.diagonal().cwiseProduct(
45 rhs.transpose().col(j).head(j + 1));
// For each column j, updates rows j..m-1 of dst (lower triangle incl. diagonal)
// with alpha * diag(lhs) .* (last m-j entries of column j of rhs^T).
// NOTE(review): lhs.diagonal() is again unsliced; to pair it with rows j..m-1
// the matching slice would be lhs.diagonal().tail(m - j) — confirm.
53 PROXSUITE_NLP_GEMMT_SIGNATURE(Scalar, dst, lhs, rhs, alpha) {
61 for (isize j = 0; j < n; ++j) {
62 dst.col(j).tail(m - j) += alpha * lhs.diagonal().cwiseProduct(
63 rhs.transpose().col(j).tail(m - j));
// Full update: dst += alpha * diag(lhs) * rhs^T.
71 PROXSUITE_NLP_GEMMT_SIGNATURE(Scalar, dst, lhs, rhs, alpha) {
72 dst += alpha * (lhs.diagonal().asDiagonal() * rhs.transpose());
// Reads only the diagonal entries rhs(j, j): for each column j, adds
// alpha * rhs(j, j) * lhs.col(j) restricted to rows j..m-1 into dst.
79 PROXSUITE_NLP_GEMMT_SIGNATURE(Scalar, dst, lhs, rhs, alpha) {
87 for (isize j = 0; j < n; ++j) {
88 dst.col(j).tail(m - j) += (alpha * rhs(j, j)) * lhs.col(j).tail(m - j);
// Computes alpha * lhs * upper(rhs^T); the statement's target is not visible here.
96 PROXSUITE_NLP_GEMMT_SIGNATURE(Scalar, dst, lhs, rhs, alpha) {
102 alpha * (lhs * rhs.transpose().
template triangularView<Eigen::Upper>());
// Adds alpha * lhs * lower(rhs^T) into the lower triangle of dst only.
109 PROXSUITE_NLP_GEMMT_SIGNATURE(Scalar, dst, lhs, rhs, alpha) {
114 dst.template triangularView<Eigen::Lower>() +=
115 alpha * (lhs * rhs.transpose().
template triangularView<Eigen::Lower>());
// Computes lower(lhs) * (alpha * rhs^T); the statement's target is not visible here.
122 PROXSUITE_NLP_GEMMT_SIGNATURE(Scalar, dst, lhs, rhs, alpha) {
124 lhs.template triangularView<Eigen::Lower>() * (alpha * rhs.transpose());
// Reads only the diagonal entries rhs(j, j): for each column j, adds
// alpha * rhs(j, j) * lhs.col(j) restricted to rows 0..j into dst.
131 PROXSUITE_NLP_GEMMT_SIGNATURE(Scalar, dst, lhs, rhs, alpha) {
136 isize n = dst.cols();
138 for (isize j = 0; j < n; ++j) {
139 dst.col(j).head(j + 1) += (alpha * rhs(j, j)) * lhs.col(j).head(j + 1);
// Adds alpha * lhs * upper(rhs^T) into the upper triangle of dst only.
147 PROXSUITE_NLP_GEMMT_SIGNATURE(Scalar, dst, lhs, rhs, alpha) {
152 dst.template triangularView<Eigen::Upper>() +=
153 alpha * (lhs * rhs.transpose().
template triangularView<Eigen::Upper>());
// Computes alpha * lhs * lower(rhs^T); the statement's target is not visible here.
160 PROXSUITE_NLP_GEMMT_SIGNATURE(Scalar, dst, lhs, rhs, alpha) {
166 alpha * (lhs * rhs.transpose().
template triangularView<Eigen::Lower>());
// dst += upper(lhs) * (alpha * rhs^T); noalias() asserts dst does not alias
// lhs or rhs, so no temporary is created.
173 PROXSUITE_NLP_GEMMT_SIGNATURE(Scalar, dst, lhs, rhs, alpha) {
174 dst.noalias() += (lhs.template triangularView<Eigen::Upper>() *
175 (alpha * rhs.transpose()));
// dst += alpha * lhs * diag(rhs^T), assuming no aliasing (noalias()).
181 PROXSUITE_NLP_GEMMT_SIGNATURE(Scalar, dst, lhs, rhs, alpha) {
182 dst.noalias() += alpha * (lhs * rhs.transpose().diagonal().asDiagonal());
// Computes alpha * lhs * upper(rhs^T); the statement's target is not visible here.
188 PROXSUITE_NLP_GEMMT_SIGNATURE(Scalar, dst, lhs, rhs, alpha) {
190 alpha * (lhs * rhs.transpose().
template triangularView<Eigen::Upper>());
// Computes alpha * lhs * lower(rhs^T); the statement's target is not visible here.
196 PROXSUITE_NLP_GEMMT_SIGNATURE(Scalar, dst, lhs, rhs, alpha) {
198 alpha * (lhs * rhs.transpose().
template triangularView<Eigen::Lower>());
// General (unstructured) case: dst += alpha * lhs * rhs^T, assuming no aliasing.
204 PROXSUITE_NLP_GEMMT_SIGNATURE(Scalar, dst, lhs, rhs, alpha) {
205 dst.noalias() += alpha * lhs * rhs.transpose();
/// Runtime dispatcher: forwards (dst, lhs, rhs, alpha) to the GemmT kernel
/// whose compile-time BlockKind arguments match the block structure of lhs
/// (given by `lhs_kind`) and, presumably, of rhs.
/// NOTE(review): this listing is incomplete. The `typename RhsDerived` template
/// parameter is used but not declared here; `alpha` and any further parameters
/// (presumably an rhs kind) are not shown; and the switch/if logic selecting
/// each branch is missing. Only the branches with a Diag or TriL lhs kernel
/// are visible.
209template <
typename Scalar,
typename DstDerived,
typename LhsDerived,
211inline void gemmt(Eigen::MatrixBase<DstDerived> &dst,
212 Eigen::MatrixBase<LhsDerived>
const &lhs,
213 Eigen::MatrixBase<RhsDerived>
const &rhs,
BlockKind lhs_kind,
// Branches using a diagonal-lhs kernel, one per rhs BlockKind.
240 GemmT<Scalar, Diag, Zero>::fn(dst, lhs, rhs, alpha);
243 GemmT<Scalar, Diag, Diag>::fn(dst, lhs, rhs, alpha);
246 GemmT<Scalar, Diag, TriL>::fn(dst, lhs, rhs, alpha);
249 GemmT<Scalar, Diag, TriU>::fn(dst, lhs, rhs, alpha);
252 GemmT<Scalar, Diag, Dense>::fn(dst, lhs, rhs, alpha);
// Branches using a lower-triangular-lhs kernel, one per rhs BlockKind.
260 GemmT<Scalar, TriL, Zero>::fn(dst, lhs, rhs, alpha);
263 GemmT<Scalar, TriL, Diag>::fn(dst, lhs, rhs, alpha);
266 GemmT<Scalar, TriL, TriL>::fn(dst, lhs, rhs, alpha);
269 GemmT<Scalar, TriL, TriU>::fn(dst, lhs, rhs, alpha);
272 GemmT<Scalar, TriL, Dense>::fn(dst, lhs, rhs, alpha);
Definition for matrix "kind" enums.
#define PROXSUITE_NLP_DYNAMIC_TYPEDEFS(Scalar)
Specific linear algebra routines.
BlockKind
Kind of matrix block: zeros, diagonal, lower/upper triangular or dense.
@ Dense
There is no known prior structure; assume a dense block.
@ Diag
The block is diagonal.
@ Zero
All entries in the block are zero.
@ TriL
The block is lower-triangular.
@ TriU
The block is upper-triangular.