Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 37 additions & 7 deletions autolens/imaging/model/plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import numpy as np
from typing import List

import autoarray as aa
from autogalaxy.imaging.model.plotter import PlotterImaging as AgPlotterImaging
from autogalaxy.imaging.plot.fit_imaging_plots import (
fits_fit,
Expand All @@ -18,6 +19,7 @@
subplot_tracer_from_fit,
subplot_fit_combined,
subplot_fit_combined_log10,
_compute_critical_curve_lines,
)

from autolens.analysis.plotter import plot_setting
Expand Down Expand Up @@ -50,29 +52,57 @@ def should_plot(name):

plane_indexes_to_plot = [i for i in fit.tracer.plane_indexes_with_images if i != 0]

# Compute critical curves and caustics once for all subplot functions.
tracer = fit.tracer_linear_light_profiles_to_light_profiles
_zoom = aa.Zoom2D(mask=fit.mask)
_cc_grid = aa.Grid2D.from_extent(
extent=_zoom.extent_from(buffer=0), shape_native=_zoom.shape_native
)
ip_lines, ip_colors, sp_lines, sp_colors = _compute_critical_curve_lines(tracer, _cc_grid)

if should_plot("subplot_fit") or quick_update:

if len(fit.tracer.planes) > 2:
for plane_index in plane_indexes_to_plot:
subplot_fit(fit, output_path=output_path, output_format=fmt,
plane_index=plane_index)
subplot_fit(
fit, output_path=output_path, output_format=fmt,
plane_index=plane_index,
image_plane_lines=ip_lines, image_plane_line_colors=ip_colors,
source_plane_lines=sp_lines, source_plane_line_colors=sp_colors,
)
else:
subplot_fit(fit, output_path=output_path, output_format=fmt)
subplot_fit(
fit, output_path=output_path, output_format=fmt,
image_plane_lines=ip_lines, image_plane_line_colors=ip_colors,
source_plane_lines=sp_lines, source_plane_line_colors=sp_colors,
)

if quick_update:
return

if plot_setting(section="tracer", name="subplot_tracer"):
subplot_tracer_from_fit(fit, output_path=output_path, output_format=fmt)
subplot_tracer_from_fit(
fit, output_path=output_path, output_format=fmt,
image_plane_lines=ip_lines, image_plane_line_colors=ip_colors,
source_plane_lines=sp_lines, source_plane_line_colors=sp_colors,
)

if should_plot("subplot_fit_log10"):
try:
if len(fit.tracer.planes) > 2:
for plane_index in plane_indexes_to_plot:
subplot_fit_log10(fit, output_path=output_path, output_format=fmt,
plane_index=plane_index)
subplot_fit_log10(
fit, output_path=output_path, output_format=fmt,
plane_index=plane_index,
image_plane_lines=ip_lines, image_plane_line_colors=ip_colors,
source_plane_lines=sp_lines, source_plane_line_colors=sp_colors,
)
else:
subplot_fit_log10(fit, output_path=output_path, output_format=fmt)
subplot_fit_log10(
fit, output_path=output_path, output_format=fmt,
image_plane_lines=ip_lines, image_plane_line_colors=ip_colors,
source_plane_lines=sp_lines, source_plane_line_colors=sp_colors,
)
except ValueError:
pass

Expand Down
163 changes: 105 additions & 58 deletions autolens/imaging/plot/fit_imaging_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,43 @@
logger = logging.getLogger(__name__)


def _compute_critical_curve_lines(tracer, grid):
"""Compute critical-curve and caustic lines for a tracer on a given grid.

Returns a 4-tuple ``(image_plane_lines, image_plane_line_colors,
source_plane_lines, source_plane_line_colors)`` suitable for passing
directly to :func:`~autoarray.plot.array.plot_array`. On failure
(e.g. the mass model has no critical curves) returns
``(None, None, None, None)``.

Parameters
----------
tracer
The tracer whose mass distribution is used to trace critical curves
and caustics.
grid
Image-plane grid on which the curves are evaluated.
"""
try:
tan_cc, rad_cc = _critical_curves_from(tracer, grid)
tan_ca, rad_ca = _caustics_from(tracer, grid)
_tan_cc_lines = _to_lines(list(tan_cc) if tan_cc is not None else []) or []
_rad_cc_lines = _to_lines(list(rad_cc) if rad_cc is not None else []) or []
_tan_ca_lines = _to_lines(list(tan_ca) if tan_ca is not None else []) or []
_rad_ca_lines = _to_lines(list(rad_ca) if rad_ca is not None else []) or []
image_plane_lines = (_tan_cc_lines + _rad_cc_lines) or None
image_plane_line_colors = (
["black"] * len(_tan_cc_lines) + ["white"] * len(_rad_cc_lines)
)
source_plane_lines = (_tan_ca_lines + _rad_ca_lines) or None
source_plane_line_colors = (
["black"] * len(_tan_ca_lines) + ["white"] * len(_rad_ca_lines)
)
return image_plane_lines, image_plane_line_colors, source_plane_lines, source_plane_line_colors
except Exception:
return None, None, None, None


