Skip to content

Commit e432f30

Browse files
authored
Merge pull request #242 from Jammy2211/feature/plot-improvements-4
Plot improvements: line_colors, is_subplot colorbar sizing, inversion panels
2 parents 481dcc2 + a8e8567 commit e432f30

File tree

3 files changed

+44
-14
lines changed

3 files changed

+44
-14
lines changed

autoarray/inversion/plot/mapper_plots.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ def plot_mapper(
2020
use_log10: bool = False,
2121
mesh_grid=None,
2222
lines=None,
23+
line_colors=None,
2324
title: str = "Pixelization Mesh (Source-Plane)",
2425
zoom_to_brightest: bool = True,
2526
ax=None,
@@ -64,6 +65,7 @@ def plot_mapper(
6465
use_log10=use_log10,
6566
zoom_to_brightest=zoom_to_brightest,
6667
lines=numpy_lines(lines),
68+
line_colors=line_colors,
6769
grid=numpy_grid(mesh_grid),
6870
output_path=output_path,
6971
output_filename=output_filename,

autoarray/plot/array.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ def plot_array(
5353
vmax: Optional[float] = None,
5454
use_log10: bool = False,
5555
cb_unit: Optional[str] = None,
56+
line_colors: Optional[List] = None,
5657
origin_imshow: str = "upper",
5758
# --- figure control (used only when ax is None) -----------------------------
5859
figsize: Optional[Tuple[int, int]] = None,
@@ -243,12 +244,19 @@ def plot_array(
243244
if positions is not None:
244245
colors = ["k", "g", "b", "m", "c", "y"]
245246
for i, pos in enumerate(positions):
247+
pos = np.asarray(pos).reshape(-1, 2)
246248
ax.scatter(pos[:, 1], pos[:, 0], s=20, c=colors[i % len(colors)], zorder=5)
247249

250+
248251
if lines is not None:
249-
for line in lines:
252+
for i, line in enumerate(lines):
250253
if line is not None and len(line) > 0:
251-
ax.plot(line[:, 1], line[:, 0], linewidth=2)
254+
line = np.asarray(line).reshape(-1, 2)
255+
color = line_colors[i] if (line_colors is not None and i < len(line_colors)) else None
256+
kw = {"linewidth": 2}
257+
if color is not None:
258+
kw["color"] = color
259+
ax.plot(line[:, 1], line[:, 0], **kw)
252260

253261
if vector_yx is not None:
254262
ax.quiver(

autoarray/plot/inversion.py

Lines changed: 32 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ def plot_inversion_reconstruction(
2828
zoom_to_brightest: bool = True,
2929
# --- overlays ---------------------------------------------------------------
3030
lines: Optional[List[np.ndarray]] = None,
31+
line_colors: Optional[List] = None,
3132
grid: Optional[np.ndarray] = None,
3233
# --- figure control (used only when ax is None) -----------------------------
3334
figsize: Optional[Tuple[int, int]] = None,
@@ -109,27 +110,34 @@ def plot_inversion_reconstruction(
109110
values=pixel_values, zoom_to_brightest=zoom_to_brightest
110111
)
111112

113+
is_subplot = not owns_figure
114+
112115
if isinstance(
113116
mapper.interpolator, (InterpolatorRectangular, InterpolatorRectangularUniform)
114117
):
115-
_plot_rectangular(ax, pixel_values, mapper, norm, colormap, extent)
118+
_plot_rectangular(ax, pixel_values, mapper, norm, colormap, extent, is_subplot=is_subplot)
116119
elif isinstance(
117120
mapper.interpolator, (InterpolatorDelaunay, InterpolatorKNearestNeighbor)
118121
):
119-
_plot_delaunay(ax, pixel_values, mapper, norm, colormap)
122+
_plot_delaunay(ax, pixel_values, mapper, norm, colormap, is_subplot=is_subplot)
120123

121124
# --- overlays --------------------------------------------------------------
122125
if lines is not None:
123-
for line in lines:
126+
for i, line in enumerate(lines):
124127
if line is not None and len(line) > 0:
125-
ax.plot(line[:, 1], line[:, 0], linewidth=2)
128+
line = np.asarray(line).reshape(-1, 2)
129+
color = line_colors[i] if (line_colors is not None and i < len(line_colors)) else None
130+
kw = {"linewidth": 2}
131+
if color is not None:
132+
kw["color"] = color
133+
ax.plot(line[:, 1], line[:, 0], **kw)
126134

127135
if grid is not None:
128136
ax.scatter(grid[:, 1], grid[:, 0], s=1, c="w", alpha=0.5)
129137

130138
apply_extent(ax, extent)
131139

132-
apply_labels(ax, title=title, xlabel=xlabel, ylabel=ylabel)
140+
apply_labels(ax, title=title, xlabel="" if is_subplot else xlabel, ylabel="" if is_subplot else ylabel)
133141

134142
if owns_figure:
135143
save_figure(
@@ -140,7 +148,7 @@ def plot_inversion_reconstruction(
140148
)
141149

142150

143-
def _plot_rectangular(ax, pixel_values, mapper, norm, colormap, extent):
151+
def _plot_rectangular(ax, pixel_values, mapper, norm, colormap, extent, is_subplot=False):
144152
"""Render a rectangular pixelization reconstruction onto *ax*.
145153
146154
Uses ``imshow`` for uniform rectangular grids
@@ -164,6 +172,9 @@ def _plot_rectangular(ax, pixel_values, mapper, norm, colormap, extent):
164172
Matplotlib colormap name.
165173
extent
166174
``[xmin, xmax, ymin, ymax]`` spatial extent; passed to ``imshow``.
175+
is_subplot
176+
When ``True`` uses ``labelsize_subplot`` from config for the colorbar
177+
tick labels (matches the behaviour of :func:`~autoarray.plot.array.plot_array`).
167178
"""
168179
from autoarray.inversion.mesh.interpolator.rectangular_uniform import (
169180
InterpolatorRectangularUniform,
@@ -175,6 +186,14 @@ def _plot_rectangular(ax, pixel_values, mapper, norm, colormap, extent):
175186
if pixel_values is None:
176187
pixel_values = np.zeros(shape_native[0] * shape_native[1])
177188

189+
xmin, xmax, ymin, ymax = extent
190+
x_range = abs(xmax - xmin)
191+
y_range = abs(ymax - ymin)
192+
box_aspect = (x_range / y_range) if y_range > 0 else 1.0
193+
ax.set_aspect(box_aspect, adjustable="box")
194+
195+
from autoarray.plot.utils import _apply_colorbar
196+
178197
if isinstance(mapper.interpolator, InterpolatorRectangularUniform):
179198
from autoarray.structures.arrays.uniform_2d import Array2D
180199
from autoarray.structures.arrays import array_2d_util
@@ -196,8 +215,7 @@ def _plot_rectangular(ax, pixel_values, mapper, norm, colormap, extent):
196215
aspect="auto",
197216
origin="upper",
198217
)
199-
from autoarray.plot.utils import _apply_colorbar
200-
_apply_colorbar(im, ax)
218+
_apply_colorbar(im, ax, is_subplot=is_subplot)
201219
else:
202220
y_edges, x_edges = mapper.mesh_geometry.edges_transformed.T
203221
Y, X = np.meshgrid(y_edges, x_edges, indexing="ij")
@@ -209,11 +227,10 @@ def _plot_rectangular(ax, pixel_values, mapper, norm, colormap, extent):
209227
norm=norm,
210228
cmap=colormap,
211229
)
212-
from autoarray.plot.utils import _apply_colorbar
213-
_apply_colorbar(im, ax)
230+
_apply_colorbar(im, ax, is_subplot=is_subplot)
214231

215232

216-
def _plot_delaunay(ax, pixel_values, mapper, norm, colormap):
233+
def _plot_delaunay(ax, pixel_values, mapper, norm, colormap, is_subplot=False):
217234
"""Render a Delaunay or KNN pixelization reconstruction onto *ax*.
218235
219236
Uses ``ax.tripcolor`` with Gouraud shading so that the reconstructed
@@ -235,6 +252,9 @@ def _plot_delaunay(ax, pixel_values, mapper, norm, colormap):
235252
``None`` for automatic scaling.
236253
colormap
237254
Matplotlib colormap name.
255+
is_subplot
256+
When ``True`` uses ``labelsize_subplot`` from config for the colorbar
257+
tick labels (matches the behaviour of :func:`~autoarray.plot.array.plot_array`).
238258
"""
239259
mesh_grid = mapper.source_plane_mesh_grid
240260
x = mesh_grid[:, 1]
@@ -247,4 +267,4 @@ def _plot_delaunay(ax, pixel_values, mapper, norm, colormap):
247267

248268
tc = ax.tripcolor(x, y, vals, cmap=colormap, norm=norm, shading="gouraud")
249269
from autoarray.plot.utils import _apply_colorbar
250-
_apply_colorbar(tc, ax)
270+
_apply_colorbar(tc, ax, is_subplot=is_subplot)

0 commit comments

Comments
 (0)