proxsuite 0.7.1
The Advanced Proximal Optimization Toolbox
Loading...
Searching...
No Matches
update.hpp
Go to the documentation of this file.
1
2//
3// Copyright (c) 2022 INRIA
4//
5#ifndef PROXSUITE_LINALG_SPARSE_LDLT_UPDATE_HPP
6#define PROXSUITE_LINALG_SPARSE_LDLT_UPDATE_HPP
7
10#include <algorithm>
11
12namespace proxsuite {
13namespace linalg {
14namespace sparse {
15
16/*
17Computes the stack memory required by the function merge_second_col_into_first.
18*/
19template<typename I>
20auto
21merge_second_col_into_first_req(proxsuite::linalg::veg::Tag<I> /*tag*/,
22 isize second_size) noexcept
24{
25 return {
26 second_size * isize{ sizeof(I) },
27 alignof(I),
28 };
29}
30
31template<typename T, typename I>
32auto
34 I* difference,
35 T* first_values,
36 I* first_ptr,
37 PROXSUITE_MAYBE_UNUSED isize first_full_len,
38 isize first_initial_len,
39 Slice<I> second,
40 proxsuite::linalg::veg::DoNotDeduce<I> ignore_threshold_inclusive,
41 bool move_values,
42 DynStackMut stack) noexcept(false)
44{
45 VEG_CHECK_CONCEPT(trivially_copyable<I>);
46 VEG_CHECK_CONCEPT(trivially_copyable<T>);
47
48 if (second.len() == 0) {
49 return {
50 proxsuite::linalg::veg::tuplify,
51 { unsafe, from_raw_parts, first_values, first_initial_len },
52 { unsafe, from_raw_parts, first_ptr, first_initial_len },
53 { unsafe, from_raw_parts, difference, 0 },
54 };
55 }
56
57 I const* second_ptr = second.ptr();
58 usize second_len = usize(second.len());
59
60 usize index_second = 0;
61
62 for (; index_second < second_len; ++index_second) {
63 if (second_ptr[index_second] > ignore_threshold_inclusive) {
64 break;
65 }
66 }
67 auto ufirst_initial_len = usize(first_initial_len);
68
69 second_ptr += index_second;
70 second_len -= index_second;
71 index_second = 0;
72
73 proxsuite::linalg::veg::Tag<I> tag{};
74
75 auto _ins_pos = stack.make_new_for_overwrite(tag, isize(second_len));
76
77 I* insert_pos_ptr = _ins_pos.ptr_mut();
78 usize insert_count = 0;
79
80 for (usize index_first = 0; index_first < ufirst_initial_len; ++index_first) {
81 I current_first = first_ptr[index_first];
82 while (true) {
83 if (!(index_second < second_len)) {
84 break;
85 }
86
87 I current_second = second_ptr[index_second];
88 if (!(current_second < current_first)) {
89 break;
90 }
91
92 insert_pos_ptr[insert_count] = I(index_first);
93 difference[insert_count] = current_second;
94 ++insert_count;
95 ++index_second;
96 }
97
98 if (index_second == second_len) {
99 break;
100 }
101 if (second_ptr[index_second] == current_first) {
102 ++index_second;
103 }
104 }
105
106 usize remaining_insert_count = insert_count;
107 usize first_new_len =
108 ufirst_initial_len + insert_count + (second_len - index_second);
109 VEG_ASSERT(usize(first_full_len) >= first_new_len);
110
111 usize append_count = second_len - index_second;
112 std::memmove( //
113 difference + insert_count,
114 second_ptr + index_second,
115 append_count * sizeof(I));
116 std::memmove( //
117 first_ptr + (ufirst_initial_len + insert_count),
118 second_ptr + index_second,
119 append_count * sizeof(I));
120 if (move_values) {
121 for (usize i = 0; i < append_count; ++i) {
122 first_values[i + ufirst_initial_len + insert_count] = 0;
123 }
124 }
125
126 while (remaining_insert_count != 0) {
127
128 usize old_insert_pos = usize(insert_pos_ptr[remaining_insert_count - 1]);
129 usize range_size =
130 (remaining_insert_count == insert_count)
131 ? ufirst_initial_len - old_insert_pos
132 : usize(insert_pos_ptr[remaining_insert_count]) - old_insert_pos;
133
134 usize old_pos = old_insert_pos;
135 usize new_pos = old_pos + remaining_insert_count;
136
137 std::memmove( //
138 first_ptr + new_pos,
139 first_ptr + old_pos,
140 range_size * sizeof(I));
141 if (move_values) {
142 std::memmove( //
143 first_values + new_pos,
144 first_values + old_pos,
145 range_size * sizeof(T));
146 first_values[new_pos - 1] = 0;
147 }
148
149 first_ptr[new_pos - 1] = difference[remaining_insert_count - 1];
150 --remaining_insert_count;
151 }
152
153 return {
154 proxsuite::linalg::veg::tuplify,
155 { unsafe, from_raw_parts, first_values, isize(first_new_len) },
156 { unsafe, from_raw_parts, first_ptr, isize(first_new_len) },
157 { unsafe, from_raw_parts, difference, isize(insert_count + append_count) },
158 };
159}
160
168template<typename T, typename I>
169auto
171 proxsuite::linalg::veg::Tag<T> /*tag*/,
172 proxsuite::linalg::veg::Tag<I> /*tag*/,
173 isize n,
174 bool id_perm,
176{
178 StackReq permuted_indices = { id_perm ? 0 : (col_nnz * isize{ sizeof(I) }),
179 isize{ alignof(I) } };
180 StackReq difference = { n * isize{ sizeof(I) }, isize{ alignof(I) } };
181 difference = difference & difference;
182
184 proxsuite::linalg::veg::Tag<I>{}, n);
185
186 StackReq numerical_workspace = { n * isize{ sizeof(T) },
187 isize{ alignof(T) } };
188
189 return permuted_indices & ((difference & merge) | numerical_workspace);
190}
191
205template<typename T, typename I>
206auto
208 I* etree,
209 I const* perm_inv,
210 VecRef<T, I> w,
212 DynStackMut stack) noexcept(false) -> MatMut<T, I>
213{
214 VEG_ASSERT(!ld.is_compressed());
215
216 if (w.nnz() == 0) {
217 return ld;
218 }
219
220 proxsuite::linalg::veg::Tag<I> tag;
221 usize n = usize(ld.ncols());
222 bool id_perm = perm_inv == nullptr;
223
224 auto _w_permuted_indices =
225 stack.make_new_for_overwrite(tag, id_perm ? isize(0) : w.nnz());
226
227 auto w_permuted_indices =
228 id_perm ? w.row_indices() : _w_permuted_indices.ptr();
229 if (!id_perm) {
230 I* pw_permuted_indices = _w_permuted_indices.ptr_mut();
231 for (usize k = 0; k < usize(w.nnz()); ++k) {
232 usize i = util::zero_extend(w.row_indices()[k]);
233 pw_permuted_indices[k] = perm_inv[i];
234 }
235 std::sort(pw_permuted_indices, pw_permuted_indices + w.nnz());
236 }
237
238 auto sx = util::sign_extend;
239 auto zx = util::zero_extend;
240 // symbolic update
241 {
242 usize current_col = zx(w_permuted_indices[0]);
243
244 auto _difference =
245 stack.make_new_for_overwrite(tag, isize(n - current_col));
246 auto _difference_backup =
247 stack.make_new_for_overwrite(tag, isize(n - current_col));
248
249 auto merge_col = w_permuted_indices;
250 isize merge_col_len = w.nnz();
251 I* difference = _difference.ptr_mut();
252
253 while (true) {
254 usize old_parent = sx(etree[isize(current_col)]);
255
256 usize current_ptr_idx = zx(ld.col_ptrs()[isize(current_col)]);
257 usize next_ptr_idx = zx(ld.col_ptrs()[isize(current_col) + 1]);
258
259 VEG_BIND(auto,
260 (_, new_current_col, computed_difference),
262 difference,
263 ld.values_mut() + (current_ptr_idx + 1),
264 ld.row_indices_mut() + (current_ptr_idx + 1),
265 isize(next_ptr_idx - current_ptr_idx),
266 isize(zx(ld.nnz_per_col()[isize(current_col)])) - 1,
268 unsafe, from_raw_parts, merge_col, merge_col_len },
269 I(current_col),
270 true,
271 stack));
272
273 (void)_;
274 ld._set_nnz(ld.nnz() + new_current_col.len() + 1 -
275 isize(ld.nnz_per_col()[isize(current_col)]));
276 ld.nnz_per_col_mut()[isize(current_col)] = I(new_current_col.len() + 1);
277
278 usize new_parent =
279 (new_current_col.len() == 0) ? usize(-1) : sx(new_current_col[0]);
280
281 if (new_parent == usize(-1)) {
282 break;
283 }
284
285 if (new_parent == old_parent) {
286 merge_col = computed_difference.ptr();
287 merge_col_len = computed_difference.len();
288 difference = _difference_backup.ptr_mut();
289 } else {
290 merge_col = new_current_col.ptr();
291 merge_col_len = new_current_col.len();
292 difference = _difference.ptr_mut();
293 etree[isize(current_col)] = I(new_parent);
294 }
295
296 current_col = new_parent;
297 }
298 }
299
300 // numerical update
301 {
302 usize first_col = zx(w_permuted_indices[0]);
303 auto _work =
304 stack.make_new_for_overwrite(proxsuite::linalg::veg::Tag<T>{}, isize(n));
305 T* pwork = _work.ptr_mut();
306
307 for (usize col = first_col; col != usize(-1); col = sx(etree[isize(col)])) {
308 pwork[col] = 0;
309 }
310 for (usize p = 0; p < usize(w.nnz()); ++p) {
311 pwork[id_perm ? zx(w.row_indices()[isize(p)])
312 : zx(perm_inv[w.row_indices()[isize(p)]])] =
313 w.values()[isize(p)];
314 }
315
316 I const* pldi = ld.row_indices();
317 T* pldx = ld.values_mut();
318
319 for (usize col = first_col; col != usize(-1); col = sx(etree[isize(col)])) {
320 auto col_start = ld.col_start(col);
321 auto col_end = ld.col_end(col);
322
323 T w0 = pwork[col];
324 T old_d = pldx[col_start];
325 T new_d = old_d + alpha * w0 * w0;
326 T beta = alpha * w0 / new_d;
327 alpha = alpha - new_d * beta * beta;
328
329 pldx[col_start] = new_d;
330 pwork[col] -= w0;
331
332 for (usize p = col_start + 1; p < col_end; ++p) {
333 usize i = util::zero_extend(pldi[p]);
334
335 T tmp = pldx[p];
336 pwork[i] = pwork[i] - w0 * tmp;
337 pldx[p] = tmp + beta * pwork[i];
338 }
339 }
340 }
341
342 return ld;
343}
344} // namespace sparse
345} // namespace linalg
346} // namespace proxsuite
347
348#endif /* end of include guard PROXSUITE_LINALG_SPARSE_LDLT_UPDATE_HPP */
#define VEG_ASSERT(...)
#define PROXSUITE_MAYBE_UNUSED
Definition fwd.hpp:20
#define VEG_CHECK_CONCEPT(...)
Definition macros.hpp:1239
auto merge_second_col_into_first(I *difference, T *first_values, I *first_ptr, PROXSUITE_MAYBE_UNUSED isize first_full_len, isize first_initial_len, Slice< I > second, proxsuite::linalg::veg::DoNotDeduce< I > ignore_threshold_inclusive, bool move_values, DynStackMut stack) noexcept(false) -> proxsuite::linalg::veg::Tuple< SliceMut< T >, SliceMut< I >, SliceMut< I > >
Definition update.hpp:33
VEG_INLINE void etree(I *parent, SymbolicMatRef< I > a, DynStackMut stack) noexcept
auto rank1_update_req(proxsuite::linalg::veg::Tag< T >, proxsuite::linalg::veg::Tag< I >, isize n, bool id_perm, isize col_nnz) noexcept -> proxsuite::linalg::veg::dynstack::StackReq
Definition update.hpp:170
auto rank1_update(MatMut< T, I > ld, I *etree, I const *perm_inv, VecRef< T, I > w, proxsuite::linalg::veg::DoNotDeduce< T > alpha, DynStackMut stack) noexcept(false) -> MatMut< T, I >
Definition update.hpp:207
auto merge_second_col_into_first_req(proxsuite::linalg::veg::Tag< I >, isize second_size) noexcept -> proxsuite::linalg::veg::dynstack::StackReq
Definition update.hpp:21
meta::type_identity_t< T > DoNotDeduce
Definition core.hpp:292
_detail::_meta::make_signed< usize >::Type isize
Definition typedefs.hpp:43
decltype(sizeof(0)) usize
Definition macros.hpp:702
#define VEG_BIND(CV_Auto, Identifiers, Tuple)
Definition tuple.hpp:53