aligator  0.16.0
A versatile and efficient C++ library for real-time constrained trajectory optimization.
Loading...
Searching...
No Matches
plotting.py
Go to the documentation of this file.
1import matplotlib.pyplot as plt
2import numpy as np
3
4from aligator import HistoryCallback, Results
5
# sqrt(10): dividing a value by this lowers it by half a decade on a log axis;
# used below to leave headroom under the smallest plotted residual.
_ROOT_10 = pow(10.0, 0.5)
7
8
def plot_pd_errs(ax0: "plt.Axes", prim_errs, dual_errs):
    """Plot primal and dual residual histories on a log-scaled axis.

    Args:
        ax0: matplotlib Axes to draw on.
        prim_errs: array-like of primal residuals, one entry per iteration.
        dual_errs: array-like of dual residuals, one entry per iteration.

    After plotting, the lower y-limit is moved to half a decade below the
    smallest residual that exceeds ``2 * machine epsilon``. Residuals at or
    below that threshold are ignored because on a log axis they would
    stretch the view by many decades. If no residual exceeds the threshold
    (including when both histories are empty), the autoscaled limits are
    left untouched.
    """
    prim_errs = np.asarray(prim_errs)
    dual_errs = np.asarray(dual_errs)
    ax0.plot(prim_errs, c="tab:blue")
    ax0.set_xlabel("Iterations")
    col2 = "tab:orange"
    ax0.plot(dual_errs, c=col2)
    ax0.spines["top"].set_visible(False)
    ax0.spines["right"].set_color(col2)
    ax0.yaxis.label.set_color(col2)
    ax0.set_yscale("log")
    ax0.legend(["Primal error $p$", "Dual error $d$"])
    ax0.set_title("Solver primal-dual residuals")

    # handle scaling
    yhigh = ax0.get_ylim()[1]
    floor = 2 * np.finfo(float).eps
    # Keep only the residuals that are meaningfully above zero, from whichever
    # history has any; an empty history simply contributes nothing.
    significant = [errs[errs > floor] for errs in (prim_errs, dual_errs)]
    significant = [vals for vals in significant if vals.size > 0]
    if not significant:
        # Nothing to anchor the lower limit on. The previous sentinel
        # (float max) would have produced an inverted y-axis here.
        return
    ymin = min(np.min(vals) for vals in significant)
    ax0.set_ylim(ymin / _ROOT_10, yhigh)
39
40
def plot_convergence(
    cb: HistoryCallback,
    ax: plt.Axes,
    res: Results = None,
    *,
    show_al_iters=False,
    target_tol: float = None,
    legend_kwargs=None,
):
    """Plot solver convergence (primal/dual residuals) from a history callback.

    Args:
        cb: history callback. ``cb.prim_infeas`` and ``cb.dual_infeas`` are
            plotted. When ``show_al_iters`` is set, ``cb.prim_tols`` and
            ``cb.al_index`` are read as well.
        ax: matplotlib Axes to draw on.
        res: optional solver results. When given, ``res.primal_infeas`` and
            ``res.dual_infeas`` are appended as a final point.
        show_al_iters: also draw the ``cb.prim_tols`` schedule as a step curve
            (legend ``$\\eta_k$``). Shade iteration ``k`` with a gray vertical
            bar whenever ``al_index[k + 1] > al_index[k]`` (legend "AL iters").
        target_tol: if truthy, draw a horizontal black line at this value.
        legend_kwargs: extra keyword arguments forwarded to ``ax.legend``.
            Defaults to no extra arguments.

    Returns:
        The list of legend labels passed to ``ax.legend``. Labels are
        matched to the axes' artists in drawing order.
    """
    if legend_kwargs is None:
        legend_kwargs = {}
    prim_infeas = cb.prim_infeas.tolist()
    dual_infeas = cb.dual_infeas.tolist()
    if res is not None:
        prim_infeas.append(res.primal_infeas)
        dual_infeas.append(res.dual_infeas)
    plot_pd_errs(ax, prim_infeas, dual_infeas)

    ax.grid(axis="y", which="major")
    # Start from the labels of any artists that already carry an explicit label.
    _, labels = ax.get_legend_handles_labels()
    labels += [
        "Prim. err $p$",
        "Dual err $d$",
    ]
    if show_al_iters:
        prim_tols = np.array(cb.prim_tols)
        al_iters = np.array(cb.al_index)
        itrange = np.arange(len(al_iters))
        if itrange.size > 0:
            # Labels are only appended together with the artist they describe.
            # A dangling label would otherwise be attached to the next artist
            # drawn, e.g. the target_tol line below.
            labels.append("$\\eta_k$")
            ax.step(itrange, prim_tols, c="green", alpha=0.9, lw=1.1)
            al_change = al_iters[1:] - al_iters[:-1]
            al_change_idx = itrange[:-1][al_change > 0]
            if al_change_idx.size > 0:
                labels.append("AL iters")
                ax.vlines(
                    al_change_idx, *ax.get_ylim(), colors="gray", lw=4.0, alpha=0.5
                )

    if target_tol:
        ax.axhline(target_tol, color="k", lw=1.2)

    ax.legend(labels=labels, **legend_kwargs)
    return labels
