Skip to content

Commit ceea220

Browse files
committed
Move autoarray extraction into plot_array/plot_grid/plot_yx; remove all private _plot_* helpers
plot_array, plot_grid, and plot_yx now accept autoarray objects directly: - plot_array: calls zoom_array, extracts .native.array / .geometry.extent, derives mask via auto_mask_edge, converts all overlay params (grid, positions, lines, border, origin, array_overlay) via numpy_* helpers - plot_grid: extracts .array from grid objects, converts lines via numpy_lines, preserves extent_with_buffer_from before numpy conversion - plot_yx: extracts .array, falls back to .grid_radial for default x Consequences: - All private _plot_fit_array / _plot_dataset_array / _plot_array / _plot_grid / _plot_yx helpers removed from every subplot file; callers now call the core functions directly - plot_mapper_image removed (was just plot_array with extraction, now redundant) - structure_plots.py reduced to three re-export aliases - symmetric_vmin_vmax moved to utils.py and exported publicly https://claude.ai/code/session_01B9sVEV54XWCa2LJw1C8gvv
1 parent b491a11 commit ceea220

File tree

14 files changed

+181
-706
lines changed

14 files changed

+181
-706
lines changed
Lines changed: 9 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -1,54 +1,9 @@
1-
import numpy as np
21
from typing import Optional
32

43
import matplotlib.pyplot as plt
54

65
from autoarray.plot.plots.array import plot_array
7-
from autoarray.plot.plots.utils import (
8-
auto_mask_edge,
9-
zoom_array,
10-
numpy_grid,
11-
numpy_lines,
12-
numpy_positions,
13-
subplot_save,
14-
)
15-
16-
17-
def _plot_dataset_array(
18-
array,
19-
ax,
20-
title,
21-
colormap,
22-
use_log10,
23-
grid=None,
24-
positions=None,
25-
lines=None,
26-
):
27-
"""Internal helper: plot one array component onto *ax*."""
28-
if array is None:
29-
return
30-
31-
array = zoom_array(array)
32-
33-
try:
34-
arr = array.native.array
35-
extent = array.geometry.extent
36-
except AttributeError:
37-
arr = np.asarray(array)
38-
extent = None
39-
40-
plot_array(
41-
array=arr,
42-
ax=ax,
43-
extent=extent,
44-
mask=auto_mask_edge(array) if hasattr(array, "mask") else None,
45-
grid=numpy_grid(grid),
46-
positions=numpy_positions(positions),
47-
lines=numpy_lines(lines),
48-
title=title,
49-
colormap=colormap,
50-
use_log10=use_log10,
51-
)
6+
from autoarray.plot.plots.utils import subplot_save
527

538

