Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions autolens/analysis/plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ def tracer(
self,
tracer: Tracer,
grid: aa.type.Grid2DLike,
image_plane_lines=None,
image_plane_line_colors=None,
source_plane_lines=None,
source_plane_line_colors=None,
):
"""
Visualizes a `Tracer` object.
Expand All @@ -47,6 +51,14 @@ def tracer(
The maximum log likelihood `Tracer` of the non-linear search.
grid
A 2D grid of (y,x) arc-second coordinates used to perform ray-tracing.
image_plane_lines
Pre-computed critical-curve lines to overlay on image-plane panels.
image_plane_line_colors
Colours for each image-plane line.
source_plane_lines
Pre-computed caustic lines to overlay on source-plane panels.
source_plane_line_colors
Colours for each source-plane line.
"""

def should_plot(name):
Expand All @@ -55,6 +67,18 @@ def should_plot(name):
output_path = str(self.image_path)
fmt = self.fmt

if should_plot("subplot_tracer"):
subplot_tracer(
tracer=tracer,
grid=grid,
output_path=output_path,
output_format=fmt,
image_plane_lines=image_plane_lines,
image_plane_line_colors=image_plane_line_colors,
source_plane_lines=source_plane_lines,
source_plane_line_colors=source_plane_line_colors,
)

if should_plot("subplot_galaxies_images"):
subplot_galaxies_images(
tracer=tracer,
Expand Down
34 changes: 25 additions & 9 deletions autolens/imaging/model/plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
subplot_tracer_from_fit,
subplot_fit_combined,
subplot_fit_combined_log10,
_compute_critical_curve_lines,
_compute_critical_curves_from_fit,
)

from autolens.analysis.plotter import plot_setting
Expand All @@ -31,7 +31,13 @@ class PlotterImaging(Plotter):
imaging_combined = AgPlotterImaging.imaging_combined

def fit_imaging(
self, fit: FitImaging, quick_update: bool = False
self,
fit: FitImaging,
quick_update: bool = False,
image_plane_lines=None,
image_plane_line_colors=None,
source_plane_lines=None,
source_plane_line_colors=None,
):
"""
Visualizes a `FitImaging` object, which fits an imaging dataset.
Expand All @@ -42,6 +48,14 @@ def fit_imaging(
The maximum log likelihood `FitImaging` of the non-linear search.
quick_update
If True only the essential subplot_fit is output.
image_plane_lines
Pre-computed critical-curve lines. Computed internally if not provided.
image_plane_line_colors
Colours for each image-plane line.
source_plane_lines
Pre-computed caustic lines. Computed internally if not provided.
source_plane_line_colors
Colours for each source-plane line.
"""

def should_plot(name):
Expand All @@ -52,13 +66,15 @@ def should_plot(name):

plane_indexes_to_plot = [i for i in fit.tracer.plane_indexes_with_images if i != 0]

# Compute critical curves and caustics once for all subplot functions.
tracer = fit.tracer_linear_light_profiles_to_light_profiles
_zoom = aa.Zoom2D(mask=fit.mask)
_cc_grid = aa.Grid2D.from_extent(
extent=_zoom.extent_from(buffer=0), shape_native=_zoom.shape_native
)
ip_lines, ip_colors, sp_lines, sp_colors = _compute_critical_curve_lines(tracer, _cc_grid)
# Compute critical curves and caustics once for all subplot functions,
# unless already provided by the caller (e.g. the visualizer).
if image_plane_lines is None and source_plane_lines is None:
ip_lines, ip_colors, sp_lines, sp_colors = _compute_critical_curves_from_fit(fit)
else:
ip_lines, ip_colors, sp_lines, sp_colors = (
image_plane_lines, image_plane_line_colors,
source_plane_lines, source_plane_line_colors,
)

if should_plot("subplot_fit") or quick_update:

Expand Down
21 changes: 14 additions & 7 deletions autolens/imaging/model/visualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import autogalaxy as ag

from autolens.imaging.model.plotter import PlotterImaging
from autolens.imaging.plot.fit_imaging_plots import _compute_critical_curves_from_fit

from autolens import exc

Expand Down Expand Up @@ -101,10 +102,19 @@ def visualize(
title_prefix=analysis.title_prefix,
)

grid = fit.mask.derive_grid.all_false

# Compute critical curves once for all plot functions.
ip_lines, ip_colors, sp_lines, sp_colors = _compute_critical_curves_from_fit(fit)