83
84
def plot_se2_pose(
    q: np.ndarray, ax: plt.Axes, alpha=0.5, fc="tab:blue"
) -> plt.Rectangle:
    """Draw a 1.0 x 0.4 rectangle for a planar pose and add it to ``ax``.

    Args:
        q: pose vector. ``q[0], q[1]`` give the rectangle centre, and the
            rotation angle is ``arctan2(q[3], q[2])``. Presumably this is a
            ``[x, y, cos, sin]`` layout; confirm against the callers.
        ax: matplotlib Axes the patch is added to.
        alpha: patch transparency.
        fc: patch face colour.

    Returns:
        The Rectangle patch that was added to ``ax``.
    """
    from matplotlib import transforms

    length, width = 1.0, 0.4
    x, y = q[0], q[1]
    heading = np.arctan2(q[3], q[2])
    # Rectangle is anchored at its lower-left corner, so shift from the centre.
    corner = (x - 0.5 * length, y - 0.5 * width)
    patch = plt.Rectangle(corner, length, width, fc=fc, alpha=alpha)
    # NOTE(review): the box is rotated by -heading about (x, y). Under a
    # counter-clockwise heading convention this turns it clockwise; confirm
    # the intended sign.
    rotation = transforms.Affine2D().rotate_around(x, y, -heading)
    patch.set_transform(rotation + ax.transData)
    ax.add_patch(patch)
    return patch
99
100
def _axes_flatten_if_ndarray(axes) -> "list[plt.Axes]":
    """Normalize an ``axes`` argument into a flat, indexable sequence.

    - A numpy array (as returned by ``plt.subplots`` for a grid) is flattened
      into a 1-D array.
    - A list is returned unchanged (same object).
    - A tuple (e.g. ``(ax1, ax2)`` unpacked from ``plt.subplots``) is
      converted to a list rather than being wrapped as a single element.
    - Anything else, typically a single Axes, is wrapped in a one-element list.
    """
    if isinstance(axes, np.ndarray):
        return axes.flatten()
    if isinstance(axes, list):
        return axes
    if isinstance(axes, tuple):
        return list(axes)
    return [axes]
107
108
def plot_controls_traj(
    times,
    us,
    ncols=2,
    axes=None,
    effort_limit=None,
    joint_names=None,
    rmodel=None,
    figsize=(6.4, 6.4),
    xlabel="Time (s)",
) -> tuple[plt.Figure, list[plt.Axes]]:
    """Plot each control component in its own subplot as a step trajectory.

    Args:
        times: time grid. Controls are drawn against ``times[:-1]``, and limit
            lines span ``times[0]`` to ``times[-1]``.
        us: control trajectory, array-like with ``len(times) - 1`` rows and
            one column per control component.
        ncols: number of subplot columns when a new figure is created.
        axes: existing Axes to draw into (ndarray, list, tuple or a single
            Axes). A new figure is created when None.
        effort_limit: optional per-component bounds, drawn at
            ``-effort_limit[i]`` and ``+effort_limit[i]`` without changing
            the y-limits.
        joint_names: optional per-component subplot titles (lower-cased).
        rmodel: optional model. When given, its ``effortLimit`` and ``names``
            replace ``effort_limit`` and ``joint_names``.
        figsize: figure size used when a new figure is created.
        xlabel: x-axis label.

    Returns:
        ``(fig, axes)``, with ``axes`` flattened.
    """
    t0 = times[0]
    tf = times[-1]
    us = np.asarray(us)
    nu = us.shape[1]
    nrows, r = divmod(nu, ncols)
    nrows += int(r > 0)

    if axes is None:
        fig, axes = plt.subplots(nrows, ncols, sharex="col", figsize=figsize)
        axes = _axes_flatten_if_ndarray(axes)
    else:
        # Flatten first so that a list or a single Axes works too; only
        # ndarrays have the ``.flat`` attribute used previously.
        axes = _axes_flatten_if_ndarray(axes)
        fig = axes[0].get_figure()

    if rmodel is not None:
        effort_limit = rmodel.effortLimit
        # NOTE(review): if rmodel follows Pinocchio conventions, names[0] is the
        # root "universe" joint, so control i is titled names[i] rather than
        # the joint it actuates. Confirm the intended indexing.
        joint_names = rmodel.names

    for i in range(nu):
        ax: plt.Axes = axes[i]
        ax.step(times[:-1], us[:, i])
        if effort_limit is not None:
            # Draw the bounds without letting them rescale the view.
            ylim = ax.get_ylim()
            ax.hlines(-effort_limit[i], t0, tf, colors="k", linestyles="--")
            ax.hlines(+effort_limit[i], t0, tf, colors="r", linestyles="dashdot")
            ax.set_ylim(*ylim)
        if joint_names is not None:
            joint_name = joint_names[i].lower()
            ax.set_title(joint_name, fontsize=8)
    if nu > 1:
        fig.supxlabel(xlabel)
        fig.suptitle("Control trajectories")
    else:
        axes[0].set_xlabel(xlabel)
        axes[0].set_title("Control trajectories")
    fig.tight_layout()
    return fig, axes
