proxsuite 0.6.7
The Advanced Proximal Optimization Toolbox
Loading...
Searching...
No Matches
update.hpp
Go to the documentation of this file.
1
2//
3// Copyright (c) 2022 INRIA
4//
5#ifndef PROXSUITE_LINALG_DENSE_LDLT_UPDATE_HPP
6#define PROXSUITE_LINALG_DENSE_LDLT_UPDATE_HPP
7
9
10namespace proxsuite {
11namespace linalg {
12namespace dense {
13namespace _detail {
14inline auto
15bytes_to_prev_aligned(void* ptr, usize align) noexcept -> isize
16{
17 using UPtr = std::uintptr_t;
18
19 UPtr mask = align - 1;
20 UPtr iptr = UPtr(ptr);
21 UPtr aligned_ptr = iptr & ~mask;
22 return isize(aligned_ptr - iptr);
23}
24inline auto
25bytes_to_next_aligned(void* ptr, usize align) noexcept -> isize
26{
27 using UPtr = std::uintptr_t;
28
29 UPtr mask = align - 1;
30 UPtr iptr = UPtr(ptr);
31 UPtr aligned_ptr = (iptr + mask) & ~mask;
32 return isize(aligned_ptr - iptr);
33}
34
35template<usize... Is, typename Fn>
36VEG_INLINE void
42
43template<usize N, typename Fn>
44VEG_INLINE void
50
51template<typename T, usize N>
53{
55 T const* pw;
56 isize w_stride;
57
58 VEG_INLINE void operator()(usize i) const
59 {
61 }
62};
63
64template<typename T, usize N>
78
79template<typename T, usize N>
81{
83 T* pw;
84 isize w_stride;
85
86 VEG_INLINE void operator()(usize i) const
87 {
88 p_wr[i].store_unaligned(pw + w_stride * isize(i));
89 }
90};
91
92template<usize R, typename T, usize N>
93VEG_INLINE void
95 _simd::Pack<T, N> const* p_p,
96 _simd::Pack<T, N> const* p_mu,
97 T* inout_l,
98 T* pw,
99 isize w_stride)
100{
101
102 _simd::Pack<T, N> p_wr[R];
103 _detail::unroll<R>(RankUpdateLoadW<T, N>{ p_wr, pw, w_stride });
105 _detail::unroll<R>(RankUpdateUpdateWAndL<T, N>{ p_wr, p_in_l, p_p, p_mu });
106 _detail::unroll<R>(RankUpdateStoreW<T, N>{ p_wr, pw, w_stride });
107
108 p_in_l.store_unaligned(inout_l);
109}
110
111template<bool VECTORIZABLE>
113
114template<typename T, usize N>
116{
119 T const* p;
120 T const* mu;
121 VEG_INLINE void operator()(usize i) const
122 {
125 }
126};
127
128template<>
130{
131 template<usize R, typename T>
132 VEG_INLINE static void fn(isize n,
133 T* inout_l,
134 T* pw,
135 isize w_stride,
136 T const* p,
137 T const* mu) noexcept
138 {
139 using Pack_ = _simd::Pack<T, 1>;
140 Pack_ p_p[R];
141 Pack_ p_mu[R];
142
143 _detail::unroll<R>(RankUpdateLoadPMu<T, 1>{ p_p, p_mu, p, mu });
144
145 auto inout_l_finish = inout_l + n;
146 while (inout_l < inout_l_finish) {
147 _detail::rank_r_update_inner_loop_iter<R>(
148 p_p, p_mu, inout_l, pw, w_stride);
149 ++inout_l;
150 ++pw;
151 }
152 }
153};
154
155template<>
157{
158 template<usize R, typename T>
159 VEG_INLINE static void fn(isize n,
160 T* inout_l,
161 T* pw,
162 isize w_stride,
163 T const* p,
164 T const* mu) noexcept
165 {
166
167 // best perf if beginning of each pw is aligned
168 // should be enforced by the Ldlt class
169
170 using Info = _simd::NativePackInfo<T>;
171 constexpr usize N = Info::N;
172 auto inout_l_vectorized_end = inout_l + usize(n) / N * N;
173 auto inout_l_end = inout_l + usize(n);
174
175 {
176 using Pack = _simd::NativePack<T>;
177 Pack p_p[R];
178 Pack p_mu[R];
179
180 _detail::unroll<R>(RankUpdateLoadPMu<T, N>{ p_p, p_mu, p, mu });
181
182 while (inout_l < inout_l_vectorized_end) {
183 _detail::rank_r_update_inner_loop_iter<R>(
184 p_p, p_mu, inout_l, pw, w_stride);
185 inout_l += N;
186 pw += N;
187 }
188 }
189 {
190 using Pack_ = _simd::Pack<T, 1>;
191 Pack_ p_p[R];
192 Pack_ p_mu[R];
193
194 _detail::unroll<R>(RankUpdateLoadPMu<T, 1>{ p_p, p_mu, p, mu });
195
196 while (inout_l < inout_l_end) {
197 _detail::rank_r_update_inner_loop_iter<R>(
198 p_p, p_mu, inout_l, pw, w_stride);
199 ++inout_l;
200 ++pw;
201 }
202 }
203 }
204};
205
206template<usize R, typename T>
207VEG_INLINE void
209 T* inout_l,
210 T* pw,
211 isize w_stride,
212 T const* p,
213 T const* mu)
214{
216 n, inout_l, pw, w_stride, p, mu);
217}
218
219template<typename LD, typename T, typename Fn>
220void
222 LD ld,
223 T* pw,
224 isize w_stride,
225 T* palpha,
226 Fn r_fn)
227{
228 static_assert(LD::InnerStrideAtCompileTime == 1, ".");
229 static_assert(!bool(LD::IsRowMajor), ".");
230
231 isize n = ld.rows();
232
233 for (isize j = 0; j < n; ++j) {
234 isize r = r_fn();
235
236 isize r_done = 0;
237 if (!(r_done < r)) {
238 continue;
239 }
240
241 while (true) {
242 isize r_chunk = min2(isize(4), r - r_done);
243
244 T p_array[4];
245 T mu_array[4];
246
247 T dj = ld(j, j);
248 for (isize k = 0; k < r_chunk; ++k) {
249 auto& p = (+p_array)[k];
250 auto& mu = (+mu_array)[k];
251 auto& alpha = palpha[r_done + k];
252
253 p = pw[(r_done + k) * w_stride];
254 T new_dj = dj + (alpha * p) * p;
255 mu = (alpha * p) / new_dj;
256 alpha -= new_dj * (mu * mu);
257
258 dj = new_dj;
259 }
260 ld(j, j) = dj;
261
262 isize rem = n - j - 1;
263
264 using FnType = void (*)(isize, T*, T*, isize, T const*, T const*);
265 FnType fn_table[] = {
266 rank_r_update_inner_loop<1, T>,
267 rank_r_update_inner_loop<2, T>,
268 rank_r_update_inner_loop<3, T>,
269 rank_r_update_inner_loop<4, T>,
270 };
271
272 (*fn_table[r_chunk - 1])( //
273 rem,
274 util::matrix_elem_addr(ld, j + 1, j),
275 pw + 1 + r_done * w_stride,
276 w_stride,
277 p_array,
278 mu_array);
279
280 r_done += r_chunk;
281 if (!(r_done < r)) {
282 break;
283 }
284 }
285 ++pw;
286 }
287}
289{
290 isize r;
291 VEG_INLINE auto operator()() const noexcept -> isize { return r; }
292};
293} // namespace _detail
294
295template<typename LD,
296 typename W,
297 typename T = typename proxsuite::linalg::veg::uncvref_t<LD>::Scalar>
298void
300 W&& w,
302{
305 w.data(),
306 0,
307 proxsuite::linalg::veg::mem::addressof(alpha),
308 _detail::ConstantR{ 1 });
309}
310
311template<typename LD,
312 typename W,
313 typename A,
314 typename T = typename proxsuite::linalg::veg::uncvref_t<LD>::Scalar>
315void
316rank_r_update_clobber_inputs(LD&& ld, W&& w, A&& alpha)
317{
318 isize r = w.cols();
321 w.data(),
322 w.outerStride(),
323 alpha.data(),
324 _detail::ConstantR{ r });
325}
326} // namespace dense
327} // namespace linalg
328} // namespace proxsuite
329
330#endif /* end of include guard PROXSUITE_LINALG_DENSE_LDLT_UPDATE_HPP */
#define VEG_EVAL_ALL(...)
Definition macros.hpp:191
#define VEG_INLINE
Definition macros.hpp:118
#define VEG_FWD(X)
Definition macros.hpp:569
typename NativePackInfo< T >::Type NativePack
Definition core.hpp:295
VEG_INLINE void unroll_impl(proxsuite::linalg::veg::meta::index_sequence< Is... >, Fn fn)
Definition update.hpp:37
auto bytes_to_prev_aligned(void *ptr, usize align) noexcept -> isize
Definition update.hpp:15
auto bytes_to_next_aligned(void *ptr, usize align) noexcept -> isize
Definition update.hpp:25
VEG_INLINE void rank_r_update_inner_loop_iter(_simd::Pack< T, N > const *p_p, _simd::Pack< T, N > const *p_mu, T *inout_l, T *pw, isize w_stride)
Definition update.hpp:94
void rank_r_update_clobber_w_impl(LD ld, T *pw, isize w_stride, T *palpha, Fn r_fn)
Definition update.hpp:221
VEG_INLINE void rank_r_update_inner_loop(isize n, T *inout_l, T *pw, isize w_stride, T const *p, T const *mu)
Definition update.hpp:208
VEG_INLINE void unroll(Fn fn)
Definition update.hpp:45
auto align() noexcept -> isize
Definition core.hpp:340
auto matrix_elem_addr(Mat &&mat, isize row, isize col) noexcept -> decltype(mat.data())
Definition core.hpp:563
auto to_view_dyn(Mat &&mat) noexcept -> Eigen::Map< _detail::const_if< _detail::ptr_is_const< decltype(mat.data())>::value, _detail::OwnedMatrix< proxsuite::linalg::veg::uncvref_t< Mat > > >, Eigen::Unaligned, _detail::StrideOf< proxsuite::linalg::veg::uncvref_t< Mat > > >
Definition core.hpp:730
void rank_r_update_clobber_inputs(LD &&ld, W &&w, A &&alpha)
Definition update.hpp:316
void rank_1_update_clobber_w(LD &&ld, W &&w, proxsuite::linalg::veg::DoNotDeduce< T > alpha)
Definition update.hpp:299
_detail::_meta::make_integer_sequence< usize, N > * make_index_sequence
VEG_INLINE auto operator()() const noexcept -> isize
Definition update.hpp:291
static VEG_INLINE void fn(isize n, T *inout_l, T *pw, isize w_stride, T const *p, T const *mu) noexcept
Definition update.hpp:132
static VEG_INLINE void fn(isize n, T *inout_l, T *pw, isize w_stride, T const *p, T const *mu) noexcept
Definition update.hpp:159
VEG_INLINE void operator()(usize i) const
Definition update.hpp:121
VEG_INLINE void operator()(usize i) const
Definition update.hpp:58
VEG_INLINE void operator()(usize i) const
Definition update.hpp:86