Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 0 additions & 8 deletions autolens/analysis/plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,14 +55,6 @@ def should_plot(name):
output_path = str(self.image_path)
fmt = self.fmt

if should_plot("subplot_tracer"):
subplot_tracer(
tracer=tracer,
grid=grid,
output_path=output_path,
output_format=fmt,
)

if should_plot("subplot_galaxies_images"):
subplot_galaxies_images(
tracer=tracer,
Expand Down
215 changes: 166 additions & 49 deletions autolens/imaging/plot/fit_imaging_plots.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
import matplotlib.pyplot as plt
import numpy as np
from typing import Optional, List
Expand All @@ -8,8 +9,12 @@
from autoarray.plot.array import plot_array, _zoom_array_2d
from autoarray.plot.utils import save_figure, hide_unused_axes, conf_subplot_figsize
from autoarray.plot.utils import numpy_lines as _to_lines
from autoarray.inversion.mappers.abstract import Mapper
from autoarray.inversion.plot.mapper_plots import plot_mapper
from autogalaxy.plot.plot_utils import _critical_curves_from, _caustics_from

logger = logging.getLogger(__name__)


def _get_source_vmax(fit):
"""
Expand Down Expand Up @@ -39,7 +44,8 @@ def _get_source_vmax(fit):


def _plot_source_plane(fit, ax, plane_index, zoom_to_brightest=True,
colormap=None, use_log10=False):
colormap=None, use_log10=False, title=None,
lines=None, line_colors=None):
"""
Plot the source-plane image (or a blank inversion placeholder) into an axes.

Expand All @@ -48,9 +54,10 @@ def _plot_source_plane(fit, ax, plane_index, zoom_to_brightest=True,
function ray-traces a zoomed image-plane grid to the source plane,
evaluates the source-galaxy light, and renders the resulting 2-D array
via :func:`~autoarray.plot.array.plot_array`. When the plane *does*
contain a pixelization (an inversion source), the axes are turned off
and labelled as "Source Reconstruction" instead, because the inversion
reconstruction is rendered separately by the inversion plotter.
contain a pixelization (an inversion source), the source reconstruction
is rendered via :func:`~autoarray.inversion.plot.mapper_plots.plot_mapper`
using ``zoom_to_brightest`` to control whether the view is zoomed in on
the brightest pixels or shown at full extent.

Parameters
----------
Expand All @@ -62,8 +69,9 @@ def _plot_source_plane(fit, ax, plane_index, zoom_to_brightest=True,
plane_index : int
Index of the plane in ``fit.tracer.planes`` to visualise.
zoom_to_brightest : bool, optional
Passed through to the zoom logic (currently unused in the
rendering call but reserved for future use).
For inversion sources, zooms the colormap extent to the brightest
reconstructed pixels. For parametric sources, this parameter has
no effect.
colormap : str, optional
Matplotlib colormap name.
use_log10 : bool, optional
Expand All @@ -80,14 +88,33 @@ def _plot_source_plane(fit, ax, plane_index, zoom_to_brightest=True,
image = plane_galaxies.image_2d_from(grid=traced_grids[plane_index])
plot_array(
array=image, ax=ax,
title=f"Source Plane {plane_index}",
colormap=colormap, use_log10=use_log10,
title=title if title is not None else f"Source Plane {plane_index}",
colormap=colormap, use_log10=use_log10, lines=lines,
line_colors=line_colors,
)
else:
# Inversion path: in subplot context show a blank panel.
if ax is not None:
ax.axis("off")
ax.set_title(f"Source Reconstruction (plane {plane_index})")
# Inversion path: plot the source reconstruction via the mapper.
try:
inversion = fit.inversion
mapper_list = inversion.cls_list_from(cls=Mapper)
mapper = mapper_list[plane_index - 1] if plane_index > 0 else mapper_list[0]
pixel_values = inversion.reconstruction_dict[mapper]
plot_mapper(
mapper,
solution_vector=pixel_values,
ax=ax,
title=title if title is not None else f"Source Reconstruction (plane {plane_index})",
colormap=colormap,
use_log10=use_log10,
zoom_to_brightest=zoom_to_brightest,
lines=lines,
line_colors=line_colors,
)
except Exception as exc:
logger.warning(f"Could not plot source reconstruction for plane {plane_index}: {exc}")
if ax is not None:
ax.axis("off")
ax.set_title(f"Source Reconstruction (plane {plane_index})")


def subplot_fit(
Expand Down Expand Up @@ -144,6 +171,31 @@ def subplot_fit(

source_vmax = _get_source_vmax(fit)

tracer = fit.tracer_linear_light_profiles_to_light_profiles
try:
_zoom = aa.Zoom2D(mask=fit.mask)
_cc_grid = aa.Grid2D.from_extent(
extent=_zoom.extent_from(buffer=0),
shape_native=_zoom.shape_native,
)
tan_cc, rad_cc = _critical_curves_from(tracer, _cc_grid)
tan_ca, rad_ca = _caustics_from(tracer, _cc_grid)
_tan_cc_lines = _to_lines(list(tan_cc) if tan_cc is not None else []) or []
_rad_cc_lines = _to_lines(list(rad_cc) if rad_cc is not None else []) or []
_tan_ca_lines = _to_lines(list(tan_ca) if tan_ca is not None else []) or []
_rad_ca_lines = _to_lines(list(rad_ca) if rad_ca is not None else []) or []
image_plane_lines = _tan_cc_lines + _rad_cc_lines
image_plane_line_colors = ["black"] * len(_tan_cc_lines) + ["white"] * len(_rad_cc_lines)
source_plane_lines = _tan_ca_lines + _rad_ca_lines
source_plane_line_colors = ["black"] * len(_tan_ca_lines) + ["white"] * len(_rad_ca_lines)
image_plane_lines = image_plane_lines or None
source_plane_lines = source_plane_lines or None
except Exception:
image_plane_lines = None
image_plane_line_colors = None
source_plane_lines = None
source_plane_line_colors = None

fig, axes = plt.subplots(3, 4, figsize=conf_subplot_figsize(3, 4))
axes_flat = list(axes.flatten())

Expand All @@ -156,7 +208,8 @@ def subplot_fit(
plot_array(array=fit.signal_to_noise_map, ax=axes_flat[2],
title="Signal-To-Noise Map", colormap=colormap)
plot_array(array=fit.model_data, ax=axes_flat[3], title="Model Image",
colormap=colormap)
colormap=colormap, lines=image_plane_lines,
line_colors=image_plane_line_colors)

# Lens model image
try:
Expand Down Expand Up @@ -188,31 +241,34 @@ def subplot_fit(
source_model_img = None
if source_model_img is not None:
plot_array(array=source_model_img, ax=axes_flat[6], title="Source Model Image",
colormap=colormap, vmax=source_vmax)
colormap=colormap, vmax=source_vmax, lines=image_plane_lines,
line_colors=image_plane_line_colors)
else:
axes_flat[6].axis("off")

# Source plane zoomed
_plot_source_plane(fit, axes_flat[7], final_plane_index, zoom_to_brightest=True,
colormap=colormap)
colormap=colormap, title="Source Plane (Zoomed)",
lines=source_plane_lines, line_colors=source_plane_line_colors)

# Normalized residual map (symmetric)
norm_resid = fit.normalized_residual_map
_abs_max = _symmetric_vmax(norm_resid)
plot_array(array=norm_resid, ax=axes_flat[8], title="Normalized Residual Map",
colormap=colormap, vmin=-_abs_max, vmax=_abs_max, cb_unit=r"$\sigma$")
colormap=colormap, vmin=-_abs_max, vmax=_abs_max)

# Normalized residual map clipped to [-1, 1]
plot_array(array=norm_resid, ax=axes_flat[9],
title=r"Normalized Residual Map $1\sigma$",
colormap=colormap, vmin=-1.0, vmax=1.0, cb_unit=r"$\sigma$")
colormap=colormap, vmin=-1.0, vmax=1.0)

plot_array(array=fit.chi_squared_map, ax=axes_flat[10],
title="Chi-Squared Map", colormap=colormap, cb_unit=r"$\chi^2$")

# Source plane not zoomed
_plot_source_plane(fit, axes_flat[11], final_plane_index, zoom_to_brightest=False,
colormap=colormap)
colormap=colormap, title="Source Plane (No Zoom)",
lines=source_plane_lines, line_colors=source_plane_line_colors)

hide_unused_axes(axes_flat)
plt.tight_layout()
Expand Down Expand Up @@ -530,17 +586,16 @@ def subplot_tracer_from_fit(
"""
Produce a 9-panel tracer subplot derived from a `FitImaging` object.

Uses the best-fit linear-light-profile tracer to render:

* Model image (full lensed image)
* Source model image (source-plane brightness at image scale)
* Source plane image (evaluated on the image-plane grid, full extent)
* Lens-plane image with critical curves (log10 scale)
* Panels 5–9 are reserved (currently blank) for future mass-map panels

The critical curves are computed from the tracer via
:func:`~autogalaxy.plot.plot_utils._critical_curves_from` and overlaid
on the lens-plane image.
Panels (3x3 = 9 axes):
0: Model image with critical curves
1: Source model image (image-plane projection) with critical curves
2: Source plane (no zoom) with caustics
3: Lens image (log10) with critical curves
4: Convergence (log10)
5: Potential (log10)
6: Deflections Y with critical curves
7: Deflections X with critical curves
8: Magnification with critical curves

Parameters
----------
Expand All @@ -554,44 +609,106 @@ def subplot_tracer_from_fit(
colormap : str, optional
Matplotlib colormap name applied to all image panels.
"""
from autogalaxy.operate.lens_calc import LensCalc

final_plane_index = len(fit.tracer.planes) - 1
tracer = fit.tracer_linear_light_profiles_to_light_profiles

# --- grid and critical curves (computed first so all panels can use them) ---
zoom = aa.Zoom2D(mask=fit.mask)
grid = aa.Grid2D.from_extent(
extent=zoom.extent_from(buffer=0), shape_native=zoom.shape_native
)

try:
tan_cc, rad_cc = _critical_curves_from(tracer, grid)
tan_ca, rad_ca = _caustics_from(tracer, grid)
_tan_cc_lines = _to_lines(list(tan_cc) if tan_cc is not None else []) or []
_rad_cc_lines = _to_lines(list(rad_cc) if rad_cc is not None else []) or []
_tan_ca_lines = _to_lines(list(tan_ca) if tan_ca is not None else []) or []
_rad_ca_lines = _to_lines(list(rad_ca) if rad_ca is not None else []) or []
image_plane_lines = (_tan_cc_lines + _rad_cc_lines) or None
image_plane_line_colors = ["black"] * len(_tan_cc_lines) + ["white"] * len(_rad_cc_lines)
source_plane_lines = (_tan_ca_lines + _rad_ca_lines) or None
source_plane_line_colors = ["black"] * len(_tan_ca_lines) + ["white"] * len(_rad_ca_lines)
except Exception:
image_plane_lines = None
image_plane_line_colors = None
source_plane_lines = None
source_plane_line_colors = None

source_vmax = _get_source_vmax(fit)

traced_grids = tracer.traced_grid_2d_list_from(grid=grid)
lens_galaxies = ag.Galaxies(galaxies=tracer.planes[0])
lens_image = lens_galaxies.image_2d_from(grid=traced_grids[0])

deflections = lens_galaxies.deflections_yx_2d_from(grid=grid)
deflections_y = aa.Array2D(values=deflections.slim[:, 0], mask=grid.mask)
deflections_x = aa.Array2D(values=deflections.slim[:, 1], mask=grid.mask)

magnification = LensCalc.from_mass_obj(tracer).magnification_2d_from(grid=grid)

fig, axes = plt.subplots(3, 3, figsize=conf_subplot_figsize(3, 3))
axes_flat = list(axes.flatten())

tracer = fit.tracer_linear_light_profiles_to_light_profiles

# Panel 0: Model Image
plot_array(array=fit.model_data, ax=axes_flat[0], title="Model Image",
lines=image_plane_lines, line_colors=image_plane_line_colors,
colormap=colormap)

# Panel 1: Source Model Image (image-plane projection)
try:
source_model_img = fit.model_images_of_planes_list[final_plane_index]
source_vmax = float(np.max(source_model_img.array))
except Exception:
source_model_img = None
if source_model_img is not None:
plot_array(array=source_model_img, ax=axes_flat[1], title="Source Model Image",
colormap=colormap, vmax=source_vmax)
except (IndexError, AttributeError, ValueError):
colormap=colormap, vmax=source_vmax,
lines=image_plane_lines, line_colors=image_plane_line_colors)
else:
axes_flat[1].axis("off")

# Panel 2: Source Plane (No Zoom)
_plot_source_plane(fit, axes_flat[2], final_plane_index, zoom_to_brightest=False,
colormap=colormap)
colormap=colormap, title="Source Plane (No Zoom)",
lines=source_plane_lines, line_colors=source_plane_line_colors)

# Lens plane mass quantities (log10)
zoom = aa.Zoom2D(mask=fit.mask)
grid = aa.Grid2D.from_extent(
extent=zoom.extent_from(buffer=0), shape_native=zoom.shape_native
)
# Panel 3: Lens Image (log10)
plot_array(array=lens_image, ax=axes_flat[3], title="Lens Image",
lines=image_plane_lines, line_colors=image_plane_line_colors,
colormap=colormap, use_log10=True)

tan_cc, rad_cc = _critical_curves_from(tracer, grid)
image_plane_lines = _to_lines(list(tan_cc) + (list(rad_cc) if rad_cc is not None else []))
# Panel 4: Convergence (log10)
try:
convergence = tracer.convergence_2d_from(grid=grid)
plot_array(array=convergence, ax=axes_flat[4], title="Convergence",
colormap=colormap, use_log10=True)
except Exception:
axes_flat[4].axis("off")

traced_grids = tracer.traced_grid_2d_list_from(grid=grid)
lens_galaxies = ag.Galaxies(galaxies=tracer.planes[0])
lens_image = lens_galaxies.image_2d_from(grid=traced_grids[0])
plot_array(array=lens_image, ax=axes_flat[3], title="Lens Image",
lines=image_plane_lines, colormap=colormap, use_log10=True)
# Panel 5: Potential (log10)
try:
potential = tracer.potential_2d_from(grid=grid)
plot_array(array=potential, ax=axes_flat[5], title="Potential",
colormap=colormap, use_log10=True)
except Exception:
axes_flat[5].axis("off")

# Panel 6: Deflections Y
plot_array(array=deflections_y, ax=axes_flat[6], title="Deflections Y",
lines=image_plane_lines, line_colors=image_plane_line_colors,
colormap=colormap)

for i in range(4, 9):
axes_flat[i].axis("off")
# Panel 7: Deflections X
plot_array(array=deflections_x, ax=axes_flat[7], title="Deflections X",
lines=image_plane_lines, line_colors=image_plane_line_colors,
colormap=colormap)

# Panel 8: Magnification
plot_array(array=magnification, ax=axes_flat[8], title="Magnification",
lines=image_plane_lines, line_colors=image_plane_line_colors,
colormap=colormap)

plt.tight_layout()
save_figure(fig, path=output_path, filename="tracer", format=output_format)
Expand Down
Loading
Loading