From 545ce1de7c13cf2aa94db539030dd4a808cac37e Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Thu, 26 Mar 2026 20:22:03 +0000 Subject: [PATCH] Plot improvements: subplot_tracer_from_fit overhaul, line_colors, caustics - subplot_tracer_from_fit: compute critical curves and caustics before drawing panels so all source panels receive line overlays; add deflections Y/X and magnification to panels 6-8; use lens_galaxies.deflections_yx_2d_from for correct panel content - subplot_tracer: add critical curves (black=tangential, white=radial) and caustics; harden source image and source plane panels; add source_vmax scaling - _plot_source_plane: log exceptions instead of silently blanking the axis - plotter.py: remove standalone subplot_tracer call from tracer plotter (imaging plotter handles subplot_tracer via subplot_tracer_from_fit) Co-Authored-By: Claude Sonnet 4.6 --- autolens/analysis/plotter.py | 8 - autolens/imaging/plot/fit_imaging_plots.py | 215 ++++++++++++++++----- autolens/lens/plot/tracer_plots.py | 40 ++-- 3 files changed, 192 insertions(+), 71 deletions(-) diff --git a/autolens/analysis/plotter.py b/autolens/analysis/plotter.py index b7ef98b15..1d29b8313 100644 --- a/autolens/analysis/plotter.py +++ b/autolens/analysis/plotter.py @@ -55,14 +55,6 @@ def should_plot(name): output_path = str(self.image_path) fmt = self.fmt - if should_plot("subplot_tracer"): - subplot_tracer( - tracer=tracer, - grid=grid, - output_path=output_path, - output_format=fmt, - ) - if should_plot("subplot_galaxies_images"): subplot_galaxies_images( tracer=tracer, diff --git a/autolens/imaging/plot/fit_imaging_plots.py b/autolens/imaging/plot/fit_imaging_plots.py index df6418c0e..04a123725 100644 --- a/autolens/imaging/plot/fit_imaging_plots.py +++ b/autolens/imaging/plot/fit_imaging_plots.py @@ -1,3 +1,4 @@ +import logging import matplotlib.pyplot as plt import numpy as np from typing import Optional, List @@ -8,8 +9,12 @@ from autoarray.plot.array import plot_array, _zoom_array_2d from autoarray.plot.utils import save_figure, hide_unused_axes, conf_subplot_figsize from autoarray.plot.utils import numpy_lines as _to_lines +from autoarray.inversion.mappers.abstract import Mapper +from autoarray.inversion.plot.mapper_plots import plot_mapper from autogalaxy.plot.plot_utils import _critical_curves_from, _caustics_from +logger = logging.getLogger(__name__) + def _get_source_vmax(fit): """ @@ -39,7 +44,8 @@ def _get_source_vmax(fit): def _plot_source_plane(fit, ax, plane_index, zoom_to_brightest=True, - colormap=None, use_log10=False): + colormap=None, use_log10=False, title=None, + lines=None, line_colors=None): """ Plot the source-plane image (or a blank inversion placeholder) into an axes. @@ -48,9 +54,10 @@ def _plot_source_plane(fit, ax, plane_index, zoom_to_brightest=True, function ray-traces a zoomed image-plane grid to the source plane, evaluates the source-galaxy light, and renders the resulting 2-D array via :func:`~autoarray.plot.array.plot_array`. When the plane *does* - contain a pixelization (an inversion source), the axes are turned off - and labelled as "Source Reconstruction" instead, because the inversion - reconstruction is rendered separately by the inversion plotter. + contain a pixelization (an inversion source), the source reconstruction + is rendered via :func:`~autoarray.inversion.plot.mapper_plots.plot_mapper` + using ``zoom_to_brightest`` to control whether the view is zoomed in on + the brightest pixels or shown at full extent. Parameters ---------- @@ -62,8 +69,9 @@ def _plot_source_plane(fit, ax, plane_index, zoom_to_brightest=True, plane_index : int Index of the plane in ``fit.tracer.planes`` to visualise. zoom_to_brightest : bool, optional - Passed through to the zoom logic (currently unused in the - rendering call but reserved for future use). + For inversion sources, zooms the colormap extent to the brightest + reconstructed pixels. For parametric sources, this parameter has + no effect. colormap : str, optional Matplotlib colormap name. use_log10 : bool, optional @@ -80,14 +88,33 @@ def _plot_source_plane(fit, ax, plane_index, zoom_to_brightest=True, image = plane_galaxies.image_2d_from(grid=traced_grids[plane_index]) plot_array( array=image, ax=ax, - title=f"Source Plane {plane_index}", - colormap=colormap, use_log10=use_log10, + title=title if title is not None else f"Source Plane {plane_index}", + colormap=colormap, use_log10=use_log10, lines=lines, + line_colors=line_colors, ) else: - # Inversion path: in subplot context show a blank panel. - if ax is not None: - ax.axis("off") - ax.set_title(f"Source Reconstruction (plane {plane_index})") + # Inversion path: plot the source reconstruction via the mapper. + try: + inversion = fit.inversion + mapper_list = inversion.cls_list_from(cls=Mapper) + mapper = mapper_list[plane_index - 1] if plane_index > 0 else mapper_list[0] + pixel_values = inversion.reconstruction_dict[mapper] + plot_mapper( + mapper, + solution_vector=pixel_values, + ax=ax, + title=title if title is not None else f"Source Reconstruction (plane {plane_index})", + colormap=colormap, + use_log10=use_log10, + zoom_to_brightest=zoom_to_brightest, + lines=lines, + line_colors=line_colors, + ) + except Exception as exc: + logger.warning(f"Could not plot source reconstruction for plane {plane_index}: {exc}") + if ax is not None: + ax.axis("off") + ax.set_title(f"Source Reconstruction (plane {plane_index})") def subplot_fit( @@ -144,6 +171,31 @@ def subplot_fit( source_vmax = _get_source_vmax(fit) + tracer = fit.tracer_linear_light_profiles_to_light_profiles + try: + _zoom = aa.Zoom2D(mask=fit.mask) + _cc_grid = aa.Grid2D.from_extent( + extent=_zoom.extent_from(buffer=0), + shape_native=_zoom.shape_native, + ) + tan_cc, rad_cc = _critical_curves_from(tracer, _cc_grid) + tan_ca, rad_ca = _caustics_from(tracer, _cc_grid) + _tan_cc_lines = _to_lines(list(tan_cc) if tan_cc is not None else []) or [] + _rad_cc_lines = _to_lines(list(rad_cc) if rad_cc is not None else []) or [] + _tan_ca_lines = _to_lines(list(tan_ca) if tan_ca is not None else []) or [] + _rad_ca_lines = _to_lines(list(rad_ca) if rad_ca is not None else []) or [] + image_plane_lines = _tan_cc_lines + _rad_cc_lines + image_plane_line_colors = ["black"] * len(_tan_cc_lines) + ["white"] * len(_rad_cc_lines) + source_plane_lines = _tan_ca_lines + _rad_ca_lines + source_plane_line_colors = ["black"] * len(_tan_ca_lines) + ["white"] * len(_rad_ca_lines) + image_plane_lines = image_plane_lines or None + source_plane_lines = source_plane_lines or None + except Exception: + image_plane_lines = None + image_plane_line_colors = None + source_plane_lines = None + source_plane_line_colors = None + fig, axes = plt.subplots(3, 4, figsize=conf_subplot_figsize(3, 4)) axes_flat = list(axes.flatten()) @@ -156,7 +208,8 @@ def subplot_fit( plot_array(array=fit.signal_to_noise_map, ax=axes_flat[2], title="Signal-To-Noise Map", colormap=colormap) plot_array(array=fit.model_data, ax=axes_flat[3], title="Model Image", - colormap=colormap) + colormap=colormap, lines=image_plane_lines, + line_colors=image_plane_line_colors) # Lens model image try: @@ -188,31 +241,34 @@ def subplot_fit( source_model_img = None if source_model_img is not None: plot_array(array=source_model_img, ax=axes_flat[6], title="Source Model Image", - colormap=colormap, vmax=source_vmax) + colormap=colormap, vmax=source_vmax, lines=image_plane_lines, + line_colors=image_plane_line_colors) else: axes_flat[6].axis("off") # Source plane zoomed _plot_source_plane(fit, axes_flat[7], final_plane_index, zoom_to_brightest=True, - colormap=colormap) + colormap=colormap, title="Source Plane (Zoomed)", + lines=source_plane_lines, line_colors=source_plane_line_colors) # Normalized residual map (symmetric) norm_resid = fit.normalized_residual_map _abs_max = _symmetric_vmax(norm_resid) plot_array(array=norm_resid, ax=axes_flat[8], title="Normalized Residual Map", - colormap=colormap, vmin=-_abs_max, vmax=_abs_max, cb_unit=r"$\sigma$") + colormap=colormap, vmin=-_abs_max, vmax=_abs_max) # Normalized residual map clipped to [-1, 1] plot_array(array=norm_resid, ax=axes_flat[9], title=r"Normalized Residual Map $1\sigma$", - colormap=colormap, vmin=-1.0, vmax=1.0, cb_unit=r"$\sigma$") + colormap=colormap, vmin=-1.0, vmax=1.0) plot_array(array=fit.chi_squared_map, ax=axes_flat[10], title="Chi-Squared Map", colormap=colormap, cb_unit=r"$\chi^2$") # Source plane not zoomed _plot_source_plane(fit, axes_flat[11], final_plane_index, zoom_to_brightest=False, - colormap=colormap) + colormap=colormap, title="Source Plane (No Zoom)", + lines=source_plane_lines, line_colors=source_plane_line_colors) hide_unused_axes(axes_flat) plt.tight_layout() @@ -530,17 +586,16 @@ def subplot_tracer_from_fit( """ Produce a 9-panel tracer subplot derived from a `FitImaging` object. - Uses the best-fit linear-light-profile tracer to render: - - * Model image (full lensed image) - * Source model image (source-plane brightness at image scale) - * Source plane image (evaluated on the image-plane grid, full extent) - * Lens-plane image with critical curves (log10 scale) - * Panels 5–9 are reserved (currently blank) for future mass-map panels - - The critical curves are computed from the tracer via - :func:`~autogalaxy.plot.plot_utils._critical_curves_from` and overlaid - on the lens-plane image. + Panels (3x3 = 9 axes): + 0: Model image with critical curves + 1: Source model image (image-plane projection) with critical curves + 2: Source plane (no zoom) with caustics + 3: Lens image (log10) with critical curves + 4: Convergence (log10) + 5: Potential (log10) + 6: Deflections Y with critical curves + 7: Deflections X with critical curves + 8: Magnification with critical curves Parameters ---------- @@ -554,44 +609,106 @@ def subplot_tracer_from_fit( colormap : str, optional Matplotlib colormap name applied to all image panels. """ + from autogalaxy.operate.lens_calc import LensCalc + final_plane_index = len(fit.tracer.planes) - 1 + tracer = fit.tracer_linear_light_profiles_to_light_profiles + + # --- grid and critical curves (computed first so all panels can use them) --- + zoom = aa.Zoom2D(mask=fit.mask) + grid = aa.Grid2D.from_extent( + extent=zoom.extent_from(buffer=0), shape_native=zoom.shape_native + ) + + try: + tan_cc, rad_cc = _critical_curves_from(tracer, grid) + tan_ca, rad_ca = _caustics_from(tracer, grid) + _tan_cc_lines = _to_lines(list(tan_cc) if tan_cc is not None else []) or [] + _rad_cc_lines = _to_lines(list(rad_cc) if rad_cc is not None else []) or [] + _tan_ca_lines = _to_lines(list(tan_ca) if tan_ca is not None else []) or [] + _rad_ca_lines = _to_lines(list(rad_ca) if rad_ca is not None else []) or [] + image_plane_lines = (_tan_cc_lines + _rad_cc_lines) or None + image_plane_line_colors = ["black"] * len(_tan_cc_lines) + ["white"] * len(_rad_cc_lines) + source_plane_lines = (_tan_ca_lines + _rad_ca_lines) or None + source_plane_line_colors = ["black"] * len(_tan_ca_lines) + ["white"] * len(_rad_ca_lines) + except Exception: + image_plane_lines = None + image_plane_line_colors = None + source_plane_lines = None + source_plane_line_colors = None + + source_vmax = _get_source_vmax(fit) + + traced_grids = tracer.traced_grid_2d_list_from(grid=grid) + lens_galaxies = ag.Galaxies(galaxies=tracer.planes[0]) + lens_image = lens_galaxies.image_2d_from(grid=traced_grids[0]) + + deflections = lens_galaxies.deflections_yx_2d_from(grid=grid) + deflections_y = aa.Array2D(values=deflections.slim[:, 0], mask=grid.mask) + deflections_x = aa.Array2D(values=deflections.slim[:, 1], mask=grid.mask) + + magnification = LensCalc.from_mass_obj(tracer).magnification_2d_from(grid=grid) fig, axes = plt.subplots(3, 3, figsize=conf_subplot_figsize(3, 3)) axes_flat = list(axes.flatten()) - tracer = fit.tracer_linear_light_profiles_to_light_profiles - + # Panel 0: Model Image plot_array(array=fit.model_data, ax=axes_flat[0], title="Model Image", + lines=image_plane_lines, line_colors=image_plane_line_colors, colormap=colormap) + # Panel 1: Source Model Image (image-plane projection) try: source_model_img = fit.model_images_of_planes_list[final_plane_index] - source_vmax = float(np.max(source_model_img.array)) + except Exception: + source_model_img = None + if source_model_img is not None: plot_array(array=source_model_img, ax=axes_flat[1], title="Source Model Image", - colormap=colormap, vmax=source_vmax) - except (IndexError, AttributeError, ValueError): + colormap=colormap, vmax=source_vmax, + lines=image_plane_lines, line_colors=image_plane_line_colors) + else: axes_flat[1].axis("off") + # Panel 2: Source Plane (No Zoom) _plot_source_plane(fit, axes_flat[2], final_plane_index, zoom_to_brightest=False, - colormap=colormap) + colormap=colormap, title="Source Plane (No Zoom)", + lines=source_plane_lines, line_colors=source_plane_line_colors) - # Lens plane mass quantities (log10) - zoom = aa.Zoom2D(mask=fit.mask) - grid = aa.Grid2D.from_extent( - extent=zoom.extent_from(buffer=0), shape_native=zoom.shape_native - ) + # Panel 3: Lens Image (log10) + plot_array(array=lens_image, ax=axes_flat[3], title="Lens Image", + lines=image_plane_lines, line_colors=image_plane_line_colors, + colormap=colormap, use_log10=True) - tan_cc, rad_cc = _critical_curves_from(tracer, grid) - image_plane_lines = _to_lines(list(tan_cc) + (list(rad_cc) if rad_cc is not None else [])) + # Panel 4: Convergence (log10) + try: + convergence = tracer.convergence_2d_from(grid=grid) + plot_array(array=convergence, ax=axes_flat[4], title="Convergence", + colormap=colormap, use_log10=True) + except Exception: + axes_flat[4].axis("off") - traced_grids = tracer.traced_grid_2d_list_from(grid=grid) - lens_galaxies = ag.Galaxies(galaxies=tracer.planes[0]) - lens_image = lens_galaxies.image_2d_from(grid=traced_grids[0]) - plot_array(array=lens_image, ax=axes_flat[3], title="Lens Image", - lines=image_plane_lines, colormap=colormap, use_log10=True) + # Panel 5: Potential (log10) + try: + potential = tracer.potential_2d_from(grid=grid) + plot_array(array=potential, ax=axes_flat[5], title="Potential", + colormap=colormap, use_log10=True) + except Exception: + axes_flat[5].axis("off") + + # Panel 6: Deflections Y + plot_array(array=deflections_y, ax=axes_flat[6], title="Deflections Y", + lines=image_plane_lines, line_colors=image_plane_line_colors, + colormap=colormap) - for i in range(4, 9): - axes_flat[i].axis("off") + # Panel 7: Deflections X + plot_array(array=deflections_x, ax=axes_flat[7], title="Deflections X", + lines=image_plane_lines, line_colors=image_plane_line_colors, + colormap=colormap) + + # Panel 8: Magnification + plot_array(array=magnification, ax=axes_flat[8], title="Magnification", + lines=image_plane_lines, line_colors=image_plane_line_colors, + colormap=colormap) plt.tight_layout() save_figure(fig, path=output_path, filename="tracer", format=output_format) diff --git a/autolens/lens/plot/tracer_plots.py b/autolens/lens/plot/tracer_plots.py index 2005b2f71..62ff69353 100644 --- a/autolens/lens/plot/tracer_plots.py +++ b/autolens/lens/plot/tracer_plots.py @@ -40,8 +40,14 @@ def subplot_tracer( tan_cc, rad_cc = _critical_curves_from(tracer, grid) tan_ca, rad_ca = _caustics_from(tracer, grid) - image_plane_lines = _to_lines((list(tan_cc) if tan_cc is not None else []) + (list(rad_cc) if rad_cc is not None else [])) - source_plane_lines = _to_lines((list(tan_ca) if tan_ca is not None else []) + (list(rad_ca) if rad_ca is not None else [])) + _tan_cc_lines = _to_lines(list(tan_cc) if tan_cc is not None else []) or [] + _rad_cc_lines = _to_lines(list(rad_cc) if rad_cc is not None else []) or [] + _tan_ca_lines = _to_lines(list(tan_ca) if tan_ca is not None else []) or [] + _rad_ca_lines = _to_lines(list(rad_ca) if rad_ca is not None else []) or [] + image_plane_lines = (_tan_cc_lines + _rad_cc_lines) or None + image_plane_line_colors = ["black"] * len(_tan_cc_lines) + ["white"] * len(_rad_cc_lines) + source_plane_lines = (_tan_ca_lines + _rad_ca_lines) or None + source_plane_line_colors = ["black"] * len(_tan_ca_lines) + ["white"] * len(_rad_ca_lines) pos_list = _to_positions(positions) # --- compute arrays --- @@ -49,6 +55,10 @@ def subplot_tracer( source_galaxies = ag.Galaxies(galaxies=tracer.planes[final_plane_index]) source_image = source_galaxies.image_2d_from(grid=traced_grids[final_plane_index]) + try: + source_vmax = float(np.max(source_image.array)) + except (AttributeError, ValueError): + source_vmax = None lens_galaxies = ag.Galaxies(galaxies=tracer.planes[0]) lens_image = lens_galaxies.image_2d_from(grid=traced_grids[0]) @@ -65,25 +75,27 @@ def subplot_tracer( fig, axes = plt.subplots(3, 3, figsize=conf_subplot_figsize(3, 3)) axes_flat = list(axes.flatten()) - plot_array(array=image, ax=axes_flat[0], title="Image", - lines=image_plane_lines, positions=pos_list, colormap=colormap, - use_log10=use_log10) - plot_array(array=source_image, ax=axes_flat[1], title="Source Image", + plot_array(array=image, ax=axes_flat[0], title="Model Image", + lines=image_plane_lines, line_colors=image_plane_line_colors, + positions=pos_list, colormap=colormap, use_log10=use_log10) + plot_array(array=source_image, ax=axes_flat[1], title="Source Model Image", + colormap=colormap, use_log10=use_log10, vmax=source_vmax) + plot_array(array=source_image, ax=axes_flat[2], title="Source Plane (No Zoom)", + lines=source_plane_lines, line_colors=source_plane_line_colors, colormap=colormap, use_log10=use_log10) - plot_array(array=source_image, ax=axes_flat[2], title="Source Plane Image", - lines=source_plane_lines, colormap=colormap, use_log10=use_log10) plot_array(array=lens_image, ax=axes_flat[3], title="Lens Image", - colormap=colormap, use_log10=use_log10) + lines=image_plane_lines, line_colors=image_plane_line_colors, + colormap=colormap, use_log10=True) plot_array(array=convergence, ax=axes_flat[4], title="Convergence", - lines=image_plane_lines, colormap=colormap, use_log10=use_log10) + colormap=colormap, use_log10=True) plot_array(array=potential, ax=axes_flat[5], title="Potential", - lines=image_plane_lines, colormap=colormap, use_log10=use_log10) + colormap=colormap, use_log10=True) plot_array(array=deflections_y, ax=axes_flat[6], title="Deflections Y", - lines=image_plane_lines, colormap=colormap) + lines=image_plane_lines, line_colors=image_plane_line_colors, colormap=colormap) plot_array(array=deflections_x, ax=axes_flat[7], title="Deflections X", - lines=image_plane_lines, colormap=colormap) + lines=image_plane_lines, line_colors=image_plane_line_colors, colormap=colormap) plot_array(array=magnification, ax=axes_flat[8], title="Magnification", - lines=image_plane_lines, colormap=colormap) + lines=image_plane_lines, line_colors=image_plane_line_colors, colormap=colormap) hide_unused_axes(axes_flat) plt.tight_layout()