Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 12 additions & 12 deletions PLOT_REFACTOR_PLAN.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ Problems:
describes the workaround as a "nasty hack"
- The config system switches every wrap object between `figure:` and `subplot:` sections
based on whether `subplot_index is not None`, adding hidden state to every config lookup
- Nested plotters (FitImagingPlotterTracerPlotter → InversionPlotter) share one
- Nested plotters (FitImagingTracer → InversionPlotter) share one
mat_plot object so their indices accumulate in the same global counter

The fix is to use matplotlib's native `plt.subplots()` and pass `ax` objects directly.
Expand Down Expand Up @@ -272,7 +272,7 @@ constructor accepts `output_path` and `output_filename` strings.

---

#### PR A3 · Update `ImagingPlotter`, `InversionPlotter`, `MapperPlotter`, `InterferometerPlotter`
#### PR A3 · Update `Imaging`, `InversionPlotter`, `MapperPlotter`, `Interferometer`

Same `ax`-passing pattern. Mixed 1D/2D subplots (e.g. interferometer) use:

Expand Down Expand Up @@ -365,13 +365,13 @@ They have no config dependency.

---

#### PR G2 · Update `LightProfilePlotter`, `MassProfilePlotter`, `GalaxyPlotter`, `GalaxiesPlotter`
#### PR G2 · Update `LightProfile`, `MassProfilePlotter`, `Galaxy`, `Galaxies`

Each plotter computes its own overlay data from its galaxy/profile then passes it
to `plot_array`:

```python
class GalaxiesPlotter(AbstractPlotter):
class Galaxies(AbstractPlotter):
def figure_image(self, ax=None):
owns = ax is None
if owns:
Expand All @@ -390,7 +390,7 @@ Remove autogalaxy `MatPlot2D` subclass and autogalaxy `Visuals2D` subclass.

---

#### PR G3 · Update autogalaxy `FitImagingPlotter` and `FitInterferometerPlotter`
#### PR G3 · Update autogalaxy `FitImaging` and `FitInterferometer`

```python
def subplot_fit(self):
Expand Down Expand Up @@ -420,13 +420,13 @@ Update `autogalaxy/plot/__init__.py`.

---

#### PR L1 · Update `TracerPlotter`
#### PR L1 · Update `Tracer`

The plotter computes critical curves / caustics itself from the tracer, then passes
them as `lines` to `plot_array`:

```python
class TracerPlotter(AbstractPlotter):
class Tracer(AbstractPlotter):
def figure_convergence(self, ax=None):
owns = ax is None
if owns:
Expand Down Expand Up @@ -464,7 +464,7 @@ Add `show_critical_curves: bool = True`, `show_caustics: bool = True`.

---

#### PR L2 · Update `FitImagingPlotter`
#### PR L2 · Update `FitImaging`

Largest single plotter. The 12-panel `subplot_fit` becomes:

Expand Down Expand Up @@ -511,7 +511,7 @@ def subplot_of_planes(self):

---

#### PR L3 · Update `FitInterferometerPlotter`, `PointDatasetPlotter`, `FitPointDatasetPlotter`
#### PR L3 · Update `FitInterferometer`, `PointDatasetPlotter`, `FitPointDatasetPlotter`

**PointDatasetPlotter** — mixed 1D/2D, which was the "nasty hack" case:

Expand Down Expand Up @@ -566,15 +566,15 @@ plot_array(
|---|---|---|---|
| A1 | autoarray | Add `plots/` module | New unit tests |
| A2 | autoarray | Rewrite Array2D/Grid2DPlotter | Update existing |
| A3 | autoarray | Rewrite Imaging/Inversion/Mapper/InterferometerPlotter | Update existing |
| A3 | autoarray | Rewrite Imaging/Inversion/Mapper/Interferometer | Update existing |
| A4 | autoarray | Delete mat_plot/, wrap/, visuals/ | Delete wrap tests |
| A5 | autoarray | Config cleanup, finalise helpers | Smoke tests |
| G1 | autogalaxy | Add overlay helpers | New unit tests |
| G2 | autogalaxy | Rewrite Galaxy/Mass/LightProfile plotters | Update existing |
| G3 | autogalaxy | Rewrite FitImaging/FitInterferometer plotters | Update existing |
| G4 | autogalaxy | Delete MatPlot2D/Visuals2D extensions | Delete wrap tests |
| L1 | autolens | Rewrite TracerPlotter | Update existing |
| L2 | autolens | Rewrite FitImagingPlotter | Update existing |
| L1 | autolens | Rewrite Tracer | Update existing |
| L2 | autolens | Rewrite FitImaging | Update existing |
| L3 | autolens | Rewrite FitInterferometer/Point plotters | Update existing |
| L4 | autolens | Rewrite Subhalo plotters, clean abstract_plotters | Update existing |

Expand Down
76 changes: 7 additions & 69 deletions autolens/analysis/plotter.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,6 @@
import ast
import numpy as np
from typing import Optional

from autoconf import conf
from autoconf.fitsable import hdu_list_for_output_from

import autoarray as aa
import autogalaxy as ag

Expand All @@ -13,7 +9,11 @@
from autogalaxy.analysis.plotter import Plotter as AgPlotter

from autolens.lens.tracer import Tracer
from autolens.lens.plot.tracer_plots import subplot_galaxies_images
from autolens.lens.plot.tracer_plots import (
subplot_galaxies_images,
save_tracer_fits,
save_source_plane_images_fits,
)
from autoarray.plot.array import plot_array


Expand Down Expand Up @@ -63,72 +63,10 @@ def should_plot(name):
)

if should_plot("fits_tracer"):

zoom = aa.Zoom2D(mask=grid.mask)
mask = zoom.mask_2d_from(buffer=1)
grid_zoom = aa.Grid2D.from_mask(mask=mask)

image_list = [
tracer.convergence_2d_from(grid=grid_zoom).native,
tracer.potential_2d_from(grid=grid_zoom).native,
tracer.deflections_yx_2d_from(grid=grid_zoom).native[:, :, 0],
tracer.deflections_yx_2d_from(grid=grid_zoom).native[:, :, 1],
]

hdu_list = hdu_list_for_output_from(
values_list=[image_list[0].mask.astype("float")] + image_list,
ext_name_list=[
"mask",
"convergence",
"potential",
"deflections_y",
"deflections_x",
],
header_dict=grid_zoom.mask.header_dict,
)

hdu_list.writeto(self.image_path / "tracer.fits", overwrite=True)
save_tracer_fits(tracer=tracer, grid=grid, output_path=self.image_path)

if should_plot("fits_source_plane_images"):

shape_native = conf.instance["visualize"]["plots"]["tracer"][
"fits_source_plane_shape"
]
shape_native = ast.literal_eval(shape_native)

zoom = aa.Zoom2D(mask=grid.mask)
mask = zoom.mask_2d_from(buffer=1)
grid_source_plane = aa.Grid2D.from_extent(
extent=mask.geometry.extent, shape_native=tuple(shape_native)
)

image_list = [grid_source_plane.mask.astype("float")]
ext_name_list = ["mask"]

for i, plane in enumerate(tracer.planes[1:]):

if plane.has(cls=ag.LightProfile):

image = plane.image_2d_from(
grid=grid_source_plane,
).native

else:

image = np.zeros(grid_source_plane.shape_native)

image_list.append(image)
ext_name_list.append(f"source_plane_image_{i+1}")

hdu_list = hdu_list_for_output_from(
values_list=image_list,
ext_name_list=ext_name_list,
header_dict=grid_source_plane.mask.header_dict,
)

hdu_list.writeto(
self.image_path / "source_plane_images.fits", overwrite=True
)
save_source_plane_images_fits(tracer=tracer, grid=grid, output_path=self.image_path)

def image_with_positions(self, image: aa.Array2D, positions: aa.Grid2DIrregular):
"""
Expand Down
23 changes: 12 additions & 11 deletions autolens/imaging/plot/fit_imaging_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import autogalaxy as ag

from autoarray.plot.array import plot_array, _zoom_array_2d
from autoarray.plot.utils import save_figure
from autoarray.plot.utils import save_figure, hide_unused_axes
from autoarray.plot.utils import numpy_lines as _to_lines
from autogalaxy.plot.plot_utils import _critical_curves_from, _caustics_from

Expand Down Expand Up @@ -39,7 +39,7 @@ def _get_source_vmax(fit):


def _plot_source_plane(fit, ax, plane_index, zoom_to_brightest=True,
colormap="jet", use_log10=False):
colormap=None, use_log10=False):
"""
Plot the source-plane image (or a blank inversion placeholder) into an axes.

Expand Down Expand Up @@ -94,7 +94,7 @@ def subplot_fit(
fit,
output_path: Optional[str] = None,
output_format: str = "png",
colormap: str = "jet",
colormap: Optional[str] = None,
plane_index: Optional[int] = None,
):
"""
Expand Down Expand Up @@ -214,6 +214,7 @@ def subplot_fit(
_plot_source_plane(fit, axes_flat[11], final_plane_index, zoom_to_brightest=False,
colormap=colormap)

hide_unused_axes(axes_flat)
plt.tight_layout()
save_figure(fig, path=output_path, filename=f"subplot_fit{plane_index_tag}", format=output_format)

Expand All @@ -222,7 +223,7 @@ def subplot_fit_x1_plane(
fit,
output_path: Optional[str] = None,
output_format: str = "png",
colormap: str = "jet",
colormap: Optional[str] = None,
):
"""
Produce a 6-panel subplot for a single-plane tracer imaging fit.
Expand Down Expand Up @@ -286,7 +287,7 @@ def subplot_fit_log10(
fit,
output_path: Optional[str] = None,
output_format: str = "png",
colormap: str = "jet",
colormap: Optional[str] = None,
plane_index: Optional[int] = None,
):
"""
Expand Down Expand Up @@ -395,7 +396,7 @@ def subplot_fit_log10_x1_plane(
fit,
output_path: Optional[str] = None,
output_format: str = "png",
colormap: str = "jet",
colormap: Optional[str] = None,
):
"""
Produce a 6-panel log10 subplot for a single-plane tracer imaging fit.
Expand Down Expand Up @@ -456,7 +457,7 @@ def subplot_of_planes(
fit,
output_path: Optional[str] = None,
output_format: str = "png",
colormap: str = "jet",
colormap: Optional[str] = None,
plane_index: Optional[int] = None,
):
"""
Expand Down Expand Up @@ -524,7 +525,7 @@ def subplot_tracer_from_fit(
fit,
output_path: Optional[str] = None,
output_format: str = "png",
colormap: str = "jet",
colormap: Optional[str] = None,
):
"""
Produce a 9-panel tracer subplot derived from a `FitImaging` object.
Expand Down Expand Up @@ -581,7 +582,7 @@ def subplot_tracer_from_fit(
)

tan_cc, rad_cc = _critical_curves_from(tracer, grid)
image_plane_lines = _to_lines(list(tan_cc) + list(rad_cc))
image_plane_lines = _to_lines(list(tan_cc) + (list(rad_cc) if rad_cc is not None else []))

traced_grids = tracer.traced_grid_2d_list_from(grid=grid)
lens_galaxies = ag.Galaxies(galaxies=tracer.planes[0])
Expand All @@ -600,7 +601,7 @@ def subplot_fit_combined(
fit_list: List,
output_path: Optional[str] = None,
output_format: str = "png",
colormap: str = "jet",
colormap: Optional[str] = None,
):
"""
Produce a combined multi-row subplot for a list of `FitImaging` objects.
Expand Down Expand Up @@ -682,7 +683,7 @@ def subplot_fit_combined_log10(
fit_list: List,
output_path: Optional[str] = None,
output_format: str = "png",
colormap: str = "jet",
colormap: Optional[str] = None,
):
"""
Produce a combined log10 multi-row subplot for a list of `FitImaging` objects.
Expand Down
6 changes: 3 additions & 3 deletions autolens/interferometer/plot/fit_interferometer_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def _plot_yx(y, x, ax, title, xlabel="", ylabel=""):


def _plot_source_plane(fit, ax, plane_index, zoom_to_brightest=True,
colormap="jet", use_log10=False):
colormap=None, use_log10=False):
"""
Plot the source-plane image (or a blank inversion placeholder) into an axes.

Expand Down Expand Up @@ -88,7 +88,7 @@ def subplot_fit(
fit,
output_path: Optional[str] = None,
output_format: str = "png",
colormap: str = "jet",
colormap: Optional[str] = None,
):
"""
Produce a 12-panel subplot summarising an interferometer fit.
Expand Down Expand Up @@ -197,7 +197,7 @@ def subplot_fit_real_space(
fit,
output_path: Optional[str] = None,
output_format: str = "png",
colormap: str = "jet",
colormap: Optional[str] = None,
):
"""
Produce a real-space subplot for an interferometer fit.
Expand Down
6 changes: 3 additions & 3 deletions autolens/lens/plot/sensitivity_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def subplot_tracer_images(
source_image,
output_path: Optional[str] = None,
output_format: str = "png",
colormap: str = "jet",
colormap: Optional[str] = None,
use_log10: bool = False,
):
"""
Expand Down Expand Up @@ -120,7 +120,7 @@ def subplot_sensitivity(
data_subtracted,
output_path: Optional[str] = None,
output_format: str = "png",
colormap: str = "jet",
colormap: Optional[str] = None,
use_log10: bool = False,
):
"""
Expand Down Expand Up @@ -248,7 +248,7 @@ def subplot_figures_of_merit_grid(
result,
output_path: Optional[str] = None,
output_format: str = "png",
colormap: str = "jet",
colormap: Optional[str] = None,
use_log_evidences: bool = True,
remove_zeros: bool = True,
):
Expand Down
4 changes: 2 additions & 2 deletions autolens/lens/plot/subhalo_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def subplot_detection_imaging(
fit_imaging_with_subhalo,
output_path: Optional[str] = None,
output_format: str = "png",
colormap: str = "jet",
colormap: Optional[str] = None,
use_log10: bool = False,
use_log_evidences: bool = True,
relative_to_value: float = 0.0,
Expand Down Expand Up @@ -103,7 +103,7 @@ def subplot_detection_fits(
fit_imaging_with_subhalo,
output_path: Optional[str] = None,
output_format: str = "png",
colormap: str = "jet",
colormap: Optional[str] = None,
):
"""
Produce a 6-panel subplot comparing imaging fits with and without a subhalo.
Expand Down
Loading
Loading