-
Notifications
You must be signed in to change notification settings - Fork 7
Expand file tree
/
Copy patharray.py
More file actions
298 lines (265 loc) · 9.53 KB
/
array.py
File metadata and controls
298 lines (265 loc) · 9.53 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
"""
Standalone function for plotting a 2D array (image) directly with matplotlib.
"""
import os
from typing import List, Optional, Tuple
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import LogNorm, Normalize
from autoarray.plot.utils import (
apply_extent,
apply_labels,
conf_figsize,
save_figure,
zoom_array,
auto_mask_edge,
numpy_grid,
numpy_lines,
numpy_positions,
)
# Private aliases of the shared plotting helpers imported above. Presumably
# kept so code that imported the older underscore names from this module keeps
# working -- NOTE(review): confirm whether any caller still relies on them.
_zoom_array_2d, _mask_edge_coords = zoom_array, auto_mask_edge
def _colour_norm(
    array: np.ndarray,
    vmin: Optional[float],
    vmax: Optional[float],
    use_log10: bool,
) -> Optional[Normalize]:
    """
    Build the colour normalisation handed to ``imshow``.

    Parameters
    ----------
    array
        Pixel values being plotted; only used to derive ``vmax`` for log
        scaling when no explicit ``vmax`` is supplied.
    vmin, vmax
        Optional explicit colour limits.
    use_log10
        When ``True`` a ``LogNorm`` is returned.

    Returns
    -------
    A ``LogNorm`` when *use_log10* is set, a linear ``Normalize`` when either
    explicit limit is given, otherwise ``None`` (matplotlib autoscaling).
    """
    if not use_log10:
        if vmin is not None or vmax is not None:
            return Normalize(vmin=vmin, vmax=vmax)
        return None

    # Floor for log scaling: read from the autoconf "visualize" config when it
    # is available, otherwise fall back to a fixed default.
    try:
        from autoconf import conf as _conf

        log10_min = _conf.instance["visualize"]["general"]["general"][
            "log10_min_value"
        ]
    except Exception:
        log10_min = 1.0e-4

    # LogNorm requires a strictly positive lower limit; a non-positive or
    # non-finite user vmin would make matplotlib raise when drawing.
    if vmin is not None and np.isfinite(vmin) and vmin > 0:
        vmin_log = vmin
    else:
        vmin_log = log10_min

    if vmax is not None and np.isfinite(vmax):
        vmax_log = vmax
    else:
        clipped = np.clip(array, log10_min, None)
        with np.errstate(all="ignore"):
            vmax_log = np.nanmax(clipped)

    # Degenerate or inverted range (all-NaN data, data entirely below vmin,
    # or a user vmax not above vmin): widen to one decade so LogNorm is valid.
    if not np.isfinite(vmax_log) or vmax_log <= vmin_log:
        vmax_log = vmin_log * 10.0

    return LogNorm(vmin=vmin_log, vmax=vmax_log)


def _axes_box_aspect(
    array: np.ndarray, extent: Optional[Tuple[float, float, float, float]]
) -> float:
    """
    Return the data aspect ratio (x_range / y_range) passed to ``set_aspect``.

    The ranges come from *extent* when given, otherwise from the array's
    ``(rows, cols)`` shape. Falls back to ``1.0`` when either range is not
    strictly positive, because ``set_aspect`` rejects zero or non-finite values.
    """
    if extent is not None:
        x_range = abs(extent[1] - extent[0])
        y_range = abs(extent[3] - extent[2])
    else:
        y_range, x_range = array.shape[:2]
    if x_range > 0 and y_range > 0:
        return x_range / y_range
    return 1.0


def _draw_contours(
    ax,
    array: np.ndarray,
    extent: Optional[Tuple[float, float, float, float]],
    contours: int,
    origin_imshow: str,
) -> None:
    """
    Best-effort overlay of *contours* evenly spaced black contour lines.

    Contours are a cosmetic extra: data with no finite, non-constant range is
    skipped, and a failure inside matplotlib is reported as a warning instead
    of aborting the whole plot.
    """
    with np.errstate(all="ignore"):
        lo, hi = np.nanmin(array), np.nanmax(array)
    if not (np.isfinite(lo) and np.isfinite(hi)) or hi <= lo:
        # A constant or all-NaN image yields non-increasing levels, which
        # contour() rejects.
        return
    levels = np.linspace(lo, hi, contours)

    # contour() with origin=None places Z[0, 0] at (extent[0], extent[2]),
    # i.e. at the bottom of the extent, whereas imshow(origin="upper") draws
    # row 0 at the top of the extent, so the rows must be reversed to line
    # up. Without an extent both calls use pixel-index coordinates and row 0
    # already coincides, and for origin="lower" the orientations already match.
    flip_rows = extent is not None and origin_imshow == "upper"
    z = array[::-1] if flip_rows else array

    try:
        cs = ax.contour(z, levels=levels, extent=extent, colors="k")
    except Exception as exc:
        import warnings

        warnings.warn(f"plot_array: could not draw contours ({exc})", stacklevel=3)
        return

    try:
        ax.clabel(cs, levels=levels, inline=True, fontsize=8)
    except Exception:
        # Labels are optional; matplotlib can fail to place them when
        # contour segments are too short.
        pass


