diff --git a/.gitignore b/.gitignore index 1cd0ed3ff..e5f27da06 100644 --- a/.gitignore +++ b/.gitignore @@ -16,7 +16,6 @@ MultiNest/ hyper hyper_galaxies files -imaging Nonesubplots scripts subplots @@ -66,3 +65,4 @@ docs/_static docs/_templates docs/generated docs/api/generated +autolens_workspace_test/ diff --git a/PLOT_REFACTOR_PLAN.md b/PLOT_REFACTOR_PLAN.md new file mode 100644 index 000000000..00ca380f0 --- /dev/null +++ b/PLOT_REFACTOR_PLAN.md @@ -0,0 +1,610 @@ +# Plot Module Refactoring Plan + +Remove `MatPlot`, `MatWrap`, `Visuals`, and `Output` from PyAutoArray/PyAutoGalaxy/PyAutoLens +in favour of direct matplotlib calls with explicit parameters. + +--- + +## Goals + +- Delete `MatPlot1D`, `MatPlot2D` and all `~50` `MatWrap` wrapper classes +- Delete `Visuals1D`, `Visuals2D` and all subclasses +- Delete `Output` — replaced by a `save_figure()` helper function +- Delete all `mat_wrap*.yaml` config files — only `visualize/general.yaml` (figsize) survives +- Keep all `*Plotter` classes as the public API (internal wiring rewired) +- Keep `plots.yaml` which controls which subplots are auto-generated during analysis runs +- All unit tests pass after every PR + +--- + +## What is wrong with the current design + +### MatWrap / MatPlot + +Every matplotlib concept (colormap, ticks, colorbar, title, axis extent, …) has a +corresponding Python class that loads default values from a YAML config file. +There are ~50 such classes, each with a `figure:` and `subplot:` config section, +totalling three config files (~10 KB of YAML) just for plot defaults. + +The indirection adds no value: the same result is achieved with plain function +default-parameter values, which are visible in the code and require no config lookup. + +### Visuals + +`Visuals2D` is a dataclass of overlays (critical curves, caustics, centres, …). +It has many variants to satisfy the config-switching machinery. The same information +can be passed as typed list/array arguments to the plot functions. + +### The subplot state machine + +The current system tracks subplot position through a mutable integer +`mat_plot_2d.subplot_index` that auto-increments after every plot call. + +Problems: +- Developers manually patch it (`self.mat_plot_2d.subplot_index = 6`) to skip slots +- `mat_plot_1d` and `mat_plot_2d` have *independent* counters; a comment in the code + describes the workaround as a "nasty hack" +- The config system switches every wrap object between `figure:` and `subplot:` sections + based on whether `subplot_index is not None`, adding hidden state to every config lookup +- Nested plotters (FitImagingPlotter → TracerPlotter → InversionPlotter) share one + mat_plot object so their indices accumulate in the same global counter + +The fix is to use matplotlib's native `plt.subplots()` and pass `ax` objects directly. + +### Ticks + +`XTicks` / `YTicks` MatWrap classes have special-case logic for log-scale ticks. +The replacement generates 3 evenly spaced linear ticks from the extent using +`np.linspace` (as in the reference `plot_grid()` example). Log scales on colorbars +are handled by passing `LogNorm()` to `imshow` — matplotlib handles the ticks itself. + +--- + +## New design in one picture + +``` +Plotter.subplot_fit() + │ + ├── fig, axes = plt.subplots(3, 4, figsize=conf_figsize("subplots")) + │ + ├── plot_array(array=fit.data, title="Data", ax=axes[0,0]) + ├── plot_array(array=fit.noise_map, title="Noise", ax=axes[0,1]) + ├── plot_array(array=fit.model_image, title="Model", ax=axes[0,2]) + │ lines=[critical_curves], + ├── ... + │ + ├── save_figure(fig, path=output_path, filename="subplot_fit") + └── plt.close(fig) + +Plotter.figure_convergence(ax=None) + │ + ├── owns_figure = ax is None + ├── if owns_figure: fig, ax = plt.subplots(1, 1, figsize=conf_figsize("figures")) + │ + ├── plot_array(array=tracer.convergence, title="Convergence", ax=ax, + │ lines=critical_curves + radial_curves) + │ + └── if owns_figure: save_figure(fig, ...) ; plt.close(fig) +``` + +Key rules: +1. Every `plot_*` function accepts an optional `ax` parameter. + - `ax=None` → creates its own figure, saves/shows, closes. + - `ax` provided → draws onto it, does **not** save/show/close (caller is responsible). +2. Overlay data (critical curves, caustics, centres, positions, …) are plain + `List[np.ndarray]` arguments, not `Visuals` objects. +3. `figsize` is the only value read from config; all other defaults are function + parameter defaults visible in source code. +4. Ticks: 3 linear ticks generated with `np.linspace` from the axis extent. + Colorbar log-scaling uses `matplotlib.colors.LogNorm` passed to `imshow`. + +--- + +## `save_figure` replacing `Output` + +```python +# autoarray/plot/plots/utils.py + +def save_figure( + fig: plt.Figure, + path: str, + filename: str, + format: str = "png", + dpi: int = 300, +) -> None: + """Save fig to /. and close it.""" + os.makedirs(path, exist_ok=True) + fig.savefig( + os.path.join(path, f"{filename}.{format}"), + dpi=dpi, + bbox_inches="tight", + ) + plt.close(fig) +``` + +The plotter base class holds `output_path: str` and `output_format: str = "png"`. +Individual `figure_*` and `subplot_*` methods call `save_figure(fig, self.output_path, "name")`. +If `output_path` is empty, `plt.show()` is called instead of saving. + +--- + +## `conf_figsize` helper + +```python +# autoarray/plot/plots/utils.py + +def conf_figsize(context: str = "figures") -> Tuple[int, int]: + """Read figsize from visualize/general.yaml for 'figures' or 'subplots'.""" + return tuple(conf.instance["visualize"]["general"][context]["figsize"]) +``` + +`visualize/general.yaml` (only surviving config for plots): +```yaml +figures: + figsize: [7, 7] +subplots: + figsize: [19, 16] +``` + +--- + +## The 13 PRs + +### Phase 1 — PyAutoArray (5 PRs) + +--- + +#### PR A1 · New `autoarray/plot/plots/` module (additive, no deletions) + +Create the replacement plot functions. No existing code is touched; existing tests +continue to pass. + +``` +autoarray/plot/plots/ + __init__.py + utils.py → save_figure(), conf_figsize(), _make_ticks(), _apply_extent() + array.py → plot_array() + grid.py → plot_grid() + yx.py → plot_yx() + inversion.py → plot_inversion_reconstruction(), plot_inversion_mappings() +``` + +**`plot_array` signature (canonical example):** + +```python +def plot_array( + array: np.ndarray, + ax: Optional[plt.Axes] = None, + # overlays + mask: Optional[np.ndarray] = None, + grid: Optional[np.ndarray] = None, + positions: Optional[List[np.ndarray]] = None, + lines: Optional[List[np.ndarray]] = None, + vector_yx: Optional[np.ndarray] = None, + # cosmetics + title: str = "", + xlabel: str = "x (arcsec)", + ylabel: str = "y (arcsec)", + colormap: str = "jet", + vmin: Optional[float] = None, + vmax: Optional[float] = None, + use_log10: bool = False, + # figure control (used only when ax is None) + figsize: Optional[Tuple[int, int]] = None, + filename: Optional[str] = None, +) -> None: + owns_figure = ax is None + if owns_figure: + figsize = figsize or conf_figsize("figures") + fig, ax = plt.subplots(1, 1, figsize=figsize) + + norm = LogNorm() if use_log10 else None + if vmin is not None or vmax is not None: + norm = Normalize(vmin=vmin, vmax=vmax) + + im = ax.imshow(array, cmap=colormap, norm=norm, origin="lower") + plt.colorbar(im, ax=ax) + + if mask is not None: + ax.scatter(mask[:, 1], mask[:, 0], s=1, c="k") + if positions is not None: + for pos in positions: + ax.scatter(pos[:, 1], pos[:, 0], s=10, c="r") + if lines is not None: + for line in lines: + ax.plot(line[:, 1], line[:, 0], linewidth=2) + + ax.set_title(title, fontsize=16) + ax.set_xlabel(xlabel, fontsize=14) + ax.set_ylabel(ylabel, fontsize=14) + ax.tick_params(labelsize=12) + + if owns_figure: + if filename: + save_figure(fig, path=os.path.dirname(filename), + filename=os.path.basename(filename)) + else: + plt.show() + plt.close(fig) +``` + +**Ticks:** The 3-linear-tick approach from the reference example is baked into +`_make_ticks(extent)`: +```python +def _apply_extent(ax, extent): + """extent = [xmin, xmax, ymin, ymax]; apply axis limits and 3 linear ticks.""" + ax.set_xlim(extent[0], extent[1]) + ax.set_ylim(extent[2], extent[3]) + ax.set_xticks(np.linspace(extent[0], extent[1], 3)) + ax.set_yticks(np.linspace(extent[2], extent[3], 3)) +``` +No `XTicks` / `YTicks` classes needed. Log colorbars: pass `LogNorm()` to `imshow`; +matplotlib generates appropriate log-spaced colorbar ticks automatically. + +New unit tests: `test_autoarray/plot/plots/test_array.py` etc., asserting that +PNG files are written when a filename is provided. + +--- + +#### PR A2 · Update `Array2DPlotter` and `Grid2DPlotter` + +Switch the two most-used base plotters to the new functions. + +- Remove `mat_plot_2d`, `visuals_2d` constructor params. +- Add explicit overlay params: `mask`, `grid`, `positions`, `lines`. +- Each `figure_*` method calls `plot_array(..., ax=ax)` where `ax` defaults to `None`. +- `subplot_*` methods: create `fig, axes = plt.subplots(...)`, pass each `ax` slice. + +The subplot open/close/index machinery is **deleted**. A subplot method looks like: + +```python +def subplot_array(self): + fig, axes = plt.subplots(1, 2, figsize=conf_figsize("subplots")) + self.figure_array(ax=axes[0]) + self.figure_array_log10(ax=axes[1]) + save_figure(fig, self.output_path, "subplot_array", self.output_format) +``` + +No `subplot_index`, no `open_subplot_figure()`, no `close_subplot_figure()`. + +Existing test assertions about output filenames keep working because plotter +constructor accepts `output_path` and `output_filename` strings. + +--- + +#### PR A3 · Update `ImagingPlotter`, `InversionPlotter`, `MapperPlotter`, `InterferometerPlotter` + +Same `ax`-passing pattern. Mixed 1D/2D subplots (e.g. interferometer) use: + +```python +fig, axes = plt.subplots(2, 3, figsize=conf_figsize("subplots")) +plot_array(array=dirty_image, ax=axes[0, 0]) +plot_yx(y=visibilities.real, ax=axes[1, 0]) +``` + +`AbstractPlotter` base class is simplified to hold only: + +```python +class AbstractPlotter: + def __init__( + self, + output_path: str = "", + output_filename: str = "", + output_format: str = "png", + figsize_figures: Optional[Tuple] = None, + figsize_subplots: Optional[Tuple] = None, + ): + self.output_path = output_path + self.output_filename = output_filename + self.output_format = output_format + self.figsize_figures = figsize_figures or conf_figsize("figures") + self.figsize_subplots = figsize_subplots or conf_figsize("subplots") + + def _filename(self, name: str) -> Optional[str]: + if self.output_path: + return os.path.join(self.output_path, + f"{name}.{self.output_format}") + return None +``` + +No subplot state, no mat_plot slots, no visuals slots. + +--- + +#### PR A4 · Delete `mat_plot/`, `wrap/`, `visuals/` directories + +``` +autoarray/plot/mat_plot/ ← deleted (3 files) +autoarray/plot/wrap/ ← deleted (~40 files) +autoarray/plot/visuals/ ← deleted (3 files) +autoarray/config/visualize/mat_wrap.yaml ← deleted +autoarray/config/visualize/mat_wrap_1d.yaml ← deleted +autoarray/config/visualize/mat_wrap_2d.yaml ← deleted +``` + +Update `autoarray/plot/__init__.py` to remove all `MatPlot*`, `Visuals*`, `MatWrap*` +re-exports. Tests that imported these classes directly are deleted or rewritten. + +--- + +#### PR A5 · Simplify config; finalise `save_figure` / `conf_figsize` + +`visualize/general.yaml` after cleanup: + +```yaml +figures: + figsize: [7, 7] +subplots: + figsize: [19, 16] +``` + +All other YAML files that existed purely for MatWrap defaults are deleted. +`plots.yaml` (which controls whether `subplot_fit` etc. are auto-generated during +analysis runs) is **kept unchanged**. + +--- + +### Phase 2 — PyAutoGalaxy (4 PRs) + +--- + +#### PR G1 · New galaxy overlay helpers (additive) + +``` +autogalaxy/plot/plots/ + __init__.py + overlays.py → overlay_critical_curves(ax, curves, color="w", linewidth=2) + overlay_caustics(ax, curves, color="y", linewidth=2) + overlay_light_profile_centres(ax, centres, marker="+", s=40) + overlay_mass_profile_centres(ax, centres, marker="x", s=40) + overlay_multiple_images(ax, positions, marker="o", s=40) +``` + +These are pure overlay helpers that accept an `ax` and draw onto it. +They have no config dependency. + +--- + +#### PR G2 · Update `LightProfilePlotter`, `MassProfilePlotter`, `GalaxyPlotter`, `GalaxiesPlotter` + +Each plotter computes its own overlay data from its galaxy/profile then passes it +to `plot_array`: + +```python +class GalaxiesPlotter(AbstractPlotter): + def figure_image(self, ax=None): + owns = ax is None + if owns: + fig, ax = plt.subplots(figsize=self.figsize_figures) + + array = self.galaxies.image_2d_from(grid=self.grid) + plot_array(array=array.native, ax=ax, title="Image", + lines=self._critical_curves() + self._caustics()) + _apply_extent(ax, self._extent()) + + if owns: + save_figure(fig, self.output_path, "image", self.output_format) +``` + +Remove autogalaxy `MatPlot2D` subclass and autogalaxy `Visuals2D` subclass. + +--- + +#### PR G3 · Update autogalaxy `FitImagingPlotter` and `FitInterferometerPlotter` + +```python +def subplot_fit(self): + fig, axes = plt.subplots(3, 4, figsize=self.figsize_subplots) + + plot_array(array=self.fit.data.native, title="Data", ax=axes[0, 0]) + plot_array(array=self.fit.noise_map.native, title="Noise", ax=axes[0, 1]) + # ... etc., no subplot_index needed + + save_figure(fig, self.output_path, "subplot_fit", self.output_format) +``` + +--- + +#### PR G4 · Remove autogalaxy MatPlot/Visuals extensions + +``` +autogalaxy/plot/mat_plot/ ← deleted +autogalaxy/plot/visuals/ ← deleted +``` + +Update `autogalaxy/plot/__init__.py`. + +--- + +### Phase 3 — PyAutoLens (4 PRs) + +--- + +#### PR L1 · Update `TracerPlotter` + +The plotter computes critical curves / caustics itself from the tracer, then passes +them as `lines` to `plot_array`: + +```python +class TracerPlotter(AbstractPlotter): + def figure_convergence(self, ax=None): + owns = ax is None + if owns: + fig, ax = plt.subplots(figsize=self.figsize_figures) + + array = self.tracer.convergence_2d_from(self.grid) + tang = self.tracer.tangential_critical_curves_from(self.grid) + rad = self.tracer.radial_critical_curves_from(self.grid) + + plot_array(array=array.native, ax=ax, title="Convergence", + lines=tang + rad) + + if owns: + save_figure(fig, self.output_path, + self.output_filename or "convergence") + + def subplot_tracer(self): + fig, axes = plt.subplots(3, 3, figsize=self.figsize_subplots) + + self.figure_image(ax=axes[0, 0]) + self.figure_source_plane(ax=axes[0, 1]) + self.figure_convergence(ax=axes[0, 2]) + self.figure_potential(ax=axes[1, 0]) + self.figure_magnification(ax=axes[1, 1]) + self.figure_deflections_y(ax=axes[1, 2]) + self.figure_deflections_x(ax=axes[2, 0]) + axes[2, 1].set_visible(False) + axes[2, 2].set_visible(False) + + save_figure(fig, self.output_path, "subplot_tracer") +``` + +Constructor: remove `mat_plot_2d`, `visuals_2d`, `visuals_2d_of_planes_list`. +Add `show_critical_curves: bool = True`, `show_caustics: bool = True`. + +--- + +#### PR L2 · Update `FitImagingPlotter` + +Largest single plotter. The 12-panel `subplot_fit` becomes: + +```python +def subplot_fit(self): + fig, axes = plt.subplots(3, 4, figsize=self.figsize_subplots) + + plot_array(array=self.fit.data.native, + title="Data", ax=axes[0, 0]) + plot_array(array=self.fit.signal_to_noise_map.native, + title="Signal-To-Noise Map", ax=axes[0, 1]) + plot_array(array=self.fit.model_image.native, + title="Model Image", ax=axes[0, 2], + lines=self._tangential_critical_curves()) + # leave axes[0, 3] blank or use for something else + + # plane decomposition (delegate to sub-plotter with explicit ax) + tracer_plotter = self.tracer_plotter_of_plane(plane_index=0) + tracer_plotter.figure_plane_image(ax=axes[1, 0]) + + plot_array(array=self.fit.residual_map.native, + title="Residual Map", ax=axes[2, 0], + colormap="coolwarm", vmin=-0.1, vmax=0.1) + plot_array(array=self.fit.normalized_residual_map.native, + title="Normalised Residual Map", ax=axes[2, 1], + colormap="coolwarm", vmin=-3, vmax=3) + plot_array(array=self.fit.chi_squared_map.native, + title="Chi-Squared Map", ax=axes[2, 2]) + + save_figure(fig, self.output_path, "subplot_fit") +``` + +No `subplot_index`, no `open_subplot_figure`, no `close_subplot_figure`, +no 1D/2D sync. + +Per-plane subplot (`subplot_of_planes`) creates its own figure: +```python +def subplot_of_planes(self): + n = len(self.fit.tracer.planes) + fig, axes = plt.subplots(1, n * 4, figsize=(n * 4 * 4, 4)) + for i in range(n): + ... +``` + +--- + +#### PR L3 · Update `FitInterferometerPlotter`, `PointDatasetPlotter`, `FitPointDatasetPlotter` + +**PointDatasetPlotter** — mixed 1D/2D, which was the "nasty hack" case: + +```python +def subplot_dataset(self): + fig, axes = plt.subplots(1, 2, figsize=self.figsize_subplots) + + plot_grid(grid=self.dataset.positions.array, + y_errors=self.dataset.positions_noise_map.array, + title="Positions", ax=axes[0]) + + plot_yx(y=self.dataset.fluxes.array, + y_errors=self.dataset.fluxes_noise_map.array, + title="Fluxes", ax=axes[1]) + + save_figure(fig, self.output_path, "subplot_dataset") +``` + +No sync hack: `axes[0]` is independent from `axes[1]`, they are just different `Axes` +objects obtained from the same `plt.subplots()` call. + +--- + +#### PR L4 · Update `SubhaloPlotter`, `SubhaloSensitivityPlotter`; clean up `autolens/plot/abstract_plotters.py` + +`autolens/plot/abstract_plotters.py` final form: + +```python +from autogalaxy.plot.abstract_plotters import AbstractPlotter + +class Plotter(AbstractPlotter): + """PyAutoLens plotter base — no MatPlot or Visuals slots.""" + pass +``` + +SubhaloPlotter significance maps use `plot_array` with an `ArrayOverlay` equivalent: + +```python +plot_array( + array=self.result.figure_of_merit_array().native, + title="Subhalo Detection Significance", + ax=ax, + positions=self.result.subhalo_centres_grid.array, +) +``` + +--- + +## Summary table + +| PR | Repo | Change type | Tests | +|---|---|---|---| +| A1 | autoarray | Add `plots/` module | New unit tests | +| A2 | autoarray | Rewrite Array2D/Grid2DPlotter | Update existing | +| A3 | autoarray | Rewrite Imaging/Inversion/Mapper/InterferometerPlotter | Update existing | +| A4 | autoarray | Delete mat_plot/, wrap/, visuals/ | Delete wrap tests | +| A5 | autoarray | Config cleanup, finalise helpers | Smoke tests | +| G1 | autogalaxy | Add overlay helpers | New unit tests | +| G2 | autogalaxy | Rewrite Galaxy/Mass/LightProfile plotters | Update existing | +| G3 | autogalaxy | Rewrite FitImaging/FitInterferometer plotters | Update existing | +| G4 | autogalaxy | Delete MatPlot2D/Visuals2D extensions | Delete wrap tests | +| L1 | autolens | Rewrite TracerPlotter | Update existing | +| L2 | autolens | Rewrite FitImagingPlotter | Update existing | +| L3 | autolens | Rewrite FitInterferometer/Point plotters | Update existing | +| L4 | autolens | Rewrite Subhalo plotters, clean abstract_plotters | Update existing | + +--- + +## Design rules applied consistently across all PRs + +1. **`ax` parameter on every `figure_*` method and every `plot_*` function.** + `ax=None` → owns the figure (creates, saves, closes). + `ax` provided → draws only, caller owns the figure lifecycle. + +2. **Overlay data as typed list/array args.** + `lines: List[np.ndarray]` replaces `Visuals2D.tangential_critical_curves` etc. + `positions: List[np.ndarray]` replaces `Visuals2D.positions`. + No `Visuals` objects anywhere. + +3. **No subplot state machine.** + `plt.subplots(rows, cols)` returns `axes`; pass each `ax` slice explicitly. + No `subplot_index`, no `open_subplot_figure`, no `close_subplot_figure`. + Blank panels: `ax.set_visible(False)`. + +4. **`figsize` from config only.** Every other default (fontsize, colormap, marker size, + linewidth, …) is an inline function-parameter default, visible in source code. + +5. **Linear ticks: `np.linspace(lo, hi, 3)`.** + Log colorbars: pass `matplotlib.colors.LogNorm()` to `imshow`; matplotlib generates + log-spaced colorbar ticks automatically — no custom tick class needed. + +6. **`save_figure(fig, path, filename, format, dpi)` replaces `Output`.** + If `output_path` is empty string, call `plt.show()` + `plt.close()` instead of saving. + +7. **No deprecation warnings.** Old `mat_plot_2d` / `visuals_2d` constructor parameters + are simply removed; callers are updated in the same PR. diff --git a/autolens/analysis/plotter_interface.py b/autolens/analysis/plotter_interface.py index 7365018d7..7c6596652 100644 --- a/autolens/analysis/plotter_interface.py +++ b/autolens/analysis/plotter_interface.py @@ -1,184 +1,163 @@ -import ast -import numpy as np -from typing import Optional - -from autoconf import conf -from autoconf.fitsable import hdu_list_for_output_from - -import autoarray as aa -import autogalaxy as ag -import autogalaxy.plot as aplt - -from autogalaxy.analysis.plotter_interface import plot_setting - -from autogalaxy.analysis.plotter_interface import PlotterInterface as AgPlotterInterface - -from autolens.lens.tracer import Tracer -from autolens.lens.plot.tracer_plotters import TracerPlotter - - -class PlotterInterface(AgPlotterInterface): - """ - Visualizes the maximum log likelihood model of a model-fit, including components of the model and fit objects. - - The methods of the `PlotterInterface` are called throughout a non-linear search using the `Analysis` - classes `visualize` method. - - The images output by the `PlotterInterface` are customized using the file `config/visualize/plots.yaml`. - - Parameters - ---------- - image_path - The path on the hard-disk to the `image` folder of the non-linear searches results. - """ - - def tracer( - self, - tracer: Tracer, - grid: aa.type.Grid2DLike, - visuals_2d_of_planes_list: Optional[aplt.Visuals2D] = None, - ): - """ - Visualizes a `Tracer` object. - - Images are output to the `image` folder of the `image_path`. When used with a non-linear search the `image_path` - points to the search's results folder and this function visualizes the maximum log likelihood `Tracer` - inferred by the search so far. - - Visualization includes a subplot of individual images of attributes of the tracer (e.g. its image, convergence, - deflection angles) and .fits files containing its attributes grouped together. - - The images output by the `PlotterInterface` are customized using the file `config/visualize/plots.yaml` under - the `tracer` header. - - Parameters - ---------- - tracer - The maximum log likelihood `Tracer` of the non-linear search. - grid - A 2D grid of (y,x) arc-second coordinates used to perform ray-tracing, which is the masked grid tied to - the dataset. - """ - - def should_plot(name): - return plot_setting(section="tracer", name=name) - - mat_plot_2d = self.mat_plot_2d_from() - - tracer_plotter = TracerPlotter( - tracer=tracer, - grid=grid, - mat_plot_2d=mat_plot_2d, - visuals_2d_of_planes_list=visuals_2d_of_planes_list, - ) - - if should_plot("subplot_galaxies_images"): - tracer_plotter.subplot_galaxies_images() - - if should_plot("fits_tracer"): - - zoom = aa.Zoom2D(mask=grid.mask) - mask = zoom.mask_2d_from(buffer=1) - grid_zoom = aa.Grid2D.from_mask(mask=mask) - - image_list = [ - tracer.convergence_2d_from(grid=grid_zoom).native, - tracer.potential_2d_from(grid=grid_zoom).native, - tracer.deflections_yx_2d_from(grid=grid_zoom).native[:, :, 0], - tracer.deflections_yx_2d_from(grid=grid_zoom).native[:, :, 1], - ] - - hdu_list = hdu_list_for_output_from( - values_list=[image_list[0].mask.astype("float")] + image_list, - ext_name_list=[ - "mask", - "convergence", - "potential", - "deflections_y", - "deflections_x", - ], - header_dict=grid_zoom.mask.header_dict, - ) - - hdu_list.writeto(self.image_path / "tracer.fits", overwrite=True) - - if should_plot("fits_source_plane_images"): - - shape_native = conf.instance["visualize"]["plots"]["tracer"][ - "fits_source_plane_shape" - ] - shape_native = ast.literal_eval(shape_native) - - zoom = aa.Zoom2D(mask=grid.mask) - mask = zoom.mask_2d_from(buffer=1) - grid_source_plane = aa.Grid2D.from_extent( - extent=mask.geometry.extent, shape_native=tuple(shape_native) - ) - - image_list = [grid_source_plane.mask.astype("float")] - ext_name_list = ["mask"] - - for i, plane in enumerate(tracer.planes[1:]): - - if plane.has(cls=ag.LightProfile): - - image = plane.image_2d_from( - grid=grid_source_plane, - ).native - - else: - - image = np.zeros(grid_source_plane.shape_native) - - image_list.append(image) - ext_name_list.append(f"source_plane_image_{i+1}") - - hdu_list = hdu_list_for_output_from( - values_list=image_list, - ext_name_list=ext_name_list, - header_dict=grid_source_plane.mask.header_dict, - ) - - hdu_list.writeto( - self.image_path / "source_plane_images.fits", overwrite=True - ) - - def image_with_positions(self, image: aa.Array2D, positions: aa.Grid2DIrregular): - """ - Visualizes the positions of a model-fit, where these positions are used to penalize lens models where - the positions to do trace within an input threshold of one another in the source-plane. - - Images are output to the `image` folder of the `image_path`. When used with a non-linear search the `image_path` - is the output folder of the non-linear search. - - The visualization is an image of the strong lens with the positions overlaid. - - The images output by the `PlotterInterface` are customized using the file `config/visualize/plots.yaml` under the - `positions` header. - - Parameters - ---------- - imaging - The imaging dataset whose image the positions are overlaid. - positions - The 2D (y,x) arc-second positions used to penalize inaccurate mass models. - """ - - def should_plot(name): - return plot_setting(section=["positions"], name=name) - - mat_plot_2d = self.mat_plot_2d_from() - - if positions is not None: - visuals_2d = aplt.Visuals2D(positions=positions) - - image_plotter = aplt.Array2DPlotter( - array=image, - mat_plot_2d=mat_plot_2d, - visuals_2d=visuals_2d, - ) - - image_plotter.set_filename("image_with_positions") - - if should_plot("image_with_positions"): - image_plotter.figure_2d() +import ast +import numpy as np +from typing import Optional + +from autoconf import conf +from autoconf.fitsable import hdu_list_for_output_from + +import autoarray as aa +import autogalaxy as ag + +from autogalaxy.analysis.plotter_interface import plot_setting + +from autogalaxy.analysis.plotter_interface import PlotterInterface as AgPlotterInterface + +from autolens.lens.tracer import Tracer +from autolens.lens.plot.tracer_plots import subplot_galaxies_images +from autoarray.plot.plots.array import plot_array + + +class PlotterInterface(AgPlotterInterface): + """ + Visualizes the maximum log likelihood model of a model-fit, including components of the model and fit objects. + + The methods of the `PlotterInterface` are called throughout a non-linear search using the `Analysis` + classes `visualize` method. + + The images output by the `PlotterInterface` are customized using the file `config/visualize/plots.yaml`. + + Parameters + ---------- + image_path + The path on the hard-disk to the `image` folder of the non-linear searches results. + """ + + def tracer( + self, + tracer: Tracer, + grid: aa.type.Grid2DLike, + ): + """ + Visualizes a `Tracer` object. + + Parameters + ---------- + tracer + The maximum log likelihood `Tracer` of the non-linear search. + grid + A 2D grid of (y,x) arc-second coordinates used to perform ray-tracing. + """ + + def should_plot(name): + return plot_setting(section="tracer", name=name) + + output_path = str(self.image_path) + fmt = self.fmt + + if should_plot("subplot_galaxies_images"): + subplot_galaxies_images( + tracer=tracer, + grid=grid, + output_path=output_path, + output_format=fmt, + ) + + if should_plot("fits_tracer"): + + zoom = aa.Zoom2D(mask=grid.mask) + mask = zoom.mask_2d_from(buffer=1) + grid_zoom = aa.Grid2D.from_mask(mask=mask) + + image_list = [ + tracer.convergence_2d_from(grid=grid_zoom).native, + tracer.potential_2d_from(grid=grid_zoom).native, + tracer.deflections_yx_2d_from(grid=grid_zoom).native[:, :, 0], + tracer.deflections_yx_2d_from(grid=grid_zoom).native[:, :, 1], + ] + + hdu_list = hdu_list_for_output_from( + values_list=[image_list[0].mask.astype("float")] + image_list, + ext_name_list=[ + "mask", + "convergence", + "potential", + "deflections_y", + "deflections_x", + ], + header_dict=grid_zoom.mask.header_dict, + ) + + hdu_list.writeto(self.image_path / "tracer.fits", overwrite=True) + + if should_plot("fits_source_plane_images"): + + shape_native = conf.instance["visualize"]["plots"]["tracer"][ + "fits_source_plane_shape" + ] + shape_native = ast.literal_eval(shape_native) + + zoom = aa.Zoom2D(mask=grid.mask) + mask = zoom.mask_2d_from(buffer=1) + grid_source_plane = aa.Grid2D.from_extent( + extent=mask.geometry.extent, shape_native=tuple(shape_native) + ) + + image_list = [grid_source_plane.mask.astype("float")] + ext_name_list = ["mask"] + + for i, plane in enumerate(tracer.planes[1:]): + + if plane.has(cls=ag.LightProfile): + + image = plane.image_2d_from( + grid=grid_source_plane, + ).native + + else: + + image = np.zeros(grid_source_plane.shape_native) + + image_list.append(image) + ext_name_list.append(f"source_plane_image_{i+1}") + + hdu_list = hdu_list_for_output_from( + values_list=image_list, + ext_name_list=ext_name_list, + header_dict=grid_source_plane.mask.header_dict, + ) + + hdu_list.writeto( + self.image_path / "source_plane_images.fits", overwrite=True + ) + + def image_with_positions(self, image: aa.Array2D, positions: aa.Grid2DIrregular): + """ + Visualizes the positions of a model-fit. + + Parameters + ---------- + image + The imaging dataset whose image the positions are overlaid. + positions + The 2D (y,x) arc-second positions used to penalize inaccurate mass models. + """ + + def should_plot(name): + return plot_setting(section=["positions"], name=name) + + if positions is not None and should_plot("image_with_positions"): + pos_arr = np.array( + positions.array if hasattr(positions, "array") else positions + ) + + fmt = self.fmt + if isinstance(fmt, (list, tuple)): + fmt = fmt[0] + + plot_array( + array=image, + positions=[pos_arr], + output_path=str(self.image_path), + output_filename="image_with_positions", + output_format=fmt, + ) diff --git a/autolens/imaging/model/plotter_interface.py b/autolens/imaging/model/plotter_interface.py index 3d72c3a8d..420b77d56 100644 --- a/autolens/imaging/model/plotter_interface.py +++ b/autolens/imaging/model/plotter_interface.py @@ -1,229 +1,124 @@ -from typing import List, Optional - -import autoarray.plot as aplt - -from autogalaxy.imaging.model.plotter_interface import PlotterInterfaceImaging as AgPlotterInterfaceImaging - -from autogalaxy.imaging.model.plotter_interface import fits_to_fits - -from autolens.analysis.plotter_interface import PlotterInterface -from autolens.imaging.fit_imaging import FitImaging -from autolens.imaging.plot.fit_imaging_plotters import FitImagingPlotter - -from autolens.analysis.plotter_interface import plot_setting - - -class PlotterInterfaceImaging(PlotterInterface): - - imaging = AgPlotterInterfaceImaging.imaging - imaging_combined = AgPlotterInterfaceImaging.imaging_combined - - def fit_imaging( - self, fit: FitImaging, visuals_2d_of_planes_list : Optional[aplt.Visuals2D] = None, quick_update: bool = False - ): - """ - Visualizes a `FitImaging` object, which fits an imaging dataset. - - Images are output to the `image` folder of the `image_path`. When used with a non-linear search the `image_path` - points to the search's results folder and this function visualizes the maximum log likelihood `FitImaging` - inferred by the search so far. - - Visualization includes a subplot of individual images of attributes of the `FitImaging` (e.g. the model data, - residual map) and .fits files containing its attributes grouped together. - - The images output by the `PlotterInterface` are customized using the file `config/visualize/plots.yaml` under - the `fit` and `fit_imaging` header. - - Parameters - ---------- - fit - The maximum log likelihood `FitImaging` of the non-linear search which is used to plot the fit. - """ - - def should_plot(name): - return plot_setting(section=["fit", "fit_imaging"], name=name) - - mat_plot_2d = self.mat_plot_2d_from(quick_update=quick_update) - - fit_plotter = FitImagingPlotter( - fit=fit, mat_plot_2d=mat_plot_2d, visuals_2d_of_planes_list=visuals_2d_of_planes_list, - ) - - plane_indexes_to_plot = [i for i in fit.tracer.plane_indexes_with_images if i != 0] - - if should_plot("subplot_fit") or quick_update: - - # This loop means that multiple subplot_fit objects are output for a double source plane lens. - - if len(fit.tracer.planes) > 2: - for plane_index in plane_indexes_to_plot: - fit_plotter.subplot_fit(plane_index=plane_index) - else: - fit_plotter.subplot_fit() - - if quick_update: - return - - if plot_setting(section="tracer", name="subplot_tracer"): - - mat_plot_2d = self.mat_plot_2d_from() - - fit_plotter = FitImagingPlotter( - fit=fit, mat_plot_2d=mat_plot_2d, visuals_2d_of_planes_list=visuals_2d_of_planes_list, - ) - - fit_plotter.subplot_tracer() - - if should_plot("subplot_fit_log10"): - - try: - if len(fit.tracer.planes) > 2: - for plane_index in plane_indexes_to_plot: - fit_plotter.subplot_fit_log10(plane_index=plane_index) - else: - fit_plotter.subplot_fit_log10() - except ValueError: - pass - - if should_plot("subplot_of_planes"): - fit_plotter.subplot_of_planes() - - if plot_setting(section="inversion", name="subplot_mappings"): - try: - fit_plotter.subplot_mappings_of_plane(plane_index=len(fit.tracer.planes) - 1) - except IndexError: - pass - - fits_to_fits(should_plot=should_plot, image_path=self.image_path, fit=fit) - - def fit_imaging_combined( - self, - fit_list: List[FitImaging], - visuals_2d_of_planes_list : Optional[aplt.Visuals2D] = None, - quick_update: bool = False, - ): - """ - Output visualization of all `FitImaging` objects in a summed combined analysis, typically during or after a - model-fit is performed. - - Images are output to the `image` folder of the `image_path`. When used with a non-linear search the `image_path` - is the output folder of the non-linear search. - - Visualization includes a subplot of individual images of attributes of each fit (e.g. data, normalized - residual-map) on a single subplot, such that the full suite of multiple datasets can be viewed on the same figure. - - The images output by the `PlotterInterface` are customized using the file `config/visualize/plots.yaml` under - the `fit` and `fit_imaging` headers. - - Parameters - ---------- - fit_list - The list of imaging fits which are visualized. - """ - - def should_plot(name): - return plot_setting(section=["fit", "fit_imaging"], name=name) - - mat_plot_2d = self.mat_plot_2d_from(quick_update=quick_update) - - fit_plotter_list = [ - FitImagingPlotter( - fit=fit, mat_plot_2d=mat_plot_2d, visuals_2d_of_planes_list=visuals_2d_of_planes_list, - ) - for fit in fit_list - ] - - subplot_columns = 6 - - subplot_shape = (len(fit_list), subplot_columns) - - multi_plotter = aplt.MultiFigurePlotter( - plotter_list=fit_plotter_list, subplot_shape=subplot_shape - ) - - if should_plot("subplot_fit") or quick_update: - - def make_subplot_fit(filename_suffix): - - multi_plotter.subplot_of_figures_multi( - func_name_list=["figures_2d"], - figure_name_list=[ - "data", - ], - filename_suffix=filename_suffix, - number_subplots=len(fit_list) * subplot_columns, - close_subplot=False, - ) - - multi_plotter.subplot_of_figures_multi( - func_name_list=["figures_2d_of_planes"], - figure_name_list=[ - "subtracted_image", - ], - filename_suffix=filename_suffix, - number_subplots=len(fit_list) * subplot_columns, - open_subplot=False, - close_subplot=False, - subplot_index_offset=1, - plane_index=1 - ) - - multi_plotter.subplot_of_figures_multi( - func_name_list=["figures_2d_of_planes"], - figure_name_list=[ - "model_image", - ], - filename_suffix=filename_suffix, - number_subplots=len(fit_list) * subplot_columns, - open_subplot=False, - close_subplot=False, - subplot_index_offset=2, - plane_index=0 - ) - - multi_plotter.subplot_of_figures_multi( - func_name_list=["figures_2d_of_planes"], - figure_name_list=[ - "model_image", - ], - filename_suffix=filename_suffix, - number_subplots=len(fit_list) * subplot_columns, - open_subplot=False, - close_subplot=False, - subplot_index_offset=3, - plane_index=len(fit_list[0].tracer.planes) - 1 - ) - - multi_plotter.subplot_of_figures_multi( - func_name_list=["figures_2d_of_planes"], - figure_name_list=[ - "plane_image", - ], - filename_suffix=filename_suffix, - number_subplots=len(fit_list) * subplot_columns, - open_subplot=False, - close_subplot=False, - subplot_index_offset=4, - plane_index=len(fit_list[0].tracer.planes) - 1 - ) - - multi_plotter.subplot_of_figures_multi( - func_name_list=["figures_2d"], - figure_name_list=[ - "normalized_residual_map", - ], - filename_suffix=filename_suffix, - number_subplots=len(fit_list) * subplot_columns, - subplot_index_offset=5, - open_subplot=False, - ) - - make_subplot_fit(filename_suffix="fit_combined") - - if quick_update: - return - - for plotter in multi_plotter.plotter_list: - plotter.mat_plot_2d.use_log10 = True - - make_subplot_fit(filename_suffix="fit_combined_log10") \ No newline at end of file +import matplotlib.pyplot as plt +import numpy as np +from typing import List + +from autogalaxy.imaging.model.plotter_interface import PlotterInterfaceImaging as AgPlotterInterfaceImaging +from autogalaxy.imaging.model.plotter_interface import fits_to_fits + +from autolens.analysis.plotter_interface import PlotterInterface +from autolens.imaging.fit_imaging import FitImaging +from autolens.imaging.plot.fit_imaging_plots import ( + subplot_fit, + subplot_fit_log10, + subplot_of_planes, + subplot_tracer_from_fit, + subplot_fit_combined, + subplot_fit_combined_log10, +) + +from autolens.analysis.plotter_interface import plot_setting + + +class PlotterInterfaceImaging(PlotterInterface): + + imaging = AgPlotterInterfaceImaging.imaging + imaging_combined = AgPlotterInterfaceImaging.imaging_combined + + def fit_imaging( + self, fit: FitImaging, quick_update: bool = False + ): + """ + Visualizes a `FitImaging` object, which fits an imaging dataset. + + Parameters + ---------- + fit + The maximum log likelihood `FitImaging` of the non-linear search. + quick_update + If True only the essential subplot_fit is output. + """ + + def should_plot(name): + return plot_setting(section=["fit", "fit_imaging"], name=name) + + output_path = str(self.image_path) + fmt = self.fmt + + plane_indexes_to_plot = [i for i in fit.tracer.plane_indexes_with_images if i != 0] + + if should_plot("subplot_fit") or quick_update: + + if len(fit.tracer.planes) > 2: + for plane_index in plane_indexes_to_plot: + subplot_fit(fit, output_path=output_path, output_format=fmt, + plane_index=plane_index) + else: + subplot_fit(fit, output_path=output_path, output_format=fmt) + + if quick_update: + return + + if plot_setting(section="tracer", name="subplot_tracer"): + subplot_tracer_from_fit(fit, output_path=output_path, output_format=fmt) + + if should_plot("subplot_fit_log10"): + try: + if len(fit.tracer.planes) > 2: + for plane_index in plane_indexes_to_plot: + subplot_fit_log10(fit, output_path=output_path, output_format=fmt, + plane_index=plane_index) + else: + subplot_fit_log10(fit, output_path=output_path, output_format=fmt) + except ValueError: + pass + + if should_plot("subplot_of_planes"): + subplot_of_planes(fit, output_path=output_path, output_format=fmt) + + if plot_setting(section="inversion", name="subplot_mappings"): + try: + import autogalaxy.plot as aplt + output = self.output_from() + inversion_plotter = aplt.InversionPlotter( + inversion=fit.inversion, + mat_plot_2d=aplt.MatPlot2D( + output=aplt.Output(path=self.image_path, format=fmt), + ), + ) + pixelization_index = 0 + inversion_plotter.subplot_of_mapper( + mapper_index=pixelization_index, + auto_filename=f"subplot_mappings_{pixelization_index}", + ) + except (IndexError, AttributeError, TypeError, Exception): + pass + + fits_to_fits(should_plot=should_plot, image_path=self.image_path, fit=fit) + + def fit_imaging_combined( + self, + fit_list: List[FitImaging], + quick_update: bool = False, + ): + """ + Output visualization of all `FitImaging` objects in a summed combined analysis. + + Parameters + ---------- + fit_list + The list of imaging fits which are visualized. + """ + + def should_plot(name): + return plot_setting(section=["fit", "fit_imaging"], name=name) + + output_path = str(self.image_path) + fmt = self.fmt + + if should_plot("subplot_fit") or quick_update: + subplot_fit_combined(fit_list, output_path=output_path, output_format=fmt) + + if quick_update: + return + + subplot_fit_combined_log10(fit_list, output_path=output_path, output_format=fmt) diff --git a/autolens/imaging/model/visualizer.py b/autolens/imaging/model/visualizer.py index fcd30e822..d28e2d07a 100644 --- a/autolens/imaging/model/visualizer.py +++ b/autolens/imaging/model/visualizer.py @@ -7,7 +7,6 @@ from autolens.imaging.model.plotter_interface import PlotterInterfaceImaging -from autolens.lens import tracer_util from autolens import exc logger = logging.getLogger(__name__) @@ -97,11 +96,6 @@ def visualize( fit = analysis.fit_from(instance=instance) tracer = fit.tracer_linear_light_profiles_to_light_profiles - visuals_2d_of_planes_list = tracer_util.visuals_2d_of_planes_list_from( - tracer=fit.tracer, - grid=fit.grids.lp.mask.derive_grid.all_false - ) - plotter_interface = PlotterInterfaceImaging( image_path=paths.image_path, title_prefix=analysis.title_prefix, @@ -110,7 +104,6 @@ def visualize( try: plotter_interface.fit_imaging( fit=fit, - visuals_2d_of_planes_list=visuals_2d_of_planes_list, quick_update=quick_update, ) except exc.InversionException: @@ -152,7 +145,6 @@ def visualize( plotter_interface.tracer( tracer=tracer, grid=grid, - visuals_2d_of_planes_list=visuals_2d_of_planes_list ) plotter_interface.galaxies( galaxies=tracer.galaxies, diff --git a/autolens/imaging/plot/fit_imaging_plots.py b/autolens/imaging/plot/fit_imaging_plots.py new file mode 100644 index 000000000..db3605439 --- /dev/null +++ b/autolens/imaging/plot/fit_imaging_plots.py @@ -0,0 +1,514 @@ +import matplotlib.pyplot as plt +import numpy as np +from typing import Optional, List + +import autoarray as aa +import autogalaxy as ag + +from autoarray.plot.plots.array import plot_array, _zoom_array_2d +from autoarray.plot.plots.utils import save_figure +from autolens.plot.plot_utils import ( + _to_lines, + _critical_curves_from, + _caustics_from, +) + + +def _get_source_vmax(fit): + """Return vmax based on source-plane model images, or None.""" + try: + return float(np.max([mi.array for mi in fit.model_images_of_planes_list[1:]])) + except (ValueError, IndexError): + return None + + +def _plot_source_plane(fit, ax, plane_index, zoom_to_brightest=True, + colormap="jet", use_log10=False): + """Plot source plane image or inversion reconstruction into *ax*.""" + tracer = fit.tracer_linear_light_profiles_to_light_profiles + if not tracer.planes[plane_index].has(cls=aa.Pixelization): + zoom = aa.Zoom2D(mask=fit.mask) + grid = aa.Grid2D.from_extent( + extent=zoom.extent_from(buffer=0), shape_native=zoom.shape_native + ) + traced_grids = tracer.traced_grid_2d_list_from(grid=grid) + plane_galaxies = ag.Galaxies(galaxies=tracer.planes[plane_index]) + image = plane_galaxies.image_2d_from(grid=traced_grids[plane_index]) + plot_array( + array=image, ax=ax, + title=f"Source Plane {plane_index}", + colormap=colormap, use_log10=use_log10, + ) + else: + # Inversion path: in subplot context show a blank panel. + if ax is not None: + ax.axis("off") + ax.set_title(f"Source Reconstruction (plane {plane_index})") + + +def subplot_fit( + fit, + output_path: Optional[str] = None, + output_format: str = "png", + colormap: str = "jet", + plane_index: Optional[int] = None, +): + """12-panel subplot of the imaging fit. + + For single-plane tracers delegates to :func:`subplot_fit_x1_plane`. + """ + if len(fit.tracer.planes) == 1: + return subplot_fit_x1_plane(fit, output_path=output_path, + output_format=output_format, colormap=colormap) + + plane_index_tag = "" if plane_index is None else f"_{plane_index}" + final_plane_index = ( + len(fit.tracer.planes) - 1 if plane_index is None else plane_index + ) + + source_vmax = _get_source_vmax(fit) + + fig, axes = plt.subplots(3, 4, figsize=(28, 21)) + axes_flat = list(axes.flatten()) + + plot_array(array=fit.data, ax=axes_flat[0], title="Data", colormap=colormap) + + # Data at source scale + plot_array(array=fit.data, ax=axes_flat[1], title="Data (Source Scale)", + colormap=colormap, vmax=source_vmax) + + plot_array(array=fit.signal_to_noise_map, ax=axes_flat[2], + title="Signal-To-Noise Map", colormap=colormap) + plot_array(array=fit.model_data, ax=axes_flat[3], title="Model Image", + colormap=colormap) + + # Lens model image + try: + lens_model_img = fit.model_images_of_planes_list[0] + except (IndexError, AttributeError): + lens_model_img = None + if lens_model_img is not None: + plot_array(array=lens_model_img, ax=axes_flat[4], + title="Lens Light Model Image", colormap=colormap) + else: + axes_flat[4].axis("off") + + # Subtracted image at source scale + try: + subtracted_img = fit.subtracted_images_of_planes_list[final_plane_index] + except (IndexError, AttributeError): + subtracted_img = None + if subtracted_img is not None: + plot_array(array=subtracted_img, ax=axes_flat[5], title="Lens Light Subtracted", + colormap=colormap, vmin=0.0 if source_vmax is not None else None, + vmax=source_vmax) + else: + axes_flat[5].axis("off") + + # Source model image at source scale + try: + source_model_img = fit.model_images_of_planes_list[final_plane_index] + except (IndexError, AttributeError): + source_model_img = None + if source_model_img is not None: + plot_array(array=source_model_img, ax=axes_flat[6], title="Source Model Image", + colormap=colormap, vmax=source_vmax) + else: + axes_flat[6].axis("off") + + # Source plane zoomed + _plot_source_plane(fit, axes_flat[7], final_plane_index, zoom_to_brightest=True, + colormap=colormap) + + # Normalized residual map (symmetric) + norm_resid = fit.normalized_residual_map + _abs_max = _symmetric_vmax(norm_resid) + plot_array(array=norm_resid, ax=axes_flat[8], title="Normalized Residual Map", + colormap=colormap, vmin=-_abs_max, vmax=_abs_max) + + # Normalized residual map clipped to [-1, 1] + plot_array(array=norm_resid, ax=axes_flat[9], + title=r"Normalized Residual Map $1\sigma$", + colormap=colormap, vmin=-1.0, vmax=1.0) + + plot_array(array=fit.chi_squared_map, ax=axes_flat[10], + title="Chi-Squared Map", colormap=colormap) + + # Source plane not zoomed + _plot_source_plane(fit, axes_flat[11], final_plane_index, zoom_to_brightest=False, + colormap=colormap) + + plt.tight_layout() + save_figure(fig, path=output_path, filename=f"subplot_fit{plane_index_tag}", format=output_format) + + +def subplot_fit_x1_plane( + fit, + output_path: Optional[str] = None, + output_format: str = "png", + colormap: str = "jet", +): + """6-panel subplot for a single-plane tracer fit.""" + fig, axes = plt.subplots(2, 3, figsize=(21, 14)) + axes_flat = list(axes.flatten()) + + try: + vmax = float(np.max(fit.model_images_of_planes_list[0].array)) + except (IndexError, AttributeError, ValueError): + vmax = None + + plot_array(array=fit.data, ax=axes_flat[0], title="Data", colormap=colormap, vmax=vmax) + + plot_array(array=fit.signal_to_noise_map, ax=axes_flat[1], + title="Signal-To-Noise Map", colormap=colormap) + + plot_array(array=fit.model_data, ax=axes_flat[2], title="Model Image", + colormap=colormap, vmax=vmax) + + norm_resid = fit.normalized_residual_map + plot_array(array=norm_resid, ax=axes_flat[3], title="Lens Light Subtracted", + colormap=colormap) + + plot_array(array=norm_resid, ax=axes_flat[4], title="Subtracted Image Zero Minimum", + colormap=colormap, vmin=0.0) + + _abs_max = _symmetric_vmax(norm_resid) + plot_array(array=norm_resid, ax=axes_flat[5], title="Normalized Residual Map", + colormap=colormap, vmin=-_abs_max, vmax=_abs_max) + + plt.tight_layout() + save_figure(fig, path=output_path, filename="subplot_fit_x1_plane", format=output_format) + + +def subplot_fit_log10( + fit, + output_path: Optional[str] = None, + output_format: str = "png", + colormap: str = "jet", + plane_index: Optional[int] = None, +): + """12-panel log10 subplot of the imaging fit.""" + if len(fit.tracer.planes) == 1: + return subplot_fit_log10_x1_plane(fit, output_path=output_path, + output_format=output_format, colormap=colormap) + + plane_index_tag = "" if plane_index is None else f"_{plane_index}" + final_plane_index = ( + len(fit.tracer.planes) - 1 if plane_index is None else plane_index + ) + + source_vmax = _get_source_vmax(fit) + + fig, axes = plt.subplots(3, 4, figsize=(28, 21)) + axes_flat = list(axes.flatten()) + + plot_array(array=fit.data, ax=axes_flat[0], title="Data", colormap=colormap, + use_log10=True) + + try: + plot_array(array=fit.data, ax=axes_flat[1], title="Data (Source Scale)", + colormap=colormap, use_log10=True) + except ValueError: + axes_flat[1].axis("off") + + try: + plot_array(array=fit.signal_to_noise_map, ax=axes_flat[2], + title="Signal-To-Noise Map", colormap=colormap, use_log10=True) + except ValueError: + axes_flat[2].axis("off") + + plot_array(array=fit.model_data, ax=axes_flat[3], title="Model Image", + colormap=colormap, use_log10=True) + + try: + lens_model_img = fit.model_images_of_planes_list[0] + plot_array(array=lens_model_img, ax=axes_flat[4], + title="Lens Light Model Image", colormap=colormap, use_log10=True) + except (IndexError, AttributeError): + axes_flat[4].axis("off") + + try: + subtracted_img = fit.subtracted_images_of_planes_list[final_plane_index] + plot_array(array=subtracted_img, ax=axes_flat[5], + title="Lens Light Subtracted", colormap=colormap, use_log10=True) + except (IndexError, AttributeError): + axes_flat[5].axis("off") + + try: + source_model_img = fit.model_images_of_planes_list[final_plane_index] + plot_array(array=source_model_img, ax=axes_flat[6], + title="Source Model Image", colormap=colormap, use_log10=True) + except (IndexError, AttributeError): + axes_flat[6].axis("off") + + _plot_source_plane(fit, axes_flat[7], final_plane_index, zoom_to_brightest=True, + colormap=colormap, use_log10=True) + + norm_resid = fit.normalized_residual_map + _abs_max = _symmetric_vmax(norm_resid) + plot_array(array=norm_resid, ax=axes_flat[8], title="Normalized Residual Map", + colormap=colormap, vmin=-_abs_max, vmax=_abs_max) + + plot_array(array=norm_resid, ax=axes_flat[9], + title=r"Normalized Residual Map $1\sigma$", + colormap=colormap, vmin=-1.0, vmax=1.0) + + plot_array(array=fit.chi_squared_map, ax=axes_flat[10], title="Chi-Squared Map", + colormap=colormap, use_log10=True) + + _plot_source_plane(fit, axes_flat[11], final_plane_index, zoom_to_brightest=False, + colormap=colormap, use_log10=True) + + plt.tight_layout() + save_figure(fig, path=output_path, filename=f"subplot_fit_log10{plane_index_tag}", format=output_format) + + +def subplot_fit_log10_x1_plane( + fit, + output_path: Optional[str] = None, + output_format: str = "png", + colormap: str = "jet", +): + """6-panel log10 subplot for a single-plane tracer fit.""" + fig, axes = plt.subplots(2, 3, figsize=(21, 14)) + axes_flat = list(axes.flatten()) + + try: + vmax = float(np.max(fit.model_images_of_planes_list[0].array)) + except (IndexError, AttributeError, ValueError): + vmax = None + + plot_array(array=fit.data, ax=axes_flat[0], title="Data", colormap=colormap, + vmax=vmax, use_log10=True) + + try: + plot_array(array=fit.signal_to_noise_map, ax=axes_flat[1], + title="Signal-To-Noise Map", colormap=colormap, use_log10=True) + except ValueError: + axes_flat[1].axis("off") + + plot_array(array=fit.model_data, ax=axes_flat[2], title="Model Image", + colormap=colormap, vmax=vmax, use_log10=True) + + norm_resid = fit.normalized_residual_map + plot_array(array=norm_resid, ax=axes_flat[3], title="Lens Light Subtracted", + colormap=colormap) + _abs_max = _symmetric_vmax(norm_resid) + plot_array(array=norm_resid, ax=axes_flat[4], title="Normalized Residual Map", + colormap=colormap, vmin=-_abs_max, vmax=_abs_max) + plot_array(array=fit.chi_squared_map, ax=axes_flat[5], title="Chi-Squared Map", + colormap=colormap, use_log10=True) + + plt.tight_layout() + save_figure(fig, path=output_path, filename="subplot_fit_log10", format=output_format) + + +def subplot_of_planes( + fit, + output_path: Optional[str] = None, + output_format: str = "png", + colormap: str = "jet", + plane_index: Optional[int] = None, +): + """4-panel subplot per plane: data, subtracted, model image, plane image.""" + if plane_index is None: + plane_indexes = range(len(fit.tracer.planes)) + else: + plane_indexes = [plane_index] + + for pidx in plane_indexes: + fig, axes = plt.subplots(1, 4, figsize=(28, 7)) + axes_flat = list(axes.flatten()) + + plot_array(array=fit.data, ax=axes_flat[0], title="Data", colormap=colormap) + + try: + subtracted = fit.subtracted_images_of_planes_list[pidx] + plot_array(array=subtracted, ax=axes_flat[1], + title=f"Subtracted Image Plane {pidx}", colormap=colormap) + except (IndexError, AttributeError): + axes_flat[1].axis("off") + + try: + model_img = fit.model_images_of_planes_list[pidx] + plot_array(array=model_img, ax=axes_flat[2], + title=f"Model Image Plane {pidx}", colormap=colormap) + except (IndexError, AttributeError): + axes_flat[2].axis("off") + + _plot_source_plane(fit, axes_flat[3], pidx, colormap=colormap) + + plt.tight_layout() + save_figure(fig, path=output_path, filename=f"subplot_of_plane_{pidx}", format=output_format) + + +def subplot_tracer_from_fit( + fit, + output_path: Optional[str] = None, + output_format: str = "png", + colormap: str = "jet", +): + """9-panel tracer subplot derived from a FitImaging object.""" + final_plane_index = len(fit.tracer.planes) - 1 + + fig, axes = plt.subplots(3, 3, figsize=(21, 21)) + axes_flat = list(axes.flatten()) + + tracer = fit.tracer_linear_light_profiles_to_light_profiles + + plot_array(array=fit.model_data, ax=axes_flat[0], title="Model Image", + colormap=colormap) + + try: + source_model_img = fit.model_images_of_planes_list[final_plane_index] + source_vmax = float(np.max(source_model_img.array)) + plot_array(array=source_model_img, ax=axes_flat[1], title="Source Model Image", + colormap=colormap, vmax=source_vmax) + except (IndexError, AttributeError, ValueError): + axes_flat[1].axis("off") + + _plot_source_plane(fit, axes_flat[2], final_plane_index, zoom_to_brightest=False, + colormap=colormap) + + # Lens plane mass quantities (log10) + zoom = aa.Zoom2D(mask=fit.mask) + grid = aa.Grid2D.from_extent( + extent=zoom.extent_from(buffer=0), shape_native=zoom.shape_native + ) + + tan_cc, rad_cc = _critical_curves_from(tracer, grid) + image_plane_lines = _to_lines(tan_cc, rad_cc) + + traced_grids = tracer.traced_grid_2d_list_from(grid=grid) + lens_galaxies = ag.Galaxies(galaxies=tracer.planes[0]) + lens_image = lens_galaxies.image_2d_from(grid=traced_grids[0]) + plot_array(array=lens_image, ax=axes_flat[3], title="Lens Image", + lines=image_plane_lines, colormap=colormap, use_log10=True) + + for i in range(4, 9): + axes_flat[i].axis("off") + + plt.tight_layout() + save_figure(fig, path=output_path, filename="subplot_tracer", format=output_format) + + +def subplot_fit_combined( + fit_list: List, + output_path: Optional[str] = None, + output_format: str = "png", + colormap: str = "jet", +): + """Combined multi-row subplot for a list of FitImaging objects.""" + n_fits = len(fit_list) + n_cols = 6 + fig, axes = plt.subplots(n_fits, n_cols, figsize=(7 * n_cols, 7 * n_fits)) + if n_fits == 1: + all_axes = [list(axes)] + else: + all_axes = [list(axes[i]) for i in range(n_fits)] + + final_plane_index = len(fit_list[0].tracer.planes) - 1 + + for row, fit in enumerate(fit_list): + row_axes = all_axes[row] + + plot_array(array=fit.data, ax=row_axes[0], title="Data", colormap=colormap) + + try: + subtracted = fit.subtracted_images_of_planes_list[1] + plot_array(array=subtracted, ax=row_axes[1], title="Subtracted Image", + colormap=colormap) + except (IndexError, AttributeError): + row_axes[1].axis("off") + + try: + lens_model = fit.model_images_of_planes_list[0] + plot_array(array=lens_model, ax=row_axes[2], title="Lens Model Image", + colormap=colormap) + except (IndexError, AttributeError): + row_axes[2].axis("off") + + try: + source_model = fit.model_images_of_planes_list[final_plane_index] + plot_array(array=source_model, ax=row_axes[3], title="Source Model Image", + colormap=colormap) + except (IndexError, AttributeError): + row_axes[3].axis("off") + + try: + _plot_source_plane(fit, row_axes[4], final_plane_index, colormap=colormap) + except Exception: + row_axes[4].axis("off") + + plot_array(array=fit.normalized_residual_map, ax=row_axes[5], + title="Normalized Residual Map", colormap=colormap) + + plt.tight_layout() + save_figure(fig, path=output_path, filename="subplot_fit_combined", format=output_format) + + +def subplot_fit_combined_log10( + fit_list: List, + output_path: Optional[str] = None, + output_format: str = "png", + colormap: str = "jet", +): + """Combined log10 multi-row subplot for a list of FitImaging objects.""" + n_fits = len(fit_list) + n_cols = 6 + fig, axes = plt.subplots(n_fits, n_cols, figsize=(7 * n_cols, 7 * n_fits)) + if n_fits == 1: + all_axes = [list(axes)] + else: + all_axes = [list(axes[i]) for i in range(n_fits)] + + final_plane_index = len(fit_list[0].tracer.planes) - 1 + + for row, fit in enumerate(fit_list): + row_axes = all_axes[row] + + plot_array(array=fit.data, ax=row_axes[0], title="Data", colormap=colormap, + use_log10=True) + + try: + subtracted = fit.subtracted_images_of_planes_list[1] + plot_array(array=subtracted, ax=row_axes[1], title="Subtracted Image", + colormap=colormap, use_log10=True) + except (IndexError, AttributeError): + row_axes[1].axis("off") + + try: + lens_model = fit.model_images_of_planes_list[0] + plot_array(array=lens_model, ax=row_axes[2], title="Lens Model Image", + colormap=colormap, use_log10=True) + except (IndexError, AttributeError): + row_axes[2].axis("off") + + try: + source_model = fit.model_images_of_planes_list[final_plane_index] + plot_array(array=source_model, ax=row_axes[3], title="Source Model Image", + colormap=colormap, use_log10=True) + except (IndexError, AttributeError): + row_axes[3].axis("off") + + try: + _plot_source_plane(fit, row_axes[4], final_plane_index, colormap=colormap, + use_log10=True) + except Exception: + row_axes[4].axis("off") + + plot_array(array=fit.normalized_residual_map, ax=row_axes[5], + title="Normalized Residual Map", colormap=colormap) + + plt.tight_layout() + save_figure(fig, path=output_path, filename="fit_combined_log10", format=output_format) + + +def _symmetric_vmax(array) -> float: + """Return abs-max finite value for symmetric colormap scaling.""" + try: + vals = _zoom_array_2d(array).native.array + except AttributeError: + vals = np.asarray(array) + finite = vals[np.isfinite(vals)] + return float(np.max(np.abs(finite))) if finite.size else 1.0 diff --git a/autolens/imaging/plot/fit_imaging_plotters.py b/autolens/imaging/plot/fit_imaging_plotters.py deleted file mode 100644 index f5e8f8d60..000000000 --- a/autolens/imaging/plot/fit_imaging_plotters.py +++ /dev/null @@ -1,980 +0,0 @@ -import copy -import numpy as np -from typing import Optional - -from autoconf import conf - -import autoarray as aa -import autogalaxy.plot as aplt - -from autoarray.plot.auto_labels import AutoLabels -from autoarray.fit.plot.fit_imaging_plotters import FitImagingPlotterMeta - -from autolens.plot.abstract_plotters import Plotter -from autolens.imaging.fit_imaging import FitImaging -from autolens.lens.plot.tracer_plotters import TracerPlotter - -from autolens.lens import tracer_util - - -class FitImagingPlotter(Plotter): - def __init__( - self, - fit: FitImaging, - mat_plot_2d: aplt.MatPlot2D = None, - visuals_2d: aplt.Visuals2D = None, - residuals_symmetric_cmap: bool = True, - visuals_2d_of_planes_list : Optional = None - ): - """ - Plots the attributes of `FitImaging` objects using the matplotlib method `imshow()` and many other matplotlib - functions which customize the plot's appearance. - - The `mat_plot_2d` attribute wraps matplotlib function calls to make the figure. By default, the settings - passed to every matplotlib function called are those specified in the `config/visualize/mat_wrap/*.ini` files, - but a user can manually input values into `MatPlot2d` to customize the figure's appearance. - - Overlaid on the figure are visuals, contained in the `Visuals2D` object. Attributes may be extracted from - the `FitImaging` and plotted via the visuals object. - - Parameters - ---------- - fit - The fit to an imaging dataset the plotter plots. - mat_plot_2d - Contains objects which wrap the matplotlib function calls that make the plot. - visuals_2d - Contains visuals that can be overlaid on the plot. - residuals_symmetric_cmap - If true, the `residual_map` and `normalized_residual_map` are plotted with a symmetric color map such - that `abs(vmin) = abs(vmax)`. - """ - super().__init__(mat_plot_2d=mat_plot_2d, visuals_2d=visuals_2d) - - self.fit = fit - - self._fit_imaging_meta_plotter = FitImagingPlotterMeta( - fit=self.fit, - mat_plot_2d=self.mat_plot_2d, - visuals_2d=self.visuals_2d, - residuals_symmetric_cmap=residuals_symmetric_cmap, - ) - - self.residuals_symmetric_cmap = residuals_symmetric_cmap - - self._visuals_2d_of_planes_list = visuals_2d_of_planes_list - - @property - def visuals_2d_of_planes_list(self): - - if self._visuals_2d_of_planes_list is None: - - self._visuals_2d_of_planes_list = tracer_util.visuals_2d_of_planes_list_from( - tracer=self.fit.tracer, - grid=self.fit.grids.lp.mask.derive_grid.all_false, - ) - - return self._visuals_2d_of_planes_list - - def visuals_2d_from( - self, plane_index: Optional[int] = None, remove_critical_caustic: bool = False - ) -> aplt.Visuals2D: - """ - Returns the `Visuals2D` of the plotter with critical curves and caustics added, which are used to plot - the critical curves and caustics of the `Tracer` object. - - If `remove_critical_caustic` is `True`, critical curves and caustics are not included in the visuals. - - Parameters - ---------- - plane_index - The index of the plane in the tracer which is used to extract quantities, as only one plane is plotted - at a time. - remove_critical_caustic - Whether to remove critical curves and caustics from the visuals. - """ - if remove_critical_caustic: - return self.visuals_2d - - return ( - self.visuals_2d - + self.visuals_2d_of_planes_list[plane_index] - ) - - @property - def tracer(self): - return self.fit.tracer_linear_light_profiles_to_light_profiles - - def tracer_plotter_of_plane( - self, plane_index: int, remove_critical_caustic: bool = False - ) -> TracerPlotter: - """ - Returns an `TracerPlotter` corresponding to the `Tracer` in the `FitImaging`. - """ - - zoom = aa.Zoom2D(mask=self.fit.mask) - - grid = aa.Grid2D.from_extent( - extent=zoom.extent_from(buffer=0), shape_native=zoom.shape_native - ) - return TracerPlotter( - tracer=self.tracer, - grid=grid, - mat_plot_2d=self.mat_plot_2d, - visuals_2d=self.visuals_2d_from( - plane_index=plane_index, remove_critical_caustic=remove_critical_caustic - ), - ) - - def inversion_plotter_of_plane( - self, plane_index: int, remove_critical_caustic: bool = False - ) -> aplt.InversionPlotter: - """ - Returns an `InversionPlotter` corresponding to one of the `Inversion`'s in the fit, which is specified via - the index of the `Plane` that inversion was performed on. - - Parameters - ---------- - plane_index - The index of the inversion in the inversion which is used to create the `InversionPlotter`. - - Returns - ------- - InversionPlotter - An object that plots inversions which is used for plotting attributes of the inversion. - """ - - inversion_plotter = aplt.InversionPlotter( - inversion=self.fit.inversion, - mat_plot_2d=self.mat_plot_2d, - visuals_2d=self.visuals_2d_from( - plane_index=plane_index, remove_critical_caustic=remove_critical_caustic - ), - ) - return inversion_plotter - - def plane_indexes_from(self, plane_index: int): - """ - Returns a list of all indexes of the planes in the fit, which is iterated over in figures that plot - individual figures of each plane in a tracer. - - Parameters - ---------- - plane_index - A specific plane index which when input means that only a single plane index is returned. - - Returns - ------- - list - A list of galaxy indexes corresponding to planes in the plane. - """ - if plane_index is None: - return range(len(self.fit.tracer.planes)) - return [plane_index] - - def figures_2d_of_planes( - self, - plane_index: Optional[int] = None, - subtracted_image: bool = False, - model_image: bool = False, - plane_image: bool = False, - plane_noise_map: bool = False, - plane_signal_to_noise_map: bool = False, - use_source_vmax: bool = False, - zoom_to_brightest: bool = True, - remove_critical_caustic: bool = False, - ): - """ - Plots images representing each individual `Plane` in the fit's `Tracer` in 2D, which are computed via the - plotter's 2D grid object. - - These images subtract or omit the contribution of other planes in the plane, such that plots showing - each individual plane are made. - - The API is such that every plottable attribute of the `Plane` object is an input parameter of type bool of - the function, which if switched to `True` means that it is plotted. - - Parameters - ---------- - plane_index - The index of the plane which figures are plotted for. - subtracted_image - Whether to make a 2D plot (via `imshow`) of the subtracted image of a plane, where this image is - the fit's `data` minus the model images of all other planes, thereby showing an individual plane in the - data. - model_image - Whether to make a 2D plot (via `imshow`) of the model image of a plane, where this image is the - model image of one plane, thereby showing how much it contributes to the overall model image. - plane_image - Whether to make a 2D plot (via `imshow`) of the image of a plane in its source-plane (e.g. unlensed). - Depending on how the fit is performed, this could either be an image of light profiles of the reconstruction - of an `Inversion`. - plane_noise_map - Whether to make a 2D plot of the noise-map of a plane in its source-plane, where the - noise map can only be computed when a pixelized source reconstruction is performed and they correspond to - the noise map in each reconstructed pixel as given by the inverse curvature matrix. - plane_signal_to_noise_map - Whether to make a 2D plot of the signal-to-noise map of a plane in its source-plane, - where the signal-to-noise map values can only be computed when a pixelized source reconstruction and they - are the ratio of reconstructed flux to error in each pixel. - use_source_vmax - If `True`, the maximum value of the lensed source (e.g. in the image-plane) is used to set the `vmax` of - certain plots (e.g. the `data`) in order to ensure the lensed source is visible compared to the lens. - zoom_to_brightest - For images not in the image-plane (e.g. the `plane_image`), whether to automatically zoom the plot to - the brightest regions of the galaxies being plotted as opposed to the full extent of the grid. - remove_critical_caustic - Whether to remove critical curves and caustics from the plot. - """ - - plane_indexes = self.plane_indexes_from(plane_index=plane_index) - - for plane_index in plane_indexes: - - if use_source_vmax: - self.mat_plot_2d.cmap.kwargs["vmax"] = np.max( - self.fit.model_images_of_planes_list[plane_index].array - ) - - if subtracted_image: - - title = f"Subtracted Image of Plane {plane_index}" - filename = f"subtracted_image_of_plane_{plane_index}" - - if len(self.tracer.planes) == 2: - - if plane_index == 0: - - title = "Source Subtracted Image" - filename = "source_subtracted_image" - - elif plane_index == 1: - - title = "Lens Subtracted Image" - filename = "lens_subtracted_image" - - self.mat_plot_2d.plot_array( - array=self.fit.subtracted_images_of_planes_list[plane_index], - visuals_2d=self.visuals_2d_from( - plane_index=plane_index, - remove_critical_caustic=remove_critical_caustic, - ), - auto_labels=aplt.AutoLabels(title=title, filename=filename), - ) - - if model_image: - - title = f"Model Image of Plane {plane_index}" - filename = f"model_image_of_plane_{plane_index}" - - if len(self.tracer.planes) == 2: - - if plane_index == 0: - - title = "Lens Model Image" - filename = "lens_model_image" - - elif plane_index == 1: - - title = "Source Model Image" - filename = "source_model_image" - - self.mat_plot_2d.plot_array( - array=self.fit.model_images_of_planes_list[plane_index], - visuals_2d=self.visuals_2d_from( - plane_index=plane_index, - remove_critical_caustic=remove_critical_caustic, - ), - auto_labels=aplt.AutoLabels(title=title, filename=filename), - ) - - if plane_image: - - if not self.tracer.planes[plane_index].has(cls=aa.Pixelization): - - tracer_plotter = self.tracer_plotter_of_plane( - plane_index=plane_index, - remove_critical_caustic=remove_critical_caustic, - ) - - tracer_plotter.figures_2d_of_planes( - plane_image=True, - plane_index=plane_index, - zoom_to_brightest=zoom_to_brightest, - retain_visuals=True, - ) - - elif self.tracer.planes[plane_index].has(cls=aa.Pixelization): - - inversion_plotter = self.inversion_plotter_of_plane( - plane_index=plane_index, - remove_critical_caustic=remove_critical_caustic, - ) - - inversion_plotter.figures_2d_of_pixelization( - pixelization_index=0, - reconstruction=True, - zoom_to_brightest=zoom_to_brightest, - ) - - if use_source_vmax: - try: - self.mat_plot_2d.cmap.kwargs.pop("vmax") - except KeyError: - pass - - if plane_noise_map: - - if self.tracer.planes[plane_index].has(cls=aa.Pixelization): - - inversion_plotter = self.inversion_plotter_of_plane( - plane_index=plane_index, - remove_critical_caustic=remove_critical_caustic, - ) - - inversion_plotter.figures_2d_of_pixelization( - pixelization_index=0, - reconstruction_noise_map=True, - zoom_to_brightest=zoom_to_brightest, - ) - - if plane_signal_to_noise_map: - - if self.tracer.planes[plane_index].has(cls=aa.Pixelization): - - inversion_plotter = self.inversion_plotter_of_plane( - plane_index=plane_index, - remove_critical_caustic=remove_critical_caustic, - ) - - inversion_plotter.figures_2d_of_pixelization( - pixelization_index=0, - signal_to_noise_map=True, - zoom_to_brightest=zoom_to_brightest, - ) - - def subplot( - self, - data: bool = False, - noise_map: bool = False, - signal_to_noise_map: bool = False, - model_data: bool = False, - residual_map: bool = False, - normalized_residual_map: bool = False, - chi_squared_map: bool = False, - auto_filename: str = "subplot_fit", - ): - """ - Plots the individual attributes of the plotter's `FitImaging` object in 2D on a subplot. - - The API is such that every plottable attribute of the `FitImaging` object is an input parameter of type bool of - the function, which if switched to `True` means that it is included on the subplot. - - Parameters - ---------- - data - Whether to include a 2D plot (via `imshow`) of the image data. - noise_map - Whether to include a 2D plot (via `imshow`) of the noise map. - psf - Whether to include a 2D plot (via `imshow`) of the psf. - signal_to_noise_map - Whether to include a 2D plot (via `imshow`) of the signal-to-noise map. - model_data - Whether to include a 2D plot (via `imshow`) of the model image. - residual_map - Whether to include a 2D plot (via `imshow`) of the residual map. - normalized_residual_map - Whether to include a 2D plot (via `imshow`) of the normalized residual map. - chi_squared_map - Whether to include a 2D plot (via `imshow`) of the chi-squared map. - auto_filename - The default filename of the output subplot if written to hard-disk. - """ - self._subplot_custom_plot( - data=data, - noise_map=noise_map, - signal_to_noise_map=signal_to_noise_map, - model_image=model_data, - residual_map=residual_map, - normalized_residual_map=normalized_residual_map, - chi_squared_map=chi_squared_map, - auto_labels=AutoLabels(filename=auto_filename), - ) - - - def subplot_fit_x1_plane(self): - """ - Standard subplot of the attributes of the plotter's `FitImaging` object. - """ - - self.open_subplot_figure(number_subplots=6) - - self.mat_plot_2d.cmap.kwargs["vmax"] = np.max(self.fit.model_images_of_planes_list[0].array) - self.figures_2d(data=True) - self.mat_plot_2d.cmap.kwargs.pop("vmax") - - self.figures_2d(signal_to_noise_map=True) - - self.mat_plot_2d.cmap.kwargs["vmax"] = np.max(self.fit.model_images_of_planes_list[0].array) - self.figures_2d(model_image=True) - self.mat_plot_2d.cmap.kwargs.pop("vmax") - - self.residuals_symmetric_cmap = False - self.set_title(label="Lens Light Subtracted") - self.figures_2d(normalized_residual_map=True) - - self.mat_plot_2d.cmap.kwargs["vmin"] = 0.0 - self.set_title(label="Subtracted Image Zero Minimum") - self.figures_2d(normalized_residual_map=True) - self.mat_plot_2d.cmap.kwargs.pop("vmin") - - self.residuals_symmetric_cmap = True - self.set_title(label="Normalized Residual Map") - self.figures_2d(normalized_residual_map=True) - self.set_title(label=None) - - self.mat_plot_2d.output.subplot_to_figure(auto_filename="subplot_fit_x1_plane") - self.close_subplot_figure() - - def subplot_fit_log10_x1_plane(self): - """ - Standard subplot of the attributes of the plotter's `FitImaging` object. - """ - - contour_original = copy.copy(self.mat_plot_2d.contour) - use_log10_original = self.mat_plot_2d.use_log10 - - self.open_subplot_figure(number_subplots=6) - - self.mat_plot_2d.contour = False - self.mat_plot_2d.use_log10 = True - - self.mat_plot_2d.cmap.kwargs["vmax"] = np.max(self.fit.model_images_of_planes_list[0].array) - self.figures_2d(data=True) - self.mat_plot_2d.cmap.kwargs.pop("vmax") - - self.figures_2d(signal_to_noise_map=True) - - self.mat_plot_2d.cmap.kwargs["vmax"] = np.max(self.fit.model_images_of_planes_list[0].array) - self.figures_2d(model_image=True) - self.mat_plot_2d.cmap.kwargs.pop("vmax") - - self.residuals_symmetric_cmap = False - self.set_title(label="Lens Light Subtracted") - self.figures_2d(normalized_residual_map=True) - - self.residuals_symmetric_cmap = True - self.set_title(label="Normalized Residual Map") - self.figures_2d(normalized_residual_map=True) - self.set_title(label=None) - - self.figures_2d(chi_squared_map=True) - - self.mat_plot_2d.output.subplot_to_figure(auto_filename="subplot_fit_log10") - self.close_subplot_figure() - - self.mat_plot_2d.use_log10 = use_log10_original - self.mat_plot_2d.contour = contour_original - - def subplot_fit(self, plane_index: Optional[int] = None): - """ - Standard subplot of the attributes of the plotter's `FitImaging` object. - """ - - if len(self.fit.tracer.planes) == 1: - return self.subplot_fit_x1_plane() - - self.open_subplot_figure(number_subplots=12) - - self.figures_2d(data=True) - - self.set_title(label="Data (Source Scale)") - self.figures_2d(data=True, use_source_vmax=True) - self.set_title(label=None) - - self.figures_2d(signal_to_noise_map=True) - self.figures_2d(model_image=True) - - self.set_title(label="Lens Light Model Image") - self.figures_2d_of_planes( - plane_index=0, model_image=True, remove_critical_caustic=True - ) - - # If the lens light is not included the subplot index does not increase, so we must manually set it to 4 - self.mat_plot_2d.subplot_index = 6 - - plane_index_tag = "" if plane_index is None else f"_{plane_index}" - - plane_index = ( - len(self.fit.tracer.planes) - 1 if plane_index is None else plane_index - ) - - self.mat_plot_2d.cmap.kwargs["vmin"] = 0.0 - - self.set_title(label="Lens Light Subtracted") - self.figures_2d_of_planes( - plane_index=plane_index, - subtracted_image=True, - use_source_vmax=True, - remove_critical_caustic=True, - ) - - self.set_title(label="Source Model Image") - self.figures_2d_of_planes( - plane_index=plane_index, - model_image=True, - use_source_vmax=True, - remove_critical_caustic=True, - ) - - self.mat_plot_2d.cmap.kwargs.pop("vmin") - - self.set_title(label="Source Plane (Zoomed)") - self.figures_2d_of_planes( - plane_index=plane_index, plane_image=True, use_source_vmax=True - ) - - self.set_title(label=None) - - self.mat_plot_2d.subplot_index = 9 - - self.figures_2d(normalized_residual_map=True) - - self.mat_plot_2d.cmap.kwargs["vmin"] = -1.0 - self.mat_plot_2d.cmap.kwargs["vmax"] = 1.0 - - self.set_title(label=r"Normalized Residual Map $1\sigma$") - self.figures_2d(normalized_residual_map=True) - self.set_title(label=None) - - self.mat_plot_2d.cmap.kwargs.pop("vmin") - self.mat_plot_2d.cmap.kwargs.pop("vmax") - - self.figures_2d(chi_squared_map=True) - - self.set_title(label="Source Plane (No Zoom)") - self.figures_2d_of_planes( - plane_index=plane_index, - plane_image=True, - zoom_to_brightest=False, - use_source_vmax=True, - ) - - self.set_title(label=None) - - self.mat_plot_2d.output.subplot_to_figure( - auto_filename=f"subplot_fit{plane_index_tag}", - # also_show=self.mat_plot_2d.quick_update - ) - self.close_subplot_figure() - - def subplot_fit_log10(self, plane_index: Optional[int] = None): - """ - Standard subplot of the attributes of the plotter's `FitImaging` object. - """ - - if len(self.fit.tracer.planes) == 1: - return self.subplot_fit_log10_x1_plane() - - contour_original = copy.copy(self.mat_plot_2d.contour) - use_log10_original = self.mat_plot_2d.use_log10 - - self.open_subplot_figure(number_subplots=12) - - self.mat_plot_2d.contour = False - self.mat_plot_2d.use_log10 = True - - self.figures_2d(data=True) - - self.set_title(label="Data (Source Scale)") - - try: - self.figures_2d(data=True, use_source_vmax=True) - except ValueError: - pass - - self.set_title(label=None) - - try: - self.figures_2d(signal_to_noise_map=True) - except ValueError: - pass - - self.figures_2d(model_image=True) - - self.set_title(label="Lens Light Model Image") - self.figures_2d_of_planes(plane_index=0, model_image=True, remove_critical_caustic=True) - - # If the lens light is not included the subplot index does not increase, so we must manually set it to 4 - self.mat_plot_2d.subplot_index = 6 - - plane_index_tag = "" if plane_index is None else f"_{plane_index}" - - plane_index = ( - len(self.fit.tracer.planes) - 1 if plane_index is None else plane_index - ) - - self.mat_plot_2d.cmap.kwargs["vmin"] = 0.0 - - self.set_title(label="Lens Light Subtracted") - self.figures_2d_of_planes( - plane_index=plane_index, subtracted_image=True, use_source_vmax=True, remove_critical_caustic=True - ) - - self.set_title(label="Source Model Image") - self.figures_2d_of_planes( - plane_index=plane_index, model_image=True, use_source_vmax=True, remove_critical_caustic=True - ) - - self.mat_plot_2d.cmap.kwargs.pop("vmin") - - self.set_title(label="Source Plane (Zoomed)") - self.figures_2d_of_planes( - plane_index=plane_index, plane_image=True, use_source_vmax=True - ) - - self.set_title(label=None) - - self.mat_plot_2d.use_log10 = False - - self.mat_plot_2d.subplot_index = 9 - - self.figures_2d(normalized_residual_map=True) - - self.mat_plot_2d.cmap.kwargs["vmin"] = -1.0 - self.mat_plot_2d.cmap.kwargs["vmax"] = 1.0 - - self.set_title(label=r"Normalized Residual Map $1\sigma$") - self.figures_2d(normalized_residual_map=True) - self.set_title(label=None) - - self.mat_plot_2d.cmap.kwargs.pop("vmin") - self.mat_plot_2d.cmap.kwargs.pop("vmax") - - self.mat_plot_2d.use_log10 = True - - self.figures_2d(chi_squared_map=True) - - self.set_title(label="Source Plane (No Zoom)") - self.figures_2d_of_planes( - plane_index=plane_index, - plane_image=True, - zoom_to_brightest=False, - use_source_vmax=True, - ) - - self.set_title(label=None) - - self.mat_plot_2d.output.subplot_to_figure( - auto_filename=f"subplot_fit_log10{plane_index_tag}" - ) - self.close_subplot_figure() - - self.mat_plot_2d.use_log10 = use_log10_original - self.mat_plot_2d.contour = contour_original - - def subplot_of_planes(self, plane_index: Optional[int] = None): - """ - Plots images representing each individual `Plane` in the plotter's `Tracer` in 2D on a subplot, which are - computed via the plotter's 2D grid object. - - These images subtract or omit the contribution of other planes in the plane, such that plots showing - each individual plane are made. - - The subplot plots the subtracted image, model image and plane image of each plane, where are described in the - `figures_2d_of_planes` function. - - Parameters - ---------- - plane_index - The index of the plane whose images are included on the subplot. - """ - - plane_indexes = self.plane_indexes_from(plane_index=plane_index) - - for plane_index in plane_indexes: - - self.open_subplot_figure(number_subplots=4) - - self.figures_2d(data=True) - - self.figures_2d_of_planes(subtracted_image=True, plane_index=plane_index) - self.figures_2d_of_planes(model_image=True, plane_index=plane_index) - self.figures_2d_of_planes(plane_image=True, plane_index=plane_index) - - self.mat_plot_2d.output.subplot_to_figure( - auto_filename=f"subplot_of_plane_{plane_index}" - ) - self.close_subplot_figure() - - def subplot_tracer(self): - """ - Standard subplot of a Tracer. - - The `subplot_tracer` method in the `Tracer` class cannot plot the images of galaxies which are computed - via an `Inversion`. Therefore, using the `subplot_tracer` method of the `FitImagingPLotter` can plot - more information. - - Returns - ------- - - """ - - use_log10_original = self.mat_plot_2d.use_log10 - - final_plane_index = len(self.fit.tracer.planes) - 1 - - self.open_subplot_figure(number_subplots=9) - - self.figures_2d(model_image=True) - - self.set_title(label="Lensed Source Image") - self.figures_2d_of_planes( - plane_index=final_plane_index, model_image=True, use_source_vmax=True - ) - self.set_title(label=None) - - self.set_title(label="Source Plane") - self.figures_2d_of_planes( - plane_index=final_plane_index, - plane_image=True, - zoom_to_brightest=False, - use_source_vmax=True, - ) - - tracer_plotter = self.tracer_plotter_of_plane(plane_index=0) - - tracer_plotter._subplot_lens_and_mass() - - self.mat_plot_2d.output.subplot_to_figure(auto_filename="subplot_tracer") - self.close_subplot_figure() - - self.mat_plot_2d.use_log10 = use_log10_original - - def subplot_mappings_of_plane( - self, plane_index: Optional[int] = None, auto_filename: str = "subplot_mappings" - ): - - try: - - plane_indexes = self.plane_indexes_from(plane_index=plane_index) - - for plane_index in plane_indexes: - - pixelization_index = 0 - - inversion_plotter = self.inversion_plotter_of_plane(plane_index=0) - - inversion_plotter.open_subplot_figure(number_subplots=4) - - inversion_plotter.figures_2d_of_pixelization( - pixelization_index=pixelization_index, data_subtracted=True - ) - - total_pixels = conf.instance["visualize"]["general"]["inversion"][ - "total_mappings_pixels" - ] - - mapper = inversion_plotter.inversion.cls_list_from( - cls=aa.Mapper - )[0] - - pix_indexes = inversion_plotter.inversion.max_pixel_list_from( - total_pixels=total_pixels, filter_neighbors=True - ) - - indexes = mapper.slim_indexes_for_pix_indexes(pix_indexes=pix_indexes) - - inversion_plotter.visuals_2d.indexes = indexes - - inversion_plotter.figures_2d_of_pixelization( - pixelization_index=pixelization_index, reconstructed_operated_data=True - ) - - self.visuals_2d.source_plane_mesh_indexes = [ - [index] for index in pix_indexes[pixelization_index] - ] - - self.figures_2d_of_planes( - plane_index=plane_index, plane_image=True, use_source_vmax=True - ) - - self.set_title(label="Source Reconstruction (Unzoomed)") - self.figures_2d_of_planes( - plane_index=plane_index, - plane_image=True, - zoom_to_brightest=False, - use_source_vmax=True, - ) - self.set_title(label=None) - - self.visuals_2d.source_plane_mesh_indexes = None - - inversion_plotter.mat_plot_2d.output.subplot_to_figure( - auto_filename=f"{auto_filename}_{pixelization_index}" - ) - - inversion_plotter.close_subplot_figure() - - except (IndexError, AttributeError, ValueError): - - pass - - def figures_2d( - self, - data: bool = False, - noise_map: bool = False, - signal_to_noise_map: bool = False, - model_image: bool = False, - residual_map: bool = False, - normalized_residual_map: bool = False, - chi_squared_map: bool = False, - residual_flux_fraction_map: bool = False, - use_source_vmax: bool = False, - suffix: str = "", - ): - """ - Plots the individual attributes of the plotter's `FitImaging` object in 2D. - - The API is such that every plottable attribute of the `FitImaging` object is an input parameter of type bool of - the function, which if switched to `True` means that it is plotted. - - Parameters - ---------- - data - Whether to make a 2D plot (via `imshow`) of the image data. - noise_map - Whether to make a 2D plot (via `imshow`) of the noise map. - signal_to_noise_map - Whether to make a 2D plot (via `imshow`) of the signal-to-noise map. - model_image - Whether to make a 2D plot (via `imshow`) of the model image. - residual_map - Whether to make a 2D plot (via `imshow`) of the residual map. - normalized_residual_map - Whether to make a 2D plot (via `imshow`) of the normalized residual map. - chi_squared_map - Whether to make a 2D plot (via `imshow`) of the chi-squared map. - residual_flux_fraction_map - Whether to make a 2D plot (via `imshow`) of the residual flux fraction map. - use_source_vmax - If `True`, the maximum value of the lensed source (e.g. in the image-plane) is used to set the `vmax` of - certain plots (e.g. the `data`) in order to ensure the lensed source is visible compared to the lens. - """ - - if use_source_vmax: - try: - source_vmax = np.max( - [ - model_image.array - for model_image in self.fit.model_images_of_planes_list[1:] - ] - ) - except ValueError: - source_vmax = None - - if data: - - if use_source_vmax: - self.mat_plot_2d.cmap.kwargs["vmax"] = source_vmax - - self.mat_plot_2d.plot_array( - array=self.fit.data, - visuals_2d=self.visuals_2d, - auto_labels=AutoLabels(title="Data", filename=f"data{suffix}"), - ) - - if use_source_vmax: - self.mat_plot_2d.cmap.kwargs.pop("vmax") - - if noise_map: - - self.mat_plot_2d.plot_array( - array=self.fit.noise_map, - visuals_2d=self.visuals_2d, - auto_labels=AutoLabels( - title="Noise-Map", filename=f"noise_map{suffix}" - ), - ) - - if signal_to_noise_map: - - self.mat_plot_2d.plot_array( - array=self.fit.signal_to_noise_map, - visuals_2d=self.visuals_2d, - auto_labels=AutoLabels( - title="Signal-To-Noise Map", - cb_unit=" S/N", - filename=f"signal_to_noise_map{suffix}", - ), - ) - - if model_image: - - if use_source_vmax: - self.mat_plot_2d.cmap.kwargs["vmax"] = source_vmax - - self.mat_plot_2d.plot_array( - array=self.fit.model_data, - visuals_2d=self.visuals_2d_from(plane_index=0), - auto_labels=AutoLabels( - title="Model Image", filename=f"model_image{suffix}" - ), - ) - - if use_source_vmax: - self.mat_plot_2d.cmap.kwargs.pop("vmax") - - cmap_original = self.mat_plot_2d.cmap - - if self.residuals_symmetric_cmap: - - self.mat_plot_2d.cmap = self.mat_plot_2d.cmap.symmetric_cmap_from() - - if residual_map: - - self.mat_plot_2d.plot_array( - array=self.fit.residual_map, - visuals_2d=self.visuals_2d, - auto_labels=AutoLabels( - title="Residual Map", filename=f"residual_map{suffix}" - ), - ) - - if normalized_residual_map: - - self.mat_plot_2d.plot_array( - array=self.fit.normalized_residual_map, - visuals_2d=self.visuals_2d, - auto_labels=AutoLabels( - title="Normalized Residual Map", - cb_unit=r" $\sigma$", - filename=f"normalized_residual_map{suffix}", - ), - ) - - self.mat_plot_2d.cmap = cmap_original - - if chi_squared_map: - - self.mat_plot_2d.plot_array( - array=self.fit.chi_squared_map, - visuals_2d=self.visuals_2d, - auto_labels=AutoLabels( - title="Chi-Squared Map", - cb_unit=r" $\chi^2$", - filename=f"chi_squared_map{suffix}", - ), - ) - - if residual_flux_fraction_map: - - self.mat_plot_2d.plot_array( - array=self.fit.residual_flux_fraction_map, - visuals_2d=self.visuals_2d, - auto_labels=AutoLabels( - title="Residual Flux Fraction Map", - filename=f"residual_flux_fraction_map{suffix}", - ), - ) diff --git a/autolens/interferometer/model/plotter_interface.py b/autolens/interferometer/model/plotter_interface.py index fef75f04d..7a17d71a9 100644 --- a/autolens/interferometer/model/plotter_interface.py +++ b/autolens/interferometer/model/plotter_interface.py @@ -1,90 +1,83 @@ -from typing import Optional - -import autoarray.plot as aplt - -from autogalaxy.interferometer.model.plotter_interface import ( - PlotterInterfaceInterferometer as AgPlotterInterfaceInterferometer, -) - -from autogalaxy.interferometer.model.plotter_interface import fits_to_fits - -from autolens.interferometer.fit_interferometer import FitInterferometer -from autolens.interferometer.plot.fit_interferometer_plotters import ( - FitInterferometerPlotter, -) -from autolens.analysis.plotter_interface import PlotterInterface - -from autolens.analysis.plotter_interface import plot_setting - - -class PlotterInterfaceInterferometer(PlotterInterface): - interferometer = AgPlotterInterfaceInterferometer.interferometer - - def fit_interferometer( - self, - fit: FitInterferometer, - visuals_2d_of_planes_list: Optional[aplt.Visuals2D] = None, - quick_update: bool = False, - ): - """ - Visualizes a `FitInterferometer` object, which fits an interferometer dataset. - - Images are output to the `image` folder of the `image_path`. When used with a non-linear search the `image_path` - is the output folder of the non-linear search. - - Visualization includes a subplot of individual images of attributes of the `FitInterferometer` (e.g. the model - data,residual map) and .fits files containing its attributes grouped together. - - The images output by the `PlotterInterface` are customized using the file `config/visualize/plots.yaml` under - the `fit` and `fit_interferometer` headers. - - Parameters - ---------- - fit - The maximum log likelihood `FitInterferometer` of the non-linear search which is used to plot the fit. - """ - - def should_plot(name): - return plot_setting(section=["fit", "fit_interferometer"], name=name) - - mat_plot_1d = self.mat_plot_1d_from() - mat_plot_2d = self.mat_plot_2d_from() - - fit_plotter = FitInterferometerPlotter( - fit=fit, - mat_plot_1d=mat_plot_1d, - mat_plot_2d=mat_plot_2d, - ) - - if should_plot("subplot_fit"): - fit_plotter.subplot_fit() - - if should_plot("subplot_fit_dirty_images"): - fit_plotter.subplot_fit_dirty_images() - - if quick_update: - return - - if should_plot("subplot_fit_real_space"): - fit_plotter.subplot_fit_real_space() - - mat_plot_1d = self.mat_plot_1d_from() - mat_plot_2d = self.mat_plot_2d_from() - - fit_plotter = FitInterferometerPlotter( - fit=fit, - mat_plot_1d=mat_plot_1d, - mat_plot_2d=mat_plot_2d, - visuals_2d_of_planes_list=visuals_2d_of_planes_list, - ) - - if plot_setting(section="inversion", name="subplot_mappings"): - fit_plotter.subplot_mappings_of_plane( - plane_index=len(fit.tracer.planes) - 1 - ) - - fits_to_fits( - should_plot=should_plot, - image_path=self.image_path, - fit=fit, - ) +from autogalaxy.interferometer.model.plotter_interface import ( + PlotterInterfaceInterferometer as AgPlotterInterfaceInterferometer, +) + +from autogalaxy.interferometer.model.plotter_interface import fits_to_fits + +from autolens.interferometer.fit_interferometer import FitInterferometer +from autolens.interferometer.plot.fit_interferometer_plots import ( + subplot_fit, + subplot_fit_real_space, +) +from autolens.analysis.plotter_interface import PlotterInterface + +from autolens.analysis.plotter_interface import plot_setting + + +class PlotterInterfaceInterferometer(PlotterInterface): + interferometer = AgPlotterInterfaceInterferometer.interferometer + + def fit_interferometer( + self, + fit: FitInterferometer, + quick_update: bool = False, + ): + """ + Visualizes a `FitInterferometer` object. + + Parameters + ---------- + fit + The maximum log likelihood `FitInterferometer` of the non-linear search. + """ + + def should_plot(name): + return plot_setting(section=["fit", "fit_interferometer"], name=name) + + output_path = str(self.image_path) + fmt = self.fmt + + if should_plot("subplot_fit"): + subplot_fit(fit, output_path=output_path, output_format=fmt) + + if should_plot("subplot_fit_dirty_images"): + # Use the autoarray FitInterferometerMeta plotter for dirty images subplot + try: + import autogalaxy.plot as aplt + from autoarray.fit.plot.fit_interferometer_plotters import FitInterferometerPlotterMeta + output = self.output_from() + meta_plotter = FitInterferometerPlotterMeta( + fit=fit, + output=output, + ) + meta_plotter.subplot_fit_dirty_images() + except Exception: + pass + + if quick_update: + return + + if should_plot("subplot_fit_real_space"): + subplot_fit_real_space(fit, output_path=output_path, output_format=fmt) + + if plot_setting(section="inversion", name="subplot_mappings"): + try: + import autogalaxy.plot as aplt + inversion_plotter = aplt.InversionPlotter( + inversion=fit.inversion, + mat_plot_2d=aplt.MatPlot2D( + output=aplt.Output(path=self.image_path, format=fmt), + ), + ) + inversion_plotter.subplot_of_mapper( + mapper_index=0, + auto_filename="subplot_mappings_0", + ) + except (IndexError, AttributeError, TypeError, Exception): + pass + + fits_to_fits( + should_plot=should_plot, + image_path=self.image_path, + fit=fit, + ) diff --git a/autolens/interferometer/model/visualizer.py b/autolens/interferometer/model/visualizer.py index 58ba934c6..86fbf1789 100644 --- a/autolens/interferometer/model/visualizer.py +++ b/autolens/interferometer/model/visualizer.py @@ -6,7 +6,6 @@ from autolens.interferometer.model.plotter_interface import ( PlotterInterfaceInterferometer, ) -from autolens.lens import tracer_util from autogalaxy import exc logger = logging.getLogger(__name__) @@ -95,10 +94,6 @@ def visualize( """ fit = analysis.fit_from(instance=instance) - visuals_2d_of_planes_list = tracer_util.visuals_2d_of_planes_list_from( - tracer=fit.tracer, grid=fit.grids.lp.mask.derive_grid.all_false - ) - plotter_interface = PlotterInterfaceInterferometer( image_path=paths.image_path, title_prefix=analysis.title_prefix ) @@ -106,7 +101,6 @@ def visualize( try: plotter_interface.fit_interferometer( fit=fit, - visuals_2d_of_planes_list=visuals_2d_of_planes_list, quick_update=quick_update, ) except exc.InversionException: @@ -146,17 +140,13 @@ def visualize( grid = ag.Grid2D.from_extent(extent=extent, shape_native=shape_native) try: - plotter_interface.fit_interferometer( - fit=fit, - visuals_2d_of_planes_list=visuals_2d_of_planes_list, - ) + plotter_interface.fit_interferometer(fit=fit) except exc.InversionException: pass plotter_interface.tracer( tracer=tracer, grid=grid, - visuals_2d_of_planes_list=visuals_2d_of_planes_list, ) plotter_interface.galaxies( galaxies=tracer.galaxies, diff --git a/autolens/interferometer/plot/fit_interferometer_plots.py b/autolens/interferometer/plot/fit_interferometer_plots.py new file mode 100644 index 000000000..d67339f73 --- /dev/null +++ b/autolens/interferometer/plot/fit_interferometer_plots.py @@ -0,0 +1,161 @@ +import matplotlib.pyplot as plt +import numpy as np +from typing import Optional + +import autoarray as aa +import autogalaxy as ag + +from autoarray.plot.plots.array import plot_array +from autoarray.plot.plots.utils import save_figure +from autolens.plot.plot_utils import ( + _to_lines, + _critical_curves_from, +) + + +def _plot_yx(y, x, ax, title, xlabel="", ylabel=""): + """Scatter plot of y vs x into an axes.""" + ax.scatter(x, y, s=1) + ax.set_title(title) + ax.set_xlabel(xlabel) + ax.set_ylabel(ylabel) + + +def _plot_source_plane(fit, ax, plane_index, zoom_to_brightest=True, + colormap="jet", use_log10=False): + tracer = fit.tracer_linear_light_profiles_to_light_profiles + if not tracer.planes[plane_index].has(cls=aa.Pixelization): + zoom = aa.Zoom2D(mask=fit.dataset.real_space_mask) + grid = aa.Grid2D.from_extent( + extent=zoom.extent_from(buffer=0), shape_native=zoom.shape_native + ) + traced_grids = tracer.traced_grid_2d_list_from(grid=grid) + plane_galaxies = ag.Galaxies(galaxies=tracer.planes[plane_index]) + image = plane_galaxies.image_2d_from(grid=traced_grids[plane_index]) + plot_array( + array=image, ax=ax, + title=f"Source Plane {plane_index}", + colormap=colormap, use_log10=use_log10, + ) + else: + if ax is not None: + ax.axis("off") + ax.set_title(f"Source Reconstruction (plane {plane_index})") + + +def subplot_fit( + fit, + output_path: Optional[str] = None, + output_format: str = "png", + colormap: str = "jet", +): + """12-panel subplot for an interferometer fit.""" + tracer = fit.tracer_linear_light_profiles_to_light_profiles + final_plane_index = len(fit.tracer.planes) - 1 + + fig, axes = plt.subplots(3, 4, figsize=(28, 21)) + axes_flat = list(axes.flatten()) + + # Panel 0: amplitudes vs UV distances + _plot_yx( + y=np.real(fit.residual_map), + x=fit.dataset.uv_distances / 10 ** 3.0, + ax=axes_flat[0], + title="Amplitudes vs UV-Distance", + xlabel=r"k$\lambda$", + ) + + plot_array(array=fit.dirty_image, ax=axes_flat[1], title="Dirty Image", + colormap=colormap) + plot_array(array=fit.dirty_signal_to_noise_map, ax=axes_flat[2], + title="Dirty Signal-To-Noise Map", colormap=colormap) + plot_array(array=fit.dirty_model_image, ax=axes_flat[3], title="Dirty Model Image", + colormap=colormap) + + # Panel 4: source image + _plot_source_plane(fit, axes_flat[4], final_plane_index, colormap=colormap) + + # Normalized residual vs UV distances (real) + _plot_yx( + y=np.real(fit.normalized_residual_map), + x=fit.dataset.uv_distances / 10 ** 3.0, + ax=axes_flat[5], + title="Norm Residual vs UV-Distance (real)", + ylabel=r"$\sigma$", + xlabel=r"k$\lambda$", + ) + + # Normalized residual vs UV distances (imag) + _plot_yx( + y=np.imag(fit.normalized_residual_map), + x=fit.dataset.uv_distances / 10 ** 3.0, + ax=axes_flat[6], + title="Norm Residual vs UV-Distance (imag)", + ylabel=r"$\sigma$", + xlabel=r"k$\lambda$", + ) + + # Panel 7: source plane zoomed + _plot_source_plane(fit, axes_flat[7], final_plane_index, colormap=colormap) + + plot_array(array=fit.dirty_normalized_residual_map, ax=axes_flat[8], + title="Dirty Normalized Residual Map", colormap=colormap) + + # Panel 9: clipped to [-1, 1] + from autolens.imaging.plot.fit_imaging_plots import _plot_with_vmin_vmax + _plot_with_vmin_vmax(fit.dirty_normalized_residual_map, axes_flat[9], + r"Normalized Residual Map $1\sigma$", colormap, + vmin=-1.0, vmax=1.0) + + plot_array(array=fit.dirty_chi_squared_map, ax=axes_flat[10], + title="Dirty Chi-Squared Map", colormap=colormap) + + # Panel 11: source plane not zoomed + _plot_source_plane(fit, axes_flat[11], final_plane_index, + zoom_to_brightest=False, colormap=colormap) + + plt.tight_layout() + save_figure(fig, path=output_path, filename="subplot_fit", format=output_format) + + +def subplot_fit_real_space( + fit, + output_path: Optional[str] = None, + output_format: str = "png", + colormap: str = "jet", +): + """Real-space subplot: image + source plane (or inversion panels).""" + tracer = fit.tracer_linear_light_profiles_to_light_profiles + final_plane_index = len(fit.tracer.planes) - 1 + + if fit.inversion is None: + # No inversion: image + source plane image + fig, axes = plt.subplots(1, 2, figsize=(14, 7)) + axes_flat = list(axes.flatten()) + + zoom = aa.Zoom2D(mask=fit.dataset.real_space_mask) + grid = aa.Grid2D.from_extent( + extent=zoom.extent_from(buffer=0), shape_native=zoom.shape_native + ) + traced_grids = tracer.traced_grid_2d_list_from(grid=grid) + + image = tracer.image_2d_from(grid=grid) + plot_array(array=image, ax=axes_flat[0], title="Image", colormap=colormap) + + source_galaxies = ag.Galaxies(galaxies=tracer.planes[final_plane_index]) + source_image = source_galaxies.image_2d_from( + grid=traced_grids[final_plane_index] + ) + plot_array(array=source_image, ax=axes_flat[1], title="Source Plane Image", + colormap=colormap) + else: + fig, axes = plt.subplots(1, 2, figsize=(14, 7)) + axes_flat = list(axes.flatten()) + # Inversion: show blank placeholder panels + for _ax in axes_flat: + _ax.axis("off") + axes_flat[0].set_title("Reconstructed Data") + axes_flat[1].set_title("Source Reconstruction") + + plt.tight_layout() + save_figure(fig, path=output_path, filename="subplot_fit_real_space", format=output_format) diff --git a/autolens/interferometer/plot/fit_interferometer_plotters.py b/autolens/interferometer/plot/fit_interferometer_plotters.py deleted file mode 100644 index 4f154468e..000000000 --- a/autolens/interferometer/plot/fit_interferometer_plotters.py +++ /dev/null @@ -1,528 +0,0 @@ -from typing import Optional - -from autoconf import conf - -import autoarray as aa -import autogalaxy.plot as aplt - -from autoarray.fit.plot.fit_interferometer_plotters import FitInterferometerPlotterMeta -from autoarray.plot.auto_labels import AutoLabels - -from autolens.interferometer.fit_interferometer import FitInterferometer -from autolens.lens.tracer import Tracer -from autolens.lens.plot.tracer_plotters import TracerPlotter -from autolens.plot.abstract_plotters import Plotter - -from autolens.lens import tracer_util - - -class FitInterferometerPlotter(Plotter): - def __init__( - self, - fit: FitInterferometer, - mat_plot_1d: aplt.MatPlot1D = None, - visuals_1d: aplt.Visuals1D = None, - mat_plot_2d: aplt.MatPlot2D = None, - visuals_2d: aplt.Visuals2D = None, - residuals_symmetric_cmap: bool = True, - visuals_2d_of_planes_list: Optional = None, - ): - """ - Plots the attributes of `FitInterferometer` objects using the matplotlib method `imshow()` and many - other matplotlib functions which customize the plot's appearance. - - The `mat_plot_1d` and `mat_plot_2d` attributes wrap matplotlib function calls to make the figure. By default, - the settings passed to every matplotlib function called are those specified in - the `config/visualize/mat_wrap/*.ini` files, but a user can manually input values into `MatPlot2d` to - customize the figure's appearance. - - Overlaid on the figure are visuals, contained in the `Visuals1D` and `Visuals2D` objects. Attributes may be - extracted from the `FitInterferometer` and plotted via the visuals object. - - Parameters - ---------- - fit - The fit to an interferometer dataset the plotter plots. - mat_plot_1d - Contains objects which wrap the matplotlib function calls that make 1D plots. - visuals_1d - Contains 1D visuals that can be overlaid on 1D plots. - mat_plot_2d - Contains objects which wrap the matplotlib function calls that make 2D plots. - visuals_2d - Contains 2D visuals that can be overlaid on 2D plots. - residuals_symmetric_cmap - If true, the `residual_map` and `normalized_residual_map` are plotted with a symmetric color map such - that `abs(vmin) = abs(vmax)`. - """ - super().__init__( - mat_plot_1d=mat_plot_1d, - visuals_1d=visuals_1d, - mat_plot_2d=mat_plot_2d, - visuals_2d=visuals_2d, - ) - - self.fit = fit - - self._fit_interferometer_meta_plotter = FitInterferometerPlotterMeta( - fit=self.fit, - mat_plot_1d=self.mat_plot_1d, - visuals_1d=self.visuals_1d, - mat_plot_2d=self.mat_plot_2d, - visuals_2d=self.visuals_2d, - residuals_symmetric_cmap=residuals_symmetric_cmap, - ) - - self.subplot = self._fit_interferometer_meta_plotter.subplot - self.subplot_fit_dirty_images = ( - self._fit_interferometer_meta_plotter.subplot_fit_dirty_images - ) - - self._visuals_2d_of_planes_list = visuals_2d_of_planes_list - - @property - def visuals_2d_of_planes_list(self): - - if self._visuals_2d_of_planes_list is None: - self._visuals_2d_of_planes_list = ( - tracer_util.visuals_2d_of_planes_list_from( - tracer=self.fit.tracer, - grid=self.fit.grids.lp.mask.derive_grid.all_false, - ) - ) - - return self._visuals_2d_of_planes_list - - def visuals_2d_from( - self, plane_index: Optional[int] = None, remove_critical_caustic: bool = False - ) -> aplt.Visuals2D: - """ - Returns the `Visuals2D` of the plotter with critical curves and caustics added, which are used to plot - the critical curves and caustics of the `Tracer` object. - - If `remove_critical_caustic` is `True`, critical curves and caustics are not included in the visuals. - - Parameters - ---------- - plane_index - The index of the plane in the tracer which is used to extract quantities, as only one plane is plotted - at a time. - remove_critical_caustic - Whether to remove critical curves and caustics from the visuals. - """ - if remove_critical_caustic: - return self.visuals_2d - - return self.visuals_2d + self.visuals_2d_of_planes_list[plane_index] - - @property - def tracer(self) -> Tracer: - return self.fit.tracer_linear_light_profiles_to_light_profiles - - def tracer_plotter_of_plane( - self, plane_index: int, remove_critical_caustic: bool = False - ) -> TracerPlotter: - """ - Returns an `TracerPlotter` corresponding to the `Tracer` in the `FitImaging`. - """ - - zoom = aa.Zoom2D(mask=self.fit.dataset.real_space_mask) - - grid = aa.Grid2D.from_extent( - extent=zoom.extent_from(buffer=0), shape_native=zoom.shape_native - ) - return TracerPlotter( - tracer=self.tracer, - grid=grid, - mat_plot_2d=self.mat_plot_2d, - visuals_2d=self.visuals_2d_from( - plane_index=plane_index, remove_critical_caustic=remove_critical_caustic - ), - ) - - def inversion_plotter_of_plane( - self, plane_index: int, remove_critical_caustic: bool = False - ) -> aplt.InversionPlotter: - """ - Returns an `InversionPlotter` corresponding to one of the `Inversion`'s in the fit, which is specified via - the index of the `Plane` that inversion was performed on. - - Parameters - ---------- - plane_index - The index of the inversion in the inversion which is used to create the `InversionPlotter`. - - Returns - ------- - InversionPlotter - An object that plots inversions which is used for plotting attributes of the inversion. - """ - - inversion_plotter = aplt.InversionPlotter( - inversion=self.fit.inversion, - mat_plot_2d=self.mat_plot_2d, - visuals_2d=self.visuals_2d_from( - plane_index=plane_index, remove_critical_caustic=remove_critical_caustic - ), - ) - return inversion_plotter - - def plane_indexes_from(self, plane_index: int): - """ - Returns a list of all indexes of the planes in the fit, which is iterated over in figures that plot - individual figures of each plane in a tracer. - - Parameters - ---------- - plane_index - A specific plane index which when input means that only a single plane index is returned. - - Returns - ------- - list - A list of galaxy indexes corresponding to planes in the plane. - """ - if plane_index is None: - return range(len(self.fit.tracer.planes)) - return [plane_index] - - def figures_2d( - self, - data: bool = False, - noise_map: bool = False, - signal_to_noise_map: bool = False, - amplitudes_vs_uv_distances: bool = False, - model_data: bool = False, - residual_map_real: bool = False, - residual_map_imag: bool = False, - normalized_residual_map_real: bool = False, - normalized_residual_map_imag: bool = False, - chi_squared_map_real: bool = False, - chi_squared_map_imag: bool = False, - image: bool = False, - dirty_image: bool = False, - dirty_noise_map: bool = False, - dirty_signal_to_noise_map: bool = False, - dirty_model_image: bool = False, - dirty_residual_map: bool = False, - dirty_normalized_residual_map: bool = False, - dirty_chi_squared_map: bool = False, - ): - """ - Plots the individual attributes of the plotter's `FitInterferometer` object in 1D and 2D. - - The API is such that every plottable attribute of the `Interferometer` object is an input parameter of type - bool of the function, which if switched to `True` means that it is plotted. - - Parameters - ---------- - data - Whether to make a 2D plot (via `scatter`) of the visibility data. - noise_map - Whether to make a 2D plot (via `scatter`) of the noise-map. - signal_to_noise_map - Whether to make a 2D plot (via `scatter`) of the signal-to-noise-map. - model_data - Whether to make a 2D plot (via `scatter`) of the model visibility data. - residual_map_real - Whether to make a 1D plot (via `plot`) of the real component of the residual map. - residual_map_imag - Whether to make a 1D plot (via `plot`) of the imaginary component of the residual map. - normalized_residual_map_real - Whether to make a 1D plot (via `plot`) of the real component of the normalized residual map. - normalized_residual_map_imag - Whether to make a 1D plot (via `plot`) of the imaginary component of the normalized residual map. - chi_squared_map_real - Whether to make a 1D plot (via `plot`) of the real component of the chi-squared map. - chi_squared_map_imag - Whether to make a 1D plot (via `plot`) of the imaginary component of the chi-squared map. - image - Whether to make a 2D plot (via `imshow`) of the source-plane image. - dirty_image - Whether to make a 2D plot (via `imshow`) of the dirty image. - dirty_noise_map - Whether to make a 2D plot (via `imshow`) of the dirty noise map. - dirty_model_image - Whether to make a 2D plot (via `imshow`) of the dirty model image. - dirty_residual_map - Whether to make a 2D plot (via `imshow`) of the dirty residual map. - dirty_normalized_residual_map - Whether to make a 2D plot (via `imshow`) of the dirty normalized residual map. - dirty_chi_squared_map - Whether to make a 2D plot (via `imshow`) of the dirty chi-squared map. - """ - self._fit_interferometer_meta_plotter.figures_2d( - data=data, - noise_map=noise_map, - signal_to_noise_map=signal_to_noise_map, - amplitudes_vs_uv_distances=amplitudes_vs_uv_distances, - model_data=model_data, - residual_map_real=residual_map_real, - residual_map_imag=residual_map_imag, - normalized_residual_map_real=normalized_residual_map_real, - normalized_residual_map_imag=normalized_residual_map_imag, - chi_squared_map_real=chi_squared_map_real, - chi_squared_map_imag=chi_squared_map_imag, - dirty_image=dirty_image, - dirty_noise_map=dirty_noise_map, - dirty_signal_to_noise_map=dirty_signal_to_noise_map, - dirty_residual_map=dirty_residual_map, - dirty_normalized_residual_map=dirty_normalized_residual_map, - dirty_chi_squared_map=dirty_chi_squared_map, - ) - - if image: - plane_index = len(self.tracer.planes) - 1 - - if not self.tracer.planes[plane_index].has(cls=aa.Pixelization): - - tracer_plotter = self.tracer_plotter_of_plane(plane_index=plane_index) - - tracer_plotter.figures_2d(image=True) - - elif self.tracer.planes[plane_index].has(cls=aa.Pixelization): - inversion_plotter = self.inversion_plotter_of_plane( - plane_index=plane_index - ) - inversion_plotter.figures_2d(reconstructed_operated_data=True) - - if dirty_model_image: - self.mat_plot_2d.plot_array( - array=self.fit.dirty_model_image, - visuals_2d=self.visuals_2d_of_planes_list[0], - auto_labels=AutoLabels( - title="Dirty Model Image", filename="dirty_model_image_2d" - ), - ) - - def figures_2d_of_planes( - self, - plane_index: Optional[int] = None, - plane_image: bool = False, - plane_noise_map: bool = False, - plane_signal_to_noise_map: bool = False, - zoom_to_brightest: bool = True, - ): - """ - Plots images representing each individual `Plane` in the fit's `Tracer` in 2D, which are computed via the - plotter's 2D grid object. - - These images subtract or omit the contribution of other planes in the plane, such that plots showing - each individual plane are made. - - The API is such that every plottable attribute of the `Plane` object is an input parameter of type bool of - the function, which if switched to `True` means that it is plotted. - - Parameters - ---------- - plane_index - The index of the plane which figures are plotted for. - plane_image - Whether to make a 2D plot (via `imshow`) of the image of a plane in its source-plane (e.g. unlensed). - Depending on how the fit is performed, this could either be an image of light profiles of the reconstruction - of an `Inversion`. - plane_noise_map - Whether to make a 2D plot of the noise-map of a plane in its source-plane, where the - noise map can only be computed when a pixelized source reconstruction is performed and they correspond to - the noise map in each reconstructed pixel as given by the inverse curvature matrix. - plane_signal_to_noise_map - Whether to make a 2D plot of the signal-to-noise map of a plane in its source-plane, - where the signal-to-noise map values can only be computed when a pixelized source reconstruction and they - are the ratio of reconstructed flux to error in each pixel. - zoom_to_brightest - For images not in the image-plane (e.g. the `plane_image`), whether to automatically zoom the plot to - the brightest regions of the galaxies being plotted as opposed to the full extent of the grid. - """ - if plane_image: - if not self.tracer.planes[plane_index].has(cls=aa.Pixelization): - - tracer_plotter = self.tracer_plotter_of_plane(plane_index=plane_index) - - tracer_plotter.figures_2d_of_planes( - plane_image=True, - plane_index=plane_index, - zoom_to_brightest=zoom_to_brightest, - ) - - elif self.tracer.planes[plane_index].has(cls=aa.Pixelization): - inversion_plotter = self.inversion_plotter_of_plane(plane_index=1) - inversion_plotter.figures_2d_of_pixelization( - pixelization_index=0, - reconstruction=True, - zoom_to_brightest=zoom_to_brightest, - ) - - if plane_noise_map: - if self.tracer.planes[plane_index].has(cls=aa.Pixelization): - inversion_plotter = self.inversion_plotter_of_plane( - plane_index=plane_index - ) - - inversion_plotter.figures_2d_of_pixelization( - pixelization_index=0, - reconstruction_noise_map=True, - zoom_to_brightest=zoom_to_brightest, - ) - - if plane_signal_to_noise_map: - if self.tracer.planes[plane_index].has(cls=aa.Pixelization): - inversion_plotter = self.inversion_plotter_of_plane( - plane_index=plane_index - ) - - inversion_plotter.figures_2d_of_pixelization( - pixelization_index=0, - signal_to_noise_map=True, - zoom_to_brightest=zoom_to_brightest, - ) - - def subplot_fit(self): - """ - Standard subplot of the attributes of the plotter's `FitImaging` object. - """ - - self.open_subplot_figure(number_subplots=12) - - self.figures_2d(amplitudes_vs_uv_distances=True) - - self.mat_plot_1d.subplot_index = 2 - self.mat_plot_2d.subplot_index = 2 - - self.figures_2d(dirty_image=True) - self.figures_2d(dirty_signal_to_noise_map=True) - self.figures_2d(dirty_model_image=True) - self.figures_2d(image=True) - - self.mat_plot_1d.subplot_index = 6 - self.mat_plot_2d.subplot_index = 6 - - self.figures_2d(normalized_residual_map_real=True) - self.figures_2d(normalized_residual_map_imag=True) - - self.mat_plot_1d.subplot_index = 8 - self.mat_plot_2d.subplot_index = 8 - - final_plane_index = len(self.fit.tracer.planes) - 1 - - self.set_title(label="Source Plane (Zoomed)") - self.figures_2d_of_planes(plane_index=final_plane_index, plane_image=True) - self.set_title(label=None) - - self.figures_2d(dirty_normalized_residual_map=True) - - self.mat_plot_2d.cmap.kwargs["vmin"] = -1.0 - self.mat_plot_2d.cmap.kwargs["vmax"] = 1.0 - - self.set_title(label=r"Normalized Residual Map $1\sigma$") - self.figures_2d(dirty_normalized_residual_map=True) - self.set_title(label=None) - - self.mat_plot_2d.cmap.kwargs.pop("vmin") - self.mat_plot_2d.cmap.kwargs.pop("vmax") - - self.figures_2d(dirty_chi_squared_map=True) - - self.set_title(label="Source Plane (No Zoom)") - self.figures_2d_of_planes( - plane_index=final_plane_index, - plane_image=True, - zoom_to_brightest=False, - ) - - self.set_title(label=None) - - self.mat_plot_2d.output.subplot_to_figure(auto_filename="subplot_fit") - self.close_subplot_figure() - - def subplot_mappings_of_plane( - self, plane_index: Optional[int] = None, auto_filename: str = "subplot_mappings" - ): - if self.fit.inversion is None: - return - - plane_indexes = self.plane_indexes_from(plane_index=plane_index) - - for plane_index in plane_indexes: - pixelization_index = 0 - - inversion_plotter = self.inversion_plotter_of_plane(plane_index=0) - - inversion_plotter.open_subplot_figure(number_subplots=4) - - self.figures_2d(dirty_image=True) - - total_pixels = conf.instance["visualize"]["general"]["inversion"][ - "total_mappings_pixels" - ] - - pix_indexes = inversion_plotter.inversion.max_pixel_list_from( - total_pixels=total_pixels, filter_neighbors=True - ) - - inversion_plotter.visuals_2d.source_plane_mesh_indexes = [ - [index] for index in pix_indexes[pixelization_index] - ] - - inversion_plotter.visuals_2d.tangential_critical_curves = None - inversion_plotter.visuals_2d.radial_critical_curves = None - - inversion_plotter.figures_2d_of_pixelization( - pixelization_index=pixelization_index, reconstructed_operated_data=True - ) - - self.visuals_2d.source_plane_mesh_indexes = [ - [index] for index in pix_indexes[pixelization_index] - ] - - self.figures_2d_of_planes( - plane_index=plane_index, - plane_image=True, - ) - - self.set_title(label="Source Reconstruction (Unzoomed)") - self.figures_2d_of_planes( - plane_index=plane_index, - plane_image=True, - zoom_to_brightest=False, - ) - self.set_title(label=None) - - self.visuals_2d.source_plane_mesh_indexes = None - - inversion_plotter.mat_plot_2d.output.subplot_to_figure( - auto_filename=f"{auto_filename}_{pixelization_index}" - ) - - inversion_plotter.close_subplot_figure() - - def subplot_fit_real_space(self): - """ - Standard subplot of the real-space attributes of the plotter's `FitInterferometer` object. - - Depending on whether `LightProfile`'s or an `Inversion` are used to represent galaxies in the `Tracer`, - different methods are called to create these real-space images. - """ - if self.fit.inversion is None: - - tracer_plotter = self.tracer_plotter_of_plane(plane_index=0) - - tracer_plotter.subplot( - image=True, source_plane=True, auto_filename="subplot_fit_real_space" - ) - - elif self.fit.inversion is not None: - self.open_subplot_figure(number_subplots=2) - - inversion_plotter = self.inversion_plotter_of_plane(plane_index=1) - - inversion_plotter.figures_2d_of_pixelization( - pixelization_index=0, reconstructed_operated_data=True - ) - inversion_plotter.figures_2d_of_pixelization( - pixelization_index=0, reconstruction=True - ) - - self.mat_plot_2d.output.subplot_to_figure( - auto_filename="subplot_fit_real_space" - ) - self.close_subplot_figure() diff --git a/autolens/lens/plot/sensitivity_plots.py b/autolens/lens/plot/sensitivity_plots.py new file mode 100644 index 000000000..67219b90f --- /dev/null +++ b/autolens/lens/plot/sensitivity_plots.py @@ -0,0 +1,186 @@ +"""Standalone subplot functions for subhalo sensitivity mapping visualisation.""" +import matplotlib.pyplot as plt +import numpy as np +from typing import Optional + +import autoarray as aa + +from autoarray.plot.plots.array import plot_array +from autoarray.plot.plots.utils import save_figure + + +def subplot_tracer_images( + mask, + tracer_perturb, + tracer_no_perturb, + source_image, + output_path: Optional[str] = None, + output_format: str = "png", + colormap: str = "jet", + use_log10: bool = False, +): + """6-panel subplot showing lensed images and residuals from a perturbed tracer.""" + from autolens.lens.tracer_util import critical_curves_from, caustics_from + from autolens.plot.plot_utils import _to_lines + + grid = aa.Grid2D.from_mask(mask=mask) + + image = tracer_perturb.image_2d_from(grid=grid) + lensed_source_image = tracer_perturb.image_2d_via_input_plane_image_from( + grid=grid, plane_image=source_image + ) + lensed_source_image_no_perturb = tracer_no_perturb.image_2d_via_input_plane_image_from( + grid=grid, plane_image=source_image + ) + + unmasked_grid = mask.derive_grid.unmasked + + try: + tan_cc_p, rad_cc_p = critical_curves_from(tracer=tracer_perturb, grid=unmasked_grid) + perturb_cc_lines = _to_lines(list(tan_cc_p), list(rad_cc_p)) + except Exception: + perturb_cc_lines = None + + try: + tan_ca_p, rad_ca_p = caustics_from(tracer=tracer_perturb, grid=unmasked_grid) + perturb_ca_lines = _to_lines(list(tan_ca_p), list(rad_ca_p)) + except Exception: + perturb_ca_lines = None + + try: + tan_cc_n, rad_cc_n = critical_curves_from(tracer=tracer_no_perturb, grid=unmasked_grid) + no_perturb_cc_lines = _to_lines(list(tan_cc_n), list(rad_cc_n)) + except Exception: + no_perturb_cc_lines = None + + residual_map = lensed_source_image - lensed_source_image_no_perturb + + fig, axes = plt.subplots(1, 6, figsize=(42, 7)) + + plot_array(array=image, ax=axes[0], title="Image", + colormap=colormap, use_log10=use_log10) + plot_array(array=lensed_source_image, ax=axes[1], title="Lensed Source Image", + colormap=colormap, use_log10=use_log10, lines=perturb_cc_lines) + plot_array(array=source_image, ax=axes[2], title="Source Image", + colormap=colormap, use_log10=use_log10, lines=perturb_ca_lines) + plot_array(array=tracer_perturb.convergence_2d_from(grid=grid), ax=axes[3], + title="Convergence", colormap=colormap, use_log10=use_log10) + plot_array(array=lensed_source_image, ax=axes[4], + title="Lensed Source Image (No Subhalo)", + colormap=colormap, use_log10=use_log10, lines=no_perturb_cc_lines) + plot_array(array=residual_map, ax=axes[5], + title="Residual Map (Subhalo - No Subhalo)", + colormap=colormap, use_log10=use_log10, lines=no_perturb_cc_lines) + + plt.tight_layout() + save_figure(fig, path=output_path, filename="subplot_lensed_images", format=output_format) + + +def subplot_sensitivity( + result, + data_subtracted, + output_path: Optional[str] = None, + output_format: str = "png", + colormap: str = "jet", + use_log10: bool = False, +): + """8-panel sensitivity subplot: log-likelihood/evidence maps and above-threshold map.""" + log_likelihoods = result.figure_of_merit_array( + use_log_evidences=False, + remove_zeros=True, + ) + + try: + log_evidences = result.figure_of_merit_array( + use_log_evidences=True, + remove_zeros=True, + ) + except TypeError: + log_evidences = np.zeros_like(log_likelihoods) + + above_threshold = np.where(log_likelihoods > 5.0, 1.0, 0.0) + above_threshold = aa.Array2D(values=above_threshold, mask=log_likelihoods.mask) + + fig, axes = plt.subplots(2, 4, figsize=(28, 14)) + axes_flat = list(axes.flatten()) + + plot_array(array=data_subtracted, ax=axes_flat[0], title="Subtracted Image", + colormap=colormap, use_log10=use_log10) + plot_array(array=log_evidences, ax=axes_flat[1], title="Increase in Log Evidence", + colormap=colormap) + plot_array(array=log_likelihoods, ax=axes_flat[2], title="Increase in Log Likelihood", + colormap=colormap) + plot_array(array=above_threshold, ax=axes_flat[3], title="Log Likelihood > 5.0", + colormap=colormap) + + ax_idx = 4 + try: + log_evidences_base = result._array_2d_from(result.log_evidences_base) + log_evidences_perturbed = result._array_2d_from(result.log_evidences_perturbed) + + base_vals = np.asarray(log_evidences_base) + perturb_vals = np.asarray(log_evidences_perturbed) + finite_base = base_vals[np.isfinite(base_vals) & (base_vals != 0)] + finite_perturb = perturb_vals[np.isfinite(perturb_vals) & (perturb_vals != 0)] + if len(finite_base) > 0 and len(finite_perturb) > 0: + vmin = float(np.min([np.min(finite_base), np.min(finite_perturb)])) + vmax = float(np.max([np.max(finite_base), np.max(finite_perturb)])) + else: + vmin = vmax = None + + plot_array(array=log_evidences_base, ax=axes_flat[ax_idx], + title="Log Evidence Base", colormap=colormap, vmin=vmin, vmax=vmax) + ax_idx += 1 + plot_array(array=log_evidences_perturbed, ax=axes_flat[ax_idx], + title="Log Evidence Perturb", colormap=colormap, vmin=vmin, vmax=vmax) + ax_idx += 1 + except (TypeError, AttributeError): + pass + + try: + log_likelihoods_base = result._array_2d_from(result.log_likelihoods_base) + log_likelihoods_perturbed = result._array_2d_from(result.log_likelihoods_perturbed) + + base_vals = np.asarray(log_likelihoods_base) + perturb_vals = np.asarray(log_likelihoods_perturbed) + finite_base = base_vals[np.isfinite(base_vals) & (base_vals != 0)] + finite_perturb = perturb_vals[np.isfinite(perturb_vals) & (perturb_vals != 0)] + if len(finite_base) > 0 and len(finite_perturb) > 0: + vmin = float(np.min([np.min(finite_base), np.min(finite_perturb)])) + vmax = float(np.max([np.max(finite_base), np.max(finite_perturb)])) + else: + vmin = vmax = None + + if ax_idx < len(axes_flat): + plot_array(array=log_likelihoods_base, ax=axes_flat[ax_idx], + title="Log Likelihood Base", colormap=colormap, vmin=vmin, vmax=vmax) + ax_idx += 1 + if ax_idx < len(axes_flat): + plot_array(array=log_likelihoods_perturbed, ax=axes_flat[ax_idx], + title="Log Likelihood Perturb", colormap=colormap, vmin=vmin, vmax=vmax) + except (TypeError, AttributeError): + pass + + plt.tight_layout() + save_figure(fig, path=output_path, filename="subplot_sensitivity", format=output_format) + + +def subplot_figures_of_merit_grid( + result, + output_path: Optional[str] = None, + output_format: str = "png", + colormap: str = "jet", + use_log_evidences: bool = True, + remove_zeros: bool = True, +): + """Single-panel subplot: the figures-of-merit grid for sensitivity mapping.""" + figures_of_merit = result.figure_of_merit_array( + use_log_evidences=use_log_evidences, + remove_zeros=remove_zeros, + ) + + fig, ax = plt.subplots(1, 1, figsize=(7, 7)) + plot_array(array=figures_of_merit, ax=ax, title="Increase in Log Evidence", + colormap=colormap) + plt.tight_layout() + save_figure(fig, path=output_path, filename="sensitivity", format=output_format) diff --git a/autolens/lens/plot/subhalo_plots.py b/autolens/lens/plot/subhalo_plots.py new file mode 100644 index 000000000..c1296babf --- /dev/null +++ b/autolens/lens/plot/subhalo_plots.py @@ -0,0 +1,104 @@ +"""Standalone subplot functions for subhalo detection visualisation.""" +import matplotlib.pyplot as plt +from typing import Optional + +from autoarray.plot.plots.array import plot_array +from autoarray.plot.plots.utils import save_figure +from autolens.imaging.plot.fit_imaging_plots import _plot_source_plane + + +def subplot_detection_imaging( + result, + fit_imaging_with_subhalo, + output_path: Optional[str] = None, + output_format: str = "png", + colormap: str = "jet", + use_log10: bool = False, + use_log_evidences: bool = True, + relative_to_value: float = 0.0, + remove_zeros: bool = False, +): + """4-panel subplot: data, S/N map, log-evidence increase, subhalo mass grid.""" + fig, axes = plt.subplots(1, 4, figsize=(28, 7)) + + plot_array( + array=fit_imaging_with_subhalo.data, + ax=axes[0], + title="Data", + colormap=colormap, + use_log10=use_log10, + ) + plot_array( + array=fit_imaging_with_subhalo.signal_to_noise_map, + ax=axes[1], + title="Signal-To-Noise Map", + colormap=colormap, + use_log10=use_log10, + ) + + fom_array = result.figure_of_merit_array( + use_log_evidences=use_log_evidences, + relative_to_value=relative_to_value, + remove_zeros=remove_zeros, + ) + plot_array( + array=fom_array, + ax=axes[2], + title="Increase in Log Evidence", + colormap=colormap, + ) + + mass_array = result.subhalo_mass_array + plot_array( + array=mass_array, + ax=axes[3], + title="Subhalo Mass", + colormap=colormap, + ) + + plt.tight_layout() + save_figure(fig, path=output_path, filename="subplot_detection_imaging", format=output_format) + + +def subplot_detection_fits( + fit_imaging_no_subhalo, + fit_imaging_with_subhalo, + output_path: Optional[str] = None, + output_format: str = "png", + colormap: str = "jet", +): + """6-panel subplot comparing fits with and without a subhalo.""" + fig, axes = plt.subplots(2, 3, figsize=(21, 14)) + + plot_array( + array=fit_imaging_no_subhalo.normalized_residual_map, + ax=axes[0][0], + title="Normalized Residual Map (No Subhalo)", + colormap=colormap, + ) + plot_array( + array=fit_imaging_no_subhalo.chi_squared_map, + ax=axes[0][1], + title="Chi-Squared Map (No Subhalo)", + colormap=colormap, + ) + _plot_source_plane(fit_imaging_no_subhalo, axes[0][2], plane_index=1, + colormap=colormap) + + plot_array( + array=fit_imaging_with_subhalo.normalized_residual_map, + ax=axes[1][0], + title="Normalized Residual Map (With Subhalo)", + colormap=colormap, + ) + plot_array( + array=fit_imaging_with_subhalo.chi_squared_map, + ax=axes[1][1], + title="Chi-Squared Map (With Subhalo)", + colormap=colormap, + ) + _plot_source_plane(fit_imaging_with_subhalo, axes[1][2], plane_index=1, + colormap=colormap) + + plt.tight_layout() + save_figure(fig, path=output_path, filename="subplot_detection_fits", format=output_format) diff --git a/autolens/lens/plot/tracer_plots.py b/autolens/lens/plot/tracer_plots.py new file mode 100644 index 000000000..baedd59ef --- /dev/null +++ b/autolens/lens/plot/tracer_plots.py @@ -0,0 +1,180 @@ +import matplotlib.pyplot as plt +import numpy as np +from typing import Optional, List + +import autoarray as aa +import autogalaxy as ag + +from autoarray.plot.plots.array import plot_array +from autoarray.plot.plots.utils import save_figure +from autolens.plot.plot_utils import ( + _to_lines, + _to_positions, + _critical_curves_from, + _caustics_from, +) + + +def subplot_tracer( + tracer, + grid: aa.type.Grid2DLike, + output_path: Optional[str] = None, + output_format: str = "png", + colormap: str = "jet", + use_log10: bool = False, + positions=None, +): + """Multi-panel subplot of the tracer: image, source images, and mass quantities. + + Panels (3x3 = 9 axes): + 0: full lensed image with critical curves + 1: source galaxy image (no caustics) + 2: source plane image (with caustics) + 3: lens galaxy image (log10) + 4: convergence (log10, with critical curves) + 5: potential (log10, with critical curves) + 6: deflections y (with critical curves) + 7: deflections x (with critical curves) + 8: magnification (with critical curves) + """ + from autogalaxy.operate.lens_calc import LensCalc + + final_plane_index = len(tracer.planes) - 1 + traced_grids = tracer.traced_grid_2d_list_from(grid=grid) + + tan_cc, rad_cc = _critical_curves_from(tracer, grid) + tan_ca, rad_ca = _caustics_from(tracer, grid) + image_plane_lines = _to_lines(tan_cc, rad_cc) + source_plane_lines = _to_lines(tan_ca, rad_ca) + pos_list = _to_positions(positions) + + # --- compute arrays --- + image = tracer.image_2d_from(grid=grid) + + source_galaxies = ag.Galaxies(galaxies=tracer.planes[final_plane_index]) + source_image = source_galaxies.image_2d_from(grid=traced_grids[final_plane_index]) + + lens_galaxies = ag.Galaxies(galaxies=tracer.planes[0]) + lens_image = lens_galaxies.image_2d_from(grid=traced_grids[0]) + + convergence = tracer.convergence_2d_from(grid=grid) + potential = tracer.potential_2d_from(grid=grid) + + deflections = tracer.deflections_yx_2d_from(grid=grid) + deflections_y = aa.Array2D(values=deflections.slim[:, 0], mask=grid.mask) + deflections_x = aa.Array2D(values=deflections.slim[:, 1], mask=grid.mask) + + magnification = LensCalc.from_mass_obj(tracer).magnification_2d_from(grid=grid) + + fig, axes = plt.subplots(3, 3, figsize=(21, 21)) + axes_flat = list(axes.flatten()) + + plot_array(array=image, ax=axes_flat[0], title="Image", + lines=image_plane_lines, positions=pos_list, colormap=colormap, + use_log10=use_log10) + plot_array(array=source_image, ax=axes_flat[1], title="Source Image", + colormap=colormap, use_log10=use_log10) + plot_array(array=source_image, ax=axes_flat[2], title="Source Plane Image", + lines=source_plane_lines, colormap=colormap, use_log10=use_log10) + plot_array(array=lens_image, ax=axes_flat[3], title="Lens Image", + colormap=colormap, use_log10=use_log10) + plot_array(array=convergence, ax=axes_flat[4], title="Convergence", + lines=image_plane_lines, colormap=colormap, use_log10=use_log10) + plot_array(array=potential, ax=axes_flat[5], title="Potential", + lines=image_plane_lines, colormap=colormap, use_log10=use_log10) + plot_array(array=deflections_y, ax=axes_flat[6], title="Deflections Y", + lines=image_plane_lines, colormap=colormap) + plot_array(array=deflections_x, ax=axes_flat[7], title="Deflections X", + lines=image_plane_lines, colormap=colormap) + plot_array(array=magnification, ax=axes_flat[8], title="Magnification", + lines=image_plane_lines, colormap=colormap) + + plt.tight_layout() + save_figure(fig, path=output_path, filename="subplot_tracer", format=output_format) + + +def subplot_lensed_images( + tracer, + grid: aa.type.Grid2DLike, + output_path: Optional[str] = None, + output_format: str = "png", + colormap: str = "jet", + use_log10: bool = False, +): + """One panel per plane showing the image of the galaxies in that plane.""" + traced_grids = tracer.traced_grid_2d_list_from(grid=grid) + n = tracer.total_planes + + fig, axes = plt.subplots(1, n, figsize=(7 * n, 7)) + axes_flat = [axes] if n == 1 else list(np.array(axes).flatten()) + + for plane_index in range(n): + galaxies = ag.Galaxies(galaxies=tracer.planes[plane_index]) + image = galaxies.image_2d_from(grid=traced_grids[plane_index]) + plot_array( + array=image, + ax=axes_flat[plane_index], + title=f"Image Of Plane {plane_index}", + colormap=colormap, + use_log10=use_log10, + ) + + plt.tight_layout() + save_figure(fig, path=output_path, filename="subplot_lensed_images", format=output_format) + + +def subplot_galaxies_images( + tracer, + grid: aa.type.Grid2DLike, + output_path: Optional[str] = None, + output_format: str = "png", + colormap: str = "jet", + use_log10: bool = False, +): + """Plane 0 image + for each plane > 0: lensed image + source plane image.""" + traced_grids = tracer.traced_grid_2d_list_from(grid=grid) + n = 2 * tracer.total_planes - 1 + + fig, axes = plt.subplots(1, n, figsize=(7 * n, 7)) + axes_flat = [axes] if n == 1 else list(np.array(axes).flatten()) + + idx = 0 + + lens_galaxies = ag.Galaxies(galaxies=tracer.planes[0]) + lens_image = lens_galaxies.image_2d_from(grid=traced_grids[0]) + plot_array( + array=lens_image, + ax=axes_flat[idx], + title="Image Of Plane 0", + colormap=colormap, + use_log10=use_log10, + ) + idx += 1 + + for plane_index in range(1, tracer.total_planes): + plane_galaxies = ag.Galaxies(galaxies=tracer.planes[plane_index]) + plane_grid = traced_grids[plane_index] + + image = plane_galaxies.image_2d_from(grid=plane_grid) + if idx < n: + plot_array( + array=image, + ax=axes_flat[idx], + title=f"Image Of Plane {plane_index}", + colormap=colormap, + use_log10=use_log10, + ) + idx += 1 + + if idx < n: + plot_array( + array=image, + ax=axes_flat[idx], + title=f"Plane Image Of Plane {plane_index}", + colormap=colormap, + use_log10=use_log10, + ) + idx += 1 + + plt.tight_layout() + save_figure(fig, path=output_path, filename="subplot_galaxies_images", format=output_format) diff --git a/autolens/lens/plot/tracer_plotters.py b/autolens/lens/plot/tracer_plotters.py deleted file mode 100644 index 7bf0ec3c4..000000000 --- a/autolens/lens/plot/tracer_plotters.py +++ /dev/null @@ -1,424 +0,0 @@ -from typing import Optional, List - -import autoarray as aa -import autogalaxy as ag -import autogalaxy.plot as aplt - -from autogalaxy.plot.mass_plotter import MassPlotter - -from autolens.plot.abstract_plotters import Plotter -from autolens.lens.tracer import Tracer - -from autolens import exc - -from autolens.lens import tracer_util - - -class TracerPlotter(Plotter): - def __init__( - self, - tracer: Tracer, - grid: aa.type.Grid2DLike, - mat_plot_1d: aplt.MatPlot1D = None, - visuals_1d: aplt.Visuals1D = None, - mat_plot_2d: aplt.MatPlot2D = None, - visuals_2d: aplt.Visuals2D = None, - visuals_2d_of_planes_list: Optional = None, - ): - """ - Plots the attributes of `Tracer` objects using the matplotlib methods `plot()` and `imshow()` and many - other matplotlib functions which customize the plot's appearance. - - The `mat_plot_1d` and `mat_plot_2d` attributes wrap matplotlib function calls to make the figure. By default, - the settings passed to every matplotlib function called are those specified in - the `config/visualize/mat_wrap/*.ini` files, but a user can manually input values into `MatPlot2D` to - customize the figure's appearance. - - Overlaid on the figure are visuals, contained in the `Visuals1D` and `Visuals2D` objects. Attributes may be - extracted from the `MassProfile` and plotted via the visuals object. - - Parameters - ---------- - tracer - The tracer the plotter plots. - grid - The 2D (y,x) grid of coordinates used to evaluate the tracer's light and mass quantities that are plotted. - mat_plot_1d - Contains objects which wrap the matplotlib function calls that make 1D plots. - visuals_1d - Contains 1D visuals that can be overlaid on 1D plots. - mat_plot_2d - Contains objects which wrap the matplotlib function calls that make 2D plots. - visuals_2d - Contains 2D visuals that can be overlaid on 2D plots. - """ - from autogalaxy.profiles.light.linear import ( - LightProfileLinear, - ) - - if tracer.has(cls=LightProfileLinear): - raise exc.raise_linear_light_profile_in_plot( - plotter_type=self.__class__.__name__, - ) - - super().__init__( - mat_plot_1d=mat_plot_1d, - visuals_1d=visuals_1d, - mat_plot_2d=mat_plot_2d, - visuals_2d=visuals_2d, - ) - - self.tracer = tracer - self.grid = grid - - self._mass_plotter = MassPlotter( - mass_obj=self.tracer, - grid=self.grid, - mat_plot_2d=self.mat_plot_2d, - visuals_2d=self.visuals_2d, - ) - - self._visuals_2d_of_planes_list = visuals_2d_of_planes_list - - @property - def visuals_2d_of_planes_list(self): - - if self._visuals_2d_of_planes_list is None: - self._visuals_2d_of_planes_list = ( - tracer_util.visuals_2d_of_planes_list_from( - tracer=self.tracer, grid=self.grid - ) - ) - - return self._visuals_2d_of_planes_list - - def galaxies_plotter_from( - self, plane_index: int, retain_visuals: bool = False - ) -> aplt.GalaxiesPlotter: - """ - Returns an `GalaxiesPlotter` corresponding to a `Plane` in the `Tracer`. - - Returns - ------- - plane_index - The index of the plane in the `Tracer` used to make the `GalaxiesPlotter`. - """ - - plane_grid = self.tracer.traced_grid_2d_list_from(grid=self.grid)[plane_index] - - if retain_visuals: - - visuals_2d = self.visuals_2d - - else: - - visuals_2d = self.visuals_2d.add_critical_curves_or_caustics( - mass_obj=self.tracer, grid=self.grid, plane_index=plane_index - ) - - return aplt.GalaxiesPlotter( - galaxies=ag.Galaxies(galaxies=self.tracer.planes[plane_index]), - grid=plane_grid, - mat_plot_2d=self.mat_plot_2d, - visuals_2d=visuals_2d, - ) - - def figures_2d( - self, - image: bool = False, - source_plane: bool = False, - convergence: bool = False, - potential: bool = False, - deflections_y: bool = False, - deflections_x: bool = False, - magnification: bool = False, - ): - """ - Plots the individual attributes of the plotter's `Tracer` object in 2D, which are computed via the plotter's 2D - grid object. - - The API is such that every plottable attribute of the `Tracer` object is an input parameter of type bool of - the function, which if switched to `True` means that it is plotted. - - Parameters - ---------- - image - Whether to make a 2D plot (via `imshow`) of the image of tracer in its image-plane (e.g. after - lensing). - source_plane - Whether to make a 2D plot (via `imshow`) of the image of the tracer in the source-plane (e.g. its - unlensed light). - convergence - Whether to make a 2D plot (via `imshow`) of the convergence. - potential - Whether to make a 2D plot (via `imshow`) of the potential. - deflections_y - Whether to make a 2D plot (via `imshow`) of the y component of the deflection angles. - deflections_x - Whether to make a 2D plot (via `imshow`) of the x component of the deflection angles. - magnification - Whether to make a 2D plot (via `imshow`) of the magnification. - """ - - if image: - self.mat_plot_2d.plot_array( - array=self.tracer.image_2d_from(grid=self.grid), - visuals_2d=self._mass_plotter.visuals_2d_with_critical_curves, - auto_labels=aplt.AutoLabels(title="Image", filename="image_2d"), - ) - - if source_plane: - self.figures_2d_of_planes( - plane_image=True, plane_index=len(self.tracer.planes) - 1 - ) - - self._mass_plotter.figures_2d( - convergence=convergence, - potential=potential, - deflections_y=deflections_y, - deflections_x=deflections_x, - magnification=magnification, - ) - - def plane_indexes_from(self, plane_index: Optional[int]) -> List[int]: - """ - Returns a list of all indexes of the planes in the fit, which is iterated over in figures that plot - individual figures of each plane in a tracer. - - Parameters - ---------- - plane_index - A specific plane index which when input means that only a single plane index is returned. - - Returns - ------- - list - A list of galaxy indexes corresponding to planes in the plane. - """ - if plane_index is None: - return list(range(len(self.tracer.planes))) - return [plane_index] - - def figures_2d_of_planes( - self, - plane_image: bool = False, - plane_grid: bool = False, - plane_index: Optional[int] = None, - zoom_to_brightest: bool = True, - retain_visuals: bool = False, - ): - """ - Plots source-plane images (e.g. the unlensed light) each individual `Plane` in the plotter's `Tracer` in 2D, - which are computed via the plotter's 2D grid object. - - The API is such that every plottable attribute of the `Plane` object is an input parameter of type bool of - the function, which if switched to `True` means that it is plotted. - - Parameters - ---------- - plane_image - Whether to make a 2D plot (via `imshow`) of the image of the plane in the soure-plane (e.g. its - unlensed light). - plane_grid - Whether to make a 2D plot (via `scatter`) of the lensed (y,x) coordinates of the plane in the - source-plane. - plane_index - If input, plots for only a single plane based on its index in the tracer are created. - zoom_to_brightest - For images not in the image-plane (e.g. the `plane_image`), whether to automatically zoom the plot to - the brightest regions of the galaxies being plotted as opposed to the full extent of the grid. - """ - plane_indexes = self.plane_indexes_from(plane_index=plane_index) - - for plane_index in plane_indexes: - galaxies_plotter = self.galaxies_plotter_from( - plane_index=plane_index, retain_visuals=retain_visuals - ) - - if plane_index == 1: - source_plane_title = True - else: - source_plane_title = False - - if plane_image: - galaxies_plotter.figures_2d( - plane_image=True, - zoom_to_brightest=zoom_to_brightest, - title_suffix=f" Of Plane {plane_index}", - filename_suffix=f"_of_plane_{plane_index}", - source_plane_title=source_plane_title, - ) - - if plane_grid: - galaxies_plotter.figures_2d( - plane_grid=True, - title_suffix=f" Of Plane {plane_index}", - filename_suffix=f"_of_plane_{plane_index}", - source_plane_title=source_plane_title, - ) - - def subplot( - self, - image: bool = False, - source_plane: bool = False, - convergence: bool = False, - potential: bool = False, - deflections_y: bool = False, - deflections_x: bool = False, - magnification: bool = False, - auto_filename: str = "subplot_tracer", - ): - """ - Plots the individual attributes of the plotter's `Tracer` object in 2D on a subplot, which are computed via - the plotter's 2D grid object. - - The API is such that every plottable attribute of the `Tracer` object is an input parameter of type bool of - the function, which if switched to `True` means that it is included on the subplot. - - Parameters - ---------- - image - Whether to include a 2D plot (via `imshow`) of the image of tracer in its image-plane (e.g. after - lensing). - source_plane - Whether to include a 2D plot (via `imshow`) of the image of the tracer in the source-plane (e.g. its - unlensed light). - convergence - Whether to include a 2D plot (via `imshow`) of the convergence. - potential - Whether to include a 2D plot (via `imshow`) of the potential. - deflections_y - Whether to include a 2D plot (via `imshow`) of the y component of the deflection angles. - deflections_x - Whether to include a 2D plot (via `imshow`) of the x component of the deflection angles. - magnification - Whether to include a 2D plot (via `imshow`) of the magnification. - auto_filename - The default filename of the output subplot if written to hard-disk. - """ - - self._subplot_custom_plot( - image=image, - source_plane=source_plane, - convergence=convergence, - potential=potential, - deflections_y=deflections_y, - deflections_x=deflections_x, - magnification=magnification, - auto_labels=aplt.AutoLabels(filename=auto_filename), - ) - - def subplot_tracer(self): - """ - Standard subplot of the attributes of the plotter's `Tracer` object. - """ - - final_plane_index = len(self.tracer.planes) - 1 - - use_log10_original = self.mat_plot_2d.use_log10 - - self.open_subplot_figure(number_subplots=9) - - self.figures_2d(image=True) - - self.set_title(label="Lensed Source Image") - - galaxies_plotter = self.galaxies_plotter_from(plane_index=final_plane_index) - - galaxies_plotter.visuals_2d.tangential_caustics = None - galaxies_plotter.visuals_2d.radial_caustics = None - - galaxies_plotter.figures_2d( - image=True, - ) - - self.set_title(label="Source Plane Image") - self.figures_2d(source_plane=True) - self.set_title(label=None) - - self._subplot_lens_and_mass() - - self.mat_plot_2d.output.subplot_to_figure(auto_filename="subplot_tracer") - self.close_subplot_figure() - - self.mat_plot_2d.use_log10 = use_log10_original - - def _subplot_lens_and_mass(self): - - self.mat_plot_2d.use_log10 = True - - self.set_title(label="Lens Galaxy Image") - - self.figures_2d_of_planes( - plane_image=True, - plane_index=0, - zoom_to_brightest=False, - retain_visuals=True, - ) - - self.mat_plot_2d.subplot_index = 5 - - self.set_title(label=None) - self.figures_2d(convergence=True) - - self.figures_2d(potential=True) - - self.mat_plot_2d.use_log10 = False - - self.figures_2d(magnification=True) - self.figures_2d(deflections_y=True) - self.figures_2d(deflections_x=True) - - def subplot_lensed_images(self): - """ - Subplot of the lensed image of every plane. - - For example, for a 2 plane `Tracer`, this creates a subplot with 2 panels, one for the image-plane image - and one for the source-plane lensed image. If there are 3 planes, 3 panels are created, showing - images at each plane. - """ - number_subplots = self.tracer.total_planes - - self.open_subplot_figure(number_subplots=number_subplots) - - for plane_index in range(0, self.tracer.total_planes): - galaxies_plotter = self.galaxies_plotter_from(plane_index=plane_index) - galaxies_plotter.figures_2d( - image=True, title_suffix=f" Of Plane {plane_index}" - ) - - self.mat_plot_2d.output.subplot_to_figure( - auto_filename=f"subplot_lensed_images" - ) - self.close_subplot_figure() - - def subplot_galaxies_images(self): - """ - Subplot of the image of every plane in its own plane. - - For example, for a 2 plane `Tracer`, this creates a subplot with 2 panels, one for the image-plane image - and one for the source-plane (e.g. unlensed) image. If there are 3 planes, 3 panels are created, showing - images at each plane. - """ - number_subplots = 2 * self.tracer.total_planes - 1 - - self.open_subplot_figure(number_subplots=number_subplots) - - galaxies_plotter = self.galaxies_plotter_from(plane_index=0) - galaxies_plotter.figures_2d(image=True, title_suffix=" Of Plane 0") - - self.mat_plot_2d.subplot_index += 1 - - for plane_index in range(1, self.tracer.total_planes): - galaxies_plotter = self.galaxies_plotter_from(plane_index=plane_index) - galaxies_plotter.figures_2d( - image=True, title_suffix=f" Of Plane {plane_index}" - ) - galaxies_plotter.figures_2d( - plane_image=True, title_suffix=f" Of Plane {plane_index}" - ) - - self.mat_plot_2d.output.subplot_to_figure( - auto_filename=f"subplot_galaxies_images" - ) - self.close_subplot_figure() diff --git a/autolens/lens/sensitivity.py b/autolens/lens/sensitivity.py index 0c2169c8c..549174b6a 100644 --- a/autolens/lens/sensitivity.py +++ b/autolens/lens/sensitivity.py @@ -22,14 +22,8 @@ import autofit as af import autoarray as aa - -from autoarray.plot.abstract_plotters import AbstractPlotter -from autoarray.plot.auto_labels import AutoLabels - from autolens.lens.tracer import Tracer -import autolens.plot as aplt - class SubhaloSensitivityResult(SensitivityResult): def __init__( @@ -152,479 +146,3 @@ def figure_of_merit_array( return self._array_2d_from(values=figures_of_merits) - -class SubhaloSensitivityPlotter(AbstractPlotter): - def __init__( - self, - mask: Optional[aa.Mask2D] = None, - tracer_perturb: Optional[Tracer] = None, - tracer_no_perturb: Optional[Tracer] = None, - source_image: Optional[aa.Array2D] = None, - result: Optional[SubhaloSensitivityResult] = None, - data_subtracted: Optional[aa.Array2D] = None, - mat_plot_2d: aplt.MatPlot2D = None, - visuals_2d: aplt.Visuals2D = None, - ): - """ - Plots the simulated datasets and results of a sensitivity mapping analysis, where dark matter halos are used - to simulate many strong lens datasets which are fitted to quantify how detectable they are. - - The `mat_plot_1d` and `mat_plot_2d` attributes wrap matplotlib function calls to make the figure. By default, - the settings passed to every matplotlib function called are those specified in - the `config/visualize/mat_wrap/*.ini` files, but a user can manually input values into `MatPlot2D` to - customize the figure's appearance. - - Overlaid on the figure are visuals, contained in the `Visuals1D` and `Visuals2D` objects. Attributes may be - extracted from the `MassProfile` and plotted via the visuals object. - - Parameters - ---------- - tracer - The tracer the plotter plots. - grid - The 2D (y,x) grid of coordinates used to evaluate the tracer's light and mass quantities that are plotted. - mat_plot_1d - Contains objects which wrap the matplotlib function calls that make 1D plots. - visuals_1d - Contains 1D visuals that can be overlaid on 1D plots. - mat_plot_2d - Contains objects which wrap the matplotlib function calls that make 2D plots. - visuals_2d - Contains 2D visuals that can be overlaid on 2D plots. - """ - - super().__init__(mat_plot_2d=mat_plot_2d, visuals_2d=visuals_2d) - - self.mask = mask - self.tracer_perturb = tracer_perturb - self.tracer_no_perturb = tracer_no_perturb - self.source_image = source_image - self.result = result - self.data_subtracted = data_subtracted - self.mat_plot_2d = mat_plot_2d - self.visuals_2d = visuals_2d - - def update_mat_plot_array_overlay(self, evidence_max): - evidence_half = evidence_max / 2.0 - - self.mat_plot_2d.array_overlay = aplt.ArrayOverlay( - alpha=0.6, vmin=0.0, vmax=evidence_max - ) - self.mat_plot_2d.colorbar = aplt.Colorbar( - manual_tick_values=[0.0, evidence_half, evidence_max], - manual_tick_labels=[ - 0.0, - np.round(evidence_half, 1), - np.round(evidence_max, 1), - ], - ) - - def subplot_tracer_images(self): - """ - Output the tracer images of the dataset simulated for sensitivity mapping as a .png subplot. - - This dataset corresponds to a single grid-cell on the sensitivity mapping grid and therefore will be output - many times over the entire sensitivity mapping grid. - - The subplot includes the overall image, the lens galaxy image, the lensed source galaxy image and the source - galaxy image interpolated to a uniform grid. - - Images are masked before visualization, so that they zoom in on the region of interest which is actually - fitted. - """ - - grid = aa.Grid2D.from_mask(mask=self.mask) - - image = self.tracer_perturb.image_2d_from(grid=grid) - lensed_source_image = self.tracer_perturb.image_2d_via_input_plane_image_from( - grid=grid, plane_image=self.source_image - ) - lensed_source_image_no_perturb = ( - self.tracer_no_perturb.image_2d_via_input_plane_image_from( - grid=grid, plane_image=self.source_image - ) - ) - - plotter = aplt.Array2DPlotter( - array=image, - mat_plot_2d=self.mat_plot_2d, - ) - plotter.open_subplot_figure(number_subplots=6) - plotter.set_title("Image") - plotter.figure_2d() - - grid = self.mask.derive_grid.unmasked - - visuals_2d = aplt.Visuals2D( - mask=self.mask, - tangential_critical_curves=self.tracer_perturb.tangential_critical_curve_list_from( - grid=grid - ), - radial_critical_curves=self.tracer_perturb.radial_critical_curve_list_from( - grid=grid - ), - ) - - plotter = aplt.Array2DPlotter( - array=lensed_source_image, - mat_plot_2d=self.mat_plot_2d, - visuals_2d=visuals_2d, - ) - plotter.set_title("Lensed Source Image") - plotter.figure_2d() - - visuals_2d = aplt.Visuals2D( - mask=self.mask, - tangential_caustics=self.tracer_perturb.tangential_caustic_list_from( - grid=grid - ), - radial_caustics=self.tracer_perturb.radial_caustic_list_from(grid=grid), - ) - - plotter = aplt.Array2DPlotter( - array=self.source_image, - mat_plot_2d=self.mat_plot_2d, - visuals_2d=visuals_2d, - ) - plotter.set_title("Source Image") - plotter.figure_2d() - - plotter = aplt.Array2DPlotter( - array=self.tracer_perturb.convergence_2d_from(grid=grid), - mat_plot_2d=self.mat_plot_2d, - ) - plotter.set_title("Convergence") - plotter.figure_2d() - - visuals_2d = aplt.Visuals2D( - mask=self.mask, - tangential_critical_curves=self.tracer_no_perturb.tangential_critical_curve_list_from( - grid=grid - ), - radial_critical_curves=self.tracer_no_perturb.radial_critical_curve_list_from( - grid=grid - ), - ) - - plotter = aplt.Array2DPlotter( - array=lensed_source_image, - mat_plot_2d=self.mat_plot_2d, - visuals_2d=visuals_2d, - ) - plotter.set_title("Lensed Source Image (No Subhalo)") - plotter.figure_2d() - - residual_map = lensed_source_image - lensed_source_image_no_perturb - - plotter = aplt.Array2DPlotter( - array=residual_map, - mat_plot_2d=self.mat_plot_2d, - visuals_2d=visuals_2d, - ) - plotter.set_title("Residual Map (Subhalo - No Subhalo)") - plotter.figure_2d() - - plotter.mat_plot_2d.output.subplot_to_figure( - auto_filename=f"subplot_lensed_images" - ) - plotter.close_subplot_figure() - - def set_auto_filename( - self, filename: str, use_log_evidences: Optional[bool] = None - ) -> bool: - """ - If a subplot figure does not have an input filename, this function is used to set one automatically. - - The filename is appended with a string that describes the figure of merit plotted, which is either the - log evidence or log likelihood. - - Parameters - ---------- - filename - The filename of the figure, e.g. 'subhalo_mass' - use_log_evidences - If `True`, figures which overlay the goodness-of-fit merit use the `log_evidence`, if `False` the - `log_likelihood` if used. - - Returns - ------- - - """ - - if self.mat_plot_2d.output.filename is None: - if use_log_evidences is None: - figure_of_merit = "" - elif use_log_evidences: - figure_of_merit = "_log_evidence" - else: - figure_of_merit = "_log_likelihood" - - self.set_filename( - filename=f"{filename}{figure_of_merit}", - ) - - return True - - return False - - def sensitivity_to_fits(self): - log_likelihoods = self.result.figure_of_merit_array( - use_log_evidences=False, - remove_zeros=False, - ) - - mat_plot_2d = aplt.MatPlot2D( - output=aplt.Output( - path=self.mat_plot_2d.output.path, - filename="sensitivity_log_likelihood", - format="fits", - ) - ) - - mat_plot_2d.plot_array( - array=log_likelihoods, - visuals_2d=self.visuals_2d, - auto_labels=AutoLabels(), - ) - - try: - log_evidences = self.result.figure_of_merit_array( - use_log_evidences=True, - remove_zeros=False, - ) - - mat_plot_2d = aplt.MatPlot2D( - output=aplt.Output( - path=self.mat_plot_2d.output.path, - filename="sensitivity_log_evidence", - format="fits", - ) - ) - - mat_plot_2d.plot_array( - array=log_evidences, - visuals_2d=self.visuals_2d, - auto_labels=AutoLabels(), - ) - - except TypeError: - pass - - def subplot_sensitivity(self): - log_likelihoods = self.result.figure_of_merit_array( - use_log_evidences=False, - remove_zeros=True, - ) - - try: - log_evidences = self.result.figure_of_merit_array( - use_log_evidences=True, - remove_zeros=True, - ) - except TypeError: - log_evidences = np.zeros_like(log_likelihoods) - - self.open_subplot_figure(number_subplots=8, subplot_shape=(2, 4)) - - plotter = aplt.Array2DPlotter( - array=self.data_subtracted, - mat_plot_2d=self.mat_plot_2d, - ) - - plotter.figure_2d() - - self.mat_plot_2d.plot_array( - array=log_evidences, - visuals_2d=self.visuals_2d, - auto_labels=AutoLabels(title="Increase in Log Evidence"), - ) - - self.mat_plot_2d.plot_array( - array=log_likelihoods, - visuals_2d=self.visuals_2d, - auto_labels=AutoLabels(title="Increase in Log Likelihood"), - ) - - above_threshold = np.where(log_likelihoods > 5.0, 1.0, 0.0) - - above_threshold = aa.Array2D(values=above_threshold, mask=log_likelihoods.mask) - - self.mat_plot_2d.plot_array( - array=above_threshold, - visuals_2d=self.visuals_2d, - auto_labels=AutoLabels(title="Log Likelihood > 5.0"), - ) - - try: - log_evidences_base = self.result._array_2d_from( - self.result.log_evidences_base - ) - log_evidences_perturbed = self.result._array_2d_from( - self.result.log_evidences_perturbed - ) - - log_evidences_base_min = np.nanmin( - np.where(log_evidences_base == 0, np.nan, log_evidences_base) - ) - log_evidences_base_max = np.nanmax( - np.where(log_evidences_base == 0, np.nan, log_evidences_base) - ) - log_evidences_perturbed_min = np.nanmin( - np.where(log_evidences_perturbed == 0, np.nan, log_evidences_perturbed) - ) - log_evidences_perturbed_max = np.nanmax( - np.where(log_evidences_perturbed == 0, np.nan, log_evidences_perturbed) - ) - - self.mat_plot_2d.cmap.kwargs["vmin"] = np.min( - [log_evidences_base_min, log_evidences_perturbed_min] - ) - self.mat_plot_2d.cmap.kwargs["vmax"] = np.max( - [log_evidences_base_max, log_evidences_perturbed_max] - ) - - self.mat_plot_2d.plot_array( - array=log_evidences_base, - visuals_2d=self.visuals_2d, - auto_labels=AutoLabels(title="Log Evidence Base"), - ) - - self.mat_plot_2d.plot_array( - array=log_evidences_perturbed, - visuals_2d=self.visuals_2d, - auto_labels=AutoLabels(title="Log Evidence Perturb"), - ) - except TypeError: - pass - - log_likelihoods_base = self.result._array_2d_from( - self.result.log_likelihoods_base - ) - log_likelihoods_perturbed = self.result._array_2d_from( - self.result.log_likelihoods_perturbed - ) - - log_likelihoods_base_min = np.nanmin( - np.where(log_likelihoods_base == 0, np.nan, log_likelihoods_base) - ) - log_likelihoods_base_max = np.nanmax( - np.where(log_likelihoods_base == 0, np.nan, log_likelihoods_base) - ) - log_likelihoods_perturbed_min = np.nanmin( - np.where(log_likelihoods_perturbed == 0, np.nan, log_likelihoods_perturbed) - ) - log_likelihoods_perturbed_max = np.nanmax( - np.where(log_likelihoods_perturbed == 0, np.nan, log_likelihoods_perturbed) - ) - - self.mat_plot_2d.cmap.kwargs["vmin"] = np.min( - [log_likelihoods_base_min, log_likelihoods_perturbed_min] - ) - self.mat_plot_2d.cmap.kwargs["vmax"] = np.max( - [log_likelihoods_base_max, log_likelihoods_perturbed_max] - ) - - self.mat_plot_2d.plot_array( - array=log_likelihoods_base, - visuals_2d=self.visuals_2d, - auto_labels=AutoLabels(title="Log Likelihood Base"), - ) - - self.mat_plot_2d.plot_array( - array=log_likelihoods_perturbed, - visuals_2d=self.visuals_2d, - auto_labels=AutoLabels(title="Log Likelihood Perturb"), - ) - - self.mat_plot_2d.output.subplot_to_figure(auto_filename="subplot_sensitivity") - - self.close_subplot_figure() - - def subplot_figures_of_merit_grid( - self, - use_log_evidences: bool = True, - remove_zeros: bool = True, - show_max_in_title: bool = True, - ): - self.open_subplot_figure(number_subplots=1) - - figures_of_merit = self.result.figure_of_merit_array( - use_log_evidences=use_log_evidences, - remove_zeros=remove_zeros, - ) - - if show_max_in_title: - max_value = np.round(np.nanmax(figures_of_merit), 2) - self.set_title(label=f"Sensitivity Map {max_value}") - - self.update_mat_plot_array_overlay(evidence_max=np.max(figures_of_merit)) - - self.mat_plot_2d.plot_array( - array=figures_of_merit, - visuals_2d=self.visuals_2d, - auto_labels=AutoLabels(title="Increase in Log Evidence"), - ) - - self.mat_plot_2d.output.subplot_to_figure(auto_filename="sensitivity") - self.close_subplot_figure() - - def figure_figures_of_merit_grid( - self, - use_log_evidences: bool = True, - remove_zeros: bool = True, - show_max_in_title: bool = True, - ): - """ - Plot the results of the subhalo grid search, where the figures of merit (e.g. `log_evidence`) of the - grid search are plotted over the image of the lensed source galaxy. - - The figures of merit can be customized to be relative to the lens model without a subhalo, or with zeros - rounded up to 0.0 to remove negative values. These produce easily to interpret and visually appealing - figure of merit overlays. - - Parameters - ---------- - use_log_evidences - If `True`, figures which overlay the goodness-of-fit merit use the `log_evidence`, if `False` the - `log_likelihood` if used. - relative_to_value - The value to subtract from every figure of merit, for example which will typically be that of the no - subhalo lens model so Bayesian model comparison can be easily performed. - remove_zeros - If `True`, the figure of merit array is altered so that all values below 0.0 and set to 0.0. For plotting - relative figures of merit for Bayesian model comparison, this is convenient to remove negative values - and produce a clearer visualization of the overlay. - show_max_in_title - Shows the maximum figure of merit value in the title of the figure, for easy reference. - """ - - reset_filename = self.set_auto_filename( - filename="sensitivity", - use_log_evidences=use_log_evidences, - ) - - array_overlay = self.result.figure_of_merit_array( - use_log_evidences=use_log_evidences, - remove_zeros=remove_zeros, - ) - - visuals_2d = self.visuals_2d + self.visuals_2d.__class__( - array_overlay=array_overlay, - ) - - self.update_mat_plot_array_overlay(evidence_max=np.max(array_overlay)) - - plotter = aplt.Array2DPlotter( - array=self.data_subtracted, - mat_plot_2d=self.mat_plot_2d, - visuals_2d=visuals_2d, - ) - - if show_max_in_title: - max_value = np.round(np.nanmax(array_overlay), 2) - plotter.set_title(label=f"Sensitivity Map {max_value}") - - plotter.figure_2d() - - if reset_filename: - self.set_filename(filename=None) diff --git a/autolens/lens/subhalo.py b/autolens/lens/subhalo.py index c916d8400..759af930c 100644 --- a/autolens/lens/subhalo.py +++ b/autolens/lens/subhalo.py @@ -18,12 +18,6 @@ import autofit as af import autoarray as aa -import autogalaxy.plot as aplt - -from autoarray.plot.abstract_plotters import AbstractPlotter - -from autolens.imaging.fit_imaging import FitImaging -from autolens.imaging.plot.fit_imaging_plotters import FitImagingPlotter class SubhaloGridSearchResult(af.GridSearchResult): @@ -180,347 +174,3 @@ def subhalo_centres_grid(self) -> aa.Grid2D: pixel_scales=self.physical_step_sizes, shape_native=self.shape, ) - - -class SubhaloPlotter(AbstractPlotter): - def __init__( - self, - result: Optional[SubhaloGridSearchResult] = None, - fit_imaging_with_subhalo: Optional[FitImaging] = None, - fit_imaging_no_subhalo: Optional[FitImaging] = None, - mat_plot_2d: aplt.MatPlot2D = None, - visuals_2d: aplt.Visuals2D = None, - ): - """ - Plots the results of scanning for a dark matter subhalo in strong lens imaging. - - This produces the following style of plots: - - - Grid Overlay: The subhalo grid search of non-linear searches fits lens models where the (y,x) coordinates of - each DM subhalo are confined to a small region of the image plane via uniform priors. Corresponding plots - overlay the grid of results (e.g. the log evidence increase, subhalo mass) over the images of the fit. This - provides spatial information of where DM subhalos are detected. - - - Comparison Plots: Plots comparing the results of the model-fit with and without a subhalo, including the - best-fit lens model, residuals. This illuminates how the inclusion of a subhalo impacts the fit and why the - DM subhalo is inferred. - - Parameters - ---------- - result - The results of a grid search of non-linear searches where each DM subhalo's (y,x) coordinates are - confined to a small region of the image plane via uniform priors. - fit_imaging_with_subhalo - The `FitImaging` of the model-fit for the lens model with a subhalo (the `subhalo[3]` search in template - SLaM pipelines). - fit_imaging_no_subhalo - The `FitImaging` of the model-fit for the lens model without a subhalo (the `subhalo[1]` search in - template SLaM pipelines). - mat_plot_2d - Contains objects which wrap the matplotlib function calls that make 2D plots. - visuals_2d - Contains 2D visuals that can be overlaid on 2D plots. - """ - super().__init__(mat_plot_2d=mat_plot_2d, visuals_2d=visuals_2d) - - self.result = result - - self.fit_imaging_with_subhalo = fit_imaging_with_subhalo - self.fit_imaging_no_subhalo = fit_imaging_no_subhalo - - def update_mat_plot_array_overlay(self, evidence_max): - evidence_half = evidence_max / 2.0 - - self.mat_plot_2d.array_overlay = aplt.ArrayOverlay( - alpha=0.6, vmin=0.0, vmax=evidence_max - ) - self.mat_plot_2d.colorbar = aplt.Colorbar( - manual_tick_values=[0.0, evidence_half, evidence_max], - manual_tick_labels=[ - 0.0, - np.round(evidence_half, 1), - np.round(evidence_max, 1), - ], - ) - - @property - def fit_imaging_no_subhalo_plotter(self) -> FitImagingPlotter: - """ - The plotter which plots the results of the model-fit without a subhalo. - - This plot is used in figures such as `subplot_detection_fits` which compare the fits with and without a - subhalo. - """ - return FitImagingPlotter( - fit=self.fit_imaging_no_subhalo, - mat_plot_2d=self.mat_plot_2d, - visuals_2d=self.visuals_2d, - ) - - @property - def fit_imaging_with_subhalo_plotter(self) -> FitImagingPlotter: - """ - The plotter which plots the results of the model-fit with a subhalo. - - This plot is used in figures such as `subplot_detection_fits` which compare the fits with and without a - subhalo, or `subplot_detection_imaging` which overlays subhalo grid search results over the image. - """ - return self.fit_imaging_with_subhalo_plotter_from(visuals_2d=self.visuals_2d) - - def fit_imaging_with_subhalo_plotter_from(self, visuals_2d) -> FitImagingPlotter: - """ - Returns a plotter of the model-fit with a subhalo, using a specific set of visuals. - - The input visuals are typically the overlay array of the grid search, so that the subhalo grid search results - can be plotted over the image. - - Parameters - ---------- - visuals_2d - The visuals that are plotted over the image of the fit, which are typically the results of the subhalo - grid search. - """ - return FitImagingPlotter( - fit=self.fit_imaging_with_subhalo, - mat_plot_2d=self.mat_plot_2d, - visuals_2d=visuals_2d, - ) - - def set_auto_filename( - self, filename: str, use_log_evidences: Optional[bool] = None - ) -> bool: - """ - If a subplot figure does not have an input filename, this function is used to set one automatically. - - The filename is appended with a string that describes the figure of merit plotted, which is either the - log evidence or log likelihood. - - Parameters - ---------- - filename - The filename of the figure, e.g. 'subhalo_mass' - use_log_evidences - If `True`, figures which overlay the goodness-of-fit merit use the `log_evidence`, if `False` the - `log_likelihood` if used. - - Returns - ------- - - """ - - if self.mat_plot_2d.output.filename is None: - if use_log_evidences is None: - figure_of_merit = "" - elif use_log_evidences: - figure_of_merit = "_log_evidence" - else: - figure_of_merit = "_log_likelihood" - - self.set_filename( - filename=f"{filename}{figure_of_merit}", - ) - - return True - - return False - - def figure_figures_of_merit_grid( - self, - use_log_evidences: bool = True, - relative_to_value: float = 0.0, - remove_zeros: bool = True, - show_max_in_title: bool = True, - ): - """ - Plot the results of the subhalo grid search, where the figures of merit (e.g. `log_evidence`) of the - grid search are plotted over the image of the lensed source galaxy. - - The figures of merit can be customized to be relative to the lens model without a subhalo, or with zeros - rounded up to 0.0 to remove negative values. These produce easily to interpret and visually appealing - figure of merit overlays. - - Parameters - ---------- - use_log_evidences - If `True`, figures which overlay the goodness-of-fit merit use the `log_evidence`, if `False` the - `log_likelihood` if used. - relative_to_value - The value to subtract from every figure of merit, for example which will typically be that of the no - subhalo lens model so Bayesian model comparison can be easily performed. - remove_zeros - If `True`, the figure of merit array is altered so that all values below 0.0 and set to 0.0. For plotting - relative figures of merit for Bayesian model comparison, this is convenient to remove negative values - and produce a clearer visualization of the overlay. - show_max_in_title - Shows the maximum figure of merit value in the title of the figure, for easy reference. - """ - - reset_filename = self.set_auto_filename( - filename="subhalo_grid", - use_log_evidences=use_log_evidences, - ) - - array_overlay = self.result.figure_of_merit_array( - use_log_evidences=use_log_evidences, - relative_to_value=relative_to_value, - remove_zeros=remove_zeros, - ) - - try: - visuals_2d = self.visuals_2d + self.visuals_2d.__class__( - array_overlay=array_overlay, - mass_profile_centres=self.result.subhalo_centres_grid, - ) - except TypeError: - visuals_2d = self.visuals_2d - - self.update_mat_plot_array_overlay(evidence_max=np.max(array_overlay)) - - fit_plotter = self.fit_imaging_with_subhalo_plotter_from(visuals_2d=visuals_2d) - - if show_max_in_title: - max_value = np.round(np.nanmax(array_overlay), 2) - fit_plotter.set_title(label=f"Image {max_value}") - - try: - fit_plotter.figures_2d_of_planes(plane_index=-1, subtracted_image=True) - except AttributeError: - pass - - if reset_filename: - self.set_filename(filename=None) - - def figure_mass_grid(self): - """ - Plots the results of the subhalo grid search, where the subhalo mass of every grid search is plotted over - the image of the lensed source galaxy. - """ - - reset_filename = self.set_auto_filename( - filename="subhalo_mass", - ) - - array_overlay = self.result.subhalo_mass_array - - try: - visuals_2d = self.visuals_2d + self.visuals_2d.__class__( - array_overlay=array_overlay, - mass_profile_centres=self.result.subhalo_centres_grid, - ) - except TypeError: - visuals_2d = self.visuals_2d + self.visuals_2d.__class__( - array_overlay=array_overlay, - ) - - self.update_mat_plot_array_overlay(evidence_max=np.max(array_overlay)) - self.mat_plot_2d.colorbar.manual_log10 = True - - fit_plotter = self.fit_imaging_with_subhalo_plotter_from(visuals_2d=visuals_2d) - - try: - fit_plotter.figures_2d_of_planes(plane_index=-1, subtracted_image=True) - except AttributeError: - pass - - if reset_filename: - self.set_filename(filename=None) - - def subplot_detection_imaging( - self, - use_log_evidences: bool = True, - relative_to_value: float = 0.0, - remove_zeros: bool = False, - ): - """ - Plots a subplot showing the image, signal-to-noise-map, figures of merit and subhalo masses of the subhalo - grid search. - - The figures of merits are plotted as an array, which can be customized to be relative to the lens model without - a subhalo, or with zeros rounded up to 0.0 to remove negative values. These produce easily to interpret and - visually appealing figure of merit overlays. - - Parameters - ---------- - use_log_evidences - If `True`, figures which overlay the goodness-of-fit merit use the `log_evidence`, if `False` the - `log_likelihood` if used. - relative_to_value - The value to subtract from every figure of merit, for example which will typically be that of the no - subhalo lens model so Bayesian model comparison can be easily performed. - remove_zeros - If `True`, the figure of merit array is altered so that all values below 0.0 and set to 0.0. For plotting - relative figures of merit for Bayesian model comparison, this is convenient to remove negative values - and produce a clearer visualization of the overlay. - show_max_in_title - Shows the maximum figure of merit value in the title of the figure, for easy reference. - """ - self.open_subplot_figure(number_subplots=4) - - self.set_title("Image") - self.fit_imaging_with_subhalo_plotter.figures_2d(data=True) - - self.set_title("Signal-To-Noise Map") - self.fit_imaging_with_subhalo_plotter.figures_2d(signal_to_noise_map=True) - self.set_title(None) - - arr = self.result.figure_of_merit_array( - use_log_evidences=use_log_evidences, - relative_to_value=relative_to_value, - remove_zeros=remove_zeros, - ) - - self.mat_plot_2d.plot_array( - array=arr, - visuals_2d=self.visuals_2d, - auto_labels=aplt.AutoLabels(title="Increase in Log Evidence"), - ) - - arr = self.result.subhalo_mass_array - - self.mat_plot_2d.plot_array( - array=arr, - visuals_2d=self.visuals_2d, - auto_labels=aplt.AutoLabels(title="Subhalo Mass"), - ) - - self.mat_plot_2d.output.subplot_to_figure( - auto_filename="subplot_detection_imaging" - ) - self.close_subplot_figure() - - def subplot_detection_fits(self): - """ - Plots a subplot comparing the results of the best fit lens models with and without a subhalo. - - This subplot shows the normalized residuals, chi-squared map and source reconstructions of the model-fits - with and without a subhalo. - """ - - self.open_subplot_figure(number_subplots=6) - - self.set_title("Normalized Residuals (No Subhalo)") - self.fit_imaging_no_subhalo_plotter.figures_2d(normalized_residual_map=True) - - self.set_title("Chi-Squared Map (No Subhalo)") - self.fit_imaging_no_subhalo_plotter.figures_2d(chi_squared_map=True) - - self.set_title("Source Reconstruction (No Subhalo)") - self.fit_imaging_no_subhalo_plotter.figures_2d_of_planes( - plane_index=1, plane_image=True - ) - - self.set_title("Normailzed Residuals (With Subhalo)") - self.fit_imaging_with_subhalo_plotter.figures_2d(normalized_residual_map=True) - - self.set_title("Chi-Squared Map (With Subhalo)") - self.fit_imaging_with_subhalo_plotter.figures_2d(chi_squared_map=True) - - self.set_title("Source Reconstruction (With Subhalo)") - self.fit_imaging_with_subhalo_plotter.figures_2d_of_planes( - plane_index=1, plane_image=True - ) - - self.mat_plot_2d.output.subplot_to_figure( - auto_filename="subplot_detection_fits" - ) - self.close_subplot_figure() diff --git a/autolens/lens/tracer_util.py b/autolens/lens/tracer_util.py index aabb2fbc7..cd418e925 100644 --- a/autolens/lens/tracer_util.py +++ b/autolens/lens/tracer_util.py @@ -283,7 +283,7 @@ def time_delays_from( xp=np, cosmology: ag.cosmo.LensingCosmology = None, ) -> aa.type.Grid2DLike: - """ + r""" Returns the gravitational lensing time delay in days for a grid of 2D (y, x) coordinates. This function calculates the time delay at each image-plane position due to both geometric and gravitational @@ -388,15 +388,15 @@ def ordered_plane_redshifts_with_slicing_from( The source-plane redshift is removed from the ordered plane redshifts that are returned, so that galaxies are not \ planed at the source-plane redshift. - For example, if the main plane redshifts are [1.0, 2.0], and the bin sizes are [1,3], the following redshift \ - slices for planes will be used: + For example, if the main plane redshifts are [1.0, 2.0], and the bin sizes are [1,3], the following redshift + slices for planes will be used:: - z=0.5 - z=1.0 - z=1.25 - z=1.5 - z=1.75 - z=2.0 + z=0.5 + z=1.0 + z=1.25 + z=1.5 + z=1.75 + z=2.0 Parameters ---------- @@ -436,18 +436,80 @@ def ordered_plane_redshifts_with_slicing_from( return plane_redshifts[0:-1] -def visuals_2d_of_planes_list_from(tracer, grid) -> aplt.Visuals2D: +def critical_curves_from(tracer, grid): + """ + Compute tangential and radial critical curves for the tracer via LensCalc and + return them as plain lists of numpy arrays. + + Returns + ------- + tuple[list, list] + ``(tangential_critical_curves, radial_critical_curves)`` where each element + is a list of (N, 2) numpy arrays. *radial_critical_curves* may be an empty + list when the radial curve area is below the pixel-scale threshold. + """ + from autogalaxy.operate.lens_calc import LensCalc + import numpy as np + + od = LensCalc.from_mass_obj(tracer) + + tangential_critical_curves = od.tangential_critical_curve_list_from(grid=grid) - visuals_2d_of_planes_list = [] + radial_critical_curve_area_list = od.radial_critical_curve_area_list_from(grid=grid) + if any(area > grid.pixel_scale for area in radial_critical_curve_area_list): + radial_critical_curves = od.radial_critical_curve_list_from(grid=grid) + else: + radial_critical_curves = [] + return tangential_critical_curves, radial_critical_curves + + +def caustics_from(tracer, grid): + """ + Compute tangential and radial caustics for the tracer via LensCalc and + return them as plain lists of numpy arrays. + + Returns + ------- + tuple[list, list] + ``(tangential_caustics, radial_caustics)`` where each element is a list + of (N, 2) numpy arrays. + """ + from autogalaxy.operate.lens_calc import LensCalc + + od = LensCalc.from_mass_obj(tracer) + + tangential_caustics = od.tangential_caustic_list_from(grid=grid) + radial_caustics = od.radial_caustic_list_from(grid=grid) + + return tangential_caustics, radial_caustics + + +def lines_of_planes_from(tracer, grid): + """ + For each plane in the tracer return the appropriate line overlays: + - plane 0 (image plane): critical curves + - plane 1+ (source planes): caustics + + Returns + ------- + list[list[np.ndarray]] + One entry per plane; each entry is a (possibly empty) list of (N, 2) numpy + arrays suitable for passing as ``lines=`` to ``_plot_array``. + """ + tan_cc, rad_cc = critical_curves_from(tracer=tracer, grid=grid) + tan_ca, rad_ca = caustics_from(tracer=tracer, grid=grid) + + critical_curve_lines = list(tan_cc) + list(rad_cc) + caustic_lines = list(tan_ca) + list(rad_ca) + + lines_of_planes = [] for plane_index in range(len(tracer.planes)): + if plane_index == 0: + lines_of_planes.append(critical_curve_lines) + else: + lines_of_planes.append(caustic_lines) + + return lines_of_planes - visuals_2d_of_planes_list.append( - aplt.Visuals2D().add_critical_curves_or_caustics( - mass_obj=tracer, - grid=grid, - plane_index=plane_index, - ) - ) - return visuals_2d_of_planes_list diff --git a/autolens/plot/__init__.py b/autolens/plot/__init__.py index 78e5ff5c1..961f22cc8 100644 --- a/autolens/plot/__init__.py +++ b/autolens/plot/__init__.py @@ -1,93 +1,56 @@ -from autofit.non_linear.plot.nest_plotters import NestPlotter -from autofit.non_linear.plot.mcmc_plotters import MCMCPlotter -from autofit.non_linear.plot.mle_plotters import MLEPlotter - -from autoarray.plot.wrap.base import ( - Units, - Figure, - Axis, - Cmap, - Colorbar, - ColorbarTickParams, - TickParams, - YTicks, - XTicks, - Title, - YLabel, - XLabel, - Legend, - Annotate, - Text, - Output, -) -from autoarray.plot.wrap.one_d import YXPlot, FillBetween -from autoarray.plot.wrap.two_d import ( - ArrayOverlay, - Contour, - GridScatter, - GridPlot, - VectorYXQuiver, - PatchOverlay, - DelaunayDrawer, - OriginScatter, - MaskScatter, - BorderScatter, - PositionsScatter, - IndexScatter, - MeshGridScatter, - ParallelOverscanPlot, - SerialPrescanPlot, - SerialOverscanPlot, -) - -from autoarray.structures.plot.structure_plotters import Array2DPlotter -from autoarray.structures.plot.structure_plotters import Grid2DPlotter -from autoarray.inversion.plot.mapper_plotters import MapperPlotter -from autoarray.structures.plot.structure_plotters import YX1DPlotter -from autoarray.structures.plot.structure_plotters import YX1DPlotter as Array1DPlotter -from autoarray.inversion.plot.inversion_plotters import InversionPlotter -from autoarray.dataset.plot.imaging_plotters import ImagingPlotter -from autoarray.dataset.plot.interferometer_plotters import InterferometerPlotter - -from autoarray.plot.multi_plotters import MultiFigurePlotter -from autoarray.plot.multi_plotters import MultiYX1DPlotter - -from autogalaxy.plot.wrap import ( - HalfLightRadiusAXVLine, - EinsteinRadiusAXVLine, - LightProfileCentresScatter, - MassProfileCentresScatter, - TangentialCriticalCurvesPlot, - TangentialCausticsPlot, - RadialCriticalCurvesPlot, - RadialCausticsPlot, - MultipleImagesScatter, -) - -from autogalaxy.plot.mat_plot.one_d import MatPlot1D -from autogalaxy.plot.mat_plot.two_d import MatPlot2D -from autogalaxy.plot.visuals.one_d import Visuals1D -from autogalaxy.plot.visuals.two_d import Visuals2D - -from autogalaxy.profiles.plot.basis_plotters import BasisPlotter -from autogalaxy.profiles.plot.light_profile_plotters import LightProfilePlotter -from autogalaxy.profiles.plot.mass_profile_plotters import MassProfilePlotter -from autogalaxy.galaxy.plot.galaxy_plotters import GalaxyPlotter -from autogalaxy.quantity.plot.fit_quantity_plotters import FitQuantityPlotter - -from autogalaxy.imaging.plot.fit_imaging_plotters import FitImagingPlotter -from autogalaxy.interferometer.plot.fit_interferometer_plotters import ( - FitInterferometerPlotter, -) -from autogalaxy.galaxy.plot.galaxies_plotters import GalaxiesPlotter -from autogalaxy.galaxy.plot.adapt_plotters import AdaptPlotter - -from autolens.point.plot.point_dataset_plotters import PointDatasetPlotter -from autolens.imaging.plot.fit_imaging_plotters import FitImagingPlotter -from autolens.interferometer.plot.fit_interferometer_plotters import ( - FitInterferometerPlotter, -) -from autolens.point.plot.fit_point_plotters import FitPointDatasetPlotter -from autolens.lens.plot.tracer_plotters import TracerPlotter -from autolens.lens.subhalo import SubhaloPlotter -from autolens.lens.sensitivity import SubhaloSensitivityPlotter +from autofit.non_linear.plot.nest_plotters import NestPlotter +from autofit.non_linear.plot.mcmc_plotters import MCMCPlotter +from autofit.non_linear.plot.mle_plotters import MLEPlotter + +from autogalaxy.plot.wrap import ( + HalfLightRadiusAXVLine, + EinsteinRadiusAXVLine, + LightProfileCentresScatter, + MassProfileCentresScatter, + TangentialCriticalCurvesPlot, + TangentialCausticsPlot, + RadialCriticalCurvesPlot, + RadialCausticsPlot, + MultipleImagesScatter, +) + +# --------------------------------------------------------------------------- +# Standalone plot helpers +# --------------------------------------------------------------------------- +from autoarray.plot.plots.array import plot_array +from autoarray.plot.plots.grid import plot_grid + +# --------------------------------------------------------------------------- +# subplot_* public API +# --------------------------------------------------------------------------- +from autolens.lens.plot.tracer_plots import ( + subplot_tracer, + subplot_lensed_images, + subplot_galaxies_images, +) +from autolens.imaging.plot.fit_imaging_plots import ( + subplot_fit as subplot_fit_imaging, + subplot_fit_log10 as subplot_fit_imaging_log10, + subplot_fit_x1_plane as subplot_fit_imaging_x1_plane, + subplot_fit_log10_x1_plane as subplot_fit_imaging_log10_x1_plane, + subplot_of_planes as subplot_fit_imaging_of_planes, + subplot_tracer_from_fit as subplot_fit_imaging_tracer, + subplot_fit_combined, + subplot_fit_combined_log10, +) +from autolens.interferometer.plot.fit_interferometer_plots import ( + subplot_fit as subplot_fit_interferometer, + subplot_fit_real_space as subplot_fit_interferometer_real_space, +) +from autolens.point.plot.fit_point_plots import subplot_fit as subplot_fit_point +from autolens.point.plot.point_dataset_plots import subplot_dataset as subplot_point_dataset + +from autolens.lens.plot.subhalo_plots import ( + subplot_detection_imaging, + subplot_detection_fits, +) +from autolens.lens.plot.sensitivity_plots import ( + subplot_tracer_images as subplot_sensitivity_tracer_images, + subplot_sensitivity, + subplot_figures_of_merit_grid as subplot_sensitivity_figures_of_merit, +) diff --git a/autolens/plot/abstract_plotters.py b/autolens/plot/abstract_plotters.py deleted file mode 100644 index 622f2ea95..000000000 --- a/autolens/plot/abstract_plotters.py +++ /dev/null @@ -1,34 +0,0 @@ -from autoarray.plot.wrap.base.abstract import set_backend - -set_backend() - -from autogalaxy.plot.abstract_plotters import AbstractPlotter - -from autogalaxy.plot.mat_plot.one_d import MatPlot1D -from autogalaxy.plot.mat_plot.two_d import MatPlot2D -from autogalaxy.plot.visuals.one_d import Visuals1D -from autogalaxy.plot.visuals.two_d import Visuals2D - - -class Plotter(AbstractPlotter): - - def __init__( - self, - mat_plot_1d: MatPlot1D = None, - visuals_1d: Visuals1D = None, - mat_plot_2d: MatPlot2D = None, - visuals_2d: Visuals2D = None, - ): - - super().__init__( - mat_plot_1d=mat_plot_1d, - visuals_1d=visuals_1d, - mat_plot_2d=mat_plot_2d, - visuals_2d=visuals_2d, - ) - - self.visuals_1d = visuals_1d or Visuals1D() - self.mat_plot_1d = mat_plot_1d or MatPlot1D() - - self.visuals_2d = visuals_2d or Visuals2D() - self.mat_plot_2d = mat_plot_2d or MatPlot2D() diff --git a/autolens/plot/plot_utils.py b/autolens/plot/plot_utils.py new file mode 100644 index 000000000..4727672ac --- /dev/null +++ b/autolens/plot/plot_utils.py @@ -0,0 +1,52 @@ +import numpy as np + + +def _to_lines(*items): + """Convert multiple line sources into a flat list of (N, 2) numpy arrays.""" + result = [] + for item in items: + if item is None: + continue + if isinstance(item, list): + for sub in item: + try: + arr = np.array(sub.array if hasattr(sub, "array") else sub) + if arr.ndim == 2 and arr.shape[1] == 2 and len(arr) > 0: + result.append(arr) + except Exception: + pass + else: + try: + arr = np.array(item.array if hasattr(item, "array") else item) + if arr.ndim == 2 and arr.shape[1] == 2 and len(arr) > 0: + result.append(arr) + except Exception: + pass + return result or None + + +def _to_positions(*items): + """Convert multiple position sources into a flat list of (N, 2) numpy arrays.""" + return _to_lines(*items) + + +def _critical_curves_from(tracer, grid): + """Return (tangential_critical_curves, radial_critical_curves) as lists of arrays.""" + from autolens.lens import tracer_util + + try: + tan_cc, rad_cc = tracer_util.critical_curves_from(tracer=tracer, grid=grid) + return list(tan_cc), list(rad_cc) + except Exception: + return [], [] + + +def _caustics_from(tracer, grid): + """Return (tangential_caustics, radial_caustics) as lists of arrays.""" + from autolens.lens import tracer_util + + try: + tan_ca, rad_ca = tracer_util.caustics_from(tracer=tracer, grid=grid) + return list(tan_ca), list(rad_ca) + except Exception: + return [], [] diff --git a/autolens/point/model/plotter_interface.py b/autolens/point/model/plotter_interface.py index f5f24bc2c..a5982c393 100644 --- a/autolens/point/model/plotter_interface.py +++ b/autolens/point/model/plotter_interface.py @@ -1,77 +1,55 @@ -from os import path - -from autolens.analysis.plotter_interface import PlotterInterface - -from autolens.point.fit.dataset import FitPointDataset -from autolens.point.plot.fit_point_plotters import FitPointDatasetPlotter -from autolens.point.dataset import PointDataset -from autolens.point.plot.point_dataset_plotters import PointDatasetPlotter - -from autolens.analysis.plotter_interface import plot_setting - - -class PlotterInterfacePoint(PlotterInterface): - def dataset_point(self, dataset: PointDataset): - """ - Output visualization of an `PointDataset` dataset, typically before a model-fit is performed. - - Images are output to the `image` folder of the `image_path`. When used with a non-linear search the `image_path` - is the output folder of the non-linear search. - - Visualization includes individual images of the different points of the dataset (e.g. the positions and fluxes) - - The images output by the `PlotterInterface` are customized using the file `config/visualize/plots.yaml` under - the `point_dataset` header. - - Parameters - ---------- - dataset - The imaging dataset which is visualized. - """ - - def should_plot(name): - return plot_setting(section=["point_dataset"], name=name) - - mat_plot_2d = self.mat_plot_2d_from() - - dataset_plotter = PointDatasetPlotter(dataset=dataset, mat_plot_2d=mat_plot_2d) - - if should_plot("subplot_dataset"): - dataset_plotter.subplot_dataset() - - def fit_point( - self, - fit: FitPointDataset, - quick_update: bool = False, - ): - """ - Visualizes a `FitPointDataset` object, which fits an imaging dataset. - - Images are output to the `image` folder of the `image_path` in a subfolder called `fit`. When - used with a non-linear search the `image_path` points to the search's results folder and this function - visualizes the maximum log likelihood `FitImaging` inferred by the search so far. - - Visualization includes a subplot of individual images of attributes of the `FitPointDataset` (e.g. the model - data and data) and .fits files containing its attributes grouped together. - - The images output by the `PlotterInterface` are customized using the file `config/visualize/plots.yaml` under - the `fit` and `fit_point_dataset` headers. - - Parameters - ---------- - fit - The maximum log likelihood `FitPointDataset` of the non-linear search which is used to plot the fit. - """ - - def should_plot(name): - return plot_setting(section=["fit", "fit_point_dataset"], name=name) - - mat_plot_2d = self.mat_plot_2d_from() - - fit_plotter = FitPointDatasetPlotter(fit=fit, mat_plot_2d=mat_plot_2d) - - if should_plot("subplot_fit") or quick_update: - fit_plotter.subplot_fit() - - if quick_update: - return +from autolens.analysis.plotter_interface import PlotterInterface + +from autolens.point.fit.dataset import FitPointDataset +from autolens.point.plot.fit_point_plots import subplot_fit as subplot_fit_point +from autolens.point.dataset import PointDataset +from autolens.point.plot.point_dataset_plots import subplot_dataset + +from autolens.analysis.plotter_interface import plot_setting + + +class PlotterInterfacePoint(PlotterInterface): + def dataset_point(self, dataset: PointDataset): + """ + Output visualization of a `PointDataset` dataset. + + Parameters + ---------- + dataset + The point dataset which is visualized. + """ + + def should_plot(name): + return plot_setting(section=["point_dataset"], name=name) + + output_path = str(self.image_path) + fmt = self.fmt + + if should_plot("subplot_dataset"): + subplot_dataset(dataset, output_path=output_path, output_format=fmt) + + def fit_point( + self, + fit: FitPointDataset, + quick_update: bool = False, + ): + """ + Visualizes a `FitPointDataset` object. + + Parameters + ---------- + fit + The maximum log likelihood `FitPointDataset` of the non-linear search. + """ + + def should_plot(name): + return plot_setting(section=["fit", "fit_point_dataset"], name=name) + + output_path = str(self.image_path) + fmt = self.fmt + + if should_plot("subplot_fit") or quick_update: + subplot_fit_point(fit, output_path=output_path, output_format=fmt) + + if quick_update: + return diff --git a/autolens/point/plot/fit_point_plots.py b/autolens/point/plot/fit_point_plots.py new file mode 100644 index 000000000..6dfa7eb79 --- /dev/null +++ b/autolens/point/plot/fit_point_plots.py @@ -0,0 +1,60 @@ +import matplotlib.pyplot as plt +import numpy as np +from typing import Optional + +from autoarray.plot.plots.utils import save_figure + + +def subplot_fit( + fit, + output_path: Optional[str] = None, + output_format: str = "png", +): + """Subplot of a FitPointDataset: positions panel and (optionally) fluxes panel.""" + from autoarray.plot.plots.grid import plot_grid + from autoarray.plot.plots.yx import plot_yx + + has_fluxes = fit.dataset.fluxes is not None + n = 2 if has_fluxes else 1 + + fig, axes = plt.subplots(1, n, figsize=(7 * n, 7)) + axes_flat = [axes] if n == 1 else list(np.array(axes).flatten()) + + # Positions panel + obs_grid = np.array( + fit.dataset.positions.array + if hasattr(fit.dataset.positions, "array") + else fit.dataset.positions + ) + model_grid = np.array( + fit.positions.model_data.array + if hasattr(fit.positions.model_data, "array") + else fit.positions.model_data + ) + + plot_grid( + grid=obs_grid, + ax=axes_flat[0], + title=f"{fit.dataset.name} Fit Positions", + output_path=None, + output_filename=None, + output_format=output_format, + ) + axes_flat[0].scatter(model_grid[:, 1], model_grid[:, 0], c="r", s=20, zorder=5) + + # Fluxes panel + if has_fluxes and n > 1: + y = np.array(fit.dataset.fluxes) + x = np.arange(len(y)) + plot_yx( + y=y, + x=x, + ax=axes_flat[1], + title=f"{fit.dataset.name} Fit Fluxes", + output_path=None, + output_filename="fit_point_fluxes", + output_format=output_format, + ) + + plt.tight_layout() + save_figure(fig, path=output_path, filename="subplot_fit", format=output_format) diff --git a/autolens/point/plot/fit_point_plotters.py b/autolens/point/plot/fit_point_plotters.py deleted file mode 100644 index 4d7cb8c74..000000000 --- a/autolens/point/plot/fit_point_plotters.py +++ /dev/null @@ -1,164 +0,0 @@ -import autogalaxy.plot as aplt - -from autolens.plot.abstract_plotters import Plotter -from autolens.point.fit.dataset import FitPointDataset - - -class FitPointDatasetPlotter(Plotter): - def __init__( - self, - fit: FitPointDataset, - mat_plot_1d: aplt.MatPlot1D = None, - visuals_1d: aplt.Visuals1D = None, - mat_plot_2d: aplt.MatPlot2D = None, - visuals_2d: aplt.Visuals2D = None, - ): - """ - Plots the attributes of `FitPointDataset` objects using matplotlib methods and functions which customize the - plot's appearance. - - The `mat_plot_2d` attribute wraps matplotlib function calls to make the figure. By default, the settings - passed to every matplotlib function called are those specified in the `config/visualize/mat_wrap/*.ini` files, - but a user can manually input values into `MatPlot2d` to customize the figure's appearance. - - Overlaid on the figure are visuals, contained in the `Visuals2D` object. Attributes may be extracted from - the `FitImaging` and plotted via the visuals object. - - Parameters - ---------- - fit - The fit to a point source dataset, which includes the data, model positions and other quantities which can - be plotted like the residual_map and chi-squared map. - mat_plot_2d - Contains objects which wrap the matplotlib function calls that make the plot. - visuals_2d - Contains visuals that can be overlaid on the plot. - """ - super().__init__( - mat_plot_1d=mat_plot_1d, - visuals_1d=visuals_1d, - mat_plot_2d=mat_plot_2d, - visuals_2d=visuals_2d, - ) - - self.fit = fit - - def figures_2d(self, positions: bool = False, fluxes: bool = False): - """ - Plots the individual attributes of the plotter's `FitPointDataset` object in 2D. - - The API is such that every plottable attribute of the `FitPointDataset` object is an input parameter of type - bool of the function, which if switched to `True` means that it is plotted. - - Parameters - ---------- - positions - If `True`, the dataset's positions are plotted on the figure compared to the model positions. - fluxes - If `True`, the dataset's fluxes are plotted on the figure compared to the model fluxes. - """ - if positions: - visuals_2d = self.visuals_2d - - visuals_2d += visuals_2d.__class__( - multiple_images=self.fit.positions.model_data - ) - - if self.mat_plot_2d.axis.kwargs.get("extent") is None: - buffer = 0.1 - - y_max = ( - max( - max(self.fit.dataset.positions[:, 0]), - max(self.fit.positions.model_data[:, 0]), - ) - + buffer - ) - y_min = ( - min( - min(self.fit.dataset.positions[:, 0]), - min(self.fit.positions.model_data[:, 0]), - ) - - buffer - ) - x_max = ( - max( - max(self.fit.dataset.positions[:, 1]), - max(self.fit.positions.model_data[:, 1]), - ) - + buffer - ) - x_min = ( - min( - min(self.fit.dataset.positions[:, 1]), - min(self.fit.positions.model_data[:, 1]), - ) - - buffer - ) - - extent = [y_min, y_max, x_min, x_max] - - self.mat_plot_2d.axis.kwargs["extent"] = extent - - self.mat_plot_2d.plot_grid( - grid=self.fit.dataset.positions, - y_errors=self.fit.dataset.positions_noise_map, - x_errors=self.fit.dataset.positions_noise_map, - visuals_2d=visuals_2d, - auto_labels=aplt.AutoLabels( - title=f"{self.fit.dataset.name} Fit Positions", - filename="fit_point_positions", - ), - buffer=0.1, - ) - - # nasty hack to ensure subplot index between 2d and 1d plots are syncs. Need a refactor that mvoes subplot - # functionality out of mat_plot and into plotter. - - if ( - self.mat_plot_1d.subplot_index is not None - and self.mat_plot_2d.subplot_index is not None - ): - self.mat_plot_1d.subplot_index = max( - self.mat_plot_1d.subplot_index, self.mat_plot_2d.subplot_index - ) - - if fluxes: - if self.fit.dataset.fluxes is not None: - visuals_1d = self.visuals_1d - - # Dataset may have flux but model may not - - try: - visuals_1d += visuals_1d.__class__( - model_fluxes=self.fit.flux.model_fluxes.array - ) - except AttributeError: - pass - - self.mat_plot_1d.plot_yx( - y=self.fit.dataset.fluxes, - y_errors=self.fit.dataset.fluxes_noise_map, - visuals_1d=visuals_1d, - auto_labels=aplt.AutoLabels( - title=f" {self.fit.dataset.name} Fit Fluxes", - filename="fit_point_fluxes", - xlabel="Point Number", - ), - plot_axis_type_override="errorbar", - ) - - def subplot( - self, - positions: bool = False, - fluxes: bool = False, - auto_filename: str = "subplot_fit", - ): - self._subplot_custom_plot( - positions=positions, - fluxes=fluxes, - auto_labels=aplt.AutoLabels(filename=auto_filename), - ) - - def subplot_fit(self): - self.subplot(positions=True, fluxes=True) diff --git a/autolens/point/plot/point_dataset_plots.py b/autolens/point/plot/point_dataset_plots.py new file mode 100644 index 000000000..45de0e8be --- /dev/null +++ b/autolens/point/plot/point_dataset_plots.py @@ -0,0 +1,52 @@ +import matplotlib.pyplot as plt +import numpy as np +from typing import Optional + +from autoarray.plot.plots.utils import save_figure + + +def subplot_dataset( + dataset, + output_path: Optional[str] = None, + output_format: str = "png", +): + """Subplot of a PointDataset: positions panel and (optionally) fluxes panel.""" + from autoarray.plot.plots.grid import plot_grid + from autoarray.plot.plots.yx import plot_yx + + has_fluxes = dataset.fluxes is not None + n = 2 if has_fluxes else 1 + + fig, axes = plt.subplots(1, n, figsize=(7 * n, 7)) + axes_flat = [axes] if n == 1 else list(np.array(axes).flatten()) + + grid = np.array( + dataset.positions.array + if hasattr(dataset.positions, "array") + else dataset.positions + ) + + plot_grid( + grid=grid, + ax=axes_flat[0], + title=f"{dataset.name} Positions", + output_path=None, + output_filename=None, + output_format=output_format, + ) + + if has_fluxes and n > 1: + y = np.array(dataset.fluxes) + x = np.arange(len(y)) + plot_yx( + y=y, + x=x, + ax=axes_flat[1], + title=f"{dataset.name} Fluxes", + output_path=None, + output_filename="point_dataset_fluxes", + output_format=output_format, + ) + + plt.tight_layout() + save_figure(fig, path=output_path, filename="subplot_dataset_point", format=output_format) diff --git a/autolens/point/plot/point_dataset_plotters.py b/autolens/point/plot/point_dataset_plotters.py deleted file mode 100644 index f781fc729..000000000 --- a/autolens/point/plot/point_dataset_plotters.py +++ /dev/null @@ -1,110 +0,0 @@ -import autogalaxy.plot as aplt - -from autolens.point.dataset import PointDataset -from autolens.plot.abstract_plotters import Plotter - - -class PointDatasetPlotter(Plotter): - def __init__( - self, - dataset: PointDataset, - mat_plot_1d: aplt.MatPlot1D = None, - visuals_1d: aplt.Visuals1D = None, - mat_plot_2d: aplt.MatPlot2D = None, - visuals_2d: aplt.Visuals2D = None, - ): - """ - Plots the attributes of `PointDataset` objects using the matplotlib methods and functions functions which - customize the plot's appearance. - - The `mat_plot_2d` attribute wraps matplotlib function calls to make the figure. By default, the settings - passed to every matplotlib function called are those specified in the `config/visualize/mat_wrap/*.ini` files, - but a user can manually input values into `MatPlot2d` to customize the figure's appearance. - - Overlaid on the figure are visuals, contained in the `Visuals2D` object. Attributes may be extracted from - the `Imaging` and plotted via the visuals object. - - Parameters - ---------- - dataset - The imaging dataset the plotter plots. - mat_plot_2d - Contains objects which wrap the matplotlib function calls that make 2D plots. - visuals_2d - Contains 2D visuals that can be overlaid on 2D plots. - """ - super().__init__( - mat_plot_1d=mat_plot_1d, - visuals_1d=visuals_1d, - mat_plot_2d=mat_plot_2d, - visuals_2d=visuals_2d, - ) - - self.dataset = dataset - - def figures_2d(self, positions: bool = False, fluxes: bool = False): - """ - Plots the individual attributes of the plotter's `PointDataset` object in 2D. - - The API is such that every plottable attribute of the `Imaging` object is an input parameter of type bool of - the function, which if switched to `True` means that it is plotted. - - Parameters - ---------- - positions - If `True`, the dataset's positions are plotted on the figure compared to the model positions. - fluxes - If `True`, the dataset's fluxes are plotted on the figure compared to the model fluxes. - """ - if positions: - self.mat_plot_2d.plot_grid( - grid=self.dataset.positions, - y_errors=self.dataset.positions_noise_map, - x_errors=self.dataset.positions_noise_map, - visuals_2d=self.visuals_2d, - auto_labels=aplt.AutoLabels( - title=f"{self.dataset.name} Positions", - filename="point_dataset_positions", - ), - buffer=0.1, - ) - - # nasty hack to ensure subplot index between 2d and 1d plots are syncs. Need a refactor that mvoes subplot - # functionality out of mat_plot and into plotter. - - if ( - self.mat_plot_1d.subplot_index is not None - and self.mat_plot_2d.subplot_index is not None - ): - self.mat_plot_1d.subplot_index = max( - self.mat_plot_1d.subplot_index, self.mat_plot_2d.subplot_index - ) - - if fluxes: - if self.dataset.fluxes is not None: - self.mat_plot_1d.plot_yx( - y=self.dataset.fluxes, - y_errors=self.dataset.fluxes_noise_map, - visuals_1d=self.visuals_1d, - auto_labels=aplt.AutoLabels( - title=f" {self.dataset.name} Fluxes", - filename="point_dataset_fluxes", - xlabel="Point Number", - ), - plot_axis_type_override="errorbar", - ) - - def subplot( - self, - positions: bool = False, - fluxes: bool = False, - auto_filename="subplot_dataset_point", - ): - self._subplot_custom_plot( - positions=positions, - fluxes=fluxes, - auto_labels=aplt.AutoLabels(filename=auto_filename), - ) - - def subplot_dataset(self): - self.subplot(positions=True, fluxes=True) diff --git a/patches/autoarray/0001-PR-A1-Add-autoarray-plot-plots-direct-matplotlib-mod.patch b/patches/autoarray/0001-PR-A1-Add-autoarray-plot-plots-direct-matplotlib-mod.patch new file mode 100644 index 000000000..4e6ae15ef --- /dev/null +++ b/patches/autoarray/0001-PR-A1-Add-autoarray-plot-plots-direct-matplotlib-mod.patch @@ -0,0 +1,864 @@ +From bf29f6dbd7ca37fc15c8aa9a8956b87e10627e1c Mon Sep 17 00:00:00 2001 +From: Claude +Date: Mon, 16 Mar 2026 14:08:58 +0000 +Subject: [PATCH 1/3] PR A1: Add autoarray/plot/plots/ direct-matplotlib module + +New standalone functions that replace MatPlot2D/MatPlot1D/MatWrap: + + plot_array() - imshow-based 2D array plot + plot_grid() - scatter/errorbar grid plot + plot_yx() - 1D y-vs-x line/errorbar plot + plot_inversion_reconstruction() - rectangular + Delaunay mapper plots + save_figure() - replaces Output class + conf_figsize() - reads figsize from general.yaml + apply_extent() - 3-tick linear axis limits helper + +ax=None creates its own figure; ax provided draws onto caller's axes. +Overlay data (lines, positions, mask) passed as plain list/array args. +3 linear ticks via np.linspace; log colorbars via LogNorm to imshow. +Purely additive - all existing tests continue to pass. + +https://claude.ai/code/session_01CzJBy8KvFXiNchoNdk5i9k +--- + autoarray/plot/__init__.py | 10 ++ + autoarray/plot/plots/__init__.py | 15 +++ + autoarray/plot/plots/array.py | 192 ++++++++++++++++++++++++++++++ + autoarray/plot/plots/grid.py | 156 ++++++++++++++++++++++++ + autoarray/plot/plots/inversion.py | 180 ++++++++++++++++++++++++++++ + autoarray/plot/plots/utils.py | 96 +++++++++++++++ + autoarray/plot/plots/yx.py | 131 ++++++++++++++++++++ + 7 files changed, 780 insertions(+) + create mode 100644 autoarray/plot/plots/__init__.py + create mode 100644 autoarray/plot/plots/array.py + create mode 100644 autoarray/plot/plots/grid.py + create mode 100644 autoarray/plot/plots/inversion.py + create mode 100644 autoarray/plot/plots/utils.py + create mode 100644 autoarray/plot/plots/yx.py + +diff --git a/autoarray/plot/__init__.py b/autoarray/plot/__init__.py +index 1ec4ff8a..6b967dab 100644 +--- a/autoarray/plot/__init__.py ++++ b/autoarray/plot/__init__.py +@@ -59,3 +59,13 @@ from autoarray.fit.plot.fit_interferometer_plotters import FitInterferometerPlot + + from autoarray.plot.multi_plotters import MultiFigurePlotter + from autoarray.plot.multi_plotters import MultiYX1DPlotter ++ ++from autoarray.plot.plots import ( ++ plot_array, ++ plot_grid, ++ plot_yx, ++ plot_inversion_reconstruction, ++ apply_extent, ++ conf_figsize, ++ save_figure, ++) +diff --git a/autoarray/plot/plots/__init__.py b/autoarray/plot/plots/__init__.py +new file mode 100644 +index 00000000..e8920734 +--- /dev/null ++++ b/autoarray/plot/plots/__init__.py +@@ -0,0 +1,15 @@ ++from autoarray.plot.plots.array import plot_array ++from autoarray.plot.plots.grid import plot_grid ++from autoarray.plot.plots.yx import plot_yx ++from autoarray.plot.plots.inversion import plot_inversion_reconstruction ++from autoarray.plot.plots.utils import apply_extent, conf_figsize, save_figure ++ ++__all__ = [ ++ "plot_array", ++ "plot_grid", ++ "plot_yx", ++ "plot_inversion_reconstruction", ++ "apply_extent", ++ "conf_figsize", ++ "save_figure", ++] +diff --git a/autoarray/plot/plots/array.py b/autoarray/plot/plots/array.py +new file mode 100644 +index 00000000..20f5f267 +--- /dev/null ++++ b/autoarray/plot/plots/array.py +@@ -0,0 +1,192 @@ ++""" ++Standalone function for plotting a 2D array (image) directly with matplotlib. ++ ++This replaces the ``MatPlot2D.plot_array`` / ``MatWrap`` system with a plain ++function whose defaults are ordinary Python parameter defaults rather than ++values loaded from YAML config files. ++""" ++import os ++from typing import List, Optional, Tuple ++ ++import matplotlib.pyplot as plt ++import numpy as np ++from matplotlib.colors import LogNorm, Normalize ++ ++from autoarray.plot.plots.utils import apply_extent, conf_figsize, save_figure ++ ++ ++def plot_array( ++ array: np.ndarray, ++ ax: Optional[plt.Axes] = None, ++ # --- spatial metadata ------------------------------------------------------- ++ extent: Optional[Tuple[float, float, float, float]] = None, ++ # --- overlays --------------------------------------------------------------- ++ mask: Optional[np.ndarray] = None, ++ grid: Optional[np.ndarray] = None, ++ positions: Optional[List[np.ndarray]] = None, ++ lines: Optional[List[np.ndarray]] = None, ++ vector_yx: Optional[np.ndarray] = None, ++ array_overlay: Optional[np.ndarray] = None, ++ # --- cosmetics -------------------------------------------------------------- ++ title: str = "", ++ xlabel: str = 'x (")', ++ ylabel: str = 'y (")', ++ colormap: str = "jet", ++ vmin: Optional[float] = None, ++ vmax: Optional[float] = None, ++ use_log10: bool = False, ++ aspect: str = "auto", ++ origin: str = "upper", ++ # --- figure control (used only when ax is None) ----------------------------- ++ figsize: Optional[Tuple[int, int]] = None, ++ output_path: Optional[str] = None, ++ output_filename: str = "array", ++ output_format: str = "png", ++) -> None: ++ """ ++ Plot a 2D array (image) using ``plt.imshow``. ++ ++ This is the direct-matplotlib replacement for ``MatPlot2D.plot_array``. ++ ++ Parameters ++ ---------- ++ array ++ 2D numpy array of pixel values. ++ ax ++ Existing matplotlib ``Axes`` to draw onto. If ``None`` a new figure ++ is created and saved / shown according to *output_path*. ++ extent ++ ``[xmin, xmax, ymin, ymax]`` spatial extent in data coordinates. ++ When ``None`` the array pixel indices are used by matplotlib. ++ mask ++ Array of shape ``(N, 2)`` with ``(y, x)`` coordinates of masked ++ pixels to overlay as black dots. ++ grid ++ Array of shape ``(N, 2)`` with ``(y, x)`` coordinates to scatter. ++ positions ++ List of ``(N, 2)`` arrays; each is scattered as a distinct group ++ of lensed image positions. ++ lines ++ List of ``(N, 2)`` arrays with ``(y, x)`` columns to plot as lines ++ (e.g. critical curves, caustics). ++ vector_yx ++ Array of shape ``(N, 4)`` — ``(y, x, vy, vx)`` — plotted as quiver ++ arrows. ++ array_overlay ++ A second 2D array rendered on top of *array* with partial alpha. ++ title ++ Figure title string. ++ xlabel, ylabel ++ Axis label strings. ++ colormap ++ Matplotlib colormap name. ++ vmin, vmax ++ Explicit color scale limits. When ``None`` the data range is used. ++ use_log10 ++ When ``True`` a ``LogNorm`` is applied. ++ aspect ++ Passed directly to ``imshow``. ++ origin ++ Passed directly to ``imshow`` (``"upper"`` or ``"lower"``). ++ figsize ++ Figure size in inches ``(width, height)``. Falls back to the value ++ in ``visualize/general.yaml`` when ``None``. ++ output_path ++ Directory to save the figure. When empty / ``None`` ``plt.show()`` ++ is called instead. ++ output_filename ++ Base file name (without extension). ++ output_format ++ File format, e.g. ``"png"``. ++ """ ++ if array is None or np.all(array == 0): ++ return ++ ++ owns_figure = ax is None ++ if owns_figure: ++ figsize = figsize or conf_figsize("figures") ++ fig, ax = plt.subplots(1, 1, figsize=figsize) ++ else: ++ fig = ax.get_figure() ++ ++ # --- colour normalisation -------------------------------------------------- ++ if use_log10: ++ from autoconf import conf as _conf ++ ++ try: ++ log10_min = _conf.instance["visualize"]["general"]["general"][ ++ "log10_min_value" ++ ] ++ except Exception: ++ log10_min = 1.0e-4 ++ ++ clipped = np.clip(array, log10_min, None) ++ norm = LogNorm(vmin=vmin or log10_min, vmax=vmax or clipped.max()) ++ elif vmin is not None or vmax is not None: ++ norm = Normalize(vmin=vmin, vmax=vmax) ++ else: ++ norm = None ++ ++ im = ax.imshow( ++ array, ++ cmap=colormap, ++ norm=norm, ++ extent=extent, ++ aspect=aspect, ++ origin=origin, ++ ) ++ ++ plt.colorbar(im, ax=ax) ++ ++ # --- overlays -------------------------------------------------------------- ++ if array_overlay is not None: ++ ax.imshow( ++ array_overlay, ++ cmap="Greys", ++ alpha=0.5, ++ extent=extent, ++ aspect=aspect, ++ origin=origin, ++ ) ++ ++ if mask is not None: ++ ax.scatter(mask[:, 1], mask[:, 0], s=1, c="k") ++ ++ if grid is not None: ++ ax.scatter(grid[:, 1], grid[:, 0], s=1, c="k") ++ ++ if positions is not None: ++ colors = ["r", "g", "b", "m", "c", "y"] ++ for i, pos in enumerate(positions): ++ ax.scatter(pos[:, 1], pos[:, 0], s=20, c=colors[i % len(colors)], zorder=5) ++ ++ if lines is not None: ++ for line in lines: ++ if line is not None and len(line) > 0: ++ ax.plot(line[:, 1], line[:, 0], linewidth=2) ++ ++ if vector_yx is not None: ++ ax.quiver( ++ vector_yx[:, 1], ++ vector_yx[:, 0], ++ vector_yx[:, 3], ++ vector_yx[:, 2], ++ ) ++ ++ # --- labels / ticks -------------------------------------------------------- ++ ax.set_title(title, fontsize=16) ++ ax.set_xlabel(xlabel, fontsize=14) ++ ax.set_ylabel(ylabel, fontsize=14) ++ ax.tick_params(labelsize=12) ++ ++ if extent is not None: ++ apply_extent(ax, extent) ++ ++ # --- output ---------------------------------------------------------------- ++ if owns_figure: ++ save_figure( ++ fig, ++ path=output_path or "", ++ filename=output_filename, ++ format=output_format, ++ ) +diff --git a/autoarray/plot/plots/grid.py b/autoarray/plot/plots/grid.py +new file mode 100644 +index 00000000..b310a5f7 +--- /dev/null ++++ b/autoarray/plot/plots/grid.py +@@ -0,0 +1,156 @@ ++""" ++Standalone function for plotting a 2D grid of (y, x) coordinates. ++ ++This replaces the ``MatPlot2D.plot_grid`` / ``MatWrap`` system. ++""" ++from typing import Iterable, List, Optional, Tuple ++ ++import matplotlib.pyplot as plt ++import numpy as np ++ ++from autoarray.plot.plots.utils import apply_extent, conf_figsize, save_figure ++ ++ ++def plot_grid( ++ grid: np.ndarray, ++ ax: Optional[plt.Axes] = None, ++ # --- errors ----------------------------------------------------------------- ++ y_errors: Optional[np.ndarray] = None, ++ x_errors: Optional[np.ndarray] = None, ++ # --- overlays --------------------------------------------------------------- ++ lines: Optional[Iterable[np.ndarray]] = None, ++ color_array: Optional[np.ndarray] = None, ++ # --- cosmetics -------------------------------------------------------------- ++ title: str = "", ++ xlabel: str = 'x (")', ++ ylabel: str = 'y (")', ++ colormap: str = "jet", ++ buffer: float = 0.1, ++ extent: Optional[Tuple[float, float, float, float]] = None, ++ force_symmetric_extent: bool = True, ++ # --- figure control (used only when ax is None) ----------------------------- ++ figsize: Optional[Tuple[int, int]] = None, ++ output_path: Optional[str] = None, ++ output_filename: str = "grid", ++ output_format: str = "png", ++) -> None: ++ """ ++ Plot a 2D grid of ``(y, x)`` coordinates as a scatter plot. ++ ++ This is the direct-matplotlib replacement for ``MatPlot2D.plot_grid``. ++ ++ Parameters ++ ---------- ++ grid ++ Array of shape ``(N, 2)``; column 0 is *y*, column 1 is *x*. ++ ax ++ Existing ``Axes`` to draw onto. ``None`` creates a new figure. ++ y_errors, x_errors ++ Per-point error values; when provided ``plt.errorbar`` is used. ++ lines ++ Iterable of ``(N, 2)`` arrays (y, x columns) drawn as lines. ++ color_array ++ 1D array of scalar values for colouring each point; triggers a ++ colorbar. ++ title ++ Figure title. ++ xlabel, ylabel ++ Axis labels. ++ colormap ++ Matplotlib colormap name. ++ buffer ++ Fractional padding for the auto-computed extent. The grid's ++ ``extent_with_buffer_from`` method is called when *extent* is ++ ``None`` and the grid object exposes that method. ++ extent ++ Manual axis limits ``[xmin, xmax, ymin, ymax]``. Auto-computed ++ when ``None``. ++ force_symmetric_extent ++ When ``True`` (and *extent* is auto-computed) the limits are made ++ symmetric about the origin so the plot is centred. ++ figsize ++ Figure size in inches ``(width, height)``. ++ output_path ++ Directory for saving. Empty string / ``None`` triggers ++ ``plt.show()``. ++ output_filename ++ Base file name without extension. ++ output_format ++ File format, e.g. ``"png"``. ++ """ ++ owns_figure = ax is None ++ if owns_figure: ++ figsize = figsize or conf_figsize("figures") ++ fig, ax = plt.subplots(1, 1, figsize=figsize) ++ else: ++ fig = ax.get_figure() ++ ++ # --- scatter / errorbar ---------------------------------------------------- ++ if color_array is not None: ++ cmap = plt.get_cmap(colormap) ++ colors = cmap((color_array - color_array.min()) / (color_array.ptp() or 1)) ++ ++ if y_errors is None and x_errors is None: ++ sc = ax.scatter(grid[:, 1], grid[:, 0], s=1, c=color_array, cmap=colormap) ++ else: ++ sc = ax.scatter(grid[:, 1], grid[:, 0], s=1, c=color_array, cmap=colormap) ++ ax.errorbar( ++ grid[:, 1], ++ grid[:, 0], ++ yerr=y_errors, ++ xerr=x_errors, ++ fmt="none", ++ ecolor=colors, ++ ) ++ ++ plt.colorbar(sc, ax=ax) ++ else: ++ if y_errors is None and x_errors is None: ++ ax.scatter(grid[:, 1], grid[:, 0], s=1, c="k") ++ else: ++ ax.errorbar( ++ grid[:, 1], ++ grid[:, 0], ++ yerr=y_errors, ++ xerr=x_errors, ++ fmt="o", ++ markersize=2, ++ color="k", ++ ) ++ ++ # --- line overlays --------------------------------------------------------- ++ if lines is not None: ++ for line in lines: ++ if line is not None and len(line) > 0: ++ ax.plot(line[:, 1], line[:, 0], linewidth=2) ++ ++ # --- labels ---------------------------------------------------------------- ++ ax.set_title(title, fontsize=16) ++ ax.set_xlabel(xlabel, fontsize=14) ++ ax.set_ylabel(ylabel, fontsize=14) ++ ax.tick_params(labelsize=12) ++ ++ # --- extent ---------------------------------------------------------------- ++ if extent is None: ++ try: ++ extent = grid.extent_with_buffer_from(buffer=buffer) ++ except AttributeError: ++ y_vals = grid[:, 0] ++ x_vals = grid[:, 1] ++ extent = [x_vals.min(), x_vals.max(), y_vals.min(), y_vals.max()] ++ ++ if force_symmetric_extent and extent is not None: ++ x_abs = max(abs(extent[0]), abs(extent[1])) ++ y_abs = max(abs(extent[2]), abs(extent[3])) ++ extent = [-x_abs, x_abs, -y_abs, y_abs] ++ ++ apply_extent(ax, extent) ++ ++ # --- output ---------------------------------------------------------------- ++ if owns_figure: ++ save_figure( ++ fig, ++ path=output_path or "", ++ filename=output_filename, ++ format=output_format, ++ ) +diff --git a/autoarray/plot/plots/inversion.py b/autoarray/plot/plots/inversion.py +new file mode 100644 +index 00000000..a58d9e89 +--- /dev/null ++++ b/autoarray/plot/plots/inversion.py +@@ -0,0 +1,180 @@ ++""" ++Standalone functions for plotting inversion / pixelization reconstructions. ++ ++Replaces the inversion-specific paths in ``MatPlot2D.plot_mapper``. ++""" ++from typing import List, Optional, Tuple ++ ++import matplotlib.pyplot as plt ++import numpy as np ++from matplotlib.colors import LogNorm, Normalize ++ ++from autoarray.plot.plots.utils import apply_extent, conf_figsize, save_figure ++ ++ ++def plot_inversion_reconstruction( ++ pixel_values: np.ndarray, ++ mapper, ++ ax: Optional[plt.Axes] = None, ++ # --- cosmetics -------------------------------------------------------------- ++ title: str = "Reconstruction", ++ xlabel: str = 'x (")', ++ ylabel: str = 'y (")', ++ colormap: str = "jet", ++ vmin: Optional[float] = None, ++ vmax: Optional[float] = None, ++ use_log10: bool = False, ++ zoom_to_brightest: bool = True, ++ # --- overlays --------------------------------------------------------------- ++ lines: Optional[List[np.ndarray]] = None, ++ grid: Optional[np.ndarray] = None, ++ # --- figure control (used only when ax is None) ----------------------------- ++ figsize: Optional[Tuple[int, int]] = None, ++ output_path: Optional[str] = None, ++ output_filename: str = "reconstruction", ++ output_format: str = "png", ++) -> None: ++ """ ++ Plot an inversion reconstruction using the appropriate mapper type. ++ ++ Chooses between rectangular (``imshow``/``pcolormesh``) and Delaunay ++ (``tripcolor``) rendering based on the mapper's interpolator type. ++ ++ Parameters ++ ---------- ++ pixel_values ++ 1D array of reconstructed flux values, one per source pixel. ++ mapper ++ Autoarray mapper object exposing ``interpolator``, ``mesh_geometry``, ++ ``source_plane_mesh_grid``, etc. ++ ax ++ Existing ``Axes``. ``None`` creates a new figure. ++ title, xlabel, ylabel ++ Text labels. ++ colormap ++ Matplotlib colormap name. ++ vmin, vmax ++ Explicit colour scale limits. ++ use_log10 ++ Apply ``LogNorm``. ++ zoom_to_brightest ++ Pass through to ``mapper.extent_from``. ++ lines ++ Line overlays (e.g. critical curves). ++ grid ++ Scatter overlay (e.g. data-plane grid). ++ figsize, output_path, output_filename, output_format ++ Figure output controls. ++ """ ++ from autoarray.inversion.mesh.interpolator.rectangular import ( ++ InterpolatorRectangular, ++ ) ++ from autoarray.inversion.mesh.interpolator.rectangular_uniform import ( ++ InterpolatorRectangularUniform, ++ ) ++ from autoarray.inversion.mesh.interpolator.delaunay import InterpolatorDelaunay ++ from autoarray.inversion.mesh.interpolator.knn import InterpolatorKNearestNeighbor ++ ++ owns_figure = ax is None ++ if owns_figure: ++ figsize = figsize or conf_figsize("figures") ++ fig, ax = plt.subplots(1, 1, figsize=figsize) ++ else: ++ fig = ax.get_figure() ++ ++ # --- colour normalisation -------------------------------------------------- ++ if use_log10: ++ norm = LogNorm(vmin=vmin or 1e-4, vmax=vmax) ++ elif vmin is not None or vmax is not None: ++ norm = Normalize(vmin=vmin, vmax=vmax) ++ else: ++ norm = None ++ ++ extent = mapper.extent_from(values=pixel_values, zoom_to_brightest=zoom_to_brightest) ++ ++ if isinstance(mapper.interpolator, (InterpolatorRectangular, InterpolatorRectangularUniform)): ++ _plot_rectangular(ax, pixel_values, mapper, norm, colormap, extent) ++ elif isinstance(mapper.interpolator, (InterpolatorDelaunay, InterpolatorKNearestNeighbor)): ++ _plot_delaunay(ax, pixel_values, mapper, norm, colormap) ++ ++ # --- overlays -------------------------------------------------------------- ++ if lines is not None: ++ for line in lines: ++ if line is not None and len(line) > 0: ++ ax.plot(line[:, 1], line[:, 0], linewidth=2) ++ ++ if grid is not None: ++ ax.scatter(grid[:, 1], grid[:, 0], s=1, c="w", alpha=0.5) ++ ++ apply_extent(ax, extent) ++ ++ ax.set_title(title, fontsize=16) ++ ax.set_xlabel(xlabel, fontsize=14) ++ ax.set_ylabel(ylabel, fontsize=14) ++ ax.tick_params(labelsize=12) ++ ++ if owns_figure: ++ save_figure( ++ fig, ++ path=output_path or "", ++ filename=output_filename, ++ format=output_format, ++ ) ++ ++ ++def _plot_rectangular(ax, pixel_values, mapper, norm, colormap, extent): ++ """Render a rectangular mesh reconstruction with pcolormesh or imshow.""" ++ from autoarray.inversion.mesh.interpolator.rectangular_uniform import ( ++ InterpolatorRectangularUniform, ++ ) ++ import numpy as np ++ ++ shape_native = mapper.mesh_geometry.shape ++ ++ if isinstance(mapper.interpolator, InterpolatorRectangularUniform): ++ from autoarray.structures.arrays.uniform_2d import Array2D ++ from autoarray.structures.arrays import array_2d_util ++ ++ solution_array_2d = array_2d_util.array_2d_native_from( ++ array_2d_slim=pixel_values, ++ mask_2d=np.full(fill_value=False, shape=shape_native), ++ ) ++ pix_array = Array2D.no_mask( ++ values=solution_array_2d, ++ pixel_scales=mapper.mesh_geometry.pixel_scales, ++ origin=mapper.mesh_geometry.origin, ++ ) ++ ax.imshow( ++ pix_array.native.array, ++ cmap=colormap, ++ norm=norm, ++ extent=pix_array.geometry.extent, ++ aspect="auto", ++ origin="upper", ++ ) ++ else: ++ y_edges, x_edges = mapper.mesh_geometry.edges_transformed.T ++ Y, X = np.meshgrid(y_edges, x_edges, indexing="ij") ++ im = ax.pcolormesh( ++ X, Y, ++ pixel_values.reshape(shape_native), ++ shading="flat", ++ norm=norm, ++ cmap=colormap, ++ ) ++ plt.colorbar(im, ax=ax) ++ ++ ++def _plot_delaunay(ax, pixel_values, mapper, norm, colormap): ++ """Render a Delaunay mesh reconstruction with tripcolor.""" ++ mesh_grid = mapper.source_plane_mesh_grid ++ x = mesh_grid[:, 1] ++ y = mesh_grid[:, 0] ++ ++ if hasattr(pixel_values, "array"): ++ vals = pixel_values.array ++ else: ++ vals = pixel_values ++ ++ tc = ax.tripcolor(x, y, vals, cmap=colormap, norm=norm, shading="gouraud") ++ plt.colorbar(tc, ax=ax) +diff --git a/autoarray/plot/plots/utils.py b/autoarray/plot/plots/utils.py +new file mode 100644 +index 00000000..4d1cc4f0 +--- /dev/null ++++ b/autoarray/plot/plots/utils.py +@@ -0,0 +1,96 @@ ++""" ++Shared utilities for the direct-matplotlib plot functions. ++""" ++import logging ++import os ++from typing import Optional, Tuple ++ ++import matplotlib.pyplot as plt ++import numpy as np ++ ++logger = logging.getLogger(__name__) ++ ++ ++def conf_figsize(context: str = "figures") -> Tuple[int, int]: ++ """ ++ Read figsize from ``visualize/general.yaml`` for the given context. ++ ++ Parameters ++ ---------- ++ context ++ Either ``"figures"`` (single-panel) or ``"subplots"`` (multi-panel). ++ """ ++ try: ++ from autoconf import conf ++ ++ return tuple(conf.instance["visualize"]["general"][context]["figsize"]) ++ except Exception: ++ return (7, 7) if context == "figures" else (19, 16) ++ ++ ++def save_figure( ++ fig: plt.Figure, ++ path: str, ++ filename: str, ++ format: str = "png", ++ dpi: int = 300, ++) -> None: ++ """ ++ Save *fig* to ``/.`` then close it. ++ ++ If *path* is an empty string or ``None``, ``plt.show()`` is called instead. ++ After either action ``plt.close(fig)`` is always called to free memory. ++ ++ Parameters ++ ---------- ++ fig ++ The matplotlib figure to save. ++ path ++ Directory where the file is written. Created if it does not exist. ++ filename ++ File name without extension. ++ format ++ File format passed to ``fig.savefig`` (e.g. ``"png"``, ``"pdf"``). ++ dpi ++ Resolution in dots per inch. ++ """ ++ if path: ++ os.makedirs(path, exist_ok=True) ++ try: ++ fig.savefig( ++ os.path.join(path, f"{filename}.{format}"), ++ dpi=dpi, ++ bbox_inches="tight", ++ pad_inches=0.1, ++ ) ++ except Exception as exc: ++ logger.warning(f"save_figure: could not save {filename}.{format}: {exc}") ++ else: ++ plt.show() ++ ++ plt.close(fig) ++ ++ ++def apply_extent( ++ ax: plt.Axes, ++ extent: Tuple[float, float, float, float], ++ n_ticks: int = 3, ++) -> None: ++ """ ++ Apply axis limits and evenly spaced linear ticks to *ax*. ++ ++ Parameters ++ ---------- ++ ax ++ The matplotlib axes to configure. ++ extent ++ ``[xmin, xmax, ymin, ymax]`` limits. ++ n_ticks ++ Number of ticks on each axis. ``3`` produces ``[-R, 0, R]`` for ++ a symmetric extent, matching the reference ``plot_grid`` example. ++ """ ++ xmin, xmax, ymin, ymax = extent ++ ax.set_xlim(xmin, xmax) ++ ax.set_ylim(ymin, ymax) ++ ax.set_xticks(np.linspace(xmin, xmax, n_ticks)) ++ ax.set_yticks(np.linspace(ymin, ymax, n_ticks)) +diff --git a/autoarray/plot/plots/yx.py b/autoarray/plot/plots/yx.py +new file mode 100644 +index 00000000..383c75e5 +--- /dev/null ++++ b/autoarray/plot/plots/yx.py +@@ -0,0 +1,131 @@ ++""" ++Standalone function for plotting 1D y-vs-x data. ++ ++Replaces ``MatPlot1D.plot_yx`` / ``MatWrap`` system. ++""" ++from typing import List, Optional, Tuple ++ ++import matplotlib.pyplot as plt ++import numpy as np ++ ++from autoarray.plot.plots.utils import conf_figsize, save_figure ++ ++ ++def plot_yx( ++ y: np.ndarray, ++ x: Optional[np.ndarray] = None, ++ ax: Optional[plt.Axes] = None, ++ # --- errors / extras -------------------------------------------------------- ++ y_errors: Optional[np.ndarray] = None, ++ x_errors: Optional[np.ndarray] = None, ++ y_extra: Optional[np.ndarray] = None, ++ shaded_region: Optional[Tuple[np.ndarray, np.ndarray]] = None, ++ # --- cosmetics -------------------------------------------------------------- ++ title: str = "", ++ xlabel: str = "", ++ ylabel: str = "", ++ label: Optional[str] = None, ++ color: str = "b", ++ linestyle: str = "-", ++ plot_axis_type: str = "linear", ++ # --- figure control (used only when ax is None) ----------------------------- ++ figsize: Optional[Tuple[int, int]] = None, ++ output_path: Optional[str] = None, ++ output_filename: str = "yx", ++ output_format: str = "png", ++) -> None: ++ """ ++ Plot 1D y versus x data. ++ ++ Replaces ``MatPlot1D.plot_yx`` with direct matplotlib calls. ++ ++ Parameters ++ ---------- ++ y ++ 1D numpy array of y values. ++ x ++ 1D numpy array of x values. When ``None`` integer indices are used. ++ ax ++ Existing ``Axes`` to draw onto. ``None`` creates a new figure. ++ y_errors, x_errors ++ Per-point error values; trigger ``plt.errorbar``. ++ y_extra ++ Optional second y series to overlay. ++ shaded_region ++ Tuple ``(y1, y2)`` arrays; filled region drawn with alpha. ++ title ++ Figure title. ++ xlabel, ylabel ++ Axis labels. ++ label ++ Legend label for the main series. ++ color ++ Line / marker colour. ++ linestyle ++ Line style string. ++ plot_axis_type ++ One of ``"linear"``, ``"log"``, ``"loglog"``, ``"symlog"``. ++ figsize ++ Figure size in inches. ++ output_path ++ Directory for saving. Empty / ``None`` calls ``plt.show()``. ++ output_filename ++ Base file name without extension. ++ output_format ++ File format, e.g. ``"png"``. ++ """ ++ # guard: nothing to draw ++ if y is None or np.count_nonzero(y) == 0 or np.isnan(y).all(): ++ return ++ ++ owns_figure = ax is None ++ if owns_figure: ++ figsize = figsize or conf_figsize("figures") ++ fig, ax = plt.subplots(1, 1, figsize=figsize) ++ else: ++ fig = ax.get_figure() ++ ++ if x is None: ++ x = np.arange(len(y)) ++ ++ # --- main line / scatter --------------------------------------------------- ++ if y_errors is not None or x_errors is not None: ++ ax.errorbar( ++ x, y, yerr=y_errors, xerr=x_errors, ++ fmt="-o", color=color, label=label, markersize=3, ++ ) ++ elif plot_axis_type in ("log", "semilogy"): ++ ax.semilogy(x, y, color=color, linestyle=linestyle, label=label) ++ elif plot_axis_type == "loglog": ++ ax.loglog(x, y, color=color, linestyle=linestyle, label=label) ++ else: ++ ax.plot(x, y, color=color, linestyle=linestyle, label=label) ++ ++ if plot_axis_type == "symlog": ++ ax.set_yscale("symlog") ++ ++ # --- extras ---------------------------------------------------------------- ++ if y_extra is not None: ++ ax.plot(x, y_extra, color="r", linestyle="--", alpha=0.7) ++ ++ if shaded_region is not None: ++ y1, y2 = shaded_region ++ ax.fill_between(x, y1, y2, alpha=0.3) ++ ++ # --- labels ---------------------------------------------------------------- ++ ax.set_title(title, fontsize=16) ++ ax.set_xlabel(xlabel, fontsize=14) ++ ax.set_ylabel(ylabel, fontsize=14) ++ ax.tick_params(labelsize=12) ++ ++ if label is not None: ++ ax.legend(fontsize=12) ++ ++ # --- output ---------------------------------------------------------------- ++ if owns_figure: ++ save_figure( ++ fig, ++ path=output_path or "", ++ filename=output_filename, ++ format=output_format, ++ ) +-- +2.43.0 + diff --git a/patches/autoarray/0002-PR-A2-A3-Switch-all-autoarray-plotters-to-use-direct.patch b/patches/autoarray/0002-PR-A2-A3-Switch-all-autoarray-plotters-to-use-direct.patch new file mode 100644 index 000000000..d5736e18e --- /dev/null +++ b/patches/autoarray/0002-PR-A2-A3-Switch-all-autoarray-plotters-to-use-direct.patch @@ -0,0 +1,1033 @@ +From 9c1ab66b47873a906733f393d04e87ea25254c49 Mon Sep 17 00:00:00 2001 +From: Claude +Date: Mon, 16 Mar 2026 14:19:37 +0000 +Subject: [PATCH 2/3] PR A2+A3: Switch all autoarray plotters to use + direct-matplotlib functions + +Array2DPlotter, Grid2DPlotter, YX1DPlotter (structure_plotters.py): +- figure_2d() / figure_1d() now call plot_array() / plot_grid() / plot_yx() +- Subplot mode bridged: setup_subplot() positions the panel, ax passed to plot fn +- Overlay data extracted from Visuals2D to typed numpy arrays via helpers +- _output_for_mat_plot() bridges mat_plot.output to path/filename/format args + +ImagingPlotterMeta (imaging_plotters.py): +- figures_2d() uses new _plot_array() helper calling plot_array() directly +- subplot_dataset() unchanged: open/close subplot still via old mechanism + +MapperPlotter (mapper_plotters.py): +- figure_2d(), figure_2d_image(), plot_source_from() use plot_inversion_reconstruction() + and plot_array() respectively + +InversionPlotter (inversion_plotters.py): +- Added _plot_array() helper; all mat_plot_2d.plot_array() calls replaced + +All integration tests pass. + +https://claude.ai/code/session_01CzJBy8KvFXiNchoNdk5i9k +--- + autoarray/dataset/plot/imaging_plotters.py | 192 ++++-------- + .../inversion/plot/inversion_plotters.py | 113 +++---- + autoarray/inversion/plot/mapper_plotters.py | 157 +++++----- + .../structures/plot/structure_plotters.py | 277 +++++++++++------- + 4 files changed, 374 insertions(+), 365 deletions(-) + +diff --git a/autoarray/dataset/plot/imaging_plotters.py b/autoarray/dataset/plot/imaging_plotters.py +index ac1f23de..afc2cff5 100644 +--- a/autoarray/dataset/plot/imaging_plotters.py ++++ b/autoarray/dataset/plot/imaging_plotters.py +@@ -1,10 +1,19 @@ + import copy ++import numpy as np + from typing import Callable, Optional + + from autoarray.plot.visuals.two_d import Visuals2D + from autoarray.plot.mat_plot.two_d import MatPlot2D + from autoarray.plot.auto_labels import AutoLabels + from autoarray.plot.abstract_plotters import AbstractPlotter ++from autoarray.plot.plots.array import plot_array ++from autoarray.structures.plot.structure_plotters import ( ++ _lines_from_visuals, ++ _positions_from_visuals, ++ _mask_edge_from, ++ _grid_from_visuals, ++ _output_for_mat_plot, ++) + from autoarray.dataset.imaging.dataset import Imaging + + +@@ -15,35 +24,49 @@ class ImagingPlotterMeta(AbstractPlotter): + mat_plot_2d: MatPlot2D = None, + visuals_2d: Visuals2D = None, + ): +- """ +- Plots the attributes of `Imaging` objects using the matplotlib method `imshow()` and many other matplotlib +- functions which customize the plot's appearance. +- +- The `mat_plot_2d` attribute wraps matplotlib function calls to make the figure. By default, the settings +- passed to every matplotlib function called are those specified in the `config/visualize/mat_wrap/*.ini` files, +- but a user can manually input values into `MatPlot2d` to customize the figure's appearance. +- +- Overlaid on the figure are visuals, contained in the `Visuals2D` object. Attributes may be extracted from +- the `Imaging` and plotted via the visuals object. +- +- Parameters +- ---------- +- dataset +- The imaging dataset the plotter plots. +- mat_plot_2d +- Contains objects which wrap the matplotlib function calls that make 2D plots. +- visuals_2d +- Contains 2D visuals that can be overlaid on 2D plots. +- """ +- + super().__init__(mat_plot_2d=mat_plot_2d, visuals_2d=visuals_2d) +- + self.dataset = dataset + + @property + def imaging(self): + return self.dataset + ++ def _plot_array(self, array, auto_filename: str, title: str, ax=None): ++ """Internal helper: plot an Array2D via plot_array().""" ++ if array is None: ++ return ++ ++ is_sub = self.mat_plot_2d.is_for_subplot ++ if ax is None: ++ ax = self.mat_plot_2d.setup_subplot() if is_sub else None ++ ++ output_path, filename, fmt = _output_for_mat_plot( ++ self.mat_plot_2d, is_sub, auto_filename ++ ) ++ ++ try: ++ arr = array.native.array ++ extent = array.geometry.extent ++ except AttributeError: ++ arr = np.asarray(array) ++ extent = None ++ ++ plot_array( ++ array=arr, ++ ax=ax, ++ extent=extent, ++ mask=_mask_edge_from(array if hasattr(array, "mask") else None, self.visuals_2d), ++ grid=_grid_from_visuals(self.visuals_2d), ++ positions=_positions_from_visuals(self.visuals_2d), ++ lines=_lines_from_visuals(self.visuals_2d), ++ title=title, ++ colormap=self.mat_plot_2d.cmap.cmap, ++ use_log10=self.mat_plot_2d.use_log10, ++ output_path=output_path, ++ output_filename=filename, ++ output_format=fmt, ++ ) ++ + def figures_2d( + self, + data: bool = False, +@@ -54,86 +77,47 @@ class ImagingPlotterMeta(AbstractPlotter): + over_sample_size_pixelization: bool = False, + title_str: Optional[str] = None, + ): +- """ +- Plots the individual attributes of the plotter's `Imaging` object in 2D. +- +- The API is such that every plottable attribute of the `Imaging` object is an input parameter of type bool of +- the function, which if switched to `True` means that it is plotted. +- +- Parameters +- ---------- +- data +- Whether to make a 2D plot (via `imshow`) of the image data. +- noise_map +- Whether to make a 2D plot (via `imshow`) of the noise map. +- psf +- Whether to make a 2D plot (via `imshow`) of the psf. +- signal_to_noise_map +- Whether to make a 2D plot (via `imshow`) of the signal-to-noise map. +- over_sample_size_lp +- Whether to make a 2D plot (via `imshow`) of the Over Sampling for input light profiles. If +- adaptive sub size is used, the sub size grid for a centre of (0.0, 0.0) is used. +- over_sample_size_pixelization +- Whether to make a 2D plot (via `imshow`) of the Over Sampling for pixelizations. +- """ +- + if data: +- self.mat_plot_2d.plot_array( ++ self._plot_array( + array=self.dataset.data, +- visuals_2d=self.visuals_2d, +- auto_labels=AutoLabels(title=title_str or f" Data", filename="data"), ++ auto_filename="data", ++ title=title_str or "Data", + ) + + if noise_map: +- self.mat_plot_2d.plot_array( ++ self._plot_array( + array=self.dataset.noise_map, +- visuals_2d=self.visuals_2d, +- auto_labels=AutoLabels(title_str or f"Noise-Map", filename="noise_map"), ++ auto_filename="noise_map", ++ title=title_str or "Noise-Map", + ) + + if psf: + if self.dataset.psf is not None: +- self.mat_plot_2d.plot_array( ++ self._plot_array( + array=self.dataset.psf.kernel, +- visuals_2d=self.visuals_2d, +- auto_labels=AutoLabels( +- title=title_str or f"Point Spread Function", +- filename="psf", +- cb_unit="", +- ), ++ auto_filename="psf", ++ title=title_str or "Point Spread Function", + ) + + if signal_to_noise_map: +- self.mat_plot_2d.plot_array( ++ self._plot_array( + array=self.dataset.signal_to_noise_map, +- visuals_2d=self.visuals_2d, +- auto_labels=AutoLabels( +- title=title_str or f"Signal-To-Noise Map", +- filename="signal_to_noise_map", +- cb_unit="", +- ), ++ auto_filename="signal_to_noise_map", ++ title=title_str or "Signal-To-Noise Map", + ) + + if over_sample_size_lp: +- self.mat_plot_2d.plot_array( ++ self._plot_array( + array=self.dataset.grids.over_sample_size_lp, +- visuals_2d=self.visuals_2d, +- auto_labels=AutoLabels( +- title=title_str or f"Over Sample Size (Light Profiles)", +- filename="over_sample_size_lp", +- cb_unit="", +- ), ++ auto_filename="over_sample_size_lp", ++ title=title_str or "Over Sample Size (Light Profiles)", + ) + + if over_sample_size_pixelization: +- self.mat_plot_2d.plot_array( ++ self._plot_array( + array=self.dataset.grids.over_sample_size_pixelization, +- visuals_2d=self.visuals_2d, +- auto_labels=AutoLabels( +- title=title_str or f"Over Sample Size (Pixelization)", +- filename="over_sample_size_pixelization", +- cb_unit="", +- ), ++ auto_filename="over_sample_size_pixelization", ++ title=title_str or "Over Sample Size (Pixelization)", + ) + + def subplot( +@@ -146,30 +130,6 @@ class ImagingPlotterMeta(AbstractPlotter): + over_sampling_pixelization: bool = False, + auto_filename: str = "subplot_dataset", + ): +- """ +- Plots the individual attributes of the plotter's `Imaging` object in 2D on a subplot. +- +- The API is such that every plottable attribute of the `Imaging` object is an input parameter of type bool of +- the function, which if switched to `True` means that it is included on the subplot. +- +- Parameters +- ---------- +- data +- Whether to include a 2D plot (via `imshow`) of the image data. +- noise_map +- Whether to include a 2D plot (via `imshow`) of the noise map. +- psf +- Whether to include a 2D plot (via `imshow`) of the psf. +- signal_to_noise_map +- Whether to include a 2D plot (via `imshow`) of the signal-to-noise map. +- over_sampling +- Whether to include a 2D plot (via `imshow`) of the Over Sampling. If adaptive sub size is used, the +- sub size grid for a centre of (0.0, 0.0) is used. +- over_sampling_pixelization +- Whether to include a 2D plot (via `imshow`) of the Over Sampling for pixelizations. +- auto_filename +- The default filename of the output subplot if written to hard-disk. +- """ + self._subplot_custom_plot( + data=data, + noise_map=noise_map, +@@ -181,9 +141,6 @@ class ImagingPlotterMeta(AbstractPlotter): + ) + + def subplot_dataset(self): +- """ +- Standard subplot of the attributes of the plotter's `Imaging` object. +- """ + use_log10_original = self.mat_plot_2d.use_log10 + + self.open_subplot_figure(number_subplots=9) +@@ -199,7 +156,6 @@ class ImagingPlotterMeta(AbstractPlotter): + self.mat_plot_2d.contour = contour_original + + self.figures_2d(noise_map=True) +- + self.figures_2d(psf=True) + + self.mat_plot_2d.use_log10 = True +@@ -207,7 +163,6 @@ class ImagingPlotterMeta(AbstractPlotter): + self.mat_plot_2d.use_log10 = False + + self.figures_2d(signal_to_noise_map=True) +- + self.figures_2d(over_sample_size_lp=True) + self.figures_2d(over_sample_size_pixelization=True) + +@@ -224,27 +179,6 @@ class ImagingPlotter(AbstractPlotter): + mat_plot_2d: MatPlot2D = None, + visuals_2d: Visuals2D = None, + ): +- """ +- Plots the attributes of `Imaging` objects using the matplotlib method `imshow()` and many other matplotlib +- functions which customize the plot's appearance. +- +- The `mat_plot_2d` attribute wraps matplotlib function calls to make the figure. By default, the settings +- passed to every matplotlib function called are those specified in the `config/visualize/mat_wrap/*.ini` files, +- but a user can manually input values into `MatPlot2d` to customize the figure's appearance. +- +- Overlaid on the figure are visuals, contained in the `Visuals2D` object. Attributes may be extracted from +- the `Imaging` and plotted via the visuals object. +- +- Parameters +- ---------- +- imaging +- The imaging dataset the plotter plots. +- mat_plot_2d +- Contains objects which wrap the matplotlib function calls that make 2D plots. +- visuals_2d +- Contains 2D visuals that can be overlaid on 2D plots. +- """ +- + super().__init__(mat_plot_2d=mat_plot_2d, visuals_2d=visuals_2d) + + self.dataset = dataset +diff --git a/autoarray/inversion/plot/inversion_plotters.py b/autoarray/inversion/plot/inversion_plotters.py +index ef2ae6ea..586eabc7 100644 +--- a/autoarray/inversion/plot/inversion_plotters.py ++++ b/autoarray/inversion/plot/inversion_plotters.py +@@ -7,9 +7,17 @@ from autoarray.plot.abstract_plotters import AbstractPlotter + from autoarray.plot.visuals.two_d import Visuals2D + from autoarray.plot.mat_plot.two_d import MatPlot2D + from autoarray.plot.auto_labels import AutoLabels ++from autoarray.plot.plots.array import plot_array + from autoarray.structures.arrays.uniform_2d import Array2D + from autoarray.inversion.inversion.abstract import AbstractInversion + from autoarray.inversion.plot.mapper_plotters import MapperPlotter ++from autoarray.structures.plot.structure_plotters import ( ++ _lines_from_visuals, ++ _mask_edge_from, ++ _grid_from_visuals, ++ _positions_from_visuals, ++ _output_for_mat_plot, ++) + + + class InversionPlotter(AbstractPlotter): +@@ -66,35 +74,53 @@ class InversionPlotter(AbstractPlotter): + visuals_2d=self.visuals_2d, + ) + ++ def _plot_array(self, array, auto_filename: str, title: str): ++ """Helper: plot an Array2D using the new direct-matplotlib function.""" ++ is_sub = self.mat_plot_2d.is_for_subplot ++ ax = self.mat_plot_2d.setup_subplot() if is_sub else None ++ output_path, filename, fmt = _output_for_mat_plot( ++ self.mat_plot_2d, is_sub, auto_filename ++ ) ++ try: ++ arr = array.native.array ++ extent = array.geometry.extent ++ mask_overlay = _mask_edge_from(array, self.visuals_2d) ++ except AttributeError: ++ arr = np.asarray(array) ++ extent = None ++ mask_overlay = None ++ plot_array( ++ array=arr, ++ ax=ax, ++ extent=extent, ++ mask=mask_overlay, ++ grid=_grid_from_visuals(self.visuals_2d), ++ positions=_positions_from_visuals(self.visuals_2d), ++ lines=_lines_from_visuals(self.visuals_2d), ++ title=title, ++ colormap=self.mat_plot_2d.cmap.cmap, ++ use_log10=self.mat_plot_2d.use_log10, ++ output_path=output_path, ++ output_filename=filename, ++ output_format=fmt, ++ ) ++ + def figures_2d(self, reconstructed_operated_data: bool = False): + """ + Plots the individual attributes of the plotter's `Inversion` object in 2D. +- +- The API is such that every plottable attribute of the `Inversion` object is an input parameter of type bool of +- the function, which if switched to `True` means that it is plotted. +- +- Parameters +- ---------- +- reconstructed_operated_data +- Whether to make a 2D plot (via `imshow`) of the reconstructed image data. + """ + if reconstructed_operated_data: + try: +- self.mat_plot_2d.plot_array( ++ self._plot_array( + array=self.inversion.mapped_reconstructed_operated_data, +- visuals_2d=self.visuals_2d, +- auto_labels=AutoLabels( +- title="Reconstructed Image", +- filename="reconstructed_operated_data", +- ), ++ auto_filename="reconstructed_operated_data", ++ title="Reconstructed Image", + ) + except AttributeError: +- self.mat_plot_2d.plot_array( ++ self._plot_array( + array=self.inversion.mapped_reconstructed_data, +- visuals_2d=self.visuals_2d, +- auto_labels=AutoLabels( +- title="Reconstructed Image", filename="reconstructed_data" +- ), ++ auto_filename="reconstructed_data", ++ title="Reconstructed Image", + ) + + def figures_2d_of_pixelization( +@@ -153,19 +179,13 @@ class InversionPlotter(AbstractPlotter): + mapper_plotter = self.mapper_plotter_from(mapper_index=pixelization_index) + + if data_subtracted: +- # Attribute error is cause this raises an error for interferometer inversion, because the data is +- # visibilities not an image. Update this to be handled better in future. +- ++ # Attribute error is raised for interferometer inversion where data is visibilities not an image. + try: + array = self.inversion.data_subtracted_dict[mapper_plotter.mapper] +- +- self.mat_plot_2d.plot_array( ++ self._plot_array( + array=array, +- visuals_2d=self.visuals_2d, +- grid_indexes=mapper_plotter.mapper.over_sampler.uniform_over_sampled, +- auto_labels=AutoLabels( +- title="Data Subtracted", filename="data_subtracted" +- ), ++ auto_filename="data_subtracted", ++ title="Data Subtracted", + ) + except AttributeError: + pass +@@ -182,13 +202,10 @@ class InversionPlotter(AbstractPlotter): + mapper_plotter.mapper + ] + +- self.mat_plot_2d.plot_array( ++ self._plot_array( + array=array, +- visuals_2d=self.visuals_2d, +- grid_indexes=mapper_plotter.mapper.over_sampler.uniform_over_sampled, +- auto_labels=AutoLabels( +- title="Reconstructed Image", filename="reconstructed_operated_data" +- ), ++ auto_filename="reconstructed_operated_data", ++ title="Reconstructed Image", + ) + + if reconstruction: +@@ -271,29 +288,19 @@ class InversionPlotter(AbstractPlotter): + values=mapper_plotter.mapper.over_sampler.sub_size, + mask=self.inversion.dataset.mask, + ) +- +- self.mat_plot_2d.plot_array( ++ self._plot_array( + array=sub_size, +- visuals_2d=self.visuals_2d, +- auto_labels=AutoLabels( +- title="Sub Pixels Per Image Pixels", +- filename="sub_pixels_per_image_pixels", +- ), ++ auto_filename="sub_pixels_per_image_pixels", ++ title="Sub Pixels Per Image Pixels", + ) + + if mesh_pixels_per_image_pixels: + try: +- mesh_pixels_per_image_pixels = ( +- mapper_plotter.mapper.mesh_pixels_per_image_pixels +- ) +- +- self.mat_plot_2d.plot_array( +- array=mesh_pixels_per_image_pixels, +- visuals_2d=self.visuals_2d, +- auto_labels=AutoLabels( +- title="Mesh Pixels Per Image Pixels", +- filename="mesh_pixels_per_image_pixels", +- ), ++ mesh_arr = mapper_plotter.mapper.mesh_pixels_per_image_pixels ++ self._plot_array( ++ array=mesh_arr, ++ auto_filename="mesh_pixels_per_image_pixels", ++ title="Mesh Pixels Per Image Pixels", + ) + except Exception: + pass +diff --git a/autoarray/inversion/plot/mapper_plotters.py b/autoarray/inversion/plot/mapper_plotters.py +index b9a44679..47617e1e 100644 +--- a/autoarray/inversion/plot/mapper_plotters.py ++++ b/autoarray/inversion/plot/mapper_plotters.py +@@ -1,12 +1,20 @@ + import numpy as np ++import logging + + from autoarray.plot.abstract_plotters import AbstractPlotter + from autoarray.plot.visuals.two_d import Visuals2D + from autoarray.plot.mat_plot.two_d import MatPlot2D + from autoarray.plot.auto_labels import AutoLabels ++from autoarray.plot.plots.inversion import plot_inversion_reconstruction ++from autoarray.plot.plots.array import plot_array + from autoarray.structures.arrays.uniform_2d import Array2D +- +-import logging ++from autoarray.structures.plot.structure_plotters import ( ++ _lines_from_visuals, ++ _positions_from_visuals, ++ _mask_edge_from, ++ _grid_from_visuals, ++ _output_for_mat_plot, ++) + + logger = logging.getLogger(__name__) + +@@ -18,80 +26,70 @@ class MapperPlotter(AbstractPlotter): + mat_plot_2d: MatPlot2D = None, + visuals_2d: Visuals2D = None, + ): +- """ +- Plots the attributes of `Mapper` objects using the matplotlib method `imshow()` and many other matplotlib +- functions which customize the plot's appearance. +- +- The `mat_plot_2d` attribute wraps matplotlib function calls to make the figure. By default, the settings +- passed to every matplotlib function called are those specified in the `config/visualize/mat_wrap/*.ini` files, +- but a user can manually input values into `MatPlot2d` to customize the figure's appearance. +- +- Overlaid on the figure are visuals, contained in the `Visuals2D` object. Attributes may be extracted from +- the `Mapper` and plotted via the visuals object. +- +- Parameters +- ---------- +- mapper +- The mapper the plotter plots. +- mat_plot_2d +- Contains objects which wrap the matplotlib function calls that make 2D plots. +- visuals_2d +- Contains 2D visuals that can be overlaid on 2D plots. +- """ + super().__init__(visuals_2d=visuals_2d, mat_plot_2d=mat_plot_2d) +- + self.mapper = mapper + +- def figure_2d(self, solution_vector: bool = None): +- """ +- Plots the plotter's `Mapper` object in 2D. ++ def figure_2d(self, solution_vector=None): ++ """Plot the mapper's source-plane reconstruction.""" ++ is_sub = self.mat_plot_2d.is_for_subplot ++ ax = self.mat_plot_2d.setup_subplot() if is_sub else None + +- Parameters +- ---------- +- solution_vector +- A vector of values which can culor the pixels of the mapper's source pixels. +- """ +- self.mat_plot_2d.plot_mapper( +- mapper=self.mapper, +- visuals_2d=self.visuals_2d, +- pixel_values=solution_vector, +- auto_labels=AutoLabels( +- title="Pixelization Mesh (Source-Plane)", filename="mapper" +- ), ++ output_path, filename, fmt = _output_for_mat_plot( ++ self.mat_plot_2d, is_sub, "mapper" + ) + ++ try: ++ plot_inversion_reconstruction( ++ pixel_values=solution_vector, ++ mapper=self.mapper, ++ ax=ax, ++ title="Pixelization Mesh (Source-Plane)", ++ colormap=self.mat_plot_2d.cmap.cmap, ++ use_log10=self.mat_plot_2d.use_log10, ++ lines=_lines_from_visuals(self.visuals_2d), ++ output_path=output_path, ++ output_filename=filename, ++ output_format=fmt, ++ ) ++ except Exception as exc: ++ logger.info( ++ f"Could not plot the source-plane via the Mapper: {exc}" ++ ) ++ + def figure_2d_image(self, image): ++ """Plot an image-plane representation of the mapper.""" ++ is_sub = self.mat_plot_2d.is_for_subplot ++ ax = self.mat_plot_2d.setup_subplot() if is_sub else None + +- self.mat_plot_2d.plot_array( +- array=image, +- visuals_2d=self.visuals_2d, +- grid_indexes=self.mapper.image_plane_data_grid.over_sampled, +- auto_labels=AutoLabels( +- title="Image (Image-Plane)", filename="mapper_image" +- ), ++ output_path, filename, fmt = _output_for_mat_plot( ++ self.mat_plot_2d, is_sub, "mapper_image" + ) + +- def subplot_image_and_mapper( +- self, +- image: Array2D, +- ): +- """ +- Make a subplot of an input image and the `Mapper`'s source-plane reconstruction. +- +- This function can include colored points that mark the mappings between the image pixels and their +- corresponding locations in the `Mapper` source-plane and reconstruction. This therefore visually illustrates +- the mapping process. ++ try: ++ arr = image.native.array ++ extent = image.geometry.extent ++ except AttributeError: ++ arr = np.asarray(image) ++ extent = None ++ ++ plot_array( ++ array=arr, ++ ax=ax, ++ extent=extent, ++ mask=_mask_edge_from(image if hasattr(image, "mask") else None, self.visuals_2d), ++ lines=_lines_from_visuals(self.visuals_2d), ++ title="Image (Image-Plane)", ++ colormap=self.mat_plot_2d.cmap.cmap, ++ use_log10=self.mat_plot_2d.use_log10, ++ output_path=output_path, ++ output_filename=filename, ++ output_format=fmt, ++ ) + +- Parameters +- ---------- +- image +- The image which is plotted on the subplot. +- """ ++ def subplot_image_and_mapper(self, image: Array2D): + self.open_subplot_figure(number_subplots=2) +- + self.figure_2d_image(image=image) + self.figure_2d() +- + self.mat_plot_2d.output.subplot_to_figure( + auto_filename="subplot_image_and_mapper" + ) +@@ -103,26 +101,29 @@ class MapperPlotter(AbstractPlotter): + zoom_to_brightest: bool = True, + auto_labels: AutoLabels = AutoLabels(), + ): +- """ +- Plot the source of the `Mapper` where the coloring is specified by an input set of values. ++ """Plot mapper source coloured by pixel_values.""" ++ is_sub = self.mat_plot_2d.is_for_subplot ++ ax = self.mat_plot_2d.setup_subplot() if is_sub else None ++ ++ output_path, filename, fmt = _output_for_mat_plot( ++ self.mat_plot_2d, ++ is_sub, ++ auto_labels.filename or "reconstruction", ++ ) + +- Parameters +- ---------- +- pixel_values +- The values of the mapper's source pixels used for coloring the figure. +- zoom_to_brightest +- For images not in the image-plane (e.g. the `plane_image`), whether to automatically zoom the plot to +- the brightest regions of the galaxies being plotted as opposed to the full extent of the grid. +- auto_labels +- The labels given to the figure. +- """ + try: +- self.mat_plot_2d.plot_mapper( +- mapper=self.mapper, +- visuals_2d=self.visuals_2d, +- auto_labels=auto_labels, ++ plot_inversion_reconstruction( + pixel_values=pixel_values, ++ mapper=self.mapper, ++ ax=ax, ++ title=auto_labels.title or "Source Reconstruction", ++ colormap=self.mat_plot_2d.cmap.cmap, ++ use_log10=self.mat_plot_2d.use_log10, + zoom_to_brightest=zoom_to_brightest, ++ lines=_lines_from_visuals(self.visuals_2d), ++ output_path=output_path, ++ output_filename=filename, ++ output_format=fmt, + ) + except ValueError: + logger.info( +diff --git a/autoarray/structures/plot/structure_plotters.py b/autoarray/structures/plot/structure_plotters.py +index 7e7cf655..e05c19f9 100644 +--- a/autoarray/structures/plot/structure_plotters.py ++++ b/autoarray/structures/plot/structure_plotters.py +@@ -7,12 +7,115 @@ from autoarray.plot.visuals.two_d import Visuals2D + from autoarray.plot.mat_plot.one_d import MatPlot1D + from autoarray.plot.mat_plot.two_d import MatPlot2D + from autoarray.plot.auto_labels import AutoLabels ++from autoarray.plot.plots.array import plot_array ++from autoarray.plot.plots.grid import plot_grid ++from autoarray.plot.plots.yx import plot_yx + from autoarray.structures.arrays.uniform_1d import Array1D + from autoarray.structures.arrays.uniform_2d import Array2D + from autoarray.structures.grids.uniform_1d import Grid1D + from autoarray.structures.grids.uniform_2d import Grid2D + + ++# --------------------------------------------------------------------------- ++# Helpers to extract plain numpy overlay data from Visuals2D/Visuals1D ++# --------------------------------------------------------------------------- ++ ++def _lines_from_visuals(visuals_2d: Visuals2D) -> Optional[List[np.ndarray]]: ++ """Return a list of (N, 2) numpy arrays from visuals_2d.lines.""" ++ if visuals_2d is None or visuals_2d.lines is None: ++ return None ++ lines = visuals_2d.lines ++ result = [] ++ try: ++ # Grid2DIrregular or list of array-like objects ++ for line in lines: ++ try: ++ arr = np.array(line.array if hasattr(line, "array") else line) ++ if arr.ndim == 2 and arr.shape[1] == 2: ++ result.append(arr) ++ except Exception: ++ pass ++ except TypeError: ++ pass ++ return result or None ++ ++ ++def _positions_from_visuals(visuals_2d: Visuals2D) -> Optional[List[np.ndarray]]: ++ """Return a list of (N, 2) numpy arrays from visuals_2d.positions.""" ++ if visuals_2d is None or visuals_2d.positions is None: ++ return None ++ positions = visuals_2d.positions ++ try: ++ arr = np.array(positions.array if hasattr(positions, "array") else positions) ++ if arr.ndim == 2 and arr.shape[1] == 2: ++ return [arr] ++ except Exception: ++ pass ++ if isinstance(positions, list): ++ result = [] ++ for p in positions: ++ try: ++ arr = np.array(p.array if hasattr(p, "array") else p) ++ result.append(arr) ++ except Exception: ++ pass ++ return result or None ++ return None ++ ++ ++def _mask_edge_from(array: Array2D, visuals_2d: Optional[Visuals2D]) -> Optional[np.ndarray]: ++ """Return edge-pixel coordinates to scatter as mask overlay.""" ++ if visuals_2d is not None and visuals_2d.mask is not None: ++ try: ++ return np.array(visuals_2d.mask.derive_grid.edge.array) ++ except Exception: ++ pass ++ if array is not None and not array.mask.is_all_false: ++ try: ++ return np.array(array.mask.derive_grid.edge.array) ++ except Exception: ++ pass ++ return None ++ ++ ++def _grid_from_visuals(visuals_2d: Visuals2D) -> Optional[np.ndarray]: ++ """Return grid scatter coordinates from visuals_2d.grid.""" ++ if visuals_2d is None or visuals_2d.grid is None: ++ return None ++ grid = visuals_2d.grid ++ try: ++ return np.array(grid.array if hasattr(grid, "array") else grid) ++ except Exception: ++ return None ++ ++ ++def _output_for_mat_plot(mat_plot, is_for_subplot: bool, auto_filename: str): ++ """ ++ Derive (output_path, output_filename, output_format) from a MatPlot object. ++ ++ When in subplot mode, returns output_path=None so that plot_array does not ++ save — the subplot is saved later by close_subplot_figure(). ++ """ ++ if is_for_subplot: ++ return None, auto_filename, "png" ++ ++ output = mat_plot.output ++ fmt_list = output.format_list ++ fmt = fmt_list[0] if fmt_list else "show" ++ ++ filename = output.filename_from(auto_filename) ++ ++ if fmt == "show": ++ return None, filename, "png" ++ ++ path = output.output_path_from(fmt) ++ return path, filename, fmt ++ ++ ++# --------------------------------------------------------------------------- ++# Plotters ++# --------------------------------------------------------------------------- ++ + class Array2DPlotter(AbstractPlotter): + def __init__( + self, +@@ -20,38 +123,35 @@ class Array2DPlotter(AbstractPlotter): + mat_plot_2d: MatPlot2D = None, + visuals_2d: Visuals2D = None, + ): +- """ +- Plots `Array2D` objects using the matplotlib method `imshow()` and many other matplotlib functions which +- customize the plot's appearance. +- +- The `mat_plot_2d` attribute wraps matplotlib function calls to make the figure. By default, the settings +- passed to every matplotlib function called are those specified in the `config/visualize/mat_wrap/*.ini` files, +- but a user can manually input values into `MatPlot2d` to customize the figure's appearance. +- +- Overlaid on the figure are visuals, contained in the `Visuals2D` object. Attributes may be extracted from +- the `Array2D` and plotted via the visuals object. +- +- Parameters +- ---------- +- array +- The 2D array the plotter plot. +- mat_plot_2d +- Contains objects which wrap the matplotlib function calls that make 2D plots. +- visuals_2d +- Contains 2D visuals that can be overlaid on 2D plots. +- """ + super().__init__(visuals_2d=visuals_2d, mat_plot_2d=mat_plot_2d) +- + self.array = array + + def figure_2d(self): +- """ +- Plots the plotter's `Array2D` object in 2D. +- """ +- self.mat_plot_2d.plot_array( +- array=self.array, +- visuals_2d=self.visuals_2d, +- auto_labels=AutoLabels(title="Array2D", filename="array"), ++ """Plot the array as a 2D image.""" ++ if self.array is None or np.all(self.array == 0): ++ return ++ ++ is_sub = self.mat_plot_2d.is_for_subplot ++ ax = self.mat_plot_2d.setup_subplot() if is_sub else None ++ ++ output_path, filename, fmt = _output_for_mat_plot( ++ self.mat_plot_2d, is_sub, "array" ++ ) ++ ++ plot_array( ++ array=self.array.native.array, ++ ax=ax, ++ extent=self.array.geometry.extent, ++ mask=_mask_edge_from(self.array, self.visuals_2d), ++ grid=_grid_from_visuals(self.visuals_2d), ++ positions=_positions_from_visuals(self.visuals_2d), ++ lines=_lines_from_visuals(self.visuals_2d), ++ title="Array2D", ++ colormap=self.mat_plot_2d.cmap.cmap, ++ use_log10=self.mat_plot_2d.use_log10, ++ output_path=output_path, ++ output_filename=filename, ++ output_format=fmt, + ) + + +@@ -62,28 +162,7 @@ class Grid2DPlotter(AbstractPlotter): + mat_plot_2d: MatPlot2D = None, + visuals_2d: Visuals2D = None, + ): +- """ +- Plots `Grid2D` objects using the matplotlib method `scatter()` and many other matplotlib functions which +- customize the plot's appearance. +- +- The `mat_plot_2d` attribute wraps matplotlib function calls to make the figure. By default, the settings +- passed to every matplotlib function called are those specified in the `config/visualize/mat_wrap/*.ini` files, +- but a user can manually input values into `MatPlot2d` to customize the figure's appearance. +- +- Overlaid on the figure are visuals, contained in the `Visuals2D` object. Attributes may be extracted from +- the `Grid2D` and plotted via the visuals object. +- +- Parameters +- ---------- +- grid +- The 2D grid the plotter plot. +- mat_plot_2d +- Contains objects which wrap the matplotlib function calls that make 2D plots. +- visuals_2d +- Contains 2D visuals that can be overlaid on 2D plots. +- """ + super().__init__(visuals_2d=visuals_2d, mat_plot_2d=mat_plot_2d) +- + self.grid = grid + + def figure_2d( +@@ -92,27 +171,24 @@ class Grid2DPlotter(AbstractPlotter): + plot_grid_lines: bool = False, + plot_over_sampled_grid: bool = False, + ): +- """ +- Plots the plotter's `Grid2D` object in 2D. +- +- Parameters +- ---------- +- color_array +- An array of RGB color values which can be used to give the plotted 2D grid a colorscale (w/ colorbar). +- plot_grid_lines +- If True, a rectangular grid of lines is plotted on the figure showing the pixels which the grid coordinates +- are centred on. +- plot_over_sampled_grid +- If True, the grid is plotted with over-sampled sub-gridded coordinates based on the `sub_size` attribute +- of the grid's over-sampling object. +- """ +- self.mat_plot_2d.plot_grid( +- grid=self.grid, +- visuals_2d=self.visuals_2d, +- auto_labels=AutoLabels(title="Grid2D", filename="grid"), ++ """Plot the grid as a 2D scatter.""" ++ is_sub = self.mat_plot_2d.is_for_subplot ++ ax = self.mat_plot_2d.setup_subplot() if is_sub else None ++ ++ output_path, filename, fmt = _output_for_mat_plot( ++ self.mat_plot_2d, is_sub, "grid" ++ ) ++ ++ grid_plot = self.grid.over_sampled if plot_over_sampled_grid else self.grid ++ ++ plot_grid( ++ grid=np.array(grid_plot.array), ++ ax=ax, ++ lines=_lines_from_visuals(self.visuals_2d), + color_array=color_array, +- plot_grid_lines=plot_grid_lines, +- plot_over_sampled_grid=plot_over_sampled_grid, ++ output_path=output_path, ++ output_filename=filename, ++ output_format=fmt, + ) + + +@@ -129,29 +205,6 @@ class YX1DPlotter(AbstractPlotter): + plot_yx_dict=None, + auto_labels=AutoLabels(), + ): +- """ +- Plots two 1D objects using the matplotlib method `plot()` (or a similar method) and many other matplotlib +- functions which customize the plot's appearance. +- +- The `mat_plot_1d` attribute wraps matplotlib function calls to make the figure. By default, the settings +- passed to every matplotlib function called are those specified in the `config/visualize/mat_wrap/*.ini` files, +- but a user can manually input values into `MatPlot1d` to customize the figure's appearance. +- +- Overlaid on the figure are visuals, contained in the `Visuals1D` object. Attributes may be extracted from +- the `Array1D` and plotted via the visuals object. +- +- Parameters +- ---------- +- y +- The 1D y values the plotter plot. +- x +- The 1D x values the plotter plot. +- mat_plot_1d +- Contains objects which wrap the matplotlib function calls that make 1D plots. +- visuals_1d +- Contains 1D visuals that can be overlaid on 1D plots. +- """ +- + if isinstance(y, list): + y = Array1D.no_mask(values=y, pixel_scales=1.0) + +@@ -169,17 +222,31 @@ class YX1DPlotter(AbstractPlotter): + self.auto_labels = auto_labels + + def figure_1d(self): +- """ +- Plots the plotter's y and x values in 1D. +- """ +- +- self.mat_plot_1d.plot_yx( +- y=self.y, +- x=self.x, +- visuals_1d=self.visuals_1d, +- auto_labels=self.auto_labels, +- should_plot_grid=self.should_plot_grid, +- should_plot_zero=self.should_plot_zero, +- plot_axis_type_override=self.plot_axis_type, +- **self.plot_yx_dict, ++ """Plot the y and x values as a 1D line.""" ++ y_arr = self.y.array if hasattr(self.y, "array") else np.array(self.y) ++ x_arr = self.x.array if hasattr(self.x, "array") else np.array(self.x) ++ ++ is_sub = self.mat_plot_1d.is_for_subplot ++ ax = self.mat_plot_1d.setup_subplot() if is_sub else None ++ ++ output_path, filename, fmt = _output_for_mat_plot( ++ self.mat_plot_1d, is_sub, self.auto_labels.filename or "yx" ++ ) ++ ++ shaded = None ++ if self.visuals_1d is not None and self.visuals_1d.shaded_region is not None: ++ shaded = self.visuals_1d.shaded_region ++ ++ plot_yx( ++ y=y_arr, ++ x=x_arr, ++ ax=ax, ++ shaded_region=shaded, ++ title=self.auto_labels.title or "", ++ xlabel=self.auto_labels.xlabel or "", ++ ylabel=self.auto_labels.ylabel or "", ++ plot_axis_type=self.plot_axis_type or "linear", ++ output_path=output_path, ++ output_filename=filename, ++ output_format=fmt, + ) +-- +2.43.0 + diff --git a/patches/autoarray/0003-PR-A3-replace-mat_plot_2d.plot_array-in-FitImagingPl.patch b/patches/autoarray/0003-PR-A3-replace-mat_plot_2d.plot_array-in-FitImagingPl.patch new file mode 100644 index 000000000..f41d9d003 --- /dev/null +++ b/patches/autoarray/0003-PR-A3-replace-mat_plot_2d.plot_array-in-FitImagingPl.patch @@ -0,0 +1,172 @@ +From ebbb315cdd0eee0a3cc4e7c2ea0600859cc95330 Mon Sep 17 00:00:00 2001 +From: Claude +Date: Mon, 16 Mar 2026 17:31:52 +0000 +Subject: [PATCH 3/3] PR A3: replace mat_plot_2d.plot_array in + FitImagingPlotterMeta with plot_array() + +Bridge FitImagingPlotterMeta.figures_2d() to use the new direct-matplotlib +plot_array() function via a _plot_array() helper method, eliminating the +MatWrap system for fit imaging plots. + +https://claude.ai/code/session_01CzJBy8KvFXiNchoNdk5i9k +--- + autoarray/fit/plot/fit_imaging_plotters.py | 72 +++++++++++++++++----- + 1 file changed, 55 insertions(+), 17 deletions(-) + +diff --git a/autoarray/fit/plot/fit_imaging_plotters.py b/autoarray/fit/plot/fit_imaging_plotters.py +index 86aa0d34..20552945 100644 +--- a/autoarray/fit/plot/fit_imaging_plotters.py ++++ b/autoarray/fit/plot/fit_imaging_plotters.py +@@ -1,10 +1,19 @@ ++import numpy as np + from typing import Callable + + from autoarray.plot.abstract_plotters import AbstractPlotter + from autoarray.plot.visuals.two_d import Visuals2D + from autoarray.plot.mat_plot.two_d import MatPlot2D + from autoarray.plot.auto_labels import AutoLabels ++from autoarray.plot.plots.array import plot_array + from autoarray.fit.fit_imaging import FitImaging ++from autoarray.structures.plot.structure_plotters import ( ++ _mask_edge_from, ++ _grid_from_visuals, ++ _lines_from_visuals, ++ _positions_from_visuals, ++ _output_for_mat_plot, ++) + + + class FitImagingPlotterMeta(AbstractPlotter): +@@ -43,6 +52,44 @@ class FitImagingPlotterMeta(AbstractPlotter): + self.fit = fit + self.residuals_symmetric_cmap = residuals_symmetric_cmap + ++ def _plot_array(self, array, auto_labels, visuals_2d=None): ++ """Helper: plot an Array2D using the new direct-matplotlib function.""" ++ if array is None: ++ return ++ ++ v2d = visuals_2d if visuals_2d is not None else self.visuals_2d ++ is_sub = self.mat_plot_2d.is_for_subplot ++ ax = self.mat_plot_2d.setup_subplot() if is_sub else None ++ ++ output_path, filename, fmt = _output_for_mat_plot( ++ self.mat_plot_2d, ++ is_sub, ++ auto_labels.filename if auto_labels else "array", ++ ) ++ ++ try: ++ arr = array.native.array ++ extent = array.geometry.extent ++ except AttributeError: ++ arr = np.asarray(array) ++ extent = None ++ ++ plot_array( ++ array=arr, ++ ax=ax, ++ extent=extent, ++ mask=_mask_edge_from(array if hasattr(array, "mask") else None, v2d), ++ grid=_grid_from_visuals(v2d), ++ positions=_positions_from_visuals(v2d), ++ lines=_lines_from_visuals(v2d), ++ title=auto_labels.title if auto_labels else "", ++ colormap=self.mat_plot_2d.cmap.cmap, ++ use_log10=self.mat_plot_2d.use_log10, ++ output_path=output_path, ++ output_filename=filename, ++ output_format=fmt, ++ ) ++ + def figures_2d( + self, + data: bool = False, +@@ -82,57 +129,50 @@ class FitImagingPlotterMeta(AbstractPlotter): + """ + + if data: +- self.mat_plot_2d.plot_array( ++ self._plot_array( + array=self.fit.data, +- visuals_2d=self.visuals_2d, + auto_labels=AutoLabels(title="Data", filename=f"data{suffix}"), + ) + + if noise_map: +- self.mat_plot_2d.plot_array( ++ self._plot_array( + array=self.fit.noise_map, +- visuals_2d=self.visuals_2d, + auto_labels=AutoLabels( + title="Noise-Map", filename=f"noise_map{suffix}" + ), + ) + + if signal_to_noise_map: +- self.mat_plot_2d.plot_array( ++ self._plot_array( + array=self.fit.signal_to_noise_map, +- visuals_2d=self.visuals_2d, + auto_labels=AutoLabels( + title="Signal-To-Noise Map", filename=f"signal_to_noise_map{suffix}" + ), + ) + + if model_image: +- self.mat_plot_2d.plot_array( ++ self._plot_array( + array=self.fit.model_data, +- visuals_2d=self.visuals_2d, + auto_labels=AutoLabels( + title="Model Image", filename=f"model_image{suffix}" + ), + ) + + cmap_original = self.mat_plot_2d.cmap +- + if self.residuals_symmetric_cmap: + self.mat_plot_2d.cmap = self.mat_plot_2d.cmap.symmetric_cmap_from() + + if residual_map: +- self.mat_plot_2d.plot_array( ++ self._plot_array( + array=self.fit.residual_map, +- visuals_2d=self.visuals_2d, + auto_labels=AutoLabels( + title="Residual Map", filename=f"residual_map{suffix}" + ), + ) + + if normalized_residual_map: +- self.mat_plot_2d.plot_array( ++ self._plot_array( + array=self.fit.normalized_residual_map, +- visuals_2d=self.visuals_2d, + auto_labels=AutoLabels( + title="Normalized Residual Map", + filename=f"normalized_residual_map{suffix}", +@@ -142,18 +182,16 @@ class FitImagingPlotterMeta(AbstractPlotter): + self.mat_plot_2d.cmap = cmap_original + + if chi_squared_map: +- self.mat_plot_2d.plot_array( ++ self._plot_array( + array=self.fit.chi_squared_map, +- visuals_2d=self.visuals_2d, + auto_labels=AutoLabels( + title="Chi-Squared Map", filename=f"chi_squared_map{suffix}" + ), + ) + + if residual_flux_fraction_map: +- self.mat_plot_2d.plot_array( ++ self._plot_array( + array=self.fit.residual_map, +- visuals_2d=self.visuals_2d, + auto_labels=AutoLabels( + title="Residual Flux Fraction Map", + filename=f"residual_flux_fraction_map{suffix}", +-- +2.43.0 + diff --git a/patches/autogalaxy/0001-PR-G1-G2-replace-mat_plot_2d.plot_array-with-_plot_a.patch b/patches/autogalaxy/0001-PR-G1-G2-replace-mat_plot_2d.plot_array-with-_plot_a.patch new file mode 100644 index 000000000..22f7069d7 --- /dev/null +++ b/patches/autogalaxy/0001-PR-G1-G2-replace-mat_plot_2d.plot_array-with-_plot_a.patch @@ -0,0 +1,466 @@ +From 01763ab468688ef6222694f14dbc073755ac268f Mon Sep 17 00:00:00 2001 +From: Claude +Date: Mon, 16 Mar 2026 17:32:12 +0000 +Subject: [PATCH] PR G1-G2: replace mat_plot_2d.plot_array with _plot_array() + bridge in all plotters + +- Add autogalaxy/plot/plots/overlays.py with helpers to extract critical + curves, caustics, and profile centres from Visuals2D as plain arrays +- Add _plot_array() and _plot_grid() bridge methods to the Plotter base class + that route to the new direct-matplotlib plot_array()/plot_grid() functions +- Update all autogalaxy plotters to use self._plot_array() instead of + self.mat_plot_2d.plot_array(), covering: LightProfilePlotter, BasisPlotter, + MassPlotter, GalaxyPlotter, GalaxiesPlotter, AdaptPlotter, + FitImagingPlotter, FitEllipsePlotter, FitEllipsePDFPlotter + +https://claude.ai/code/session_01CzJBy8KvFXiNchoNdk5i9k +--- + .../ellipse/plot/fit_ellipse_plotters.py | 4 +- + autogalaxy/galaxy/plot/adapt_plotters.py | 4 +- + autogalaxy/galaxy/plot/galaxies_plotters.py | 6 +- + autogalaxy/galaxy/plot/galaxy_plotters.py | 2 +- + .../imaging/plot/fit_imaging_plotters.py | 11 +-- + autogalaxy/plot/abstract_plotters.py | 80 +++++++++++++++ + autogalaxy/plot/mass_plotter.py | 15 +-- + autogalaxy/plot/plots/__init__.py | 6 ++ + autogalaxy/plot/plots/overlays.py | 98 +++++++++++++++++++ + autogalaxy/profiles/plot/basis_plotters.py | 2 +- + .../profiles/plot/light_profile_plotters.py | 2 +- + 11 files changed, 201 insertions(+), 29 deletions(-) + create mode 100644 autogalaxy/plot/plots/__init__.py + create mode 100644 autogalaxy/plot/plots/overlays.py + +diff --git a/autogalaxy/ellipse/plot/fit_ellipse_plotters.py b/autogalaxy/ellipse/plot/fit_ellipse_plotters.py +index 0bb67a21..0506f810 100644 +--- a/autogalaxy/ellipse/plot/fit_ellipse_plotters.py ++++ b/autogalaxy/ellipse/plot/fit_ellipse_plotters.py +@@ -89,7 +89,7 @@ class FitEllipsePlotter(Plotter): + positions=ellipse_list, lines=ellipse_list + ) + +- self.mat_plot_2d.plot_array( ++ self._plot_array( + array=self.fit_list[0].data, + visuals_2d=visuals_2d, + auto_labels=aplt.AutoLabels( +@@ -217,7 +217,7 @@ class FitEllipsePDFPlotter(Plotter): + lines=median_ellipse, fill_region=[y_fill, x_fill] + ) + +- self.mat_plot_2d.plot_array( ++ self._plot_array( + array=self.fit_pdf_list[0][0].data, + visuals_2d=visuals_2d, + auto_labels=aplt.AutoLabels( +diff --git a/autogalaxy/galaxy/plot/adapt_plotters.py b/autogalaxy/galaxy/plot/adapt_plotters.py +index a2b07e2e..d9f7818d 100644 +--- a/autogalaxy/galaxy/plot/adapt_plotters.py ++++ b/autogalaxy/galaxy/plot/adapt_plotters.py +@@ -27,7 +27,7 @@ class AdaptPlotter(Plotter): + The adapt model image that is plotted. + """ + +- self.mat_plot_2d.plot_array( ++ self._plot_array( + array=model_image, + visuals_2d=self.visuals_2d, + auto_labels=aplt.AutoLabels( +@@ -44,7 +44,7 @@ class AdaptPlotter(Plotter): + galaxy_image + The galaxy image that is plotted. + """ +- self.mat_plot_2d.plot_array( ++ self._plot_array( + array=galaxy_image, + visuals_2d=self.visuals_2d, + auto_labels=aplt.AutoLabels( +diff --git a/autogalaxy/galaxy/plot/galaxies_plotters.py b/autogalaxy/galaxy/plot/galaxies_plotters.py +index e49a8e1a..c3b1ed1b 100644 +--- a/autogalaxy/galaxy/plot/galaxies_plotters.py ++++ b/autogalaxy/galaxy/plot/galaxies_plotters.py +@@ -146,7 +146,7 @@ class GalaxiesPlotter(Plotter): + Add a suffix to the end of the filename the plot is saved to hard-disk using. + """ + if image: +- self.mat_plot_2d.plot_array( ++ self._plot_array( + array=self.galaxies.image_2d_from(grid=self.grid), + visuals_2d=self.visuals_2d, + auto_labels=aplt.AutoLabels( +@@ -160,7 +160,7 @@ class GalaxiesPlotter(Plotter): + else: + title = f"Plane Image{title_suffix}" + +- self.mat_plot_2d.plot_array( ++ self._plot_array( + array=self.galaxies.plane_image_2d_from( + grid=self.grid, zoom_to_brightest=zoom_to_brightest + ), +@@ -177,7 +177,7 @@ class GalaxiesPlotter(Plotter): + else: + title = f"Plane Grid{title_suffix}" + +- self.mat_plot_2d.plot_grid( ++ self._plot_grid( + grid=self.grid, + visuals_2d=self.visuals_2d, + auto_labels=aplt.AutoLabels( +diff --git a/autogalaxy/galaxy/plot/galaxy_plotters.py b/autogalaxy/galaxy/plot/galaxy_plotters.py +index d1a37283..d2464f81 100644 +--- a/autogalaxy/galaxy/plot/galaxy_plotters.py ++++ b/autogalaxy/galaxy/plot/galaxy_plotters.py +@@ -201,7 +201,7 @@ class GalaxyPlotter(Plotter): + Whether to make a 2D plot (via `imshow`) of the magnification. + """ + if image: +- self.mat_plot_2d.plot_array( ++ self._plot_array( + array=self.galaxy.image_2d_from(grid=self.grid), + visuals_2d=self.visuals_2d, + auto_labels=aplt.AutoLabels( +diff --git a/autogalaxy/imaging/plot/fit_imaging_plotters.py b/autogalaxy/imaging/plot/fit_imaging_plotters.py +index 3f21b5a2..8a68222a 100644 +--- a/autogalaxy/imaging/plot/fit_imaging_plotters.py ++++ b/autogalaxy/imaging/plot/fit_imaging_plotters.py +@@ -131,14 +131,7 @@ class FitImagingPlotter(Plotter): + + for galaxy_index in galaxy_indices: + if subtracted_image: +- self.mat_plot_2d.cmap.kwargs["vmin"] = np.max( +- self.fit.model_images_of_galaxies_list[galaxy_index] +- ) +- self.mat_plot_2d.cmap.kwargs["vmin"] = np.min( +- self.fit.model_images_of_galaxies_list[galaxy_index] +- ) +- +- self.mat_plot_2d.plot_array( ++ self._plot_array( + array=self.fit.subtracted_images_of_galaxies_list[galaxy_index], + visuals_2d=self.visuals_2d, + auto_labels=aplt.AutoLabels( +@@ -148,7 +141,7 @@ class FitImagingPlotter(Plotter): + ) + + if model_image: +- self.mat_plot_2d.plot_array( ++ self._plot_array( + array=self.fit.model_images_of_galaxies_list[galaxy_index], + visuals_2d=self.visuals_2d, + auto_labels=aplt.AutoLabels( +diff --git a/autogalaxy/plot/abstract_plotters.py b/autogalaxy/plot/abstract_plotters.py +index 3acd3667..79bf31f2 100644 +--- a/autogalaxy/plot/abstract_plotters.py ++++ b/autogalaxy/plot/abstract_plotters.py +@@ -2,7 +2,19 @@ from autoarray.plot.wrap.base.abstract import set_backend + + set_backend() + ++import numpy as np ++ + from autoarray.plot.abstract_plotters import AbstractPlotter ++from autoarray.plot.plots.array import plot_array ++from autoarray.structures.plot.structure_plotters import ( ++ _mask_edge_from, ++ _grid_from_visuals, ++ _output_for_mat_plot, ++) ++from autogalaxy.plot.plots.overlays import ( ++ _galaxy_lines_from_visuals, ++ _galaxy_positions_from_visuals, ++) + + from autogalaxy.plot.mat_plot.one_d import MatPlot1D + from autogalaxy.plot.mat_plot.two_d import MatPlot2D +@@ -32,3 +44,71 @@ class Plotter(AbstractPlotter): + + self.visuals_2d = visuals_2d or Visuals2D() + self.mat_plot_2d = mat_plot_2d or MatPlot2D() ++ ++ def _plot_array(self, array, visuals_2d, auto_labels): ++ """Bridge: replace mat_plot_2d.plot_array() with the new plot_array() function.""" ++ if array is None: ++ return ++ ++ is_sub = self.mat_plot_2d.is_for_subplot ++ ax = self.mat_plot_2d.setup_subplot() if is_sub else None ++ ++ output_path, filename, fmt = _output_for_mat_plot( ++ self.mat_plot_2d, ++ is_sub, ++ auto_labels.filename if auto_labels else "array", ++ ) ++ ++ try: ++ arr = array.native.array ++ extent = array.geometry.extent ++ except AttributeError: ++ arr = np.asarray(array) ++ extent = None ++ ++ v2d = visuals_2d if visuals_2d is not None else self.visuals_2d ++ ++ plot_array( ++ array=arr, ++ ax=ax, ++ extent=extent, ++ mask=_mask_edge_from(array if hasattr(array, "mask") else None, v2d), ++ grid=_grid_from_visuals(v2d), ++ positions=_galaxy_positions_from_visuals(v2d), ++ lines=_galaxy_lines_from_visuals(v2d), ++ title=auto_labels.title if auto_labels else "", ++ colormap=self.mat_plot_2d.cmap.cmap, ++ use_log10=self.mat_plot_2d.use_log10, ++ output_path=output_path, ++ output_filename=filename, ++ output_format=fmt, ++ ) ++ ++ def _plot_grid(self, grid, visuals_2d, auto_labels): ++ """Bridge: replace mat_plot_2d.plot_grid() with plot_grid() function.""" ++ from autoarray.plot.plots.grid import plot_grid ++ ++ is_sub = self.mat_plot_2d.is_for_subplot ++ ax = self.mat_plot_2d.setup_subplot() if is_sub else None ++ ++ output_path, filename, fmt = _output_for_mat_plot( ++ self.mat_plot_2d, ++ is_sub, ++ auto_labels.filename if auto_labels else "grid", ++ ) ++ ++ v2d = visuals_2d if visuals_2d is not None else self.visuals_2d ++ ++ try: ++ grid_arr = np.array(grid.array if hasattr(grid, "array") else grid) ++ except Exception: ++ grid_arr = np.asarray(grid) ++ ++ plot_grid( ++ grid=grid_arr, ++ ax=ax, ++ lines=_galaxy_lines_from_visuals(v2d), ++ output_path=output_path, ++ output_filename=filename, ++ output_format=fmt, ++ ) +diff --git a/autogalaxy/plot/mass_plotter.py b/autogalaxy/plot/mass_plotter.py +index 16167ad5..fcd7d483 100644 +--- a/autogalaxy/plot/mass_plotter.py ++++ b/autogalaxy/plot/mass_plotter.py +@@ -64,24 +64,22 @@ class MassPlotter(Plotter): + """ + + if convergence: +- self.mat_plot_2d.plot_array( ++ self._plot_array( + array=self.mass_obj.convergence_2d_from(grid=self.grid), + visuals_2d=self.visuals_2d_with_critical_curves, + auto_labels=aplt.AutoLabels( + title=f"Convergence{title_suffix}", + filename=f"convergence_2d{filename_suffix}", +- cb_unit="", + ), + ) + + if potential: +- self.mat_plot_2d.plot_array( ++ self._plot_array( + array=self.mass_obj.potential_2d_from(grid=self.grid), + visuals_2d=self.visuals_2d_with_critical_curves, + auto_labels=aplt.AutoLabels( + title=f"Potential{title_suffix}", + filename=f"potential_2d{filename_suffix}", +- cb_unit="", + ), + ) + +@@ -91,13 +89,12 @@ class MassPlotter(Plotter): + values=deflections.slim[:, 0], mask=self.grid.mask + ) + +- self.mat_plot_2d.plot_array( ++ self._plot_array( + array=deflections_y, + visuals_2d=self.visuals_2d_with_critical_curves, + auto_labels=aplt.AutoLabels( + title=f"Deflections Y{title_suffix}", + filename=f"deflections_y_2d{filename_suffix}", +- cb_unit="", + ), + ) + +@@ -107,20 +104,19 @@ class MassPlotter(Plotter): + values=deflections.slim[:, 1], mask=self.grid.mask + ) + +- self.mat_plot_2d.plot_array( ++ self._plot_array( + array=deflections_x, + visuals_2d=self.visuals_2d_with_critical_curves, + auto_labels=aplt.AutoLabels( + title=f"Deflections X{title_suffix}", + filename=f"deflections_x_2d{filename_suffix}", +- cb_unit="", + ), + ) + + if magnification: + from autogalaxy.operate.lens_calc import LensCalc + +- self.mat_plot_2d.plot_array( ++ self._plot_array( + array=LensCalc.from_mass_obj( + self.mass_obj + ).magnification_2d_from(grid=self.grid), +@@ -128,6 +124,5 @@ class MassPlotter(Plotter): + auto_labels=aplt.AutoLabels( + title=f"Magnification{title_suffix}", + filename=f"magnification_2d{filename_suffix}", +- cb_unit="", + ), + ) +diff --git a/autogalaxy/plot/plots/__init__.py b/autogalaxy/plot/plots/__init__.py +new file mode 100644 +index 00000000..1f20cbdc +--- /dev/null ++++ b/autogalaxy/plot/plots/__init__.py +@@ -0,0 +1,6 @@ ++from autogalaxy.plot.plots.overlays import ( ++ _critical_curves_from_visuals, ++ _caustics_from_visuals, ++ _galaxy_lines_from_visuals, ++ _galaxy_positions_from_visuals, ++) +diff --git a/autogalaxy/plot/plots/overlays.py b/autogalaxy/plot/plots/overlays.py +new file mode 100644 +index 00000000..9f79ad94 +--- /dev/null ++++ b/autogalaxy/plot/plots/overlays.py +@@ -0,0 +1,98 @@ ++""" ++Helper functions to extract autogalaxy-specific overlay data from Visuals2D ++objects and convert them to plain numpy arrays suitable for plot_array(). ++""" ++from typing import List, Optional ++ ++import numpy as np ++ ++ ++def _critical_curves_from_visuals(visuals_2d) -> Optional[List[np.ndarray]]: ++ """Return list of (N,2) arrays for tangential and radial critical curves.""" ++ if visuals_2d is None: ++ return None ++ curves = [] ++ for attr in ("tangential_critical_curves", "radial_critical_curves"): ++ val = getattr(visuals_2d, attr, None) ++ if val is None: ++ continue ++ try: ++ for item in val: ++ try: ++ arr = np.array(item.array if hasattr(item, "array") else item) ++ if arr.ndim == 2 and arr.shape[1] == 2 and len(arr) > 0: ++ curves.append(arr) ++ except Exception: ++ pass ++ except TypeError: ++ try: ++ arr = np.array(val.array if hasattr(val, "array") else val) ++ if arr.ndim == 2 and arr.shape[1] == 2 and len(arr) > 0: ++ curves.append(arr) ++ except Exception: ++ pass ++ return curves or None ++ ++ ++def _caustics_from_visuals(visuals_2d) -> Optional[List[np.ndarray]]: ++ """Return list of (N,2) arrays for tangential and radial caustics.""" ++ if visuals_2d is None: ++ return None ++ curves = [] ++ for attr in ("tangential_caustics", "radial_caustics"): ++ val = getattr(visuals_2d, attr, None) ++ if val is None: ++ continue ++ try: ++ for item in val: ++ try: ++ arr = np.array(item.array if hasattr(item, "array") else item) ++ if arr.ndim == 2 and arr.shape[1] == 2 and len(arr) > 0: ++ curves.append(arr) ++ except Exception: ++ pass ++ except TypeError: ++ try: ++ arr = np.array(val.array if hasattr(val, "array") else val) ++ if arr.ndim == 2 and arr.shape[1] == 2 and len(arr) > 0: ++ curves.append(arr) ++ except Exception: ++ pass ++ return curves or None ++ ++ ++def _galaxy_lines_from_visuals(visuals_2d) -> Optional[List[np.ndarray]]: ++ """ ++ Return all line overlays from an autogalaxy Visuals2D, combining regular ++ lines, critical curves, and caustics into a single list. ++ """ ++ from autoarray.structures.plot.structure_plotters import _lines_from_visuals ++ ++ lines = _lines_from_visuals(visuals_2d) or [] ++ critical = _critical_curves_from_visuals(visuals_2d) or [] ++ caustics = _caustics_from_visuals(visuals_2d) or [] ++ combined = lines + critical + caustics ++ return combined or None ++ ++ ++def _galaxy_positions_from_visuals(visuals_2d) -> Optional[List[np.ndarray]]: ++ """ ++ Return all scatter-point overlays from an autogalaxy Visuals2D, combining ++ regular positions, light/mass profile centres, and multiple images. ++ """ ++ from autoarray.structures.plot.structure_plotters import _positions_from_visuals ++ ++ result = _positions_from_visuals(visuals_2d) or [] ++ ++ for attr in ("light_profile_centres", "mass_profile_centres", "multiple_images"): ++ val = getattr(visuals_2d, attr, None) ++ if val is None: ++ continue ++ try: ++ arr = np.array(val.array if hasattr(val, "array") else val) ++ if arr.ndim == 2 and arr.shape[1] == 2 and len(arr) > 0: ++ result.append(arr) ++ except Exception: ++ pass ++ ++ return result or None +diff --git a/autogalaxy/profiles/plot/basis_plotters.py b/autogalaxy/profiles/plot/basis_plotters.py +index 93d4bbc6..3ff5fac1 100644 +--- a/autogalaxy/profiles/plot/basis_plotters.py ++++ b/autogalaxy/profiles/plot/basis_plotters.py +@@ -115,7 +115,7 @@ class BasisPlotter(Plotter): + self.open_subplot_figure(number_subplots=len(self.basis.light_profile_list)) + + for light_profile in self.basis.light_profile_list: +- self.mat_plot_2d.plot_array( ++ self._plot_array( + array=light_profile.image_2d_from(grid=self.grid), + visuals_2d=self.visuals_2d, + auto_labels=aplt.AutoLabels(title=light_profile.coefficient_tag), +diff --git a/autogalaxy/profiles/plot/light_profile_plotters.py b/autogalaxy/profiles/plot/light_profile_plotters.py +index 7292f5eb..07b21f29 100644 +--- a/autogalaxy/profiles/plot/light_profile_plotters.py ++++ b/autogalaxy/profiles/plot/light_profile_plotters.py +@@ -89,7 +89,7 @@ class LightProfilePlotter(Plotter): + Whether to make a 2D plot (via `imshow`) of the image. + """ + if image: +- self.mat_plot_2d.plot_array( ++ self._plot_array( + array=self.light_profile.image_2d_from(grid=self.grid), + visuals_2d=self.visuals_2d, + auto_labels=aplt.AutoLabels(title="Image", filename="image_2d"), +-- +2.43.0 + diff --git a/pyproject.toml b/pyproject.toml index 76d4900df..ba256b30c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,7 +25,7 @@ classifiers = [ ] keywords = ["cli"] dependencies = [ - "autogalaxy", + "autogalaxy @ git+https://github.com/Jammy2211/PyAutoGalaxy.git@claude/refactor-plotting-module-3ZdD8", "nautilus-sampler==1.0.5" ] diff --git a/test_autolens/analysis/test_plotter_interface.py b/test_autolens/analysis/test_plotter_interface.py index 2a5d4573c..34809ef4e 100644 --- a/test_autolens/analysis/test_plotter_interface.py +++ b/test_autolens/analysis/test_plotter_interface.py @@ -1,45 +1,45 @@ -import os -import shutil -from os import path - -import pytest -import autolens as al -from autolens.analysis import plotter_interface as vis - -directory = path.dirname(path.realpath(__file__)) - - -@pytest.fixture(name="plot_path") -def make_plotter_interface_plotter_setup(): - return path.join("{}".format(directory), "files") - - -def test__tracer(masked_imaging_7x7, tracer_x2_plane_7x7, plot_path, plot_patch): - if os.path.exists(plot_path): - shutil.rmtree(plot_path) - - plotter_interface = vis.PlotterInterface(image_path=plot_path) - - plotter_interface.tracer( - tracer=tracer_x2_plane_7x7, - grid=masked_imaging_7x7.grids.lp, - ) - - assert path.join(plot_path, "subplot_galaxies_images.png") in plot_patch.paths - - image = al.ndarray_via_fits_from( - file_path=path.join(plot_path, "tracer.fits"), hdu=0 - ) - - assert image.shape == (5, 5) - - -def test__image_with_positions(image_7x7, positions_x2, plot_path, plot_patch): - if os.path.exists(plot_path): - shutil.rmtree(plot_path) - - plotter_interface = vis.PlotterInterface(image_path=plot_path) - - plotter_interface.image_with_positions(image=image_7x7, positions=positions_x2) - - assert path.join(plot_path, "image_with_positions.png") in plot_patch.paths +import os +import shutil +from os import path + +import pytest +import autolens as al +from autolens.analysis import plotter_interface as vis + +directory = path.dirname(path.realpath(__file__)) + + +@pytest.fixture(name="plot_path") +def make_plotter_interface_plotter_setup(): + return path.join("{}".format(directory), "files") + + +def test__tracer(masked_imaging_7x7, tracer_x2_plane_7x7, plot_path, plot_patch): + if os.path.exists(plot_path): + shutil.rmtree(plot_path) + + plotter_interface = vis.PlotterInterface(image_path=plot_path) + + plotter_interface.tracer( + tracer=tracer_x2_plane_7x7, + grid=masked_imaging_7x7.grids.lp, + ) + + assert path.join(plot_path, "subplot_galaxies_images.png") in plot_patch.paths + + image = al.ndarray_via_fits_from( + file_path=path.join(plot_path, "tracer.fits"), hdu=0 + ) + + assert image.shape == (5, 5) + + +def test__image_with_positions(image_7x7, positions_x2, plot_path, plot_patch): + if os.path.exists(plot_path): + shutil.rmtree(plot_path) + + plotter_interface = vis.PlotterInterface(image_path=plot_path) + + plotter_interface.image_with_positions(image=image_7x7, positions=positions_x2) + + assert path.join(plot_path, "image_with_positions.png") in plot_patch.paths diff --git a/test_autolens/config/visualize.yaml b/test_autolens/config/visualize.yaml index 9b4f10396..8c69d3c6d 100644 --- a/test_autolens/config/visualize.yaml +++ b/test_autolens/config/visualize.yaml @@ -3,52 +3,6 @@ general: backend: default imshow_origin: upper zoom_around_mask: true -mat_wrap_2d: - CausticsLine: - figure: - c: w,g - linestyle: -- - linewidth: 5 - subplot: - c: g - linestyle: -- - linewidth: 7 - CriticalCurvesLine: - figure: - c: w,k - linestyle: '-' - linewidth: 4 - subplot: - c: b - linestyle: '-' - linewidth: 6 - LightProfileCentreScatter: - figure: - c: k,r - marker: + - s: 1 - subplot: - c: b - marker: . - s: 15 - MassProfileCentreScatter: - figure: - c: r,k - marker: x - s: 2 - subplot: - c: k - marker: o - s: 16 - MultipleImagesScatter: - figure: - c: k,w - marker: o - s: 3 - subplot: - c: g - marker: . - s: 17 plots: fits_are_zoomed: true dataset: @@ -59,7 +13,9 @@ plots: fits_fit: true # Output a .fits file containing the fit model data, residual map, normalized residual map and chi-squared? fits_galaxy_images : true # Output a .fits file containing the images (e.g. without PSF convolution) of every galaxy? fits_model_galaxy_images : true # Output a .fits file containing the model images (e.g. with PSF convolution) of every galaxy? - fit_imaging: {} + fit_imaging: + subplot_fit_log10: true + subplot_of_planes: true fit_interferometer: subplot_fit_real_space: true subplot_fit_dirty_images: true diff --git a/test_autolens/conftest.py b/test_autolens/conftest.py index f75cefec4..17da6b12c 100644 --- a/test_autolens/conftest.py +++ b/test_autolens/conftest.py @@ -3,6 +3,7 @@ import pytest from matplotlib import pyplot +import matplotlib.figure from autofit import conf from autolens import fixtures @@ -26,6 +27,7 @@ def __call__(self, path, *args, **kwargs): def make_plot_patch(monkeypatch): plot_patch = PlotPatch() monkeypatch.setattr(pyplot, "savefig", plot_patch) + monkeypatch.setattr(matplotlib.figure.Figure, "savefig", plot_patch) return plot_patch diff --git a/test_autolens/imaging/model/test_plotter_interface_imaging.py b/test_autolens/imaging/model/test_plotter_interface_imaging.py index 692020ebf..389719d2d 100644 --- a/test_autolens/imaging/model/test_plotter_interface_imaging.py +++ b/test_autolens/imaging/model/test_plotter_interface_imaging.py @@ -1,55 +1,55 @@ -import os -import shutil -from os import path - -import pytest -import autolens as al -from autolens.imaging.model.plotter_interface import PlotterInterfaceImaging - -directory = path.dirname(path.realpath(__file__)) - - -@pytest.fixture(name="plot_path") -def make_plotter_interface_plotter_setup(): - return path.join("{}".format(directory), "files") - - -def test__fit_imaging( - fit_imaging_x2_plane_inversion_7x7, plot_path, plot_patch -): - if os.path.exists(plot_path): - shutil.rmtree(plot_path) - - plotter_interface = PlotterInterfaceImaging(image_path=plot_path) - - plotter_interface.fit_imaging( - fit=fit_imaging_x2_plane_inversion_7x7, - ) - - assert path.join(plot_path, "subplot_tracer.png") in plot_patch.paths - assert path.join(plot_path, "subplot_fit.png") in plot_patch.paths - assert path.join(plot_path, "subplot_fit_log10.png") in plot_patch.paths - - image = al.ndarray_via_fits_from( - file_path=path.join(plot_path, "fit.fits"), hdu=0 - ) - - assert image.shape == (5, 5) - - image = al.ndarray_via_fits_from( - file_path=path.join(plot_path, "model_galaxy_images.fits"), hdu=0 - ) - - assert image.shape == (5, 5) - -def test__fit_imaging_combined( - fit_imaging_x2_plane_inversion_7x7, plot_path, plot_patch -): - if path.exists(plot_path): - shutil.rmtree(plot_path) - - visualizer = PlotterInterfaceImaging(image_path=plot_path) - - visualizer.fit_imaging_combined(fit_list=2 * [fit_imaging_x2_plane_inversion_7x7]) - - assert path.join(plot_path, "subplot_fit_combined.png") in plot_patch.paths \ No newline at end of file +import os +import shutil +from os import path + +import pytest +import autolens as al +from autolens.imaging.model.plotter_interface import PlotterInterfaceImaging + +directory = path.dirname(path.realpath(__file__)) + + +@pytest.fixture(name="plot_path") +def make_plotter_interface_plotter_setup(): + return path.join("{}".format(directory), "files") + + +def test__fit_imaging( + fit_imaging_x2_plane_inversion_7x7, plot_path, plot_patch +): + if os.path.exists(plot_path): + shutil.rmtree(plot_path) + + plotter_interface = PlotterInterfaceImaging(image_path=plot_path) + + plotter_interface.fit_imaging( + fit=fit_imaging_x2_plane_inversion_7x7, + ) + + assert path.join(plot_path, "subplot_tracer.png") in plot_patch.paths + assert path.join(plot_path, "subplot_fit.png") in plot_patch.paths + assert path.join(plot_path, "subplot_fit_log10.png") in plot_patch.paths + + image = al.ndarray_via_fits_from( + file_path=path.join(plot_path, "fit.fits"), hdu=0 + ) + + assert image.shape == (5, 5) + + image = al.ndarray_via_fits_from( + file_path=path.join(plot_path, "model_galaxy_images.fits"), hdu=0 + ) + + assert image.shape == (5, 5) + +def test__fit_imaging_combined( + fit_imaging_x2_plane_inversion_7x7, plot_path, plot_patch +): + if path.exists(plot_path): + shutil.rmtree(plot_path) + + visualizer = PlotterInterfaceImaging(image_path=plot_path) + + visualizer.fit_imaging_combined(fit_list=2 * [fit_imaging_x2_plane_inversion_7x7]) + + assert path.join(plot_path, "subplot_fit_combined.png") in plot_patch.paths diff --git a/test_autolens/imaging/plot/test_fit_imaging_plots.py b/test_autolens/imaging/plot/test_fit_imaging_plots.py new file mode 100644 index 000000000..ae8df649f --- /dev/null +++ b/test_autolens/imaging/plot/test_fit_imaging_plots.py @@ -0,0 +1,65 @@ +from os import path + +import pytest + +from autolens.imaging.plot.fit_imaging_plots import ( + subplot_fit, + subplot_fit_log10, + subplot_of_planes, +) + +directory = path.dirname(path.realpath(__file__)) + + +@pytest.fixture(name="plot_path") +def make_fit_imaging_plotter_setup(): + return path.join( + "{}".format(path.dirname(path.realpath(__file__))), "files", "plots", "fit" + ) + + +def test_subplot_fit_is_output( + fit_imaging_x2_plane_7x7, plot_path, plot_patch +): + subplot_fit( + fit=fit_imaging_x2_plane_7x7, + output_path=plot_path, + output_format="png", + ) + assert path.join(plot_path, "subplot_fit.png") in plot_patch.paths + + +def test_subplot_fit_log10_is_output( + fit_imaging_x2_plane_7x7, plot_path, plot_patch +): + subplot_fit_log10( + fit=fit_imaging_x2_plane_7x7, + output_path=plot_path, + output_format="png", + ) + assert path.join(plot_path, "subplot_fit_log10.png") in plot_patch.paths + + +def test__subplot_of_planes( + fit_imaging_x2_plane_7x7, plot_path, plot_patch +): + subplot_of_planes( + fit=fit_imaging_x2_plane_7x7, + output_path=plot_path, + output_format="png", + ) + + assert path.join(plot_path, "subplot_of_plane_0.png") in plot_patch.paths + assert path.join(plot_path, "subplot_of_plane_1.png") in plot_patch.paths + + plot_patch.paths = [] + + subplot_of_planes( + fit=fit_imaging_x2_plane_7x7, + output_path=plot_path, + output_format="png", + plane_index=0, + ) + + assert path.join(plot_path, "subplot_of_plane_0.png") in plot_patch.paths + assert path.join(plot_path, "subplot_of_plane_1.png") not in plot_patch.paths diff --git a/test_autolens/imaging/plot/test_fit_imaging_plotters.py b/test_autolens/imaging/plot/test_fit_imaging_plotters.py deleted file mode 100644 index fce14771c..000000000 --- a/test_autolens/imaging/plot/test_fit_imaging_plotters.py +++ /dev/null @@ -1,134 +0,0 @@ -from os import path - -import pytest - -import autolens.plot as aplt - -directory = path.dirname(path.realpath(__file__)) - - -@pytest.fixture(name="plot_path") -def make_fit_imaging_plotter_setup(): - return path.join( - "{}".format(path.dirname(path.realpath(__file__))), "files", "plots", "fit" - ) - - -def test__fit_quantities_are_output( - fit_imaging_x2_plane_7x7, plot_path, plot_patch -): - - fit_plotter = aplt.FitImagingPlotter( - fit=fit_imaging_x2_plane_7x7, - mat_plot_2d=aplt.MatPlot2D(output=aplt.Output(path=plot_path, format="png")), - ) - - fit_plotter.figures_2d( - data=True, - noise_map=True, - signal_to_noise_map=True, - model_image=True, - residual_map=True, - normalized_residual_map=True, - chi_squared_map=True, - ) - - assert path.join(plot_path, "data.png") in plot_patch.paths - assert path.join(plot_path, "noise_map.png") in plot_patch.paths - assert path.join(plot_path, "signal_to_noise_map.png") in plot_patch.paths - assert path.join(plot_path, "model_image.png") in plot_patch.paths - assert path.join(plot_path, "residual_map.png") in plot_patch.paths - assert path.join(plot_path, "normalized_residual_map.png") in plot_patch.paths - assert path.join(plot_path, "chi_squared_map.png") in plot_patch.paths - - plot_patch.paths = [] - - fit_plotter.figures_2d( - data=True, - noise_map=False, - signal_to_noise_map=False, - model_image=True, - chi_squared_map=True, - ) - - assert path.join(plot_path, "data.png") in plot_patch.paths - assert path.join(plot_path, "noise_map.png") not in plot_patch.paths - assert path.join(plot_path, "signal_to_noise_map.png") not in plot_patch.paths - assert path.join(plot_path, "model_image.png") in plot_patch.paths - assert path.join(plot_path, "residual_map.png") not in plot_patch.paths - assert path.join(plot_path, "normalized_residual_map.png") not in plot_patch.paths - assert path.join(plot_path, "chi_squared_map.png") in plot_patch.paths - - -def test__figures_of_plane( - fit_imaging_x2_plane_7x7, plot_path, plot_patch -): - - fit_plotter = aplt.FitImagingPlotter( - fit=fit_imaging_x2_plane_7x7, - mat_plot_2d=aplt.MatPlot2D(output=aplt.Output(path=plot_path, format="png")), - ) - - fit_plotter.figures_2d_of_planes( - subtracted_image=True, model_image=True, plane_image=True - ) - - assert path.join(plot_path, "source_subtracted_image.png") in plot_patch.paths - assert path.join(plot_path, "lens_subtracted_image.png") in plot_patch.paths - assert path.join(plot_path, "lens_model_image.png") in plot_patch.paths - assert path.join(plot_path, "source_model_image.png") in plot_patch.paths - assert path.join(plot_path, "plane_image_of_plane_0.png") in plot_patch.paths - assert path.join(plot_path, "plane_image_of_plane_1.png") in plot_patch.paths - - plot_patch.paths = [] - - fit_plotter.figures_2d_of_planes( - subtracted_image=True, model_image=True, plane_index=0, plane_image=True - ) - - assert path.join(plot_path, "source_subtracted_image.png") in plot_patch.paths - assert ( - path.join(plot_path, "lens_subtracted_image.png") not in plot_patch.paths - ) - assert path.join(plot_path, "lens_model_image.png") in plot_patch.paths - assert path.join(plot_path, "source_model_image.png") not in plot_patch.paths - assert path.join(plot_path, "plane_image_of_plane_0.png") in plot_patch.paths - assert path.join(plot_path, "plane_image_of_plane_1.png") not in plot_patch.paths - - -def test_subplot_fit_is_output( - fit_imaging_x2_plane_7x7, plot_path, plot_patch -): - - fit_plotter = aplt.FitImagingPlotter( - fit=fit_imaging_x2_plane_7x7, - mat_plot_2d=aplt.MatPlot2D(output=aplt.Output(plot_path, format="png")), - ) - - fit_plotter.subplot_fit() - assert path.join(plot_path, "subplot_fit.png") in plot_patch.paths - - fit_plotter.subplot_fit_log10() - assert path.join(plot_path, "subplot_fit_log10.png") in plot_patch.paths - - -def test__subplot_of_planes( - fit_imaging_x2_plane_7x7, plot_path, plot_patch -): - - fit_plotter = aplt.FitImagingPlotter( - fit=fit_imaging_x2_plane_7x7, - mat_plot_2d=aplt.MatPlot2D(output=aplt.Output(plot_path, format="png")), - ) - - fit_plotter.subplot_of_planes() - - assert path.join(plot_path, "subplot_of_plane_0.png") in plot_patch.paths - assert path.join(plot_path, "subplot_of_plane_1.png") in plot_patch.paths - - plot_patch.paths = [] - - fit_plotter.subplot_of_planes(plane_index=0) - - assert path.join(plot_path, "subplot_of_plane_0.png") in plot_patch.paths - assert path.join(plot_path, "subplot_of_plane_1.png") not in plot_patch.paths diff --git a/test_autolens/interferometer/model/test_plotter_interface_interferometer.py b/test_autolens/interferometer/model/test_plotter_interface_interferometer.py index ec7cd8d83..b869a7b15 100644 --- a/test_autolens/interferometer/model/test_plotter_interface_interferometer.py +++ b/test_autolens/interferometer/model/test_plotter_interface_interferometer.py @@ -1,44 +1,44 @@ -from os import path - -import pytest - -import autolens as al - -from autolens.interferometer.model.plotter_interface import ( - PlotterInterfaceInterferometer, -) - -directory = path.dirname(path.realpath(__file__)) - - -@pytest.fixture(name="plot_path") -def make_plotter_interface_plotter_setup(): - return path.join("{}".format(directory), "files") - - -def test__fit_interferometer( - fit_interferometer_x2_plane_7x7, - plot_path, - plot_patch, -): - plotter_interface = PlotterInterfaceInterferometer(image_path=plot_path) - - plotter_interface.fit_interferometer( - fit=fit_interferometer_x2_plane_7x7, - ) - - assert path.join(plot_path, "subplot_fit.png") in plot_patch.paths - assert path.join(plot_path, "subplot_fit_real_space.png") in plot_patch.paths - assert path.join(plot_path, "subplot_fit_dirty_images.png") in plot_patch.paths - - image = al.ndarray_via_fits_from( - file_path=path.join(plot_path, "galaxy_images.fits"), hdu=0 - ) - - assert image.shape == (5, 5) - - image = al.ndarray_via_fits_from( - file_path=path.join(plot_path, "dirty_images.fits"), hdu=0 - ) - - assert image.shape == (5, 5) +from os import path + +import pytest + +import autolens as al + +from autolens.interferometer.model.plotter_interface import ( + PlotterInterfaceInterferometer, +) + +directory = path.dirname(path.realpath(__file__)) + + +@pytest.fixture(name="plot_path") +def make_plotter_interface_plotter_setup(): + return path.join("{}".format(directory), "files") + + +def test__fit_interferometer( + fit_interferometer_x2_plane_7x7, + plot_path, + plot_patch, +): + plotter_interface = PlotterInterfaceInterferometer(image_path=plot_path) + + plotter_interface.fit_interferometer( + fit=fit_interferometer_x2_plane_7x7, + ) + + assert path.join(plot_path, "subplot_fit.png") in plot_patch.paths + assert path.join(plot_path, "subplot_fit_real_space.png") in plot_patch.paths + assert path.join(plot_path, "subplot_fit_dirty_images.png") in plot_patch.paths + + image = al.ndarray_via_fits_from( + file_path=path.join(plot_path, "galaxy_images.fits"), hdu=0 + ) + + assert image.shape == (5, 5) + + image = al.ndarray_via_fits_from( + file_path=path.join(plot_path, "dirty_images.fits"), hdu=0 + ) + + assert image.shape == (5, 5) diff --git a/test_autolens/interferometer/plot/test_fit_interferometer_plots.py b/test_autolens/interferometer/plot/test_fit_interferometer_plots.py new file mode 100644 index 000000000..1029f8ffe --- /dev/null +++ b/test_autolens/interferometer/plot/test_fit_interferometer_plots.py @@ -0,0 +1,39 @@ +from os import path +import pytest + +from autolens.interferometer.plot.fit_interferometer_plots import ( + subplot_fit, + subplot_fit_real_space, +) + + +@pytest.fixture(name="plot_path") +def make_fit_interferometer_plotter_setup(): + return path.join( + "{}".format(path.dirname(path.realpath(__file__))), "files", "plots", "fit" + ) + + +def test__subplot_fit( + fit_interferometer_x2_plane_7x7, plot_path, plot_patch +): + subplot_fit( + fit=fit_interferometer_x2_plane_7x7, + output_path=plot_path, + output_format="png", + ) + assert path.join(plot_path, "subplot_fit.png") in plot_patch.paths + + +def test__subplot_fit_real_space( + fit_interferometer_x2_plane_7x7, + fit_interferometer_x2_plane_inversion_7x7, + plot_path, + plot_patch, +): + subplot_fit_real_space( + fit=fit_interferometer_x2_plane_7x7, + output_path=plot_path, + output_format="png", + ) + assert path.join(plot_path, "subplot_fit_real_space.png") in plot_patch.paths diff --git a/test_autolens/interferometer/plot/test_fit_interferometer_plotters.py b/test_autolens/interferometer/plot/test_fit_interferometer_plotters.py deleted file mode 100644 index b405caa4b..000000000 --- a/test_autolens/interferometer/plot/test_fit_interferometer_plotters.py +++ /dev/null @@ -1,120 +0,0 @@ -from os import path -import autolens.plot as aplt -import pytest - - -@pytest.fixture(name="plot_path") -def make_fit_imaging_plotter_setup(): - return path.join( - "{}".format(path.dirname(path.realpath(__file__))), "files", "plots", "fit" - ) - - -def test__fit_quantities_are_output( - fit_interferometer_x2_plane_7x7, plot_path, plot_patch -): - fit_plotter = aplt.FitInterferometerPlotter( - fit=fit_interferometer_x2_plane_7x7, - mat_plot_1d=aplt.MatPlot1D(output=aplt.Output(path=plot_path, format="png")), - mat_plot_2d=aplt.MatPlot2D(output=aplt.Output(path=plot_path, format="png")), - ) - - fit_plotter.figures_2d( - data=True, - noise_map=True, - signal_to_noise_map=True, - model_data=True, - residual_map_real=True, - residual_map_imag=True, - normalized_residual_map_real=True, - normalized_residual_map_imag=True, - chi_squared_map_real=True, - chi_squared_map_imag=True, - image=True, - ) - - assert path.join(plot_path, "data.png") in plot_patch.paths - assert path.join(plot_path, "noise_map.png") in plot_patch.paths - assert path.join(plot_path, "signal_to_noise_map.png") in plot_patch.paths - assert path.join(plot_path, "model_data.png") in plot_patch.paths - assert ( - path.join(plot_path, "real_residual_map_vs_uv_distances.png") - in plot_patch.paths - ) - assert ( - path.join(plot_path, "real_residual_map_vs_uv_distances.png") - in plot_patch.paths - ) - assert ( - path.join(plot_path, "real_normalized_residual_map_vs_uv_distances.png") - in plot_patch.paths - ) - assert ( - path.join(plot_path, "imag_residual_map_vs_uv_distances.png") - in plot_patch.paths - ) - assert ( - path.join(plot_path, "imag_residual_map_vs_uv_distances.png") - in plot_patch.paths - ) - assert ( - path.join(plot_path, "imag_normalized_residual_map_vs_uv_distances.png") - in plot_patch.paths - ) - assert path.join(plot_path, "image_2d.png") in plot_patch.paths - - plot_patch.paths = [] - - fit_plotter.figures_2d( - data=True, - noise_map=False, - signal_to_noise_map=False, - model_data=True, - chi_squared_map_real=True, - chi_squared_map_imag=True, - ) - - assert path.join(plot_path, "data.png") in plot_patch.paths - assert path.join(plot_path, "noise_map.png") not in plot_patch.paths - assert path.join(plot_path, "signal_to_noise_map.png") not in plot_patch.paths - assert path.join(plot_path, "model_data.png") in plot_patch.paths - assert ( - path.join(plot_path, "real_residual_map_vs_uv_distances.png") - not in plot_patch.paths - ) - assert ( - path.join(plot_path, "real_normalized_residual_map_vs_uv_distances.png") - not in plot_patch.paths - ) - assert ( - path.join(plot_path, "real_chi_squared_map_vs_uv_distances.png") - in plot_patch.paths - ) - assert ( - path.join(plot_path, "imag_residual_map_vs_uv_distances.png") - not in plot_patch.paths - ) - assert ( - path.join(plot_path, "imag_normalized_residual_map_vs_uv_distances.png") - not in plot_patch.paths - ) - assert ( - path.join(plot_path, "imag_chi_squared_map_vs_uv_distances.png") - in plot_patch.paths - ) - - -def test__fit_sub_plot_real_space( - fit_interferometer_x2_plane_7x7, - fit_interferometer_x2_plane_inversion_7x7, - plot_path, - plot_patch, -): - fit_plotter = aplt.FitInterferometerPlotter( - fit=fit_interferometer_x2_plane_7x7, - mat_plot_1d=aplt.MatPlot1D(output=aplt.Output(plot_path, format="png")), - mat_plot_2d=aplt.MatPlot2D(output=aplt.Output(plot_path, format="png")), - ) - - fit_plotter.subplot_fit_real_space() - assert path.join(plot_path, "subplot_fit_real_space.png") in plot_patch.paths diff --git a/test_autolens/lens/plot/test_tracer_plots.py b/test_autolens/lens/plot/test_tracer_plots.py new file mode 100644 index 000000000..9cd6ba2b3 --- /dev/null +++ b/test_autolens/lens/plot/test_tracer_plots.py @@ -0,0 +1,56 @@ +from os import path + +import pytest + +import autolens.plot as aplt +from autolens.lens.plot.tracer_plots import ( + subplot_tracer, + subplot_lensed_images, + subplot_galaxies_images, +) + +directory = path.dirname(path.realpath(__file__)) + + +@pytest.fixture(name="plot_path") +def make_tracer_plotter_setup(): + return path.join( + "{}".format(path.dirname(path.realpath(__file__))), + "files", + "plots", + "tracer", + ) + + +def test__subplot_tracer(tracer_x2_plane_7x7, grid_2d_7x7, plot_path, plot_patch): + subplot_tracer( + tracer=tracer_x2_plane_7x7, + grid=grid_2d_7x7, + output_path=plot_path, + output_format="png", + ) + assert path.join(plot_path, "subplot_tracer.png") in plot_patch.paths + + +def test__subplot_galaxies_images( + tracer_x2_plane_7x7, grid_2d_7x7, plot_path, plot_patch +): + subplot_galaxies_images( + tracer=tracer_x2_plane_7x7, + grid=grid_2d_7x7, + output_path=plot_path, + output_format="png", + ) + assert path.join(plot_path, "subplot_galaxies_images.png") in plot_patch.paths + + +def test__subplot_lensed_images( + tracer_x2_plane_7x7, grid_2d_7x7, plot_path, plot_patch +): + subplot_lensed_images( + tracer=tracer_x2_plane_7x7, + grid=grid_2d_7x7, + output_path=plot_path, + output_format="png", + ) + assert path.join(plot_path, "subplot_lensed_images.png") in plot_patch.paths diff --git a/test_autolens/lens/plot/test_tracer_plotters.py b/test_autolens/lens/plot/test_tracer_plotters.py deleted file mode 100644 index 67cb86e60..000000000 --- a/test_autolens/lens/plot/test_tracer_plotters.py +++ /dev/null @@ -1,109 +0,0 @@ -from os import path - -import pytest - -import autolens.plot as aplt - -directory = path.dirname(path.realpath(__file__)) - - -@pytest.fixture(name="plot_path") -def make_tracer_plotter_setup(): - return path.join( - "{}".format(path.dirname(path.realpath(__file__))), - "files", - "plots", - "tracer", - ) - - -def test__all_individual_plotter( - tracer_x2_plane_7x7, - grid_2d_7x7, - mask_2d_7x7, - plot_path, - plot_patch, -): - tracer_plotter = aplt.TracerPlotter( - tracer=tracer_x2_plane_7x7, - grid=grid_2d_7x7, - mat_plot_2d=aplt.MatPlot2D(output=aplt.Output(plot_path, format="png")), - ) - - tracer_plotter.figures_2d( - image=True, - source_plane=True, - convergence=True, - potential=True, - deflections_y=True, - deflections_x=True, - magnification=True, - ) - - assert path.join(plot_path, "image_2d.png") in plot_patch.paths - assert path.join(plot_path, "plane_image_of_plane_1.png") in plot_patch.paths - assert path.join(plot_path, "convergence_2d.png") in plot_patch.paths - assert path.join(plot_path, "potential_2d.png") in plot_patch.paths - assert path.join(plot_path, "deflections_y_2d.png") in plot_patch.paths - assert path.join(plot_path, "deflections_x_2d.png") in plot_patch.paths - assert path.join(plot_path, "magnification_2d.png") in plot_patch.paths - - plot_patch.paths = [] - - tracer_plotter = aplt.TracerPlotter( - tracer=tracer_x2_plane_7x7, - grid=grid_2d_7x7, - mat_plot_2d=aplt.MatPlot2D(output=aplt.Output(plot_path, format="png")), - ) - - tracer_plotter.figures_2d( - image=True, source_plane=True, potential=True, magnification=True - ) - - assert path.join(plot_path, "image_2d.png") in plot_patch.paths - assert path.join(plot_path, "plane_image_of_plane_1.png") in plot_patch.paths - assert path.join(plot_path, "convergence_2d.png") not in plot_patch.paths - assert path.join(plot_path, "potential_2d.png") in plot_patch.paths - assert path.join(plot_path, "deflections_y_2d.png") not in plot_patch.paths - assert path.join(plot_path, "deflections_x_2d.png") not in plot_patch.paths - assert path.join(plot_path, "magnification_2d.png") in plot_patch.paths - - -def test__figures_of_plane( - tracer_x2_plane_7x7, - grid_2d_7x7, - mask_2d_7x7, - plot_path, - plot_patch, -): - tracer_plotter = aplt.TracerPlotter( - tracer=tracer_x2_plane_7x7, - grid=grid_2d_7x7, - mat_plot_2d=aplt.MatPlot2D(output=aplt.Output(path=plot_path, format="png")), - ) - - tracer_plotter.figures_2d_of_planes(plane_image=True, plane_grid=True) - - assert path.join(plot_path, "plane_image_of_plane_0.png") in plot_patch.paths - assert path.join(plot_path, "plane_image_of_plane_1.png") in plot_patch.paths - - plot_patch.paths = [] - - tracer_plotter.figures_2d_of_planes(plane_index=0, plane_image=True) - - assert path.join(plot_path, "plane_image_of_plane_0.png") in plot_patch.paths - assert path.join(plot_path, "plane_image_of_plane_1.png") not in plot_patch.paths - - -def test__tracer_plot_output(tracer_x2_plane_7x7, grid_2d_7x7, plot_path, plot_patch): - tracer_plotter = aplt.TracerPlotter( - tracer=tracer_x2_plane_7x7, - grid=grid_2d_7x7, - mat_plot_2d=aplt.MatPlot2D(output=aplt.Output(plot_path, format="png")), - ) - - tracer_plotter.subplot_tracer() - assert path.join(plot_path, "subplot_tracer.png") in plot_patch.paths - - tracer_plotter.subplot_galaxies_images() - assert path.join(plot_path, "subplot_galaxies_images.png") in plot_patch.paths diff --git a/test_autolens/point/model/test_plotter_interface_point.py b/test_autolens/point/model/test_plotter_interface_point.py index 73b4a02b4..a10eeaf82 100644 --- a/test_autolens/point/model/test_plotter_interface_point.py +++ b/test_autolens/point/model/test_plotter_interface_point.py @@ -1,24 +1,24 @@ -import os -import shutil -from os import path - -import pytest -from autolens.point.model.plotter_interface import PlotterInterfacePoint - -directory = path.dirname(path.realpath(__file__)) - - -@pytest.fixture(name="plot_path") -def make_plotter_interface_plotter_setup(): - return path.join("{}".format(directory), "files") - - -def test__fit_point(fit_point_dataset_x2_plane, plot_path, plot_patch): - if os.path.exists(plot_path): - shutil.rmtree(plot_path) - - plotter_interface = PlotterInterfacePoint(image_path=plot_path) - - plotter_interface.fit_point(fit=fit_point_dataset_x2_plane) - - assert path.join(plot_path, "subplot_fit.png") in plot_patch.paths +import os +import shutil +from os import path + +import pytest +from autolens.point.model.plotter_interface import PlotterInterfacePoint + +directory = path.dirname(path.realpath(__file__)) + + +@pytest.fixture(name="plot_path") +def make_plotter_interface_plotter_setup(): + return path.join("{}".format(directory), "files") + + +def test__fit_point(fit_point_dataset_x2_plane, plot_path, plot_patch): + if os.path.exists(plot_path): + shutil.rmtree(plot_path) + + plotter_interface = PlotterInterfacePoint(image_path=plot_path) + + plotter_interface.fit_point(fit=fit_point_dataset_x2_plane) + + assert path.join(plot_path, "subplot_fit.png") in plot_patch.paths diff --git a/test_autolens/point/plot/test_fit_point_plots.py b/test_autolens/point/plot/test_fit_point_plots.py new file mode 100644 index 000000000..17ad135ea --- /dev/null +++ b/test_autolens/point/plot/test_fit_point_plots.py @@ -0,0 +1,26 @@ +from os import path + +import pytest + +from autolens.point.plot.fit_point_plots import subplot_fit + +directory = path.dirname(path.realpath(__file__)) + + +@pytest.fixture(name="plot_path") +def make_fit_point_plotter_setup(): + return path.join( + "{}".format(path.dirname(path.realpath(__file__))), + "files", + "plots", + "fit_point", + ) + + +def test__subplot_fit(fit_point_dataset_x2_plane, plot_path, plot_patch): + subplot_fit( + fit=fit_point_dataset_x2_plane, + output_path=plot_path, + output_format="png", + ) + assert path.join(plot_path, "subplot_fit.png") in plot_patch.paths diff --git a/test_autolens/point/plot/test_fit_point_plotters.py b/test_autolens/point/plot/test_fit_point_plotters.py deleted file mode 100644 index 43eed3d4f..000000000 --- a/test_autolens/point/plot/test_fit_point_plotters.py +++ /dev/null @@ -1,65 +0,0 @@ -from os import path - -import pytest - -import autolens.plot as aplt - -directory = path.dirname(path.realpath(__file__)) - - -@pytest.fixture(name="plot_path") -def make_fit_point_plotter_setup(): - return path.join( - "{}".format(path.dirname(path.realpath(__file__))), - "files", - "plots", - "fit_point", - ) - - -def test__fit_point_quantities_are_output( - fit_point_dataset_x2_plane, plot_path, plot_patch -): - fit_point_plotter = aplt.FitPointDatasetPlotter( - fit=fit_point_dataset_x2_plane, - mat_plot_1d=aplt.MatPlot1D(output=aplt.Output(path=plot_path, format="png")), - mat_plot_2d=aplt.MatPlot2D(output=aplt.Output(path=plot_path, format="png")), - ) - - fit_point_plotter.figures_2d(positions=True, fluxes=True) - - assert path.join(plot_path, "fit_point_positions.png") in plot_patch.paths - assert path.join(plot_path, "fit_point_fluxes.png") in plot_patch.paths - - plot_patch.paths = [] - - fit_point_plotter.figures_2d(positions=True, fluxes=False) - - assert path.join(plot_path, "fit_point_positions.png") in plot_patch.paths - assert path.join(plot_path, "fit_point_fluxes.png") not in plot_patch.paths - - plot_patch.paths = [] - - fit_point_dataset_x2_plane.dataset.fluxes = None - - fit_point_plotter = aplt.FitPointDatasetPlotter( - fit=fit_point_dataset_x2_plane, - mat_plot_2d=aplt.MatPlot2D(output=aplt.Output(path=plot_path, format="png")), - ) - - fit_point_plotter.figures_2d(positions=True, fluxes=True) - - assert path.join(plot_path, "fit_point_positions.png") in plot_patch.paths - assert path.join(plot_path, "fit_point_fluxes.png") not in plot_patch.paths - - -def test__subplot_fit(fit_point_dataset_x2_plane, plot_path, plot_patch): - fit_point_plotter = aplt.FitPointDatasetPlotter( - fit=fit_point_dataset_x2_plane, - mat_plot_1d=aplt.MatPlot1D(output=aplt.Output(path=plot_path, format="png")), - mat_plot_2d=aplt.MatPlot2D(output=aplt.Output(path=plot_path, format="png")), - ) - - fit_point_plotter.subplot_fit() - - assert path.join(plot_path, "subplot_fit.png") in plot_patch.paths diff --git a/test_autolens/point/plot/test_point_dataset_plots.py b/test_autolens/point/plot/test_point_dataset_plots.py new file mode 100644 index 000000000..29862211f --- /dev/null +++ b/test_autolens/point/plot/test_point_dataset_plots.py @@ -0,0 +1,26 @@ +from os import path + +import pytest + +from autolens.point.plot.point_dataset_plots import subplot_dataset + +directory = path.dirname(path.realpath(__file__)) + + +@pytest.fixture(name="plot_path") +def make_point_dataset_plotter_setup(): + return path.join( + "{}".format(path.dirname(path.realpath(__file__))), + "files", + "plots", + "point_dataset", + ) + + +def test__subplot_dataset(point_dataset, plot_path, plot_patch): + subplot_dataset( + dataset=point_dataset, + output_path=plot_path, + output_format="png", + ) + assert path.join(plot_path, "subplot_dataset_point.png") in plot_patch.paths diff --git a/test_autolens/point/plot/test_point_dataset_plotters.py b/test_autolens/point/plot/test_point_dataset_plotters.py deleted file mode 100644 index 404015750..000000000 --- a/test_autolens/point/plot/test_point_dataset_plotters.py +++ /dev/null @@ -1,63 +0,0 @@ -from os import path - -import pytest - -import autolens.plot as aplt - -directory = path.dirname(path.realpath(__file__)) - - -@pytest.fixture(name="plot_path") -def make_point_dataset_plotter_setup(): - return path.join( - "{}".format(path.dirname(path.realpath(__file__))), - "files", - "plots", - "point_dataset", - ) - - -def test__point_dataset_quantities_are_output(point_dataset, plot_path, plot_patch): - point_dataset_plotter = aplt.PointDatasetPlotter( - dataset=point_dataset, - mat_plot_1d=aplt.MatPlot1D(output=aplt.Output(path=plot_path, format="png")), - mat_plot_2d=aplt.MatPlot2D(output=aplt.Output(path=plot_path, format="png")), - ) - - point_dataset_plotter.figures_2d(positions=True, fluxes=True) - - assert path.join(plot_path, "point_dataset_positions.png") in plot_patch.paths - assert path.join(plot_path, "point_dataset_fluxes.png") in plot_patch.paths - - plot_patch.paths = [] - - point_dataset_plotter.figures_2d(positions=True, fluxes=False) - - assert path.join(plot_path, "point_dataset_positions.png") in plot_patch.paths - assert path.join(plot_path, "point_dataset_fluxes.png") not in plot_patch.paths - - plot_patch.paths = [] - - point_dataset.fluxes = None - - point_dataset_plotter = aplt.PointDatasetPlotter( - dataset=point_dataset, - mat_plot_2d=aplt.MatPlot2D(output=aplt.Output(path=plot_path, format="png")), - ) - - point_dataset_plotter.figures_2d(positions=True, fluxes=True) - - assert path.join(plot_path, "point_dataset_positions.png") in plot_patch.paths - assert path.join(plot_path, "point_dataset_fluxes.png") not in plot_patch.paths - - -def test__subplot_dataset(point_dataset, plot_path, plot_patch): - point_dataset_plotter = aplt.PointDatasetPlotter( - dataset=point_dataset, - mat_plot_1d=aplt.MatPlot1D(output=aplt.Output(path=plot_path, format="png")), - mat_plot_2d=aplt.MatPlot2D(output=aplt.Output(path=plot_path, format="png")), - ) - - point_dataset_plotter.subplot_dataset() - - assert path.join(plot_path, "subplot_dataset_point.png") in plot_patch.paths