157
158
def plot_velocity_traj(
    times,
    vs,
    rmodel,
    axes=None,
    ncols=2,
    vel_limit=None,
    figsize=(6.4, 6.4),
    xlabel="Time (s)",
) -> tuple[plt.Figure, list[plt.Axes]]:
    """Plot each velocity component in its own subplot, titled by joint name.

    Args:
        times: time grid, same length as ``vs``. Limit lines span
            ``times[0]`` to ``times[-1]``.
        vs: velocity trajectory, array-like with ``rmodel.nv`` columns.
        rmodel: model providing ``nv``, ``idx_vs`` and ``names``.
        axes: existing Axes to draw into (ndarray, list, tuple or a single
            Axes). A new figure is created when None.
        ncols: number of subplot columns when a new figure is created.
        vel_limit: optional array-like of ``nv`` bounds, drawn at
            ``-vel_limit[i]`` and ``+vel_limit[i]`` without changing the
            y-limits.
        figsize: figure size used when a new figure is created.
        xlabel: figure-level x-axis label.

    Returns:
        ``(fig, axes)``, with ``axes`` flattened.

    Raises:
        AssertionError: if ``vs`` or ``vel_limit`` do not have ``nv`` entries
            along the velocity dimension.
    """
    vs = np.asarray(vs)
    nv = rmodel.nv
    assert nv == vs.shape[1]
    if vel_limit is not None:
        # Accept plain sequences too; ndarrays are passed through unchanged.
        vel_limit = np.asarray(vel_limit)
        assert nv == vel_limit.shape[0]

    # Map every velocity index to a joint id. jid counts how many joint
    # velocity offsets (idx_vs) have been reached so far, so it is already 1
    # at index 0 when 0 is an offset. Presumably this skips a root joint
    # stored at names[0]; confirm against the model convention.
    # The offset set is built once instead of calling .tolist() per index.
    joint_starts = set(rmodel.idx_vs.tolist())
    idx_to_joint_id_map = {}
    jid = 0
    for i in range(nv):
        if i in joint_starts:
            jid += 1
        idx_to_joint_id_map[i] = jid

    nrows, r = divmod(nv, ncols)
    nrows += int(r > 0)

    t0 = times[0]
    tf = times[-1]

    if axes is None:
        fig, axes = plt.subplots(nrows, ncols, sharex=True, figsize=figsize)
        axes = _axes_flatten_if_ndarray(axes)
    else:
        # Flatten first so that a list or a single Axes works too; only
        # ndarrays have the ``.flat`` attribute used previously.
        axes = _axes_flatten_if_ndarray(axes)
        fig = axes[0].get_figure()

    for i in range(nv):
        ax: plt.Axes = axes[i]
        ax.plot(times, vs[:, i])
        joint_name = rmodel.names[idx_to_joint_id_map[i]].lower()
        if vel_limit is not None:
            # Draw the bounds without letting them rescale the view.
            ylim = ax.get_ylim()
            ax.hlines(-vel_limit[i], t0, tf, colors="k", linestyles="--")
            ax.hlines(+vel_limit[i], t0, tf, colors="r", linestyles="dashdot")
            ax.set_ylim(*ylim)
        ax.set_title(joint_name, fontsize=8)

    fig.supxlabel(xlabel)
    fig.suptitle("Velocity trajectories")
    fig.tight_layout()
    return fig, axes
plot_convergence(HistoryCallback cb, plt.Axes ax, Results res=None, *, show_al_iters=False, float target_tol=None, legend_kwargs={})
Definition plotting.py:49
plt.Rectangle plot_se2_pose(np.ndarray q, plt.Axes ax, alpha=0.5, fc="tab:blue")
Definition plotting.py:87
tuple[plt.Figure, list[plt.Axes]] plot_controls_traj(times, us, ncols=2, axes=None, effort_limit=None, joint_names=None, rmodel=None, figsize=(6.4, 6.4), xlabel="Time (s)")
Definition plotting.py:119
plot_pd_errs(ax0, prim_errs, dual_errs)
Definition plotting.py:9
tuple[plt.Figure, list[plt.Axes]] plot_velocity_traj(times, vs, rmodel, axes=None, ncols=2, vel_limit=None, figsize=(6.4, 6.4), xlabel="Time (s)")
Definition plotting.py:168
list[plt.Axes] _axes_flatten_if_ndarray(axes)
Definition plotting.py:101