proxsuite 0.7.2
The Advanced Proximal Optimization Toolbox
 
Loading...
Searching...
No Matches
utils.py
Go to the documentation of this file.
1import torch
2import numpy as np
3
4
def extract_nBatch(Q, p, A, b, G, l, u):
    """Return the batch size implied by the QP parameters.

    The parameters are scanned in the order ``Q, p, A, b, G, l, u``.
    The first one whose rank equals its "batched" rank determines the batch
    size, which is its leading dimension. The batched rank is 3 for ``Q``,
    ``A`` and ``G``, and 2 for ``p``, ``b``, ``l`` and ``u``.

    Parameters that are ``None`` are skipped, so absent parameters do not
    crash the lookup.

    Returns:
        The leading dimension of the first batched parameter, or ``1`` if no
        parameter is batched.
    """
    batched_ranks = (3, 2, 3, 2, 3, 2, 2)
    params = (Q, p, A, b, G, l, u)
    for param, rank in zip(params, batched_ranks):
        # ``None`` marks an absent parameter; it carries no batch information.
        if param is not None and param.ndimension() == rank:
            return param.size(0)
    return 1
12
13
14# from qpth: https://github.com/locuslab/qpth/blob/master/qpth/util.py
def print_header(msg):
    """Print ``msg`` to standard output, preceded by an ``===>`` marker."""
    marker = "===>"
    print(marker, msg)
17
18
def to_np(t):
    """Convert a torch tensor to a NumPy array.

    Returns:
        ``None`` if ``t`` is ``None``.
        A 1-D empty ``np.array([])`` if ``t`` has no elements (the original
        shape is not preserved).
        Otherwise, the tensor's data moved to CPU as a NumPy array.

    The tensor is detached from the autograd graph first, so tensors with
    ``requires_grad=True`` can be converted too.
    """
    if t is None:
        return None
    if t.nelement() == 0:
        return np.array([])
    # Tensor.numpy() raises on tensors that require grad; detach() drops the
    # autograd link without copying data.
    return t.detach().cpu().numpy()
26
27
def bger(x, y):
    """Batched outer product.

    With ``x`` of shape ``(B, n)`` and ``y`` of shape ``(B, m)``, returns a
    ``(B, n, m)`` tensor whose ``b``-th slice is ``outer(x[b], y[b])``.
    """
    column = torch.unsqueeze(x, 2)  # (B, n, 1)
    row = torch.unsqueeze(y, 1)  # (B, 1, m)
    return torch.bmm(column, row)
30
31
def get_sizes(G, A=None):
    """Read problem dimensions from ``G`` and, optionally, ``A``.

    Args:
        G: tensor of shape ``(nineq, nz)`` or ``(nBatch, nineq, nz)``.
        A: optional tensor laid out like ``G`` (rows along the
            second-to-last dimension). An empty tensor means ``neq == 0``.

    Returns:
        ``(nineq, nz, neq, nBatch)``. ``neq`` is ``None`` when ``A`` is
        ``None``, and ``nBatch`` is 1 for an unbatched ``G``.

    Raises:
        RuntimeError: if ``G`` is neither 2-D nor 3-D.
    """
    if G.dim() == 2:
        nineq, nz = G.size()
        nBatch = 1
    elif G.dim() == 3:
        nBatch, nineq, nz = G.size()
    else:
        # Without this branch the locals above would be unbound and the
        # return would fail with an obscure UnboundLocalError.
        raise RuntimeError("Unexpected number of dimensions.")

    if A is None:
        neq = None
    elif A.nelement() == 0:
        neq = 0
    else:
        # The row count is the second-to-last dimension in both the batched
        # and unbatched layouts; A.size(1) would return nz for a 2-D A.
        neq = A.size(-2)
    return nineq, nz, neq, nBatch
44
45
def bdiag(d):
    """Build a batch of diagonal matrices from a batch of diagonals.

    Args:
        d: 2-D tensor of shape ``(nBatch, sz)``.

    Returns:
        Tensor of shape ``(nBatch, sz, sz)``. Slice ``i`` has ``d[i]`` on its
        diagonal and zeros elsewhere. The result has the same dtype and device
        as ``d``.

    Raises:
        ValueError: if ``d`` is not 2-D.
    """
    if d.dim() != 2:
        raise ValueError(
            f"bdiag expects a 2-D tensor (nBatch, sz), got {d.dim()}-D"
        )
    # diag_embed puts the last dimension on the diagonal of a new trailing
    # pair of dimensions. It also works for non-contiguous inputs, where the
    # previous masked ``view(-1)`` write raised, and it needs no (sz x sz)
    # boolean mask.
    return torch.diag_embed(d)
52
53
def expandParam(X, nBatch, nDim):
    """Broadcast ``X`` to a batched parameter of rank ``nDim``.

    Returns:
        A ``(tensor, expanded)`` pair.
        - If ``X`` is a scalar (rank 0), already has rank ``nDim``, or is
          empty, it is returned unchanged with ``expanded=False``.
        - If ``X`` has rank ``nDim - 1``, a leading batch dimension of size
          ``nBatch`` is added as an ``expand`` view (no copy), and
          ``expanded=True`` is returned.

    Raises:
        RuntimeError: for any other rank.
    """
    rank = X.ndimension()
    if rank == 0 or rank == nDim or X.nelement() == 0:
        return X, False
    if rank != nDim - 1:
        raise RuntimeError("Unexpected number of dimensions.")
    batched = X.unsqueeze(0).expand(nBatch, *X.size())
    return batched, True
expandParam(X, nBatch, nDim)
Definition utils.py:54
get_sizes(G, A=None)
Definition utils.py:32
extract_nBatch(Q, p, A, b, G, l, u)
Definition utils.py:5