549
def subplot_imaging_dataset(
@@ -95,17 +50,17 @@ def subplot_imaging_dataset(
9550
fig, axes = plt.subplots(3, 3, figsize=(21, 21))
9651
axes = axes.flatten()
9752

98-
_plot_dataset_array(dataset.data, axes[0], "Data", colormap, use_log10, grid, positions, lines)
99-
_plot_dataset_array(dataset.data, axes[1], "Data (log10)", colormap, True, grid, positions, lines)
100-
_plot_dataset_array(dataset.noise_map, axes[2], "Noise-Map", colormap, use_log10, grid, positions, lines)
53+
plot_array(dataset.data, ax=axes[0], title="Data", colormap=colormap, use_log10=use_log10, grid=grid, positions=positions, lines=lines)
54+
plot_array(dataset.data, ax=axes[1], title="Data (log10)", colormap=colormap, use_log10=True, grid=grid, positions=positions, lines=lines)
55+
plot_array(dataset.noise_map, ax=axes[2], title="Noise-Map", colormap=colormap, use_log10=use_log10, grid=grid, positions=positions, lines=lines)
10156

10257
if dataset.psf is not None:
103-
_plot_dataset_array(dataset.psf.kernel, axes[3], "Point Spread Function", colormap, use_log10)
104-
_plot_dataset_array(dataset.psf.kernel, axes[4], "PSF (log10)", colormap, True)
58+
plot_array(dataset.psf.kernel, ax=axes[3], title="Point Spread Function", colormap=colormap, use_log10=use_log10)
59+
plot_array(dataset.psf.kernel, ax=axes[4], title="PSF (log10)", colormap=colormap, use_log10=True)
10560

106-
_plot_dataset_array(dataset.signal_to_noise_map, axes[5], "Signal-To-Noise Map", colormap, use_log10, grid, positions, lines)
107-
_plot_dataset_array(dataset.grids.over_sample_size_lp, axes[6], "Over Sample Size (Light Profiles)", colormap, use_log10)
108-
_plot_dataset_array(dataset.grids.over_sample_size_pixelization, axes[7], "Over Sample Size (Pixelization)", colormap, use_log10)
61+
plot_array(dataset.signal_to_noise_map, ax=axes[5], title="Signal-To-Noise Map", colormap=colormap, use_log10=use_log10, grid=grid, positions=positions, lines=lines)
62+
plot_array(dataset.grids.over_sample_size_lp, ax=axes[6], title="Over Sample Size (Light Profiles)", colormap=colormap, use_log10=use_log10)
63+
plot_array(dataset.grids.over_sample_size_pixelization, ax=axes[7], title="Over Sample Size (Pixelization)", colormap=colormap, use_log10=use_log10)
10964

11065
plt.tight_layout()
11166
subplot_save(fig, output_path, output_filename, output_format)

autoarray/dataset/plot/interferometer_plots.py

Lines changed: 13 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -6,61 +6,10 @@
66
from autoarray.plot.plots.array import plot_array
77
from autoarray.plot.plots.grid import plot_grid
88
from autoarray.plot.plots.yx import plot_yx
9-
from autoarray.plot.plots.utils import auto_mask_edge, zoom_array, subplot_save
9+
from autoarray.plot.plots.utils import subplot_save
1010
from autoarray.structures.grids.irregular_2d import Grid2DIrregular
1111

1212

13-
def _plot_array(array, ax, title, colormap, use_log10, output_path=None, output_filename=None, output_format="png"):
14-
array = zoom_array(array)
15-
try:
16-
arr = array.native.array
17-
extent = array.geometry.extent
18-
except AttributeError:
19-
arr = np.asarray(array)
20-
extent = None
21-
22-
plot_array(
23-
array=arr,
24-
ax=ax,
25-
extent=extent,
26-
mask=auto_mask_edge(array) if hasattr(array, "mask") else None,
27-
title=title,
28-
colormap=colormap,
29-
use_log10=use_log10,
30-
output_path=output_path,
31-
output_filename=output_filename or "",
32-
output_format=output_format,
33-
)
34-
35-
36-
def _plot_grid(grid, ax, title, colormap, color_array=None, output_path=None, output_filename=None, output_format="png"):
37-
plot_grid(
38-
grid=np.array(grid.array),
39-
ax=ax,
40-
color_array=color_array,
41-
title=title,
42-
output_path=output_path,
43-
output_filename=output_filename or "",
44-
output_format=output_format,
45-
)
46-
47-
48-
def _plot_yx(y, x, ax, title, ylabel="", xlabel="", plot_axis_type="linear",
49-
output_path=None, output_filename=None, output_format="png"):
50-
plot_yx(
51-
y=np.asarray(y),
52-
x=np.asarray(x) if x is not None else None,
53-
ax=ax,
54-
title=title,
55-
ylabel=ylabel,
56-
xlabel=xlabel,
57-
plot_axis_type=plot_axis_type,
58-
output_path=output_path,
59-
output_filename=output_filename or "",
60-
output_format=output_format,
61-
)
62-
63-
6413
def subplot_interferometer_dataset(
6514
dataset,
6615
output_path: Optional[str] = None,
@@ -93,26 +42,20 @@ def subplot_interferometer_dataset(
9342
fig, axes = plt.subplots(2, 3, figsize=(21, 14))
9443
axes = axes.flatten()
9544

96-
_plot_grid(dataset.data.in_grid, axes[0], "Visibilities", colormap)
97-
_plot_grid(
45+
plot_grid(dataset.data.in_grid, ax=axes[0], title="Visibilities")
46+
plot_grid(
9847
Grid2DIrregular.from_yx_1d(
9948
y=dataset.uv_wavelengths[:, 1] / 10**3.0,
10049
x=dataset.uv_wavelengths[:, 0] / 10**3.0,
10150
),
102-
axes[1], "UV-Wavelengths", colormap,
103-
)
104-
_plot_yx(
105-
dataset.amplitudes, dataset.uv_distances / 10**3.0,
106-
axes[2], "Amplitudes vs UV-distances",
107-
ylabel="Jy", xlabel="k$\\lambda$", plot_axis_type="scatter",
108-
)
109-
_plot_yx(
110-
dataset.phases, dataset.uv_distances / 10**3.0,
111-
axes[3], "Phases vs UV-distances",
112-
ylabel="deg", xlabel="k$\\lambda$", plot_axis_type="scatter",
51+
ax=axes[1], title="UV-Wavelengths",
11352
)
114-
_plot_array(dataset.dirty_image, axes[4], "Dirty Image", colormap, use_log10)
115-
_plot_array(dataset.dirty_signal_to_noise_map, axes[5], "Dirty Signal-To-Noise Map", colormap, use_log10)
53+
plot_yx(dataset.amplitudes, dataset.uv_distances / 10**3.0, ax=axes[2],
54+
title="Amplitudes vs UV-distances", ylabel="Jy", xlabel="k$\\lambda$", plot_axis_type="scatter")
55+
plot_yx(dataset.phases, dataset.uv_distances / 10**3.0, ax=axes[3],
56+
title="Phases vs UV-distances", ylabel="deg", xlabel="k$\\lambda$", plot_axis_type="scatter")
57+
plot_array(dataset.dirty_image, ax=axes[4], title="Dirty Image", colormap=colormap, use_log10=use_log10)
58+
plot_array(dataset.dirty_signal_to_noise_map, ax=axes[5], title="Dirty Signal-To-Noise Map", colormap=colormap, use_log10=use_log10)
11659

11760
plt.tight_layout()
11861
subplot_save(fig, output_path, output_filename, output_format)
@@ -146,9 +89,9 @@ def subplot_interferometer_dirty_images(
14689
"""
14790
fig, axes = plt.subplots(1, 3, figsize=(21, 7))
14891

149-
_plot_array(dataset.dirty_image, axes[0], "Dirty Image", colormap, use_log10)
150-
_plot_array(dataset.dirty_noise_map, axes[1], "Dirty Noise Map", colormap, use_log10)
151-
_plot_array(dataset.dirty_signal_to_noise_map, axes[2], "Dirty Signal-To-Noise Map", colormap, use_log10)
92+
plot_array(dataset.dirty_image, ax=axes[0], title="Dirty Image", colormap=colormap, use_log10=use_log10)
93+
plot_array(dataset.dirty_noise_map, ax=axes[1], title="Dirty Noise Map", colormap=colormap, use_log10=use_log10)
94+
plot_array(dataset.dirty_signal_to_noise_map, ax=axes[2], title="Dirty Signal-To-Noise Map", colormap=colormap, use_log10=use_log10)
15295

15396
plt.tight_layout()
15497
subplot_save(fig, output_path, output_filename, output_format)
Lines changed: 9 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -1,67 +1,9 @@
1-
import numpy as np
21
from typing import Optional
32

43
import matplotlib.pyplot as plt
54

65
from autoarray.plot.plots.array import plot_array
7-
from autoarray.plot.plots.utils import (
8-
auto_mask_edge,
9-
zoom_array,
10-
numpy_grid,
11-
numpy_lines,
12-
numpy_positions,
13-
subplot_save,
14-
)
15-
16-
17-
def _plot_fit_array(
18-
array,
19-
ax,
20-
title,
21-
colormap,
22-
use_log10,
23-
vmin=None,
24-
vmax=None,
25-
grid=None,
26-
positions=None,
27-
lines=None,
28-
):
29-
if array is None:
30-
return
31-
32-
array = zoom_array(array)
33-
34-
try:
35-
arr = array.native.array
36-
extent = array.geometry.extent
37-
except AttributeError:
38-
arr = np.asarray(array)
39-
extent = None
40-
41-
plot_array(
42-
array=arr,
43-
ax=ax,
44-
extent=extent,
45-
mask=auto_mask_edge(array) if hasattr(array, "mask") else None,
46-
grid=numpy_grid(grid),
47-
positions=numpy_positions(positions),
48-
lines=numpy_lines(lines),
49-
title=title,
50-
colormap=colormap,
51-
use_log10=use_log10,
52-
vmin=vmin,
53-
vmax=vmax,
54-
)
55-
56-
57-
def _symmetric_vmin_vmax(array):
58-
"""Return (-abs_max, abs_max) for a symmetric colormap."""
59-
try:
60-
arr = array.native.array if hasattr(array, "native") else np.asarray(array)
61-
abs_max = np.nanmax(np.abs(arr))
62-
return -abs_max, abs_max
63-
except Exception:
64-
return None, None
6+
from autoarray.plot.plots.utils import subplot_save, symmetric_vmin_vmax
657

668

679
def subplot_fit_imaging(
@@ -104,19 +46,19 @@ def subplot_fit_imaging(
10446
fig, axes = plt.subplots(2, 3, figsize=(21, 14))
10547
axes = axes.flatten()
10648

107-
_plot_fit_array(fit.data, axes[0], "Data", colormap, use_log10, grid=grid, positions=positions, lines=lines)
108-
_plot_fit_array(fit.signal_to_noise_map, axes[1], "Signal-To-Noise Map", colormap, use_log10, grid=grid, positions=positions, lines=lines)
109-
_plot_fit_array(fit.model_data, axes[2], "Model Image", colormap, use_log10, grid=grid, positions=positions, lines=lines)
49+
plot_array(fit.data, ax=axes[0], title="Data", colormap=colormap, use_log10=use_log10, grid=grid, positions=positions, lines=lines)
50+
plot_array(fit.signal_to_noise_map, ax=axes[1], title="Signal-To-Noise Map", colormap=colormap, use_log10=use_log10, grid=grid, positions=positions, lines=lines)
51+
plot_array(fit.model_data, ax=axes[2], title="Model Image", colormap=colormap, use_log10=use_log10, grid=grid, positions=positions, lines=lines)
11052

11153
if residuals_symmetric_cmap:
112-
vmin_r, vmax_r = _symmetric_vmin_vmax(fit.residual_map)
113-
vmin_n, vmax_n = _symmetric_vmin_vmax(fit.normalized_residual_map)
54+
vmin_r, vmax_r = symmetric_vmin_vmax(fit.residual_map)
55+
vmin_n, vmax_n = symmetric_vmin_vmax(fit.normalized_residual_map)
11456
else:
11557
vmin_r = vmax_r = vmin_n = vmax_n = None
11658

117-
_plot_fit_array(fit.residual_map, axes[3], "Residual Map", colormap, False, vmin=vmin_r, vmax=vmax_r, grid=grid, positions=positions, lines=lines)
118-
_plot_fit_array(fit.normalized_residual_map, axes[4], "Normalized Residual Map", colormap, False, vmin=vmin_n, vmax=vmax_n, grid=grid, positions=positions, lines=lines)
119-
_plot_fit_array(fit.chi_squared_map, axes[5], "Chi-Squared Map", colormap, use_log10, grid=grid, positions=positions, lines=lines)
59+
plot_array(fit.residual_map, ax=axes[3], title="Residual Map", colormap=colormap, use_log10=False, vmin=vmin_r, vmax=vmax_r, grid=grid, positions=positions, lines=lines)
60+
plot_array(fit.normalized_residual_map, ax=axes[4], title="Normalized Residual Map", colormap=colormap, use_log10=False, vmin=vmin_n, vmax=vmax_n, grid=grid, positions=positions, lines=lines)
61+
plot_array(fit.chi_squared_map, ax=axes[5], title="Chi-Squared Map", colormap=colormap, use_log10=use_log10, grid=grid, positions=positions, lines=lines)
12062

12163
plt.tight_layout()
12264
subplot_save(fig, output_path, output_filename, output_format)

0 commit comments

Comments
 (0)