try:
plotter.fit_imaging(
fit=fit,
quick_update=quick_update,
image_plane_lines=ip_lines,
image_plane_line_colors=ip_colors,
source_plane_lines=sp_lines,
source_plane_line_colors=sp_colors,
)
except exc.InversionException:
pass
Expand Down Expand Up @@ -135,16 +145,13 @@ def visualize(
)
return

zoom = ag.Zoom2D(mask=fit.mask)

extent = zoom.extent_from(buffer=0)
shape_native = zoom.shape_native

grid = ag.Grid2D.from_extent(extent=extent, shape_native=shape_native)

plotter.tracer(
tracer=tracer,
grid=grid,
image_plane_lines=ip_lines,
image_plane_line_colors=ip_colors,
source_plane_lines=sp_lines,
source_plane_line_colors=sp_colors,
)
plotter.galaxies(
galaxies=tracer.galaxies,
Expand Down
44 changes: 21 additions & 23 deletions autolens/imaging/plot/fit_imaging_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,19 @@ def _compute_critical_curve_lines(tracer, grid):
return None, None, None, None


def _compute_critical_curves_from_fit(fit):
"""Compute critical-curve and caustic lines from a FitImaging object.

Convenience wrapper around :func:`_compute_critical_curve_lines` that
derives the tracer and grid from *fit* directly, using the fully unmasked
image-plane grid so the curves cover the whole image extent.

Returns the same 4-tuple as :func:`_compute_critical_curve_lines`.
"""
tracer = fit.tracer_linear_light_profiles_to_light_profiles
return _compute_critical_curve_lines(tracer, fit.mask.derive_grid.all_false)


def _get_source_vmax(fit):
"""
Return the colour-scale maximum for source-plane panels.
Expand Down Expand Up @@ -217,14 +230,8 @@ def subplot_fit(
source_vmax = _get_source_vmax(fit)

if image_plane_lines is None and source_plane_lines is None:
tracer = fit.tracer_linear_light_profiles_to_light_profiles
_zoom = aa.Zoom2D(mask=fit.mask)
_cc_grid = aa.Grid2D.from_extent(
extent=_zoom.extent_from(buffer=0),
shape_native=_zoom.shape_native,
)
image_plane_lines, image_plane_line_colors, source_plane_lines, source_plane_line_colors = (
_compute_critical_curve_lines(tracer, _cc_grid)
_compute_critical_curves_from_fit(fit)
)

fig, axes = plt.subplots(3, 4, figsize=conf_subplot_figsize(3, 4))
Expand Down Expand Up @@ -422,14 +429,8 @@ def subplot_fit_log10(
source_vmax = _get_source_vmax(fit)

if image_plane_lines is None and source_plane_lines is None:
tracer = fit.tracer_linear_light_profiles_to_light_profiles
_zoom = aa.Zoom2D(mask=fit.mask)
_cc_grid = aa.Grid2D.from_extent(
extent=_zoom.extent_from(buffer=0),
shape_native=_zoom.shape_native,
)
image_plane_lines, image_plane_line_colors, source_plane_lines, source_plane_line_colors = (
_compute_critical_curve_lines(tracer, _cc_grid)
_compute_critical_curves_from_fit(fit)
)

fig, axes = plt.subplots(3, 4, figsize=conf_subplot_figsize(3, 4))
Expand Down Expand Up @@ -673,14 +674,11 @@ def subplot_tracer_from_fit(
tracer = fit.tracer_linear_light_profiles_to_light_profiles

# --- grid ---
zoom = aa.Zoom2D(mask=fit.mask)
grid = aa.Grid2D.from_extent(
extent=zoom.extent_from(buffer=0), shape_native=zoom.shape_native
)
grid = fit.mask.derive_grid.all_false

if image_plane_lines is None and source_plane_lines is None:
image_plane_lines, image_plane_line_colors, source_plane_lines, source_plane_line_colors = (
_compute_critical_curve_lines(tracer, grid)
_compute_critical_curves_from_fit(fit)
)

source_vmax = _get_source_vmax(fit)
Expand All @@ -689,7 +687,7 @@ def subplot_tracer_from_fit(
lens_galaxies = ag.Galaxies(galaxies=tracer.planes[0])
lens_image = lens_galaxies.image_2d_from(grid=traced_grids[0])

deflections = lens_galaxies.deflections_yx_2d_from(grid=grid)
deflections = tracer.deflections_yx_2d_from(grid=grid)
deflections_y = aa.Array2D(values=deflections.slim[:, 0], mask=grid.mask)
deflections_x = aa.Array2D(values=deflections.slim[:, 1], mask=grid.mask)

Expand All @@ -703,10 +701,10 @@ def subplot_tracer_from_fit(
lines=image_plane_lines, line_colors=image_plane_line_colors,
colormap=colormap)

# Panel 1: Source Model Image (image-plane projection)
# Panel 1: Source Model Image (same as subplot_fit panel 7)
try:
source_model_img = fit.model_images_of_planes_list[final_plane_index]
except Exception:
except (IndexError, AttributeError):
source_model_img = None
if source_model_img is not None:
plot_array(array=source_model_img, ax=axes_flat[1], title="Source Model Image",
Expand All @@ -715,7 +713,7 @@ def subplot_tracer_from_fit(
else:
axes_flat[1].axis("off")

# Panel 2: Source Plane (No Zoom)
# Panel 2: Source Plane (No Zoom) (same as subplot_fit panel 12)
_plot_source_plane(fit, axes_flat[2], final_plane_index, zoom_to_brightest=False,
colormap=colormap, title="Source Plane (No Zoom)",
lines=source_plane_lines, line_colors=source_plane_line_colors,
Expand Down
49 changes: 42 additions & 7 deletions autolens/interferometer/model/plotter.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import autoarray as aa

from autogalaxy.interferometer.model.plotter import (
PlotterInterferometer as AgPlotterInterferometer,
)

from autogalaxy.interferometer.plot import fit_interferometer_plots as ag_fit_interferometer_plots
from autogalaxy.interferometer.plot.fit_interferometer_plots import (
fits_galaxy_images,
fits_dirty_images,
Expand All @@ -11,7 +12,9 @@
from autolens.interferometer.fit_interferometer import FitInterferometer
from autolens.interferometer.plot.fit_interferometer_plots import (
subplot_fit,
subplot_fit_dirty_images,
subplot_fit_real_space,
_compute_critical_curve_lines,
)
from autolens.analysis.plotter import Plotter

Expand All @@ -25,6 +28,10 @@ def fit_interferometer(
self,
fit: FitInterferometer,
quick_update: bool = False,
image_plane_lines=None,
image_plane_line_colors=None,
source_plane_lines=None,
source_plane_line_colors=None,
):
"""
Visualizes a `FitInterferometer` object.
Expand All @@ -33,6 +40,14 @@ def fit_interferometer(
----------
fit
The maximum log likelihood `FitInterferometer` of the non-linear search.
image_plane_lines
Pre-computed critical-curve lines. Computed internally if not provided.
image_plane_line_colors
Colours for each image-plane line.
source_plane_lines
Pre-computed caustic lines. Computed internally if not provided.
source_plane_line_colors
Colours for each source-plane line.
"""

def should_plot(name):
Expand All @@ -41,21 +56,41 @@ def should_plot(name):
output_path = str(self.image_path)
fmt = self.fmt

# Use pre-computed critical curves if provided, otherwise compute once here.
if image_plane_lines is None and source_plane_lines is None:
tracer = fit.tracer_linear_light_profiles_to_light_profiles
_zoom = aa.Zoom2D(mask=fit.dataset.real_space_mask)
_cc_grid = aa.Grid2D.from_extent(
extent=_zoom.extent_from(buffer=0), shape_native=_zoom.shape_native
)
ip_lines, ip_colors, sp_lines, sp_colors = _compute_critical_curve_lines(tracer, _cc_grid)
else:
ip_lines, ip_colors, sp_lines, sp_colors = (
image_plane_lines, image_plane_line_colors,
source_plane_lines, source_plane_line_colors,
)

if should_plot("subplot_fit"):
subplot_fit(fit, output_path=output_path, output_format=fmt)
subplot_fit(
fit, output_path=output_path, output_format=fmt,
image_plane_lines=ip_lines, image_plane_line_colors=ip_colors,
source_plane_lines=sp_lines, source_plane_line_colors=sp_colors,
)

if should_plot("subplot_fit_dirty_images") or quick_update:
ag_fit_interferometer_plots.subplot_fit_dirty_images(
fit=fit,
output_path=self.image_path,
output_format=self.fmt,
subplot_fit_dirty_images(
fit, output_path=output_path, output_format=fmt,
image_plane_lines=ip_lines, image_plane_line_colors=ip_colors,
)

if quick_update:
return

if should_plot("subplot_fit_real_space"):
subplot_fit_real_space(fit, output_path=output_path, output_format=fmt)
subplot_fit_real_space(
fit, output_path=output_path, output_format=fmt,
source_plane_lines=sp_lines, source_plane_line_colors=sp_colors,
)

if should_plot("fits_galaxy_images"):
fits_galaxy_images(fit=fit, output_path=self.image_path)
Expand Down
Loading
Loading