def plot_array(
    array,
    ax: Optional[plt.Axes] = None,
    # --- spatial metadata -------------------------------------------------------
    extent: Optional[Tuple[float, float, float, float]] = None,
    # --- overlays ---------------------------------------------------------------
    mask: Optional[np.ndarray] = None,
    border=None,
    origin=None,
    grid=None,
    mesh_grid=None,
    positions=None,
    lines=None,
    vector_yx: Optional[np.ndarray] = None,
    array_overlay=None,
    patches: Optional[List] = None,
    fill_region: Optional[List] = None,
    contours: Optional[int] = None,
    # --- cosmetics --------------------------------------------------------------
    title: str = "",
    xlabel: str = "",
    ylabel: str = "",
    colormap: Optional[str] = None,
    vmin: Optional[float] = None,
    vmax: Optional[float] = None,
    use_log10: bool = False,
    cb_unit: Optional[str] = None,
    origin_imshow: str = "upper",
    # --- figure control (used only when ax is None) -----------------------------
    figsize: Optional[Tuple[int, int]] = None,
    output_path: Optional[str] = None,
    output_filename: str = "array",
    output_format: str = "png",
    structure=None,
) -> None:
    """
    Plot a 2D array (image) using ``plt.imshow``.

    Parameters
    ----------
    array
        2D array of pixel values. Either a plain array-like or an autoarray
        object exposing ``.geometry.extent`` and ``.native.array``, in which
        case *extent*, *mask* and *structure* are derived from it when not
        given. ``None`` or an empty array draws nothing.
    ax
        Existing matplotlib ``Axes`` to draw onto. If ``None`` a new figure is
        created and handed to ``save_figure`` at the end.
    extent
        ``[xmin, xmax, ymin, ymax]`` spatial extent in data coordinates.
    mask
        Array of shape ``(N, 2)`` with ``(y, x)`` coordinates of masked pixels
        to overlay as black dots. When ``None`` it is derived from *array* via
        ``auto_mask_edge``.
    border
        Array of shape ``(N, 2)`` with ``(y, x)`` border pixel coordinates.
    origin
        ``(y, x)`` origin coordinate(s) to scatter as a red ``x`` marker.
    grid
        Array of shape ``(N, 2)`` with ``(y, x)`` coordinates to scatter.
    mesh_grid
        Array of shape ``(N, 2)`` mesh grid coordinates to scatter.
    positions
        List of ``(N, 2)`` arrays; each is scattered as a distinct colour group.
    lines
        List of ``(N, 2)`` arrays with ``(y, x)`` columns to plot as lines.
    vector_yx
        Array of shape ``(N, 4)`` -- ``(y, x, vy, vx)`` -- plotted as quiver.
    array_overlay
        A second 2D array rendered on top of *array* with partial alpha.
    patches
        List of matplotlib ``Patch`` objects to draw over the image (each is
        shallow-copied before being added).
    fill_region
        List of two arrays ``[y1_arr, y2_arr]`` passed to ``ax.fill_between``
        against their index positions.
    contours
        Number of evenly spaced contour levels to overlay; ``None`` or ``0``
        disables contours.
    title
        Figure title string.
    xlabel, ylabel
        Axis label strings.
    colormap
        Matplotlib colormap name; defaults to the project default colormap.
    vmin, vmax
        Explicit color scale limits.
    use_log10
        When ``True`` a ``LogNorm`` is applied. Non-positive *vmin* values are
        replaced by the configured log10 floor.
    cb_unit
        Unit label forwarded to the colorbar helper.
    origin_imshow
        Passed directly to ``imshow`` (``"upper"`` or ``"lower"``); also used
        to orient the contour overlay.
    figsize
        Figure size in inches (only used when *ax* is ``None``).
    output_path
        Directory forwarded to ``save_figure`` (``None`` becomes ``""``).
        An empty path is expected to make ``save_figure`` show the figure
        instead of writing it.
    output_filename
        Base file name (without extension).
    output_format
        File format, e.g. ``"png"``.
    structure
        Object forwarded to ``save_figure``; defaults to *array* after zooming.
    """
    if array is None:
        return

    # --- autoarray extraction --------------------------------------------------
    array = zoom_array(array)
    try:
        if structure is None:
            structure = array
        if extent is None:
            extent = array.geometry.extent
        if mask is None:
            mask = auto_mask_edge(array)
        array = array.native.array
    except AttributeError:
        # Plain array-likes have no geometry / native view: plot them as-is.
        array = np.asarray(array)
    if array is None or array.size == 0:
        return

    if colormap is None:
        from autoarray.plot.utils import _default_colormap

        colormap = _default_colormap()

    # Convert overlay params (safe for None and already-numpy inputs).
    border = numpy_grid(border)
    origin = numpy_grid(origin)
    grid = numpy_grid(grid)
    mesh_grid = numpy_grid(mesh_grid)
    positions = numpy_positions(positions)
    lines = numpy_lines(lines)
    if array_overlay is not None:
        try:
            array_overlay = array_overlay.native.array
        except AttributeError:
            array_overlay = np.asarray(array_overlay)

    owns_figure = ax is None
    if owns_figure:
        figsize = figsize or conf_figsize("figures")
        fig, ax = plt.subplots(1, 1, figsize=figsize)
    else:
        fig = ax.get_figure()

    # --- colour normalisation --------------------------------------------------
    norm = _colour_norm(array, vmin, vmax, use_log10)

    # --- image -----------------------------------------------------------------
    im = ax.imshow(
        array,
        cmap=colormap,
        norm=norm,
        extent=extent,
        aspect="auto",  # overridden by set_aspect below
        origin=origin_imshow,
    )
    # Reproduces the old "square" subplot behaviour, where
    # ratio = x_range / y_range was passed to plt.subplot(aspect=ratio).
    # With the data limits equal to the extent, this data aspect makes the
    # axes box square in display units (height / width =
    # aspect * y_range / x_range = 1), so tight_layout has no whitespace to
    # absorb when the panel sits in a subplot grid.
    ax.set_aspect(_axes_box_aspect(array, extent), adjustable="box")

    from autoarray.plot.utils import _apply_colorbar

    _apply_colorbar(im, ax, cb_unit=cb_unit, is_subplot=not owns_figure)

    # --- overlays --------------------------------------------------------------
    if array_overlay is not None:
        ax.imshow(
            array_overlay,
            cmap="Greys",
            alpha=0.5,
            extent=extent,
            aspect="auto",
            origin=origin_imshow,
        )
    if mask is not None:
        ax.scatter(mask[:, 1], mask[:, 0], s=1, c="k")
    if border is not None:
        ax.scatter(border[:, 1], border[:, 0], s=1, c="b")
    if origin is not None:
        origin_arr = np.asarray(origin)
        if origin_arr.ndim == 1:
            # A single (y, x) pair: promote to shape (1, 2).
            origin_arr = origin_arr[np.newaxis, :]
        ax.scatter(
            origin_arr[:, 1], origin_arr[:, 0], s=20, c="r", marker="x", zorder=6
        )
    if grid is not None:
        ax.scatter(grid[:, 1], grid[:, 0], s=1, c="k")
    if mesh_grid is not None:
        ax.scatter(mesh_grid[:, 1], mesh_grid[:, 0], s=1, c="w", alpha=0.5)
    if positions is not None:
        colors = ["k", "g", "b", "m", "c", "y"]
        for i, pos in enumerate(positions):
            ax.scatter(pos[:, 1], pos[:, 0], s=20, c=colors[i % len(colors)], zorder=5)
    if lines is not None:
        for line in lines:
            if line is not None and len(line) > 0:
                ax.plot(line[:, 1], line[:, 0], linewidth=2)
    if vector_yx is not None:
        ax.quiver(
            vector_yx[:, 1],
            vector_yx[:, 0],
            vector_yx[:, 3],
            vector_yx[:, 2],
        )
    if patches is not None:
        import copy

        # Add copies, presumably so the caller's Patch objects stay unattached
        # and can be reused on other axes (a matplotlib artist can belong to
        # only one Axes).
        for patch in patches:
            ax.add_patch(copy.copy(patch))
    if fill_region is not None:
        y1, y2 = fill_region[0], fill_region[1]
        x_fill = np.arange(len(y1))
        ax.fill_between(x_fill, y1, y2, alpha=0.3)
    if contours is not None and contours > 0:
        _draw_contours(ax, array, extent, contours, origin_imshow)

    # --- labels / ticks --------------------------------------------------------
    apply_labels(ax, title=title, xlabel=xlabel, ylabel=ylabel, is_subplot=not owns_figure)
    if extent is not None:
        apply_extent(ax, extent)

    # --- output ----------------------------------------------------------------
    if owns_figure:
        save_figure(
            fig,
            path=output_path or "",
            filename=output_filename,
            format=output_format,
            structure=structure,
        )