def _get_source_vmax(fit):
"""
Return the colour-scale maximum for source-plane panels.
Expand Down Expand Up @@ -43,17 +80,22 @@ def _get_source_vmax(fit):
return None


from autolens.lens.plot.tracer_plots import plane_image_from


def _plot_source_plane(fit, ax, plane_index, zoom_to_brightest=True,
colormap=None, use_log10=False, title=None,
lines=None, line_colors=None):
lines=None, line_colors=None, vmax=None):
"""
Plot the source-plane image (or a blank inversion placeholder) into an axes.

When the plane at ``plane_index`` does not contain a
`~autoarray.Pixelization` (i.e. it is a parametric source), the
function ray-traces a zoomed image-plane grid to the source plane,
evaluates the source-galaxy light, and renders the resulting 2-D array
via :func:`~autoarray.plot.array.plot_array`. When the plane *does*
`~autoarray.Pixelization` (i.e. it is a parametric source), the source
galaxy light profiles are evaluated on a plain uniform grid
(``fit.mask.derive_grid.all_false``) — **not** a ray-traced grid. This
shows the source as it appears in its own plane, without lensing
distortion. :func:`~autolens.lens.plot.tracer_plots.plane_image_from`
handles the optional zoom to the brightest region. When the plane *does*
contain a pixelization (an inversion source), the source reconstruction
is rendered via :func:`~autoarray.inversion.plot.mapper_plots.plot_mapper`
using ``zoom_to_brightest`` to control whether the view is zoomed in on
Expand All @@ -69,27 +111,25 @@ def _plot_source_plane(fit, ax, plane_index, zoom_to_brightest=True,
plane_index : int
Index of the plane in ``fit.tracer.planes`` to visualise.
zoom_to_brightest : bool, optional
For inversion sources, zooms the colormap extent to the brightest
reconstructed pixels. For parametric sources, this parameter has
no effect.
For parametric sources, zooms the evaluation grid in on the brightest
region of the source plane via :func:`plane_image_from`. For inversion
sources, zooms the colormap extent to the brightest reconstructed pixels.
colormap : str, optional
Matplotlib colormap name.
use_log10 : bool, optional
If ``True`` the colour scale is applied on a log10 stretch.
"""
tracer = fit.tracer_linear_light_profiles_to_light_profiles
if not tracer.planes[plane_index].has(cls=aa.Pixelization):
zoom = aa.Zoom2D(mask=fit.mask)
grid = aa.Grid2D.from_extent(
extent=zoom.extent_from(buffer=0), shape_native=zoom.shape_native
image = plane_image_from(
galaxies=tracer.planes[plane_index],
grid=fit.mask.derive_grid.all_false,
zoom_to_brightest=zoom_to_brightest,
)
traced_grids = tracer.traced_grid_2d_list_from(grid=grid)
plane_galaxies = ag.Galaxies(galaxies=tracer.planes[plane_index])
image = plane_galaxies.image_2d_from(grid=traced_grids[0])
plot_array(
array=image, ax=ax,
title=title if title is not None else f"Source Plane {plane_index}",
colormap=colormap, use_log10=use_log10, lines=lines,
colormap=colormap, use_log10=use_log10, vmax=vmax, lines=lines,
line_colors=line_colors,
)
else:
Expand All @@ -106,6 +146,7 @@ def _plot_source_plane(fit, ax, plane_index, zoom_to_brightest=True,
title=title if title is not None else f"Source Reconstruction (plane {plane_index})",
colormap=colormap,
use_log10=use_log10,
vmax=vmax,
zoom_to_brightest=zoom_to_brightest,
lines=lines,
line_colors=line_colors,
Expand All @@ -123,6 +164,10 @@ def subplot_fit(
output_format: str = "png",
colormap: Optional[str] = None,
plane_index: Optional[int] = None,
image_plane_lines=None,
image_plane_line_colors=None,
source_plane_lines=None,
source_plane_line_colors=None,
):
"""
Produce a 12-panel subplot summarising an imaging fit.
Expand Down Expand Up @@ -171,30 +216,16 @@ def subplot_fit(

source_vmax = _get_source_vmax(fit)

tracer = fit.tracer_linear_light_profiles_to_light_profiles
try:
if image_plane_lines is None and source_plane_lines is None:
tracer = fit.tracer_linear_light_profiles_to_light_profiles
_zoom = aa.Zoom2D(mask=fit.mask)
_cc_grid = aa.Grid2D.from_extent(
extent=_zoom.extent_from(buffer=0),
shape_native=_zoom.shape_native,
)
tan_cc, rad_cc = _critical_curves_from(tracer, _cc_grid)
tan_ca, rad_ca = _caustics_from(tracer, _cc_grid)
_tan_cc_lines = _to_lines(list(tan_cc) if tan_cc is not None else []) or []
_rad_cc_lines = _to_lines(list(rad_cc) if rad_cc is not None else []) or []
_tan_ca_lines = _to_lines(list(tan_ca) if tan_ca is not None else []) or []
_rad_ca_lines = _to_lines(list(rad_ca) if rad_ca is not None else []) or []
image_plane_lines = _tan_cc_lines + _rad_cc_lines
image_plane_line_colors = ["black"] * len(_tan_cc_lines) + ["white"] * len(_rad_cc_lines)
source_plane_lines = _tan_ca_lines + _rad_ca_lines
source_plane_line_colors = ["black"] * len(_tan_ca_lines) + ["white"] * len(_rad_ca_lines)
image_plane_lines = image_plane_lines or None
source_plane_lines = source_plane_lines or None
except Exception:
image_plane_lines = None
image_plane_line_colors = None
source_plane_lines = None
source_plane_line_colors = None
image_plane_lines, image_plane_line_colors, source_plane_lines, source_plane_line_colors = (
_compute_critical_curve_lines(tracer, _cc_grid)
)

fig, axes = plt.subplots(3, 4, figsize=conf_subplot_figsize(3, 4))
axes_flat = list(axes.flatten())
Expand Down Expand Up @@ -249,7 +280,8 @@ def subplot_fit(
# Source plane zoomed
_plot_source_plane(fit, axes_flat[7], final_plane_index, zoom_to_brightest=True,
colormap=colormap, title="Source Plane (Zoomed)",
lines=source_plane_lines, line_colors=source_plane_line_colors)
lines=source_plane_lines, line_colors=source_plane_line_colors,
vmax=source_vmax)

# Normalized residual map (symmetric)
norm_resid = fit.normalized_residual_map
Expand All @@ -268,7 +300,8 @@ def subplot_fit(
# Source plane not zoomed
_plot_source_plane(fit, axes_flat[11], final_plane_index, zoom_to_brightest=False,
colormap=colormap, title="Source Plane (No Zoom)",
lines=source_plane_lines, line_colors=source_plane_line_colors)
lines=source_plane_lines, line_colors=source_plane_line_colors,
vmax=source_vmax)

hide_unused_axes(axes_flat)
plt.tight_layout()
Expand Down Expand Up @@ -345,6 +378,10 @@ def subplot_fit_log10(
output_format: str = "png",
colormap: Optional[str] = None,
plane_index: Optional[int] = None,
image_plane_lines=None,
image_plane_line_colors=None,
source_plane_lines=None,
source_plane_line_colors=None,
):
"""
Produce a 12-panel subplot summarising an imaging fit with log10 colour scaling.
Expand Down Expand Up @@ -384,6 +421,17 @@ def subplot_fit_log10(

source_vmax = _get_source_vmax(fit)

if image_plane_lines is None and source_plane_lines is None:
tracer = fit.tracer_linear_light_profiles_to_light_profiles
_zoom = aa.Zoom2D(mask=fit.mask)
_cc_grid = aa.Grid2D.from_extent(
extent=_zoom.extent_from(buffer=0),
shape_native=_zoom.shape_native,
)
image_plane_lines, image_plane_line_colors, source_plane_lines, source_plane_line_colors = (
_compute_critical_curve_lines(tracer, _cc_grid)
)

fig, axes = plt.subplots(3, 4, figsize=conf_subplot_figsize(3, 4))
axes_flat = list(axes.flatten())

Expand All @@ -403,7 +451,8 @@ def subplot_fit_log10(
axes_flat[2].axis("off")

plot_array(array=fit.model_data, ax=axes_flat[3], title="Model Image",
colormap=colormap, use_log10=True)
colormap=colormap, use_log10=True, lines=image_plane_lines,
line_colors=image_plane_line_colors)

try:
lens_model_img = fit.model_images_of_planes_list[0]
Expand All @@ -422,12 +471,15 @@ def subplot_fit_log10(
try:
source_model_img = fit.model_images_of_planes_list[final_plane_index]
plot_array(array=source_model_img, ax=axes_flat[6],
title="Source Model Image", colormap=colormap, use_log10=True)
title="Source Model Image", colormap=colormap, use_log10=True,
lines=image_plane_lines, line_colors=image_plane_line_colors)
except (IndexError, AttributeError):
axes_flat[6].axis("off")

_plot_source_plane(fit, axes_flat[7], final_plane_index, zoom_to_brightest=True,
colormap=colormap, use_log10=True)
colormap=colormap, use_log10=True,
lines=source_plane_lines, line_colors=source_plane_line_colors,
vmax=source_vmax)

norm_resid = fit.normalized_residual_map
_abs_max = _symmetric_vmax(norm_resid)
Expand All @@ -442,7 +494,9 @@ def subplot_fit_log10(
colormap=colormap, use_log10=True, cb_unit=r"$\chi^2$")

_plot_source_plane(fit, axes_flat[11], final_plane_index, zoom_to_brightest=False,
colormap=colormap, use_log10=True)
colormap=colormap, use_log10=True,
lines=source_plane_lines, line_colors=source_plane_line_colors,
vmax=source_vmax)

plt.tight_layout()
save_figure(fig, path=output_path, filename=f"fit_log10{plane_index_tag}", format=output_format)
Expand Down Expand Up @@ -582,6 +636,10 @@ def subplot_tracer_from_fit(
output_path: Optional[str] = None,
output_format: str = "png",
colormap: Optional[str] = None,
image_plane_lines=None,
image_plane_line_colors=None,
source_plane_lines=None,
source_plane_line_colors=None,
):
"""
Produce a 9-panel tracer subplot derived from a `FitImaging` object.
Expand Down Expand Up @@ -614,28 +672,16 @@ def subplot_tracer_from_fit(
final_plane_index = len(fit.tracer.planes) - 1
tracer = fit.tracer_linear_light_profiles_to_light_profiles

# --- grid and critical curves (computed first so all panels can use them) ---
# --- grid ---
zoom = aa.Zoom2D(mask=fit.mask)
grid = aa.Grid2D.from_extent(
extent=zoom.extent_from(buffer=0), shape_native=zoom.shape_native
)

try:
tan_cc, rad_cc = _critical_curves_from(tracer, grid)
tan_ca, rad_ca = _caustics_from(tracer, grid)
_tan_cc_lines = _to_lines(list(tan_cc) if tan_cc is not None else []) or []
_rad_cc_lines = _to_lines(list(rad_cc) if rad_cc is not None else []) or []
_tan_ca_lines = _to_lines(list(tan_ca) if tan_ca is not None else []) or []
_rad_ca_lines = _to_lines(list(rad_ca) if rad_ca is not None else []) or []
image_plane_lines = (_tan_cc_lines + _rad_cc_lines) or None
image_plane_line_colors = ["black"] * len(_tan_cc_lines) + ["white"] * len(_rad_cc_lines)
source_plane_lines = (_tan_ca_lines + _rad_ca_lines) or None
source_plane_line_colors = ["black"] * len(_tan_ca_lines) + ["white"] * len(_rad_ca_lines)
except Exception:
image_plane_lines = None
image_plane_line_colors = None
source_plane_lines = None
source_plane_line_colors = None
if image_plane_lines is None and source_plane_lines is None:
image_plane_lines, image_plane_line_colors, source_plane_lines, source_plane_line_colors = (
_compute_critical_curve_lines(tracer, grid)
)

source_vmax = _get_source_vmax(fit)

Expand Down Expand Up @@ -672,7 +718,8 @@ def subplot_tracer_from_fit(
# Panel 2: Source Plane (No Zoom)
_plot_source_plane(fit, axes_flat[2], final_plane_index, zoom_to_brightest=False,
colormap=colormap, title="Source Plane (No Zoom)",
lines=source_plane_lines, line_colors=source_plane_line_colors)
lines=source_plane_lines, line_colors=source_plane_line_colors,
vmax=source_vmax)

# Panel 3: Lens Image (log10)
plot_array(array=lens_image, ax=axes_flat[3], title="Lens Image",
Expand Down
2 changes: 1 addition & 1 deletion autolens/interferometer/plot/fit_interferometer_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ def subplot_fit_real_space(
for _ax in axes_flat:
_ax.axis("off")
axes_flat[0].set_title("Reconstructed Data")
axes_flat[1].set_title("Source Reconstruction")
axes_flat[1].set_title("Source Plane (Zoom)")

plt.tight_layout()
save_figure(fig, path=output_path, filename="fit_real_space", format=output_format)
Loading
Loading