diff --git a/autolens/analysis/plotter.py b/autolens/analysis/plotter.py index 1d29b8313..088bc6bfb 100644 --- a/autolens/analysis/plotter.py +++ b/autolens/analysis/plotter.py @@ -37,6 +37,10 @@ def tracer( self, tracer: Tracer, grid: aa.type.Grid2DLike, + image_plane_lines=None, + image_plane_line_colors=None, + source_plane_lines=None, + source_plane_line_colors=None, ): """ Visualizes a `Tracer` object. @@ -47,6 +51,14 @@ def tracer( The maximum log likelihood `Tracer` of the non-linear search. grid A 2D grid of (y,x) arc-second coordinates used to perform ray-tracing. + image_plane_lines + Pre-computed critical-curve lines to overlay on image-plane panels. + image_plane_line_colors + Colours for each image-plane line. + source_plane_lines + Pre-computed caustic lines to overlay on source-plane panels. + source_plane_line_colors + Colours for each source-plane line. """ def should_plot(name): @@ -55,6 +67,18 @@ def should_plot(name): output_path = str(self.image_path) fmt = self.fmt + if should_plot("subplot_tracer"): + subplot_tracer( + tracer=tracer, + grid=grid, + output_path=output_path, + output_format=fmt, + image_plane_lines=image_plane_lines, + image_plane_line_colors=image_plane_line_colors, + source_plane_lines=source_plane_lines, + source_plane_line_colors=source_plane_line_colors, + ) + if should_plot("subplot_galaxies_images"): subplot_galaxies_images( tracer=tracer, diff --git a/autolens/imaging/model/plotter.py b/autolens/imaging/model/plotter.py index fa7932c44..2f542a448 100644 --- a/autolens/imaging/model/plotter.py +++ b/autolens/imaging/model/plotter.py @@ -19,7 +19,7 @@ subplot_tracer_from_fit, subplot_fit_combined, subplot_fit_combined_log10, - _compute_critical_curve_lines, + _compute_critical_curves_from_fit, ) from autolens.analysis.plotter import plot_setting @@ -31,7 +31,13 @@ class PlotterImaging(Plotter): imaging_combined = AgPlotterImaging.imaging_combined def fit_imaging( - self, fit: FitImaging, quick_update: bool = False + self, + fit: FitImaging, + quick_update: bool = False, + image_plane_lines=None, + image_plane_line_colors=None, + source_plane_lines=None, + source_plane_line_colors=None, ): """ Visualizes a `FitImaging` object, which fits an imaging dataset. @@ -42,6 +48,14 @@ def fit_imaging( The maximum log likelihood `FitImaging` of the non-linear search. quick_update If True only the essential subplot_fit is output. + image_plane_lines + Pre-computed critical-curve lines. Computed internally if not provided. + image_plane_line_colors + Colours for each image-plane line. + source_plane_lines + Pre-computed caustic lines. Computed internally if not provided. + source_plane_line_colors + Colours for each source-plane line. """ def should_plot(name): @@ -52,13 +66,15 @@ def should_plot(name): plane_indexes_to_plot = [i for i in fit.tracer.plane_indexes_with_images if i != 0] - # Compute critical curves and caustics once for all subplot functions. - tracer = fit.tracer_linear_light_profiles_to_light_profiles - _zoom = aa.Zoom2D(mask=fit.mask) - _cc_grid = aa.Grid2D.from_extent( - extent=_zoom.extent_from(buffer=0), shape_native=_zoom.shape_native - ) - ip_lines, ip_colors, sp_lines, sp_colors = _compute_critical_curve_lines(tracer, _cc_grid) + # Compute critical curves and caustics once for all subplot functions, + # unless already provided by the caller (e.g. the visualizer). + if image_plane_lines is None and source_plane_lines is None: + ip_lines, ip_colors, sp_lines, sp_colors = _compute_critical_curves_from_fit(fit) + else: + ip_lines, ip_colors, sp_lines, sp_colors = ( + image_plane_lines, image_plane_line_colors, + source_plane_lines, source_plane_line_colors, + ) if should_plot("subplot_fit") or quick_update: diff --git a/autolens/imaging/model/visualizer.py b/autolens/imaging/model/visualizer.py index e8b40c439..cb0328920 100644 --- a/autolens/imaging/model/visualizer.py +++ b/autolens/imaging/model/visualizer.py @@ -6,6 +6,7 @@ import autogalaxy as ag from autolens.imaging.model.plotter import PlotterImaging +from autolens.imaging.plot.fit_imaging_plots import _compute_critical_curves_from_fit from autolens import exc @@ -101,10 +102,19 @@ def visualize( title_prefix=analysis.title_prefix, ) + grid = fit.mask.derive_grid.all_false + + # Compute critical curves once for all plot functions. + ip_lines, ip_colors, sp_lines, sp_colors = _compute_critical_curves_from_fit(fit) + try: plotter.fit_imaging( fit=fit, quick_update=quick_update, + image_plane_lines=ip_lines, + image_plane_line_colors=ip_colors, + source_plane_lines=sp_lines, + source_plane_line_colors=sp_colors, ) except exc.InversionException: pass @@ -135,16 +145,13 @@ def visualize( ) return - zoom = ag.Zoom2D(mask=fit.mask) - - extent = zoom.extent_from(buffer=0) - shape_native = zoom.shape_native - - grid = ag.Grid2D.from_extent(extent=extent, shape_native=shape_native) - plotter.tracer( tracer=tracer, grid=grid, + image_plane_lines=ip_lines, + image_plane_line_colors=ip_colors, + source_plane_lines=sp_lines, + source_plane_line_colors=sp_colors, ) plotter.galaxies( galaxies=tracer.galaxies, diff --git a/autolens/imaging/plot/fit_imaging_plots.py b/autolens/imaging/plot/fit_imaging_plots.py index b37097038..59908304d 100644 --- a/autolens/imaging/plot/fit_imaging_plots.py +++ b/autolens/imaging/plot/fit_imaging_plots.py @@ -53,6 +53,19 @@ def _compute_critical_curve_lines(tracer, grid): return None, None, None, None +def _compute_critical_curves_from_fit(fit): + """Compute critical-curve and caustic lines from a FitImaging object. + + Convenience wrapper around :func:`_compute_critical_curve_lines` that + derives the tracer and grid from *fit* directly, using the fully unmasked + image-plane grid so the curves cover the whole image extent. + + Returns the same 4-tuple as :func:`_compute_critical_curve_lines`. + """ + tracer = fit.tracer_linear_light_profiles_to_light_profiles + return _compute_critical_curve_lines(tracer, fit.mask.derive_grid.all_false) + + def _get_source_vmax(fit): """ Return the colour-scale maximum for source-plane panels. @@ -217,14 +230,8 @@ def subplot_fit( source_vmax = _get_source_vmax(fit) if image_plane_lines is None and source_plane_lines is None: - tracer = fit.tracer_linear_light_profiles_to_light_profiles - _zoom = aa.Zoom2D(mask=fit.mask) - _cc_grid = aa.Grid2D.from_extent( - extent=_zoom.extent_from(buffer=0), - shape_native=_zoom.shape_native, - ) image_plane_lines, image_plane_line_colors, source_plane_lines, source_plane_line_colors = ( - _compute_critical_curve_lines(tracer, _cc_grid) + _compute_critical_curves_from_fit(fit) ) fig, axes = plt.subplots(3, 4, figsize=conf_subplot_figsize(3, 4)) @@ -422,14 +429,8 @@ def subplot_fit_log10( source_vmax = _get_source_vmax(fit) if image_plane_lines is None and source_plane_lines is None: - tracer = fit.tracer_linear_light_profiles_to_light_profiles - _zoom = aa.Zoom2D(mask=fit.mask) - _cc_grid = aa.Grid2D.from_extent( - extent=_zoom.extent_from(buffer=0), - shape_native=_zoom.shape_native, - ) image_plane_lines, image_plane_line_colors, source_plane_lines, source_plane_line_colors = ( - _compute_critical_curve_lines(tracer, _cc_grid) + _compute_critical_curves_from_fit(fit) ) fig, axes = plt.subplots(3, 4, figsize=conf_subplot_figsize(3, 4)) @@ -673,14 +674,11 @@ def subplot_tracer_from_fit( tracer = fit.tracer_linear_light_profiles_to_light_profiles # --- grid --- - zoom = aa.Zoom2D(mask=fit.mask) - grid = aa.Grid2D.from_extent( - extent=zoom.extent_from(buffer=0), shape_native=zoom.shape_native - ) + grid = fit.mask.derive_grid.all_false if image_plane_lines is None and source_plane_lines is None: image_plane_lines, image_plane_line_colors, source_plane_lines, source_plane_line_colors = ( - _compute_critical_curve_lines(tracer, grid) + _compute_critical_curves_from_fit(fit) ) source_vmax = _get_source_vmax(fit) @@ -689,7 +687,7 @@ def subplot_tracer_from_fit( lens_galaxies = ag.Galaxies(galaxies=tracer.planes[0]) lens_image = lens_galaxies.image_2d_from(grid=traced_grids[0]) - deflections = lens_galaxies.deflections_yx_2d_from(grid=grid) + deflections = tracer.deflections_yx_2d_from(grid=grid) deflections_y = aa.Array2D(values=deflections.slim[:, 0], mask=grid.mask) deflections_x = aa.Array2D(values=deflections.slim[:, 1], mask=grid.mask) @@ -703,10 +701,10 @@ def subplot_tracer_from_fit( lines=image_plane_lines, line_colors=image_plane_line_colors, colormap=colormap) - # Panel 1: Source Model Image (image-plane projection) + # Panel 1: Source Model Image (same as subplot_fit panel 7) try: source_model_img = fit.model_images_of_planes_list[final_plane_index] - except Exception: + except (IndexError, AttributeError): source_model_img = None if source_model_img is not None: plot_array(array=source_model_img, ax=axes_flat[1], title="Source Model Image", @@ -715,7 +713,7 @@ def subplot_tracer_from_fit( else: axes_flat[1].axis("off") - # Panel 2: Source Plane (No Zoom) + # Panel 2: Source Plane (No Zoom) (same as subplot_fit panel 12) _plot_source_plane(fit, axes_flat[2], final_plane_index, zoom_to_brightest=False, colormap=colormap, title="Source Plane (No Zoom)", lines=source_plane_lines, line_colors=source_plane_line_colors, diff --git a/autolens/interferometer/model/plotter.py b/autolens/interferometer/model/plotter.py index b537e6fbe..f03911f45 100644 --- a/autolens/interferometer/model/plotter.py +++ b/autolens/interferometer/model/plotter.py @@ -1,8 +1,9 @@ +import autoarray as aa + from autogalaxy.interferometer.model.plotter import ( PlotterInterferometer as AgPlotterInterferometer, ) -from autogalaxy.interferometer.plot import fit_interferometer_plots as ag_fit_interferometer_plots from autogalaxy.interferometer.plot.fit_interferometer_plots import ( fits_galaxy_images, fits_dirty_images, @@ -11,7 +12,9 @@ from autolens.interferometer.fit_interferometer import FitInterferometer from autolens.interferometer.plot.fit_interferometer_plots import ( subplot_fit, + subplot_fit_dirty_images, subplot_fit_real_space, + _compute_critical_curve_lines, ) from autolens.analysis.plotter import Plotter @@ -25,6 +28,10 @@ def fit_interferometer( self, fit: FitInterferometer, quick_update: bool = False, + image_plane_lines=None, + image_plane_line_colors=None, + source_plane_lines=None, + source_plane_line_colors=None, ): """ Visualizes a `FitInterferometer` object. @@ -33,6 +40,14 @@ def fit_interferometer( ---------- fit The maximum log likelihood `FitInterferometer` of the non-linear search. + image_plane_lines + Pre-computed critical-curve lines. Computed internally if not provided. + image_plane_line_colors + Colours for each image-plane line. + source_plane_lines + Pre-computed caustic lines. Computed internally if not provided. + source_plane_line_colors + Colours for each source-plane line. """ def should_plot(name): @@ -41,21 +56,41 @@ def should_plot(name): output_path = str(self.image_path) fmt = self.fmt + # Use pre-computed critical curves if provided, otherwise compute once here. + if image_plane_lines is None and source_plane_lines is None: + tracer = fit.tracer_linear_light_profiles_to_light_profiles + _zoom = aa.Zoom2D(mask=fit.dataset.real_space_mask) + _cc_grid = aa.Grid2D.from_extent( + extent=_zoom.extent_from(buffer=0), shape_native=_zoom.shape_native + ) + ip_lines, ip_colors, sp_lines, sp_colors = _compute_critical_curve_lines(tracer, _cc_grid) + else: + ip_lines, ip_colors, sp_lines, sp_colors = ( + image_plane_lines, image_plane_line_colors, + source_plane_lines, source_plane_line_colors, + ) + if should_plot("subplot_fit"): - subplot_fit(fit, output_path=output_path, output_format=fmt) + subplot_fit( + fit, output_path=output_path, output_format=fmt, + image_plane_lines=ip_lines, image_plane_line_colors=ip_colors, + source_plane_lines=sp_lines, source_plane_line_colors=sp_colors, + ) if should_plot("subplot_fit_dirty_images") or quick_update: - ag_fit_interferometer_plots.subplot_fit_dirty_images( - fit=fit, - output_path=self.image_path, - output_format=self.fmt, + subplot_fit_dirty_images( + fit, output_path=output_path, output_format=fmt, + image_plane_lines=ip_lines, image_plane_line_colors=ip_colors, ) if quick_update: return if should_plot("subplot_fit_real_space"): - subplot_fit_real_space(fit, output_path=output_path, output_format=fmt) + subplot_fit_real_space( + fit, output_path=output_path, output_format=fmt, + source_plane_lines=sp_lines, source_plane_line_colors=sp_colors, + ) if should_plot("fits_galaxy_images"): fits_galaxy_images(fit=fit, output_path=self.image_path) diff --git a/autolens/interferometer/model/visualizer.py b/autolens/interferometer/model/visualizer.py index a9d203315..830f67ba3 100644 --- a/autolens/interferometer/model/visualizer.py +++ b/autolens/interferometer/model/visualizer.py @@ -6,6 +6,7 @@ from autolens.interferometer.model.plotter import ( PlotterInterferometer, ) +from autolens.interferometer.plot.fit_interferometer_plots import _compute_critical_curve_lines from autogalaxy import exc logger = logging.getLogger(__name__) @@ -93,15 +94,29 @@ def visualize( via a non-linear search). """ fit = analysis.fit_from(instance=instance) + tracer = fit.tracer_linear_light_profiles_to_light_profiles plotter = PlotterInterferometer( image_path=paths.image_path, title_prefix=analysis.title_prefix ) + # Compute grid and critical curves once for all plot functions. + zoom = ag.Zoom2D(mask=fit.dataset.real_space_mask) + grid = ag.Grid2D.from_extent( + extent=zoom.extent_from(buffer=0), shape_native=zoom.shape_native + ) + ip_lines, ip_colors, sp_lines, sp_colors = _compute_critical_curve_lines( + tracer, grid + ) + try: plotter.fit_interferometer( fit=fit, quick_update=quick_update, + image_plane_lines=ip_lines, + image_plane_line_colors=ip_colors, + source_plane_lines=sp_lines, + source_plane_line_colors=sp_colors, ) except exc.InversionException: logger(ag.exc.invalid_linear_algebra_for_visualization_message()) @@ -130,23 +145,13 @@ def visualize( except exc.InversionException: return - tracer = fit.tracer_linear_light_profiles_to_light_profiles - - zoom = ag.Zoom2D(mask=fit.dataset.real_space_mask) - - extent = zoom.extent_from(buffer=0) - shape_native = zoom.shape_native - - grid = ag.Grid2D.from_extent(extent=extent, shape_native=shape_native) - - try: - plotter.fit_interferometer(fit=fit) - except exc.InversionException: - pass - plotter.tracer( tracer=tracer, grid=grid, + image_plane_lines=ip_lines, + image_plane_line_colors=ip_colors, + source_plane_lines=sp_lines, + source_plane_line_colors=sp_colors, ) plotter.galaxies( galaxies=tracer.galaxies, diff --git a/autolens/interferometer/plot/fit_interferometer_plots.py b/autolens/interferometer/plot/fit_interferometer_plots.py index da34d5d01..9b745b434 100644 --- a/autolens/interferometer/plot/fit_interferometer_plots.py +++ b/autolens/interferometer/plot/fit_interferometer_plots.py @@ -1,3 +1,4 @@ +import logging import matplotlib.pyplot as plt import numpy as np from typing import Optional @@ -6,82 +7,123 @@ import autogalaxy as ag from autoarray.plot.array import plot_array -from autoarray.plot.utils import save_figure +from autoarray.plot.yx import plot_yx +from autoarray.plot.utils import save_figure, conf_subplot_figsize from autoarray.plot.utils import numpy_lines as _to_lines -from autogalaxy.plot.plot_utils import _critical_curves_from +from autoarray.inversion.mappers.abstract import Mapper +from autoarray.inversion.plot.mapper_plots import plot_mapper +from autogalaxy.plot.plot_utils import _critical_curves_from, _caustics_from +from autolens.lens.plot.tracer_plots import plane_image_from +logger = logging.getLogger(__name__) -def _plot_yx(y, x, ax, title, xlabel="", ylabel=""): - """ - Render a scatter plot of *y* versus *x* into an existing axes object. - Parameters - ---------- - y : array-like - Dependent-variable values (y-axis). - x : array-like - Independent-variable values (x-axis). - ax : matplotlib.axes.Axes - The axes into which the scatter plot is drawn. - title : str - Axes title. - xlabel : str, optional - Label for the x-axis. - ylabel : str, optional - Label for the y-axis. +def _compute_critical_curve_lines(tracer, grid): + """Compute critical-curve and caustic lines for a tracer on a given grid. + + Returns a 4-tuple ``(image_plane_lines, image_plane_line_colors, + source_plane_lines, source_plane_line_colors)`` suitable for passing + directly to :func:`~autoarray.plot.array.plot_array`. On failure + returns ``(None, None, None, None)``. """ - ax.scatter(x, y, s=1) - ax.set_title(title) - ax.set_xlabel(xlabel) - ax.set_ylabel(ylabel) + try: + tan_cc, rad_cc = _critical_curves_from(tracer, grid) + tan_ca, rad_ca = _caustics_from(tracer, grid) + _tan_cc_lines = _to_lines(list(tan_cc) if tan_cc is not None else []) or [] + _rad_cc_lines = _to_lines(list(rad_cc) if rad_cc is not None else []) or [] + _tan_ca_lines = _to_lines(list(tan_ca) if tan_ca is not None else []) or [] + _rad_ca_lines = _to_lines(list(rad_ca) if rad_ca is not None else []) or [] + image_plane_lines = (_tan_cc_lines + _rad_cc_lines) or None + image_plane_line_colors = ( + ["black"] * len(_tan_cc_lines) + ["white"] * len(_rad_cc_lines) + ) + source_plane_lines = (_tan_ca_lines + _rad_ca_lines) or None + source_plane_line_colors = ( + ["black"] * len(_tan_ca_lines) + ["white"] * len(_rad_ca_lines) + ) + return image_plane_lines, image_plane_line_colors, source_plane_lines, source_plane_line_colors + except Exception: + return None, None, None, None def _plot_source_plane(fit, ax, plane_index, zoom_to_brightest=True, - colormap=None, use_log10=False): + colormap=None, use_log10=False, title=None, + lines=None, line_colors=None, vmax=None): """ - Plot the source-plane image (or a blank inversion placeholder) into an axes. + Plot the source-plane image into an axes, matching the imaging subplot_fit behaviour. - For parametric sources the function ray-traces a zoomed real-space grid - to the source plane and renders the resulting image via - :func:`~autoarray.plot.array.plot_array`. When the plane contains a - pixelization (inversion source), the axes are turned off and labelled - as a placeholder, because the reconstruction is rendered separately. + For parametric sources, evaluates the source light profiles directly on + the unmasked real-space grid (``fit.dataset.real_space_mask.derive_grid.all_false``) + via :func:`~autolens.lens.plot.tracer_plots.plane_image_from` — identical + to the imaging path. For pixelized sources, renders the inversion + reconstruction via :func:`~autoarray.inversion.plot.mapper_plots.plot_mapper`. Parameters ---------- fit : FitInterferometer - The interferometer fit providing the tracer and real-space mask. + The interferometer fit providing the tracer, real-space mask, and inversion. ax : matplotlib.axes.Axes or None - The axes into which the source-plane image is drawn. Passing - ``None`` is a no-op. + The axes into which the source-plane image is drawn. ``None`` is a no-op. plane_index : int Index of the plane in ``fit.tracer.planes`` to visualise. zoom_to_brightest : bool, optional - Reserved for future zoomed rendering; currently unused in the - rendering call. + For parametric sources: zoom the evaluation grid in on the brightest + region. For inversion sources: zoom the colourmap to brightest pixels. colormap : str, optional Matplotlib colormap name. use_log10 : bool, optional - If ``True`` the colour scale is applied on a log10 stretch. + Apply a log10 colour stretch. + title : str, optional + Axes title. Defaults to ``"Source Plane (Zoomed)"`` / + ``"Source Plane (No Zoom)"`` according to ``zoom_to_brightest``. + lines : list, optional + Caustic lines to overlay (passed to :func:`plot_array` / :func:`plot_mapper`). + line_colors : list, optional + Colours for each entry in *lines*. + vmax : float, optional + Shared colour-scale maximum. """ + if ax is None: + return + + if title is None: + title = "Source Plane (Zoomed)" if zoom_to_brightest else "Source Plane (No Zoom)" + tracer = fit.tracer_linear_light_profiles_to_light_profiles if not tracer.planes[plane_index].has(cls=aa.Pixelization): - zoom = aa.Zoom2D(mask=fit.dataset.real_space_mask) - grid = aa.Grid2D.from_extent( - extent=zoom.extent_from(buffer=0), shape_native=zoom.shape_native + image = plane_image_from( + galaxies=tracer.planes[plane_index], + grid=fit.dataset.real_space_mask.derive_grid.all_false, + zoom_to_brightest=zoom_to_brightest, ) - traced_grids = tracer.traced_grid_2d_list_from(grid=grid) - plane_galaxies = ag.Galaxies(galaxies=tracer.planes[plane_index]) - image = plane_galaxies.image_2d_from(grid=traced_grids[plane_index]) plot_array( array=image, ax=ax, - title=f"Source Plane {plane_index}", - colormap=colormap, use_log10=use_log10, + title=title, + colormap=colormap, use_log10=use_log10, vmax=vmax, + lines=lines, line_colors=line_colors, ) else: - if ax is not None: + try: + inversion = fit.inversion + mapper_list = inversion.cls_list_from(cls=Mapper) + mapper = mapper_list[plane_index - 1] if plane_index > 0 else mapper_list[0] + pixel_values = inversion.reconstruction_dict[mapper] + plot_mapper( + mapper, + solution_vector=pixel_values, + ax=ax, + title=title, + colormap=colormap, + use_log10=use_log10, + vmax=vmax, + zoom_to_brightest=zoom_to_brightest, + lines=lines, + line_colors=line_colors, + ) + except Exception as exc: + logger.warning(f"Could not plot source reconstruction for plane {plane_index}: {exc}") ax.axis("off") - ax.set_title(f"Source Reconstruction (plane {plane_index})") + ax.set_title(title) def subplot_fit( @@ -89,6 +131,10 @@ def subplot_fit( output_path: Optional[str] = None, output_format: str = "png", colormap: Optional[str] = None, + image_plane_lines=None, + image_plane_line_colors=None, + source_plane_lines=None, + source_plane_line_colors=None, ): """ Produce a 12-panel subplot summarising an interferometer fit. @@ -121,64 +167,84 @@ def subplot_fit( colormap : str, optional Matplotlib colormap name applied to all image panels. """ - tracer = fit.tracer_linear_light_profiles_to_light_profiles final_plane_index = len(fit.tracer.planes) - 1 - fig, axes = plt.subplots(3, 4, figsize=(28, 21)) + if image_plane_lines is None and source_plane_lines is None: + tracer = fit.tracer_linear_light_profiles_to_light_profiles + _zoom = aa.Zoom2D(mask=fit.dataset.real_space_mask) + _cc_grid = aa.Grid2D.from_extent( + extent=_zoom.extent_from(buffer=0), shape_native=_zoom.shape_native + ) + image_plane_lines, image_plane_line_colors, source_plane_lines, source_plane_line_colors = ( + _compute_critical_curve_lines(tracer, _cc_grid) + ) + + fig, axes = plt.subplots(3, 4, figsize=conf_subplot_figsize(3, 4)) axes_flat = list(axes.flatten()) - # Panel 0: amplitudes vs UV distances - _plot_yx( + # Panel 0: amplitudes vs UV-distances + plot_yx( y=np.real(fit.residual_map), x=fit.dataset.uv_distances / 10 ** 3.0, ax=axes_flat[0], title="Amplitudes vs UV-Distance", - xlabel=r"k$\lambda$", + xtick_suffix='"', + ytick_suffix="Jy", + plot_axis_type="scatter", ) plot_array(array=fit.dirty_image, ax=axes_flat[1], title="Dirty Image", colormap=colormap) plot_array(array=fit.dirty_signal_to_noise_map, ax=axes_flat[2], title="Dirty Signal-To-Noise Map", colormap=colormap) + + # Panel 3 (4th): dirty model image with critical curves plot_array(array=fit.dirty_model_image, ax=axes_flat[3], title="Dirty Model Image", - colormap=colormap) + colormap=colormap, lines=image_plane_lines, + line_colors=image_plane_line_colors) - # Panel 4: source image - _plot_source_plane(fit, axes_flat[4], final_plane_index, colormap=colormap) + # Panel 4: dirty residual map + plot_array(array=fit.dirty_residual_map, ax=axes_flat[4], + title="Dirty Residual Map", colormap=colormap) - # Normalized residual vs UV distances (real) - _plot_yx( + # Panel 5: normalized residual vs UV-distances (real) + plot_yx( y=np.real(fit.normalized_residual_map), x=fit.dataset.uv_distances / 10 ** 3.0, ax=axes_flat[5], - title="Norm Residual vs UV-Distance (real)", - ylabel=r"$\sigma$", - xlabel=r"k$\lambda$", + title="Normalized Residual Map (Real)", + xtick_suffix='"', + ytick_suffix=r"$\sigma$", + plot_axis_type="scatter", ) - # Normalized residual vs UV distances (imag) - _plot_yx( + # Panel 6: normalized residual vs UV-distances (imag) + plot_yx( y=np.imag(fit.normalized_residual_map), x=fit.dataset.uv_distances / 10 ** 3.0, ax=axes_flat[6], - title="Norm Residual vs UV-Distance (imag)", - ylabel=r"$\sigma$", - xlabel=r"k$\lambda$", + title="Normalized Residual Map (Imag)", + xtick_suffix='"', + ytick_suffix=r"$\sigma$", + plot_axis_type="scatter", ) - # Panel 7: source plane zoomed - _plot_source_plane(fit, axes_flat[7], final_plane_index, colormap=colormap) + # Panel 7 (8th): source plane zoomed with caustics + _plot_source_plane(fit, axes_flat[7], final_plane_index, + zoom_to_brightest=True, colormap=colormap, + title="Source Plane (Zoomed)", + lines=source_plane_lines, + line_colors=source_plane_line_colors) plot_array(array=fit.dirty_normalized_residual_map, ax=axes_flat[8], title="Dirty Normalized Residual Map", colormap=colormap, cb_unit=r"$\sigma$") - # Panel 9: clipped to [-1, 1] + # Panel 9: clipped to ±1σ plot_array( fit.dirty_normalized_residual_map, - ax=axes_flat[8], + ax=axes_flat[9], title=r"Normalized Residual Map $1\sigma$", colormap=colormap, - use_log10=False, vmin=-1.0, vmax=1.0, cb_unit=r"$\sigma$", ) @@ -186,19 +252,84 @@ def subplot_fit( plot_array(array=fit.dirty_chi_squared_map, ax=axes_flat[10], title="Dirty Chi-Squared Map", colormap=colormap, cb_unit=r"$\chi^2$") - # Panel 11: source plane not zoomed + # Panel 11 (12th): source plane not zoomed with caustics _plot_source_plane(fit, axes_flat[11], final_plane_index, - zoom_to_brightest=False, colormap=colormap) + zoom_to_brightest=False, colormap=colormap, + title="Source Plane (No Zoom)", + lines=source_plane_lines, + line_colors=source_plane_line_colors) plt.tight_layout() save_figure(fig, path=output_path, filename="fit", format=output_format) +def subplot_fit_dirty_images( + fit, + output_path: Optional[str] = None, + output_format: str = "png", + colormap: Optional[str] = None, + use_log10: bool = False, + image_plane_lines=None, + image_plane_line_colors=None, +): + """ + Produce a 2×3 subplot of dirty-image diagnostics for an interferometer fit. + + Panels (row-major order): + Dirty Image | Dirty Signal-To-Noise Map | Dirty Model Image (critical curves) + Dirty Residual Map | Dirty Norm Residual Map | Dirty Chi-Squared Map + + Parameters + ---------- + fit : FitInterferometer + The interferometer fit to visualise. + output_path : str, optional + Directory in which to save the figure. + output_format : str, optional + Image format. + colormap : str, optional + Matplotlib colormap name. + use_log10 : bool, optional + Apply a log10 colour stretch. + """ + if image_plane_lines is None: + tracer = fit.tracer_linear_light_profiles_to_light_profiles + _zoom = aa.Zoom2D(mask=fit.dataset.real_space_mask) + _cc_grid = aa.Grid2D.from_extent( + extent=_zoom.extent_from(buffer=0), shape_native=_zoom.shape_native + ) + image_plane_lines, image_plane_line_colors, _, _ = ( + _compute_critical_curve_lines(tracer, _cc_grid) + ) + + fig, axes = plt.subplots(2, 3, figsize=conf_subplot_figsize(2, 3)) + axes_flat = list(axes.flatten()) + + plot_array(array=fit.dirty_image, ax=axes_flat[0], title="Dirty Image", + colormap=colormap, use_log10=use_log10) + plot_array(array=fit.dirty_signal_to_noise_map, ax=axes_flat[1], + title="Dirty Signal-To-Noise Map", colormap=colormap) + plot_array(array=fit.dirty_model_image, ax=axes_flat[2], + title="Dirty Model Image", colormap=colormap, use_log10=use_log10, + lines=image_plane_lines, line_colors=image_plane_line_colors) + plot_array(array=fit.dirty_residual_map, ax=axes_flat[3], + title="Dirty Residual Map", colormap=colormap) + plot_array(array=fit.dirty_normalized_residual_map, ax=axes_flat[4], + title="Dirty Normalized Residual Map", colormap=colormap, cb_unit=r"$\sigma$") + plot_array(array=fit.dirty_chi_squared_map, ax=axes_flat[5], + title="Dirty Chi-Squared Map", colormap=colormap, cb_unit=r"$\chi^2$") + + plt.tight_layout() + save_figure(fig, path=output_path, filename="fit_dirty_images", format=output_format) + + def subplot_fit_real_space( fit, output_path: Optional[str] = None, output_format: str = "png", colormap: Optional[str] = None, + source_plane_lines=None, + source_plane_line_colors=None, ): """ Produce a real-space subplot for an interferometer fit. @@ -228,34 +359,30 @@ def subplot_fit_real_space( tracer = fit.tracer_linear_light_profiles_to_light_profiles final_plane_index = len(fit.tracer.planes) - 1 - if fit.inversion is None: - # No inversion: image + source plane image - fig, axes = plt.subplots(1, 2, figsize=(14, 7)) - axes_flat = list(axes.flatten()) + fig, axes = plt.subplots(1, 2, figsize=conf_subplot_figsize(1, 2)) + axes_flat = list(axes.flatten()) + if fit.inversion is None: + # Parametric source: image-plane model image + source-plane image zoom = aa.Zoom2D(mask=fit.dataset.real_space_mask) grid = aa.Grid2D.from_extent( extent=zoom.extent_from(buffer=0), shape_native=zoom.shape_native ) - traced_grids = tracer.traced_grid_2d_list_from(grid=grid) - image = tracer.image_2d_from(grid=grid) plot_array(array=image, ax=axes_flat[0], title="Image", colormap=colormap) - source_galaxies = ag.Galaxies(galaxies=tracer.planes[final_plane_index]) - source_image = source_galaxies.image_2d_from( - grid=traced_grids[final_plane_index] - ) - plot_array(array=source_image, ax=axes_flat[1], title="Source Plane Image", - colormap=colormap) + _plot_source_plane(fit, axes_flat[1], final_plane_index, + zoom_to_brightest=True, colormap=colormap, + title="Source Plane (Zoomed)", + lines=source_plane_lines, line_colors=source_plane_line_colors) else: - fig, axes = plt.subplots(1, 2, figsize=(14, 7)) - axes_flat = list(axes.flatten()) - # Inversion: show blank placeholder panels - for _ax in axes_flat: - _ax.axis("off") - axes_flat[0].set_title("Reconstructed Data") - axes_flat[1].set_title("Source Plane (Zoom)") + # Pixelized source: dirty model image + source reconstruction + plot_array(array=fit.dirty_model_image, ax=axes_flat[0], + title="Reconstructed Image", colormap=colormap) + _plot_source_plane(fit, axes_flat[1], final_plane_index, + zoom_to_brightest=True, colormap=colormap, + title="Source Reconstruction", + lines=source_plane_lines, line_colors=source_plane_line_colors) plt.tight_layout() save_figure(fig, path=output_path, filename="fit_real_space", format=output_format) diff --git a/autolens/point/model/plotter.py b/autolens/point/model/plotter.py index 9c25c82da..add77a389 100644 --- a/autolens/point/model/plotter.py +++ b/autolens/point/model/plotter.py @@ -1,4 +1,7 @@ +import autogalaxy as ag + from autolens.analysis.plotter import Plotter +from autolens.imaging.plot.fit_imaging_plots import _compute_critical_curve_lines from autolens.point.fit.dataset import FitPointDataset from autolens.point.plot.fit_point_plots import subplot_fit as subplot_fit_point @@ -32,6 +35,10 @@ def fit_point( self, fit: FitPointDataset, quick_update: bool = False, + image_plane_lines=None, + image_plane_line_colors=None, + source_plane_lines=None, + source_plane_line_colors=None, ): """ Visualizes a `FitPointDataset` object. @@ -40,6 +47,14 @@ def fit_point( ---------- fit The maximum log likelihood `FitPointDataset` of the non-linear search. + image_plane_lines + Pre-computed critical-curve lines to overlay on image-plane panels. + image_plane_line_colors + Colours for each image-plane line. + source_plane_lines + Pre-computed caustic lines to overlay on source-plane panels. + source_plane_line_colors + Colours for each source-plane line. """ def should_plot(name): @@ -48,8 +63,28 @@ def should_plot(name): output_path = str(self.image_path) fmt = self.fmt + # Use pre-computed critical curves if provided, otherwise compute once here. + if image_plane_lines is None and source_plane_lines is None: + grid = ag.Grid2D.from_extent( + extent=fit.dataset.extent_from(), shape_native=(100, 100) + ) + ip_lines, ip_colors, sp_lines, sp_colors = _compute_critical_curve_lines( + fit.tracer, grid + ) + else: + ip_lines, ip_colors, sp_lines, sp_colors = ( + image_plane_lines, image_plane_line_colors, + source_plane_lines, source_plane_line_colors, + ) + if should_plot("subplot_fit") or quick_update: - subplot_fit_point(fit, output_path=output_path, output_format=fmt) + subplot_fit_point( + fit, output_path=output_path, output_format=fmt, + image_plane_lines=ip_lines, + image_plane_line_colors=ip_colors, + source_plane_lines=sp_lines, + source_plane_line_colors=sp_colors, + ) if quick_update: return diff --git a/autolens/point/model/visualizer.py b/autolens/point/model/visualizer.py index c070bb056..0d885f8a4 100644 --- a/autolens/point/model/visualizer.py +++ b/autolens/point/model/visualizer.py @@ -2,6 +2,7 @@ import autogalaxy as ag from autolens.point.model.plotter import PlotterPoint +from autolens.imaging.plot.fit_imaging_plots import _compute_critical_curve_lines class VisualizerPoint(af.Visualizer): @@ -69,20 +70,36 @@ def visualize( image_path=paths.image_path, title_prefix=analysis.title_prefix ) - plotter.fit_point(fit=fit, quick_update=quick_update) - - if quick_update: - return - tracer = fit.tracer grid = ag.Grid2D.from_extent( extent=fit.dataset.extent_from(), shape_native=(100, 100) ) + # Compute critical curves and caustics once for all plot functions. + ip_lines, ip_colors, sp_lines, sp_colors = _compute_critical_curve_lines( + tracer, grid + ) + + plotter.fit_point( + fit=fit, + quick_update=quick_update, + image_plane_lines=ip_lines, + image_plane_line_colors=ip_colors, + source_plane_lines=sp_lines, + source_plane_line_colors=sp_colors, + ) + + if quick_update: + return + plotter.tracer( tracer=tracer, grid=grid, + image_plane_lines=ip_lines, + image_plane_line_colors=ip_colors, + source_plane_lines=sp_lines, + source_plane_line_colors=sp_colors, ) plotter.galaxies( galaxies=tracer.galaxies, diff --git a/autolens/point/plot/fit_point_plots.py b/autolens/point/plot/fit_point_plots.py index 4873c2cab..2507f9aee 100644 --- a/autolens/point/plot/fit_point_plots.py +++ b/autolens/point/plot/fit_point_plots.py @@ -9,6 +9,10 @@ def subplot_fit( fit, output_path: Optional[str] = None, output_format: str = "png", + image_plane_lines=None, + image_plane_line_colors=None, + source_plane_lines=None, + source_plane_line_colors=None, ): """ Produce a subplot summarising a `FitPointDataset`.