diff --git a/README.rst b/README.rst index 44658160e..bf8319815 100644 --- a/README.rst +++ b/README.rst @@ -26,6 +26,18 @@ PyAutoLens: Open-Source Strong Lensing .. |arXiv| image:: https://img.shields.io/badge/arXiv-1708.07377-blue :target: https://arxiv.org/abs/1708.07377 +.. image:: https://www.repostatus.org/badges/latest/active.svg + :target: https://www.repostatus.org/#active + :alt: Project Status: Active + +.. image:: https://img.shields.io/pypi/pyversions/autolens + :target: https://pypi.org/project/autolens/ + :alt: Python Versions + +.. image:: https://img.shields.io/pypi/v/autolens.svg + :target: https://pypi.org/project/autolens/ + :alt: PyPI Version + |binder| |RTD| |Tests| |Build| |code-style| |JOSS| |arXiv| `Installation Guide `_ | diff --git a/autolens/analysis/plotter_interface.py b/autolens/analysis/plotter_interface.py index 3654401b6..7365018d7 100644 --- a/autolens/analysis/plotter_interface.py +++ b/autolens/analysis/plotter_interface.py @@ -1,5 +1,6 @@ import ast import numpy as np +from typing import Optional from autoconf import conf from autoconf.fitsable import hdu_list_for_output_from @@ -31,7 +32,12 @@ class PlotterInterface(AgPlotterInterface): The path on the hard-disk to the `image` folder of the non-linear searches results. """ - def tracer(self, tracer: Tracer, grid: aa.type.Grid2DLike): + def tracer( + self, + tracer: Tracer, + grid: aa.type.Grid2DLike, + visuals_2d_of_planes_list: Optional[aplt.Visuals2D] = None, + ): """ Visualizes a `Tracer` object. @@ -63,7 +69,7 @@ def should_plot(name): tracer=tracer, grid=grid, mat_plot_2d=mat_plot_2d, - include_2d=self.include_2d, + visuals_2d_of_planes_list=visuals_2d_of_planes_list, ) if should_plot("subplot_galaxies_images"): @@ -169,7 +175,6 @@ def should_plot(name): image_plotter = aplt.Array2DPlotter( array=image, mat_plot_2d=mat_plot_2d, - include_2d=self.include_2d, visuals_2d=visuals_2d, ) diff --git a/autolens/imaging/model/plotter_interface.py b/autolens/imaging/model/plotter_interface.py index a795525a6..f4bb79889 100644 --- a/autolens/imaging/model/plotter_interface.py +++ b/autolens/imaging/model/plotter_interface.py @@ -1,4 +1,4 @@ -from typing import List +from typing import List, Optional import autoarray.plot as aplt @@ -19,7 +19,7 @@ class PlotterInterfaceImaging(PlotterInterface): imaging_combined = AgPlotterInterfaceImaging.imaging_combined def fit_imaging( - self, fit: FitImaging, + self, fit: FitImaging, visuals_2d_of_planes_list : Optional[aplt.Visuals2D] = None ): """ Visualizes a `FitImaging` object, which fits an imaging dataset. @@ -45,7 +45,7 @@ def fit_imaging( mat_plot_2d = self.mat_plot_2d_from() fit_plotter = FitImagingPlotter( - fit=fit, mat_plot_2d=mat_plot_2d, include_2d=self.include_2d + fit=fit, mat_plot_2d=mat_plot_2d, visuals_2d_of_planes_list=visuals_2d_of_planes_list, ) fit_plotter.subplot_tracer() @@ -56,7 +56,7 @@ def should_plot(name): mat_plot_2d = self.mat_plot_2d_from() fit_plotter = FitImagingPlotter( - fit=fit, mat_plot_2d=mat_plot_2d, include_2d=self.include_2d + fit=fit, mat_plot_2d=mat_plot_2d, visuals_2d_of_planes_list=visuals_2d_of_planes_list, ) plane_indexes_to_plot = [i for i in fit.tracer.plane_indexes_with_images if i != 0] @@ -72,6 +72,7 @@ def should_plot(name): fit_plotter.subplot_fit() if should_plot("subplot_fit_log10"): + try: if len(fit.tracer.planes) > 2: for plane_index in plane_indexes_to_plot: @@ -81,6 +82,7 @@ def should_plot(name): except ValueError: pass + if should_plot("subplot_of_planes"): fit_plotter.subplot_of_planes() @@ -92,7 +94,7 @@ def should_plot(name): fits_to_fits(should_plot=should_plot, image_path=self.image_path, fit=fit) - def fit_imaging_combined(self, fit_list: List[FitImaging]): + def fit_imaging_combined(self, fit_list: List[FitImaging], visuals_2d_of_planes_list : Optional[aplt.Visuals2D] = None): """ Output visualization of all `FitImaging` objects in a summed combined analysis, typically during or after a model-fit is performed. @@ -119,7 +121,7 @@ def should_plot(name): fit_plotter_list = [ FitImagingPlotter( - fit=fit, mat_plot_2d=mat_plot_2d, include_2d=self.include_2d + fit=fit, mat_plot_2d=mat_plot_2d, visuals_2d_of_planes_list=visuals_2d_of_planes_list, ) for fit in fit_list ] diff --git a/autolens/imaging/model/visualizer.py b/autolens/imaging/model/visualizer.py index 72069c31d..3c2724f24 100644 --- a/autolens/imaging/model/visualizer.py +++ b/autolens/imaging/model/visualizer.py @@ -4,6 +4,8 @@ import autogalaxy as ag from autolens.imaging.model.plotter_interface import PlotterInterfaceImaging + +from autolens.lens import tracer_util from autolens import exc @@ -109,15 +111,6 @@ def visualize( except exc.InversionException: return - plotter_interface = PlotterInterfaceImaging( - image_path=paths.image_path, title_prefix=analysis.title_prefix - ) - - try: - plotter_interface.fit_imaging(fit=fit) - except exc.InversionException: - pass - tracer = fit.tracer_linear_light_profiles_to_light_profiles zoom = ag.Zoom2D(mask=fit.mask) @@ -127,8 +120,28 @@ def visualize( grid = ag.Grid2D.from_extent(extent=extent, shape_native=shape_native) + visuals_2d_of_planes_list = tracer_util.visuals_2d_of_planes_list_from( + tracer=fit.tracer, + grid=fit.grids.lp.mask.derive_grid.all_false + ) + + plotter_interface = PlotterInterfaceImaging( + image_path=paths.image_path, + title_prefix=analysis.title_prefix, + ) + + try: + plotter_interface.fit_imaging( + fit=fit, + visuals_2d_of_planes_list=visuals_2d_of_planes_list + ) + except exc.InversionException: + pass + plotter_interface.tracer( - tracer=tracer, grid=grid, + tracer=tracer, + grid=grid, + visuals_2d_of_planes_list=visuals_2d_of_planes_list ) plotter_interface.galaxies( galaxies=tracer.galaxies, diff --git a/autolens/imaging/plot/fit_imaging_plotters.py b/autolens/imaging/plot/fit_imaging_plotters.py index 8d01c7aef..0a47fcdca 100644 --- a/autolens/imaging/plot/fit_imaging_plotters.py +++ b/autolens/imaging/plot/fit_imaging_plotters.py @@ -14,6 +14,8 @@ from autolens.imaging.fit_imaging import FitImaging from autolens.lens.plot.tracer_plotters import TracerPlotter +from autolens.lens import tracer_util + class FitImagingPlotter(Plotter): def __init__( @@ -21,8 +23,8 @@ def __init__( fit: FitImaging, mat_plot_2d: aplt.MatPlot2D = None, visuals_2d: aplt.Visuals2D = None, - include_2d: aplt.Include2D = None, - residuals_symmetric_cmap: bool = True + residuals_symmetric_cmap: bool = True, + visuals_2d_of_planes_list : Optional = None ): """ Plots the attributes of `FitImaging` objects using the matplotlib method `imshow()` and many other matplotlib @@ -33,8 +35,7 @@ def __init__( but a user can manually input values into `MatPlot2d` to customize the figure's appearance. Overlaid on the figure are visuals, contained in the `Visuals2D` object. Attributes may be extracted from - the `FitImaging` and plotted via the visuals object, if the corresponding entry is `True` in the `Include2D` - object or the `config/visualize/include.ini` file. + the `FitImaging` and plotted via the visuals object. Parameters ---------- @@ -44,38 +45,69 @@ def __init__( Contains objects which wrap the matplotlib function calls that make the plot. visuals_2d Contains visuals that can be overlaid on the plot. - include_2d - Specifies which attributes of the `Array2D` are extracted and plotted as visuals. residuals_symmetric_cmap If true, the `residual_map` and `normalized_residual_map` are plotted with a symmetric color map such that `abs(vmin) = abs(vmax)`. """ - super().__init__( - mat_plot_2d=mat_plot_2d, include_2d=include_2d, visuals_2d=visuals_2d - ) + super().__init__(mat_plot_2d=mat_plot_2d, visuals_2d=visuals_2d) self.fit = fit self._fit_imaging_meta_plotter = FitImagingPlotterMeta( fit=self.fit, - get_visuals_2d=self.get_visuals_2d, mat_plot_2d=self.mat_plot_2d, - include_2d=self.include_2d, visuals_2d=self.visuals_2d, - residuals_symmetric_cmap=residuals_symmetric_cmap + residuals_symmetric_cmap=residuals_symmetric_cmap, ) self.residuals_symmetric_cmap = residuals_symmetric_cmap - def get_visuals_2d(self) -> aplt.Visuals2D: - return self.get_2d.via_fit_imaging_from(fit=self.fit) + self._visuals_2d_of_planes_list = visuals_2d_of_planes_list + + @property + def visuals_2d_of_planes_list(self): + + if self._visuals_2d_of_planes_list is None: + + self._visuals_2d_of_planes_list = tracer_util.visuals_2d_of_planes_list_from( + tracer=self.fit.tracer, + grid=self.fit.grids.lp.mask.derive_grid.all_false, + ) + + return self._visuals_2d_of_planes_list + + def visuals_2d_from( + self, plane_index: Optional[int] = None, remove_critical_caustic: bool = False + ) -> aplt.Visuals2D: + """ + Returns the `Visuals2D` of the plotter with critical curves and caustics added, which are used to plot + the critical curves and caustics of the `Tracer` object. + + If `remove_critical_caustic` is `True`, critical curves and caustics are not included in the visuals. + + Parameters + ---------- + plane_index + The index of the plane in the tracer which is used to extract quantities, as only one plane is plotted + at a time. + remove_critical_caustic + Whether to remove critical curves and caustics from the visuals. + """ + if remove_critical_caustic: + return self.visuals_2d + + return ( + self.visuals_2d + + self.visuals_2d_of_planes_list[plane_index] + ) @property def tracer(self): return self.fit.tracer_linear_light_profiles_to_light_profiles - @property - def tracer_plotter(self) -> TracerPlotter: + def tracer_plotter_of_plane( + self, plane_index: int, remove_critical_caustic: bool = False + ) -> TracerPlotter: """ Returns an `TracerPlotter` corresponding to the `Tracer` in the `FitImaging`. """ @@ -83,19 +115,20 @@ def tracer_plotter(self) -> TracerPlotter: zoom = aa.Zoom2D(mask=self.fit.mask) grid = aa.Grid2D.from_extent( - extent=zoom.extent_from(buffer=0), - shape_native=zoom.shape_native + extent=zoom.extent_from(buffer=0), shape_native=zoom.shape_native ) - return TracerPlotter( tracer=self.tracer, grid=grid, mat_plot_2d=self.mat_plot_2d, - visuals_2d=self.visuals_2d, - include_2d=self.include_2d, + visuals_2d=self.visuals_2d_from( + plane_index=plane_index, remove_critical_caustic=remove_critical_caustic + ), ) - def inversion_plotter_of_plane(self, plane_index: int) -> aplt.InversionPlotter: + def inversion_plotter_of_plane( + self, plane_index: int, remove_critical_caustic: bool = False + ) -> aplt.InversionPlotter: """ Returns an `InversionPlotter` corresponding to one of the `Inversion`'s in the fit, which is specified via the index of the `Plane` that inversion was performed on. @@ -110,16 +143,14 @@ def inversion_plotter_of_plane(self, plane_index: int) -> aplt.InversionPlotter: InversionPlotter An object that plots inversions which is used for plotting attributes of the inversion. """ + inversion_plotter = aplt.InversionPlotter( inversion=self.fit.inversion, mat_plot_2d=self.mat_plot_2d, - visuals_2d=self.tracer_plotter.get_visuals_2d_of_plane( - plane_index=plane_index + visuals_2d=self.visuals_2d_from( + plane_index=plane_index, remove_critical_caustic=remove_critical_caustic ), - include_2d=self.include_2d, ) - inversion_plotter.visuals_2d.mask = None - inversion_plotter.visuals_2d.border = None return inversion_plotter def plane_indexes_from(self, plane_index: int): @@ -152,7 +183,7 @@ def figures_2d_of_planes( use_source_vmax: bool = False, zoom_to_brightest: bool = True, interpolate_to_uniform: bool = False, - remove_critical_caustic : bool = True + remove_critical_caustic: bool = False, ): """ Plots images representing each individual `Plane` in the fit's `Tracer` in 2D, which are computed via the @@ -200,23 +231,14 @@ def figures_2d_of_planes( Whether to remove critical curves and caustics from the plot. """ - visuals_2d = self.get_visuals_2d() - - visuals_2d_no_critical_caustic = self.get_visuals_2d() - - if remove_critical_caustic: - - visuals_2d_no_critical_caustic.tangential_critical_curves = None - visuals_2d_no_critical_caustic.radial_critical_curves = None - visuals_2d_no_critical_caustic.tangential_caustics = None - visuals_2d_no_critical_caustic.radial_caustics = None - plane_indexes = self.plane_indexes_from(plane_index=plane_index) for plane_index in plane_indexes: if use_source_vmax: - self.mat_plot_2d.cmap.kwargs["vmax"] = np.max(self.fit.model_images_of_planes_list[plane_index].array) + self.mat_plot_2d.cmap.kwargs["vmax"] = np.max( + self.fit.model_images_of_planes_list[plane_index].array + ) if subtracted_image: @@ -237,25 +259,15 @@ def figures_2d_of_planes( self.mat_plot_2d.plot_array( array=self.fit.subtracted_images_of_planes_list[plane_index], - visuals_2d=visuals_2d_no_critical_caustic, - auto_labels=aplt.AutoLabels( - title=title, - filename=filename + visuals_2d=self.visuals_2d_from( + plane_index=plane_index, + remove_critical_caustic=remove_critical_caustic, ), + auto_labels=aplt.AutoLabels(title=title, filename=filename), ) if model_image: - if self.tracer.planes[plane_index].has(cls=aa.Pixelization): - - # Overwrite plane_index=0 so that model image uses critical curves -- improve via config cutomization - - visuals_2d_model_image = self.inversion_plotter_of_plane(plane_index=0).get_visuals_2d_for_data() - - else: - - visuals_2d_model_image = visuals_2d - title = f"Model Image of Plane {plane_index}" filename = f"model_image_of_plane_{plane_index}" @@ -273,34 +285,41 @@ def figures_2d_of_planes( self.mat_plot_2d.plot_array( array=self.fit.model_images_of_planes_list[plane_index], - visuals_2d=visuals_2d_model_image, - auto_labels=aplt.AutoLabels( - title=title, - filename=filename + visuals_2d=self.visuals_2d_from( + plane_index=plane_index, + remove_critical_caustic=remove_critical_caustic, ), + auto_labels=aplt.AutoLabels(title=title, filename=filename), ) if plane_image: if not self.tracer.planes[plane_index].has(cls=aa.Pixelization): - self.tracer_plotter.figures_2d_of_planes( + tracer_plotter = self.tracer_plotter_of_plane( + plane_index=plane_index, + remove_critical_caustic=remove_critical_caustic, + ) + + tracer_plotter.figures_2d_of_planes( plane_image=True, plane_index=plane_index, - zoom_to_brightest=zoom_to_brightest + zoom_to_brightest=zoom_to_brightest, + retain_visuals=True, ) elif self.tracer.planes[plane_index].has(cls=aa.Pixelization): inversion_plotter = self.inversion_plotter_of_plane( - plane_index=plane_index + plane_index=plane_index, + remove_critical_caustic=remove_critical_caustic, ) inversion_plotter.figures_2d_of_pixelization( pixelization_index=0, reconstruction=True, zoom_to_brightest=zoom_to_brightest, - interpolate_to_uniform=interpolate_to_uniform + interpolate_to_uniform=interpolate_to_uniform, ) if use_source_vmax: @@ -314,14 +333,15 @@ def figures_2d_of_planes( if self.tracer.planes[plane_index].has(cls=aa.Pixelization): inversion_plotter = self.inversion_plotter_of_plane( - plane_index=plane_index + plane_index=plane_index, + remove_critical_caustic=remove_critical_caustic, ) inversion_plotter.figures_2d_of_pixelization( pixelization_index=0, reconstruction_noise_map=True, zoom_to_brightest=zoom_to_brightest, - interpolate_to_uniform=interpolate_to_uniform + interpolate_to_uniform=interpolate_to_uniform, ) if plane_signal_to_noise_map: @@ -329,50 +349,17 @@ def figures_2d_of_planes( if self.tracer.planes[plane_index].has(cls=aa.Pixelization): inversion_plotter = self.inversion_plotter_of_plane( - plane_index=plane_index + plane_index=plane_index, + remove_critical_caustic=remove_critical_caustic, ) inversion_plotter.figures_2d_of_pixelization( pixelization_index=0, signal_to_noise_map=True, zoom_to_brightest=zoom_to_brightest, - interpolate_to_uniform=interpolate_to_uniform + interpolate_to_uniform=interpolate_to_uniform, ) - def subplot_of_planes(self, plane_index: Optional[int] = None): - """ - Plots images representing each individual `Plane` in the plotter's `Tracer` in 2D on a subplot, which are - computed via the plotter's 2D grid object. - - These images subtract or omit the contribution of other planes in the plane, such that plots showing - each individual plane are made. - - The subplot plots the subtracted image, model image and plane image of each plane, where are described in the - `figures_2d_of_planes` function. - - Parameters - ---------- - plane_index - The index of the plane whose images are included on the subplot. - """ - - plane_indexes = self.plane_indexes_from(plane_index=plane_index) - - for plane_index in plane_indexes: - - self.open_subplot_figure(number_subplots=4) - - self.figures_2d(data=True) - - self.figures_2d_of_planes(subtracted_image=True, plane_index=plane_index) - self.figures_2d_of_planes(model_image=True, plane_index=plane_index) - self.figures_2d_of_planes(plane_image=True, plane_index=plane_index) - - self.mat_plot_2d.output.subplot_to_figure( - auto_filename=f"subplot_of_plane_{plane_index}" - ) - self.close_subplot_figure() - def subplot( self, data: bool = False, @@ -439,27 +426,43 @@ def subplot_fit(self, plane_index: Optional[int] = None): self.figures_2d(model_image=True) self.set_title(label="Lens Light Model Image") - self.figures_2d_of_planes(plane_index=0, model_image=True) + self.figures_2d_of_planes( + plane_index=0, model_image=True, remove_critical_caustic=True + ) # If the lens light is not included the subplot index does not increase, so we must manually set it to 4 self.mat_plot_2d.subplot_index = 6 plane_index_tag = "" if plane_index is None else f"_{plane_index}" - plane_index = len(self.fit.tracer.planes) - 1 if plane_index is None else plane_index + plane_index = ( + len(self.fit.tracer.planes) - 1 if plane_index is None else plane_index + ) self.mat_plot_2d.cmap.kwargs["vmin"] = 0.0 self.set_title(label="Lens Light Subtracted Image") - self.figures_2d_of_planes(plane_index=plane_index, subtracted_image=True, use_source_vmax=True) + self.figures_2d_of_planes( + plane_index=plane_index, + subtracted_image=True, + use_source_vmax=True, + remove_critical_caustic=True, + ) self.set_title(label="Source Model Image (Image Plane)") - self.figures_2d_of_planes(plane_index=plane_index, model_image=True, use_source_vmax=True) + self.figures_2d_of_planes( + plane_index=plane_index, + model_image=True, + use_source_vmax=True, + remove_critical_caustic=True, + ) self.mat_plot_2d.cmap.kwargs.pop("vmin") self.set_title(label="Source Plane (Zoomed)") - self.figures_2d_of_planes(plane_index=plane_index, plane_image=True, use_source_vmax=True) + self.figures_2d_of_planes( + plane_index=plane_index, plane_image=True, use_source_vmax=True + ) self.set_title(label=None) @@ -484,7 +487,7 @@ def subplot_fit(self, plane_index: Optional[int] = None): plane_index=plane_index, plane_image=True, zoom_to_brightest=False, - use_source_vmax=True + use_source_vmax=True, ) self.set_title(label=None) @@ -526,27 +529,35 @@ def subplot_fit_log10(self, plane_index: Optional[int] = None): self.figures_2d(model_image=True) self.set_title(label="Lens Light Model Image") - self.figures_2d_of_planes(plane_index=0, model_image=True) + self.figures_2d_of_planes(plane_index=0, model_image=True, remove_critical_caustic=True) # If the lens light is not included the subplot index does not increase, so we must manually set it to 4 self.mat_plot_2d.subplot_index = 6 plane_index_tag = "" if plane_index is None else f"_{plane_index}" - plane_index = len(self.fit.tracer.planes) - 1 if plane_index is None else plane_index + plane_index = ( + len(self.fit.tracer.planes) - 1 if plane_index is None else plane_index + ) self.mat_plot_2d.cmap.kwargs["vmin"] = 0.0 self.set_title(label="Lens Light Subtracted Image") - self.figures_2d_of_planes(plane_index=plane_index, subtracted_image=True, use_source_vmax=True) + self.figures_2d_of_planes( + plane_index=plane_index, subtracted_image=True, use_source_vmax=True, remove_critical_caustic=True + ) self.set_title(label="Source Model Image (Image Plane)") - self.figures_2d_of_planes(plane_index=plane_index, model_image=True, use_source_vmax=True) + self.figures_2d_of_planes( + plane_index=plane_index, model_image=True, use_source_vmax=True, remove_critical_caustic=True + ) self.mat_plot_2d.cmap.kwargs.pop("vmin") self.set_title(label="Source Plane (Zoomed)") - self.figures_2d_of_planes(plane_index=plane_index, plane_image=True, use_source_vmax=True) + self.figures_2d_of_planes( + plane_index=plane_index, plane_image=True, use_source_vmax=True + ) self.set_title(label=None) @@ -575,7 +586,7 @@ def subplot_fit_log10(self, plane_index: Optional[int] = None): plane_index=plane_index, plane_image=True, zoom_to_brightest=False, - use_source_vmax=True + use_source_vmax=True, ) self.set_title(label=None) @@ -588,6 +599,40 @@ def subplot_fit_log10(self, plane_index: Optional[int] = None): self.mat_plot_2d.use_log10 = use_log10_original self.mat_plot_2d.contour = contour_original + def subplot_of_planes(self, plane_index: Optional[int] = None): + """ + Plots images representing each individual `Plane` in the plotter's `Tracer` in 2D on a subplot, which are + computed via the plotter's 2D grid object. + + These images subtract or omit the contribution of other planes in the plane, such that plots showing + each individual plane are made. + + The subplot plots the subtracted image, model image and plane image of each plane, where are described in the + `figures_2d_of_planes` function. + + Parameters + ---------- + plane_index + The index of the plane whose images are included on the subplot. + """ + + plane_indexes = self.plane_indexes_from(plane_index=plane_index) + + for plane_index in plane_indexes: + + self.open_subplot_figure(number_subplots=4) + + self.figures_2d(data=True) + + self.figures_2d_of_planes(subtracted_image=True, plane_index=plane_index) + self.figures_2d_of_planes(model_image=True, plane_index=plane_index) + self.figures_2d_of_planes(plane_image=True, plane_index=plane_index) + + self.mat_plot_2d.output.subplot_to_figure( + auto_filename=f"subplot_of_plane_{plane_index}" + ) + self.close_subplot_figure() + def subplot_tracer(self): """ Standard subplot of a Tracer. @@ -600,6 +645,7 @@ def subplot_tracer(self): ------- """ + use_log10_original = self.mat_plot_2d.use_log10 final_plane_index = len(self.fit.tracer.planes) - 1 @@ -609,7 +655,9 @@ def subplot_tracer(self): self.figures_2d(model_image=True) self.set_title(label="Lensed Source Image") - self.figures_2d_of_planes(plane_index=final_plane_index, model_image=True, use_source_vmax=True) + self.figures_2d_of_planes( + plane_index=final_plane_index, model_image=True, use_source_vmax=True + ) self.set_title(label=None) self.set_title(label="Source Plane") @@ -617,26 +665,21 @@ def subplot_tracer(self): plane_index=final_plane_index, plane_image=True, zoom_to_brightest=False, - use_source_vmax=True + use_source_vmax=True, ) - tracer_plotter = self.tracer_plotter - - include_tangential_critical_curves_original = tracer_plotter.include_2d._tangential_critical_curves - include_radial_critical_curves_original = tracer_plotter.include_2d._radial_critical_curves + tracer_plotter = self.tracer_plotter_of_plane(plane_index=0) tracer_plotter._subplot_lens_and_mass() - self.mat_plot_2d.output.subplot_to_figure( - auto_filename="subplot_tracer" - ) + self.mat_plot_2d.output.subplot_to_figure(auto_filename="subplot_tracer") self.close_subplot_figure() - self.include_2d._tangential_critical_curves = include_tangential_critical_curves_original - self.include_2d._radial_critical_curves = include_radial_critical_curves_original self.mat_plot_2d.use_log10 = use_log10_original - def subplot_mappings_of_plane(self, plane_index: Optional[int] = None, auto_filename: str = "subplot_mappings"): + def subplot_mappings_of_plane( + self, plane_index: Optional[int] = None, auto_filename: str = "subplot_mappings" + ): try: @@ -658,34 +701,32 @@ def subplot_mappings_of_plane(self, plane_index: Optional[int] = None, auto_file "total_mappings_pixels" ] - mapper = inversion_plotter.inversion.cls_list_from(cls=aa.AbstractMapper)[0] + mapper = inversion_plotter.inversion.cls_list_from( + cls=aa.AbstractMapper + )[0] mapper_valued = aa.MapperValued( values=inversion_plotter.inversion.reconstruction_dict[mapper], mapper=mapper, ) + pix_indexes = mapper_valued.max_pixel_list_from( total_pixels=total_pixels, filter_neighbors=True ) - inversion_plotter.visuals_2d.pix_indexes = [ - [index] for index in pix_indexes[pixelization_index] - ] + indexes = mapper.slim_indexes_for_pix_indexes(pix_indexes=pix_indexes) - inversion_plotter.visuals_2d.tangential_critical_curves = None - inversion_plotter.visuals_2d.radial_critical_curves = None + inversion_plotter.visuals_2d.indexes = indexes inversion_plotter.figures_2d_of_pixelization( pixelization_index=pixelization_index, reconstructed_image=True ) - self.visuals_2d.pix_indexes = [ + self.visuals_2d.source_plane_mesh_indexes = [ [index] for index in pix_indexes[pixelization_index] ] self.figures_2d_of_planes( - plane_index=plane_index, - plane_image=True, - use_source_vmax=True + plane_index=plane_index, plane_image=True, use_source_vmax=True ) self.set_title(label="Source Reconstruction (Unzoomed)") @@ -693,11 +734,11 @@ def subplot_mappings_of_plane(self, plane_index: Optional[int] = None, auto_file plane_index=plane_index, plane_image=True, zoom_to_brightest=False, - use_source_vmax=True + use_source_vmax=True, ) self.set_title(label=None) - self.visuals_2d.pix_indexes = None + self.visuals_2d.source_plane_mesh_indexes = None inversion_plotter.mat_plot_2d.output.subplot_to_figure( auto_filename=f"{auto_filename}_{pixelization_index}" @@ -719,7 +760,7 @@ def figures_2d( normalized_residual_map: bool = False, chi_squared_map: bool = False, residual_flux_fraction_map: bool = False, - use_source_vmax : bool = False, + use_source_vmax: bool = False, suffix: str = "", ): """ @@ -751,25 +792,19 @@ def figures_2d( certain plots (e.g. the `data`) in order to ensure the lensed source is visible compared to the lens. """ - visuals_2d = self.get_visuals_2d() - - visuals_2d_no_critical_caustic = self.get_visuals_2d() - visuals_2d_no_critical_caustic.tangential_critical_curves = None - visuals_2d_no_critical_caustic.radial_critical_curves = None - visuals_2d_no_critical_caustic.tangential_caustics = None - visuals_2d_no_critical_caustic.radial_caustics = None - visuals_2d_no_critical_caustic.origin = None - visuals_2d_no_critical_caustic.light_profile_centres = None - visuals_2d_no_critical_caustic.mass_profile_centres = None - if data: if use_source_vmax: - self.mat_plot_2d.cmap.kwargs["vmax"] = np.max([model_image.array for model_image in self.fit.model_images_of_planes_list[1:]]) + self.mat_plot_2d.cmap.kwargs["vmax"] = np.max( + [ + model_image.array + for model_image in self.fit.model_images_of_planes_list[1:] + ] + ) self.mat_plot_2d.plot_array( array=self.fit.data, - visuals_2d=visuals_2d_no_critical_caustic, + visuals_2d=self.visuals_2d, auto_labels=AutoLabels(title="Data", filename=f"data{suffix}"), ) @@ -780,7 +815,7 @@ def figures_2d( self.mat_plot_2d.plot_array( array=self.fit.noise_map, - visuals_2d=visuals_2d_no_critical_caustic, + visuals_2d=self.visuals_2d, auto_labels=AutoLabels( title="Noise-Map", filename=f"noise_map{suffix}" ), @@ -790,20 +825,27 @@ def figures_2d( self.mat_plot_2d.plot_array( array=self.fit.signal_to_noise_map, - visuals_2d=visuals_2d_no_critical_caustic, + visuals_2d=self.visuals_2d, auto_labels=AutoLabels( - title="Signal-To-Noise Map", cb_unit=" S/N", filename=f"signal_to_noise_map{suffix}" + title="Signal-To-Noise Map", + cb_unit=" S/N", + filename=f"signal_to_noise_map{suffix}", ), ) if model_image: if use_source_vmax: - self.mat_plot_2d.cmap.kwargs["vmax"] = np.max([model_image.array for model_image in self.fit.model_images_of_planes_list[1:]]) + self.mat_plot_2d.cmap.kwargs["vmax"] = np.max( + [ + model_image.array + for model_image in self.fit.model_images_of_planes_list[1:] + ] + ) self.mat_plot_2d.plot_array( array=self.fit.model_data, - visuals_2d=visuals_2d, + visuals_2d=self.visuals_2d_from(plane_index=0), auto_labels=AutoLabels( title="Model Image", filename=f"model_image{suffix}" ), @@ -822,7 +864,7 @@ def figures_2d( self.mat_plot_2d.plot_array( array=self.fit.residual_map, - visuals_2d=visuals_2d_no_critical_caustic, + visuals_2d=self.visuals_2d, auto_labels=AutoLabels( title="Residual Map", filename=f"residual_map{suffix}" ), @@ -832,7 +874,7 @@ def figures_2d( self.mat_plot_2d.plot_array( array=self.fit.normalized_residual_map, - visuals_2d=visuals_2d_no_critical_caustic, + visuals_2d=self.visuals_2d, auto_labels=AutoLabels( title="Normalized Residual Map", cb_unit=r" $\sigma$", @@ -846,9 +888,11 @@ def figures_2d( self.mat_plot_2d.plot_array( array=self.fit.chi_squared_map, - visuals_2d=visuals_2d_no_critical_caustic, + visuals_2d=self.visuals_2d, auto_labels=AutoLabels( - title="Chi-Squared Map", cb_unit=r" $\chi^2$", filename=f"chi_squared_map{suffix}" + title="Chi-Squared Map", + cb_unit=r" $\chi^2$", + filename=f"chi_squared_map{suffix}", ), ) @@ -856,8 +900,9 @@ def figures_2d( self.mat_plot_2d.plot_array( array=self.fit.residual_flux_fraction_map, - visuals_2d=visuals_2d_no_critical_caustic, + visuals_2d=self.visuals_2d, auto_labels=AutoLabels( - title="Residual Flux Fraction Map", filename=f"residual_flux_fraction_map{suffix}" + title="Residual Flux Fraction Map", + filename=f"residual_flux_fraction_map{suffix}", ), - ) \ No newline at end of file + ) diff --git a/autolens/interferometer/model/plotter_interface.py b/autolens/interferometer/model/plotter_interface.py index 827544d41..dd889a272 100644 --- a/autolens/interferometer/model/plotter_interface.py +++ b/autolens/interferometer/model/plotter_interface.py @@ -1,4 +1,6 @@ -from os import path +from typing import Optional + +import autoarray.plot as aplt from autogalaxy.interferometer.model.plotter_interface import ( PlotterInterfaceInterferometer as AgPlotterInterfaceInterferometer, @@ -21,6 +23,7 @@ class PlotterInterfaceInterferometer(PlotterInterface): def fit_interferometer( self, fit: FitInterferometer, + visuals_2d_of_planes_list: Optional[aplt.Visuals2D] = None, ): """ Visualizes a `FitInterferometer` object, which fits an interferometer dataset. @@ -48,7 +51,6 @@ def should_plot(name): fit_plotter = FitInterferometerPlotter( fit=fit, - include_2d=self.include_2d, mat_plot_1d=mat_plot_1d, mat_plot_2d=mat_plot_2d, ) @@ -67,9 +69,9 @@ def should_plot(name): fit_plotter = FitInterferometerPlotter( fit=fit, - include_2d=self.include_2d, mat_plot_1d=mat_plot_1d, mat_plot_2d=mat_plot_2d, + visuals_2d_of_planes_list=visuals_2d_of_planes_list, ) if plot_setting(section="inversion", name="subplot_mappings"): diff --git a/autolens/interferometer/model/visualizer.py b/autolens/interferometer/model/visualizer.py index 0b1820337..1914c0a96 100644 --- a/autolens/interferometer/model/visualizer.py +++ b/autolens/interferometer/model/visualizer.py @@ -4,6 +4,7 @@ from autolens.interferometer.model.plotter_interface import ( PlotterInterfaceInterferometer, ) +from autolens.lens import tracer_util from autogalaxy import exc @@ -109,22 +110,34 @@ def visualize( except exc.InversionException: return + visuals_2d_of_planes_list = tracer_util.visuals_2d_of_planes_list_from( + tracer=fit.tracer, grid=fit.grids.lp.mask.derive_grid.all_false + ) + + tracer = fit.tracer_linear_light_profiles_to_light_profiles + + zoom = ag.Zoom2D(mask=fit.dataset.real_space_mask) + + extent = zoom.extent_from(buffer=0) + shape_native = zoom.shape_native + + grid = ag.Grid2D.from_extent(extent=extent, shape_native=shape_native) + plotter_interface = PlotterInterfaceInterferometer( image_path=paths.image_path, title_prefix=analysis.title_prefix ) try: plotter_interface.fit_interferometer( - fit=fit, + fit=fit, visuals_2d_of_planes_list=visuals_2d_of_planes_list ) except exc.InversionException: pass - tracer = fit.tracer_linear_light_profiles_to_light_profiles - plotter_interface.tracer( tracer=tracer, - grid=fit.grids.lp, + grid=grid, + visuals_2d_of_planes_list=visuals_2d_of_planes_list, ) plotter_interface.galaxies( galaxies=tracer.galaxies, diff --git a/autolens/interferometer/plot/fit_interferometer_plotters.py b/autolens/interferometer/plot/fit_interferometer_plotters.py index 9694a8d80..2ca412515 100644 --- a/autolens/interferometer/plot/fit_interferometer_plotters.py +++ b/autolens/interferometer/plot/fit_interferometer_plotters.py @@ -6,12 +6,15 @@ import autogalaxy.plot as aplt from autoarray.fit.plot.fit_interferometer_plotters import FitInterferometerPlotterMeta +from autoarray.plot.auto_labels import AutoLabels from autolens.interferometer.fit_interferometer import FitInterferometer from autolens.lens.tracer import Tracer from autolens.lens.plot.tracer_plotters import TracerPlotter from autolens.plot.abstract_plotters import Plotter +from autolens.lens import tracer_util + class FitInterferometerPlotter(Plotter): def __init__( @@ -19,11 +22,10 @@ def __init__( fit: FitInterferometer, mat_plot_1d: aplt.MatPlot1D = None, visuals_1d: aplt.Visuals1D = None, - include_1d: aplt.Include1D = None, mat_plot_2d: aplt.MatPlot2D = None, visuals_2d: aplt.Visuals2D = None, - include_2d: aplt.Include2D = None, residuals_symmetric_cmap: bool = True, + visuals_2d_of_planes_list: Optional = None, ): """ Plots the attributes of `FitInterferometer` objects using the matplotlib method `imshow()` and many @@ -35,8 +37,7 @@ def __init__( customize the figure's appearance. Overlaid on the figure are visuals, contained in the `Visuals1D` and `Visuals2D` objects. Attributes may be - extracted from the `FitInterferometer` and plotted via the visuals object, if the corresponding entry is `True` in - the `Include1D` or `Include2D` object or the `config/visualize/include.ini` file. + extracted from the `FitInterferometer` and plotted via the visuals object. Parameters ---------- @@ -46,24 +47,18 @@ def __init__( Contains objects which wrap the matplotlib function calls that make 1D plots. visuals_1d Contains 1D visuals that can be overlaid on 1D plots. - include_1d - Specifies which attributes of the `FitInterferometer` are extracted and plotted as visuals for 1D plots. mat_plot_2d Contains objects which wrap the matplotlib function calls that make 2D plots. visuals_2d Contains 2D visuals that can be overlaid on 2D plots. - include_2d - Specifies which attributes of the `FitInterferometer` are extracted and plotted as visuals for 2D plots. residuals_symmetric_cmap If true, the `residual_map` and `normalized_residual_map` are plotted with a symmetric color map such that `abs(vmin) = abs(vmax)`. """ super().__init__( mat_plot_1d=mat_plot_1d, - include_1d=include_1d, visuals_1d=visuals_1d, mat_plot_2d=mat_plot_2d, - include_2d=include_2d, visuals_2d=visuals_2d, ) @@ -71,62 +66,83 @@ def __init__( self._fit_interferometer_meta_plotter = FitInterferometerPlotterMeta( fit=self.fit, - get_visuals_2d_real_space=self.get_visuals_2d_real_space, mat_plot_1d=self.mat_plot_1d, - include_1d=self.include_1d, visuals_1d=self.visuals_1d, mat_plot_2d=self.mat_plot_2d, - include_2d=self.include_2d, visuals_2d=self.visuals_2d, residuals_symmetric_cmap=residuals_symmetric_cmap, ) self.subplot = self._fit_interferometer_meta_plotter.subplot - # self.subplot_fit = self._fit_interferometer_meta_plotter.subplot_fit self.subplot_fit_dirty_images = ( self._fit_interferometer_meta_plotter.subplot_fit_dirty_images ) - def get_visuals_2d_real_space(self) -> aplt.Visuals2D: - return self.get_2d.via_mask_from(mask=self.fit.dataset.real_space_mask) + self._visuals_2d_of_planes_list = visuals_2d_of_planes_list - def plane_indexes_from(self, plane_index: int): + @property + def visuals_2d_of_planes_list(self): + + if self._visuals_2d_of_planes_list is None: + self._visuals_2d_of_planes_list = ( + tracer_util.visuals_2d_of_planes_list_from( + tracer=self.fit.tracer, + grid=self.fit.grids.lp.mask.derive_grid.all_false, + ) + ) + + return self._visuals_2d_of_planes_list + + def visuals_2d_from( + self, plane_index: Optional[int] = None, remove_critical_caustic: bool = False + ) -> aplt.Visuals2D: """ - Returns a list of all indexes of the planes in the fit, which is iterated over in figures that plot - individual figures of each plane in a tracer. + Returns the `Visuals2D` of the plotter with critical curves and caustics added, which are used to plot + the critical curves and caustics of the `Tracer` object. + + If `remove_critical_caustic` is `True`, critical curves and caustics are not included in the visuals. Parameters ---------- plane_index - A specific plane index which when input means that only a single plane index is returned. - - Returns - ------- - list - A list of galaxy indexes corresponding to planes in the plane. + The index of the plane in the tracer which is used to extract quantities, as only one plane is plotted + at a time. + remove_critical_caustic + Whether to remove critical curves and caustics from the visuals. """ - if plane_index is None: - return range(len(self.fit.tracer.planes)) - return [plane_index] + if remove_critical_caustic: + return self.visuals_2d + + return self.visuals_2d + self.visuals_2d_of_planes_list[plane_index] @property def tracer(self) -> Tracer: return self.fit.tracer_linear_light_profiles_to_light_profiles - @property - def tracer_plotter(self) -> TracerPlotter: + def tracer_plotter_of_plane( + self, plane_index: int, remove_critical_caustic: bool = False + ) -> TracerPlotter: """ Returns an `TracerPlotter` corresponding to the `Tracer` in the `FitImaging`. """ + + zoom = aa.Zoom2D(mask=self.fit.dataset.real_space_mask) + + grid = aa.Grid2D.from_extent( + extent=zoom.extent_from(buffer=0), shape_native=zoom.shape_native + ) return TracerPlotter( tracer=self.tracer, - grid=self.fit.grids.lp, + grid=grid, mat_plot_2d=self.mat_plot_2d, - visuals_2d=self.visuals_2d, - include_2d=self.include_2d, + visuals_2d=self.visuals_2d_from( + plane_index=plane_index, remove_critical_caustic=remove_critical_caustic + ), ) - def inversion_plotter_of_plane(self, plane_index: int) -> aplt.InversionPlotter: + def inversion_plotter_of_plane( + self, plane_index: int, remove_critical_caustic: bool = False + ) -> aplt.InversionPlotter: """ Returns an `InversionPlotter` corresponding to one of the `Inversion`'s in the fit, which is specified via the index of the `Plane` that inversion was performed on. @@ -141,138 +157,34 @@ def inversion_plotter_of_plane(self, plane_index: int) -> aplt.InversionPlotter: InversionPlotter An object that plots inversions which is used for plotting attributes of the inversion. """ - return aplt.InversionPlotter( + + inversion_plotter = aplt.InversionPlotter( inversion=self.fit.inversion, mat_plot_2d=self.mat_plot_2d, - visuals_2d=self.tracer_plotter.get_visuals_2d_of_plane( - plane_index=plane_index + visuals_2d=self.visuals_2d_from( + plane_index=plane_index, remove_critical_caustic=remove_critical_caustic ), - include_2d=self.include_2d, ) + return inversion_plotter - def subplot_fit(self): - """ - Standard subplot of the attributes of the plotter's `FitImaging` object. + def plane_indexes_from(self, plane_index: int): """ + Returns a list of all indexes of the planes in the fit, which is iterated over in figures that plot + individual figures of each plane in a tracer. - self.open_subplot_figure(number_subplots=12) - - self.figures_2d(amplitudes_vs_uv_distances=True) - - self.mat_plot_1d.subplot_index = 2 - self.mat_plot_2d.subplot_index = 2 - - self.figures_2d(dirty_image=True) - self.figures_2d(dirty_signal_to_noise_map=True) - self.figures_2d(dirty_model_image=True) - self.figures_2d(image=True) - - self.mat_plot_1d.subplot_index = 6 - self.mat_plot_2d.subplot_index = 6 - - self.figures_2d(normalized_residual_map_real=True) - self.figures_2d(normalized_residual_map_imag=True) - - self.mat_plot_1d.subplot_index = 8 - self.mat_plot_2d.subplot_index = 8 - - final_plane_index = len(self.fit.tracer.planes) - 1 - - self.set_title(label="Source Plane (Zoomed)") - self.figures_2d_of_planes(plane_index=final_plane_index, plane_image=True) - self.set_title(label=None) - - self.figures_2d(dirty_normalized_residual_map=True) - - self.mat_plot_2d.cmap.kwargs["vmin"] = -1.0 - self.mat_plot_2d.cmap.kwargs["vmax"] = 1.0 - - self.set_title(label="Normalized Residual Map (1 sigma)") - self.figures_2d(dirty_normalized_residual_map=True) - self.set_title(label=None) - - self.mat_plot_2d.cmap.kwargs.pop("vmin") - self.mat_plot_2d.cmap.kwargs.pop("vmax") - - self.figures_2d(dirty_chi_squared_map=True) - - self.set_title(label="Source Plane (No Zoom)") - self.figures_2d_of_planes( - plane_index=final_plane_index, - plane_image=True, - zoom_to_brightest=False, - ) - - self.set_title(label=None) - - self.mat_plot_2d.output.subplot_to_figure(auto_filename="subplot_fit") - self.close_subplot_figure() - - def subplot_mappings_of_plane( - self, plane_index: Optional[int] = None, auto_filename: str = "subplot_mappings" - ): - if self.fit.inversion is None: - return - - plane_indexes = self.plane_indexes_from(plane_index=plane_index) - - for plane_index in plane_indexes: - pixelization_index = 0 - - inversion_plotter = self.inversion_plotter_of_plane(plane_index=0) - - inversion_plotter.open_subplot_figure(number_subplots=4) - - self.figures_2d(dirty_image=True) - - total_pixels = conf.instance["visualize"]["general"]["inversion"][ - "total_mappings_pixels" - ] - - mapper = inversion_plotter.inversion.cls_list_from(cls=aa.AbstractMapper)[0] - mapper_valued = aa.MapperValued( - values=inversion_plotter.inversion.reconstruction_dict[mapper], - mapper=mapper, - ) - pix_indexes = mapper_valued.max_pixel_list_from( - total_pixels=total_pixels, filter_neighbors=True - ) - - inversion_plotter.visuals_2d.pix_indexes = [ - [index] for index in pix_indexes[pixelization_index] - ] - - inversion_plotter.visuals_2d.tangential_critical_curves = None - inversion_plotter.visuals_2d.radial_critical_curves = None - - inversion_plotter.figures_2d_of_pixelization( - pixelization_index=pixelization_index, reconstructed_image=True - ) - - self.visuals_2d.pix_indexes = [ - [index] for index in pix_indexes[pixelization_index] - ] - - self.figures_2d_of_planes( - plane_index=plane_index, - plane_image=True, - ) - - self.set_title(label="Source Reconstruction (Unzoomed)") - self.figures_2d_of_planes( - plane_index=plane_index, - plane_image=True, - zoom_to_brightest=False, - ) - self.set_title(label=None) - - self.visuals_2d.pix_indexes = None - - inversion_plotter.mat_plot_2d.output.subplot_to_figure( - auto_filename=f"{auto_filename}_{pixelization_index}" - ) + Parameters + ---------- + plane_index + A specific plane index which when input means that only a single plane index is returned. - inversion_plotter.close_subplot_figure() + Returns + ------- + list + A list of galaxy indexes corresponding to planes in the plane. + """ + if plane_index is None: + return range(len(self.fit.tracer.planes)) + return [plane_index] def figures_2d( self, @@ -354,7 +266,6 @@ def figures_2d( dirty_image=dirty_image, dirty_noise_map=dirty_noise_map, dirty_signal_to_noise_map=dirty_signal_to_noise_map, - dirty_model_image=dirty_model_image, dirty_residual_map=dirty_residual_map, dirty_normalized_residual_map=dirty_normalized_residual_map, dirty_chi_squared_map=dirty_chi_squared_map, @@ -364,13 +275,26 @@ def figures_2d( plane_index = len(self.tracer.planes) - 1 if not self.tracer.planes[plane_index].has(cls=aa.Pixelization): - self.tracer_plotter.figures_2d(image=True) + + tracer_plotter = self.tracer_plotter_of_plane(plane_index=plane_index) + + tracer_plotter.figures_2d(image=True) + elif self.tracer.planes[plane_index].has(cls=aa.Pixelization): inversion_plotter = self.inversion_plotter_of_plane( plane_index=plane_index ) inversion_plotter.figures_2d(reconstructed_image=True) + if dirty_model_image: + self.mat_plot_2d.plot_array( + array=self.fit.dirty_model_image, + visuals_2d=self.visuals_2d_of_planes_list[0], + auto_labels=AutoLabels( + title="Dirty Model Image", filename="dirty_model_image_2d" + ), + ) + def figures_2d_of_planes( self, plane_index: Optional[int] = None, @@ -415,7 +339,10 @@ def figures_2d_of_planes( """ if plane_image: if not self.tracer.planes[plane_index].has(cls=aa.Pixelization): - self.tracer_plotter.figures_2d_of_planes( + + tracer_plotter = self.tracer_plotter_of_plane(plane_index=plane_index) + + tracer_plotter.figures_2d_of_planes( plane_image=True, plane_index=plane_index, zoom_to_brightest=zoom_to_brightest, @@ -456,6 +383,130 @@ def figures_2d_of_planes( interpolate_to_uniform=interpolate_to_uniform, ) + def subplot_fit(self): + """ + Standard subplot of the attributes of the plotter's `FitImaging` object. + """ + + self.open_subplot_figure(number_subplots=12) + + self.figures_2d(amplitudes_vs_uv_distances=True) + + self.mat_plot_1d.subplot_index = 2 + self.mat_plot_2d.subplot_index = 2 + + self.figures_2d(dirty_image=True) + self.figures_2d(dirty_signal_to_noise_map=True) + self.figures_2d(dirty_model_image=True) + self.figures_2d(image=True) + + self.mat_plot_1d.subplot_index = 6 + self.mat_plot_2d.subplot_index = 6 + + self.figures_2d(normalized_residual_map_real=True) + self.figures_2d(normalized_residual_map_imag=True) + + self.mat_plot_1d.subplot_index = 8 + self.mat_plot_2d.subplot_index = 8 + + final_plane_index = len(self.fit.tracer.planes) - 1 + + self.set_title(label="Source Plane (Zoomed)") + self.figures_2d_of_planes(plane_index=final_plane_index, plane_image=True) + self.set_title(label=None) + + self.figures_2d(dirty_normalized_residual_map=True) + + self.mat_plot_2d.cmap.kwargs["vmin"] = -1.0 + self.mat_plot_2d.cmap.kwargs["vmax"] = 1.0 + + self.set_title(label="Normalized Residual Map (1 sigma)") + self.figures_2d(dirty_normalized_residual_map=True) + self.set_title(label=None) + + self.mat_plot_2d.cmap.kwargs.pop("vmin") + self.mat_plot_2d.cmap.kwargs.pop("vmax") + + self.figures_2d(dirty_chi_squared_map=True) + + self.set_title(label="Source Plane (No Zoom)") + self.figures_2d_of_planes( + plane_index=final_plane_index, + plane_image=True, + zoom_to_brightest=False, + ) + + self.set_title(label=None) + + self.mat_plot_2d.output.subplot_to_figure(auto_filename="subplot_fit") + self.close_subplot_figure() + + def subplot_mappings_of_plane( + self, plane_index: Optional[int] = None, auto_filename: str = "subplot_mappings" + ): + if self.fit.inversion is None: + return + + plane_indexes = self.plane_indexes_from(plane_index=plane_index) + + for plane_index in plane_indexes: + pixelization_index = 0 + + inversion_plotter = self.inversion_plotter_of_plane(plane_index=0) + + inversion_plotter.open_subplot_figure(number_subplots=4) + + self.figures_2d(dirty_image=True) + + total_pixels = conf.instance["visualize"]["general"]["inversion"][ + "total_mappings_pixels" + ] + + mapper = inversion_plotter.inversion.cls_list_from(cls=aa.AbstractMapper)[0] + mapper_valued = aa.MapperValued( + values=inversion_plotter.inversion.reconstruction_dict[mapper], + mapper=mapper, + ) + pix_indexes = mapper_valued.max_pixel_list_from( + total_pixels=total_pixels, filter_neighbors=True + ) + + inversion_plotter.visuals_2d.source_plane_mesh_indexes = [ + [index] for index in pix_indexes[pixelization_index] + ] + + inversion_plotter.visuals_2d.tangential_critical_curves = None + inversion_plotter.visuals_2d.radial_critical_curves = None + + inversion_plotter.figures_2d_of_pixelization( + pixelization_index=pixelization_index, reconstructed_image=True + ) + + self.visuals_2d.source_plane_mesh_indexes = [ + [index] for index in pix_indexes[pixelization_index] + ] + + self.figures_2d_of_planes( + plane_index=plane_index, + plane_image=True, + ) + + self.set_title(label="Source Reconstruction (Unzoomed)") + self.figures_2d_of_planes( + plane_index=plane_index, + plane_image=True, + zoom_to_brightest=False, + ) + self.set_title(label=None) + + self.visuals_2d.source_plane_mesh_indexes = None + + inversion_plotter.mat_plot_2d.output.subplot_to_figure( + auto_filename=f"{auto_filename}_{pixelization_index}" + ) + + inversion_plotter.close_subplot_figure() + def subplot_fit_real_space(self): """ Standard subplot of the real-space attributes of the plotter's `FitInterferometer` object. @@ -464,7 +515,10 @@ def subplot_fit_real_space(self): different methods are called to create these real-space images. """ if self.fit.inversion is None: - self.tracer_plotter.subplot( + + tracer_plotter = self.tracer_plotter_of_plane(plane_index=0) + + tracer_plotter.subplot( image=True, source_plane=True, auto_filename="subplot_fit_real_space" ) diff --git a/autolens/lens/plot/tracer_plotters.py b/autolens/lens/plot/tracer_plotters.py index d894afff3..7bf0ec3c4 100644 --- a/autolens/lens/plot/tracer_plotters.py +++ b/autolens/lens/plot/tracer_plotters.py @@ -11,6 +11,8 @@ from autolens import exc +from autolens.lens import tracer_util + class TracerPlotter(Plotter): def __init__( @@ -19,10 +21,9 @@ def __init__( grid: aa.type.Grid2DLike, mat_plot_1d: aplt.MatPlot1D = None, visuals_1d: aplt.Visuals1D = None, - include_1d: aplt.Include1D = None, mat_plot_2d: aplt.MatPlot2D = None, visuals_2d: aplt.Visuals2D = None, - include_2d: aplt.Include2D = None, + visuals_2d_of_planes_list: Optional = None, ): """ Plots the attributes of `Tracer` objects using the matplotlib methods `plot()` and `imshow()` and many @@ -34,8 +35,7 @@ def __init__( customize the figure's appearance. Overlaid on the figure are visuals, contained in the `Visuals1D` and `Visuals2D` objects. Attributes may be - extracted from the `MassProfile` and plotted via the visuals object, if the corresponding entry is `True` in - the `Include1D` or `Include2D` object or the `config/visualize/include.ini` file. + extracted from the `MassProfile` and plotted via the visuals object. Parameters ---------- @@ -47,16 +47,11 @@ def __init__( Contains objects which wrap the matplotlib function calls that make 1D plots. visuals_1d Contains 1D visuals that can be overlaid on 1D plots. - include_1d - Specifies which attributes of the `MassProfile` are extracted and plotted as visuals for 1D plots. mat_plot_2d Contains objects which wrap the matplotlib function calls that make 2D plots. visuals_2d Contains 2D visuals that can be overlaid on 2D plots. - include_2d - Specifies which attributes of the `MassProfile` are extracted and plotted as visuals for 2D plots. """ - from autogalaxy.profiles.light.linear import ( LightProfileLinear, ) @@ -69,9 +64,7 @@ def __init__( super().__init__( mat_plot_1d=mat_plot_1d, visuals_1d=visuals_1d, - include_1d=include_1d, mat_plot_2d=mat_plot_2d, - include_2d=include_2d, visuals_2d=visuals_2d, ) @@ -81,21 +74,27 @@ def __init__( self._mass_plotter = MassPlotter( mass_obj=self.tracer, grid=self.grid, - get_visuals_2d=self.get_visuals_2d, mat_plot_2d=self.mat_plot_2d, - include_2d=self.include_2d, visuals_2d=self.visuals_2d, ) - def get_visuals_2d(self) -> aplt.Visuals2D: - return self.get_visuals_2d_of_plane(plane_index=0) + self._visuals_2d_of_planes_list = visuals_2d_of_planes_list - def get_visuals_2d_of_plane(self, plane_index: int) -> aplt.Visuals2D: - return self.get_2d.via_tracer_from( - tracer=self.tracer, grid=self.grid, plane_index=plane_index - ) + @property + def visuals_2d_of_planes_list(self): + + if self._visuals_2d_of_planes_list is None: + self._visuals_2d_of_planes_list = ( + tracer_util.visuals_2d_of_planes_list_from( + tracer=self.tracer, grid=self.grid + ) + ) - def galaxies_plotter_from(self, plane_index: int) -> aplt.GalaxiesPlotter: + return self._visuals_2d_of_planes_list + + def galaxies_plotter_from( + self, plane_index: int, retain_visuals: bool = False + ) -> aplt.GalaxiesPlotter: """ Returns an `GalaxiesPlotter` corresponding to a `Plane` in the `Tracer`. @@ -104,14 +103,24 @@ def galaxies_plotter_from(self, plane_index: int) -> aplt.GalaxiesPlotter: plane_index The index of the plane in the `Tracer` used to make the `GalaxiesPlotter`. """ + plane_grid = self.tracer.traced_grid_2d_list_from(grid=self.grid)[plane_index] + if retain_visuals: + + visuals_2d = self.visuals_2d + + else: + + visuals_2d = self.visuals_2d.add_critical_curves_or_caustics( + mass_obj=self.tracer, grid=self.grid, plane_index=plane_index + ) + return aplt.GalaxiesPlotter( galaxies=ag.Galaxies(galaxies=self.tracer.planes[plane_index]), grid=plane_grid, mat_plot_2d=self.mat_plot_2d, - visuals_2d=self.get_visuals_2d_of_plane(plane_index=plane_index), - include_2d=self.include_2d, + visuals_2d=visuals_2d, ) def figures_2d( @@ -154,7 +163,7 @@ def figures_2d( if image: self.mat_plot_2d.plot_array( array=self.tracer.image_2d_from(grid=self.grid), - visuals_2d=self.get_visuals_2d(), + visuals_2d=self._mass_plotter.visuals_2d_with_critical_curves, auto_labels=aplt.AutoLabels(title="Image", filename="image_2d"), ) @@ -196,6 +205,7 @@ def figures_2d_of_planes( plane_grid: bool = False, plane_index: Optional[int] = None, zoom_to_brightest: bool = True, + retain_visuals: bool = False, ): """ Plots source-plane images (e.g. the unlensed light) each individual `Plane` in the plotter's `Tracer` in 2D, @@ -221,7 +231,9 @@ def figures_2d_of_planes( plane_indexes = self.plane_indexes_from(plane_index=plane_index) for plane_index in plane_indexes: - galaxies_plotter = self.galaxies_plotter_from(plane_index=plane_index) + galaxies_plotter = self.galaxies_plotter_from( + plane_index=plane_index, retain_visuals=retain_visuals + ) if plane_index == 1: source_plane_title = True @@ -324,40 +336,31 @@ def subplot_tracer(self): self.figures_2d(source_plane=True) self.set_title(label=None) - include_tangential_critical_curves_original = ( - self.include_2d._tangential_critical_curves - ) - include_radial_critical_curves_original = ( - self.include_2d._radial_critical_curves - ) - self._subplot_lens_and_mass() self.mat_plot_2d.output.subplot_to_figure(auto_filename="subplot_tracer") self.close_subplot_figure() - self.include_2d._tangential_critical_curves = ( - include_tangential_critical_curves_original - ) - self.include_2d._radial_critical_curves = ( - include_radial_critical_curves_original - ) self.mat_plot_2d.use_log10 = use_log10_original def _subplot_lens_and_mass(self): + self.mat_plot_2d.use_log10 = True - self.include_2d._tangential_critical_curves = False - self.include_2d._radial_critical_curves = False self.set_title(label="Lens Galaxy Image") + self.figures_2d_of_planes( - plane_image=True, plane_index=0, zoom_to_brightest=False + plane_image=True, + plane_index=0, + zoom_to_brightest=False, + retain_visuals=True, ) self.mat_plot_2d.subplot_index = 5 self.set_title(label=None) self.figures_2d(convergence=True) + self.figures_2d(potential=True) self.mat_plot_2d.use_log10 = False diff --git a/autolens/lens/sensitivity.py b/autolens/lens/sensitivity.py index 1777293bc..fd44f450e 100644 --- a/autolens/lens/sensitivity.py +++ b/autolens/lens/sensitivity.py @@ -148,7 +148,6 @@ def __init__( data_subtracted: Optional[aa.Array2D] = None, mat_plot_2d: aplt.MatPlot2D = None, visuals_2d: aplt.Visuals2D = None, - include_2d: aplt.Include2D = None, ): """ Plots the simulated datasets and results of a sensitivity mapping analysis, where dark matter halos are used @@ -160,8 +159,7 @@ def __init__( customize the figure's appearance. Overlaid on the figure are visuals, contained in the `Visuals1D` and `Visuals2D` objects. Attributes may be - extracted from the `MassProfile` and plotted via the visuals object, if the corresponding entry is `True` in - the `Include1D` or `Include2D` object or the `config/visualize/include.ini` file. + extracted from the `MassProfile` and plotted via the visuals object. Parameters ---------- @@ -173,19 +171,13 @@ def __init__( Contains objects which wrap the matplotlib function calls that make 1D plots. visuals_1d Contains 1D visuals that can be overlaid on 1D plots. - include_1d - Specifies which attributes of the `MassProfile` are extracted and plotted as visuals for 1D plots. mat_plot_2d Contains objects which wrap the matplotlib function calls that make 2D plots. visuals_2d Contains 2D visuals that can be overlaid on 2D plots. - include_2d - Specifies which attributes of the `MassProfile` are extracted and plotted as visuals for 2D plots. """ - super().__init__( - mat_plot_2d=mat_plot_2d, include_2d=include_2d, visuals_2d=visuals_2d - ) + super().__init__(mat_plot_2d=mat_plot_2d, visuals_2d=visuals_2d) self.mask = mask self.tracer_perturb = tracer_perturb @@ -195,7 +187,6 @@ def __init__( self.data_subtracted = data_subtracted self.mat_plot_2d = mat_plot_2d self.visuals_2d = visuals_2d - self.include_2d = include_2d def update_mat_plot_array_overlay(self, evidence_max): evidence_half = evidence_max / 2.0 diff --git a/autolens/lens/subhalo.py b/autolens/lens/subhalo.py index c5baf7639..03c0b4a2e 100644 --- a/autolens/lens/subhalo.py +++ b/autolens/lens/subhalo.py @@ -175,7 +175,6 @@ def __init__( fit_imaging_no_subhalo: Optional[FitImaging] = None, mat_plot_2d: aplt.MatPlot2D = None, visuals_2d: aplt.Visuals2D = None, - include_2d: aplt.Include2D = None, ): """ Plots the results of scanning for a dark matter subhalo in strong lens imaging. @@ -206,12 +205,8 @@ def __init__( Contains objects which wrap the matplotlib function calls that make 2D plots. visuals_2d Contains 2D visuals that can be overlaid on 2D plots. - include_2d - Specifies which attributes of the `MassProfile` are extracted and plotted as visuals for 2D plots. """ - super().__init__( - mat_plot_2d=mat_plot_2d, include_2d=include_2d, visuals_2d=visuals_2d - ) + super().__init__(mat_plot_2d=mat_plot_2d, visuals_2d=visuals_2d) self.result = result @@ -245,7 +240,6 @@ def fit_imaging_no_subhalo_plotter(self) -> FitImagingPlotter: fit=self.fit_imaging_no_subhalo, mat_plot_2d=self.mat_plot_2d, visuals_2d=self.visuals_2d, - include_2d=self.include_2d, ) @property @@ -275,7 +269,6 @@ def fit_imaging_with_subhalo_plotter_from(self, visuals_2d) -> FitImagingPlotter fit=self.fit_imaging_with_subhalo, mat_plot_2d=self.mat_plot_2d, visuals_2d=visuals_2d, - include_2d=self.include_2d, ) def set_auto_filename( diff --git a/autolens/lens/tracer_util.py b/autolens/lens/tracer_util.py index 24f3276e6..6cbdbdac9 100644 --- a/autolens/lens/tracer_util.py +++ b/autolens/lens/tracer_util.py @@ -3,6 +3,7 @@ import autoarray as aa import autogalaxy as ag +import autogalaxy.plot as aplt from autolens import exc @@ -398,3 +399,20 @@ def ordered_plane_redshifts_with_slicing_from( )[1:] return plane_redshifts[0:-1] + + +def visuals_2d_of_planes_list_from(tracer, grid) -> aplt.Visuals2D: + + visuals_2d_of_planes_list = [] + + for plane_index in range(len(tracer.planes)): + + visuals_2d_of_planes_list.append( + aplt.Visuals2D().add_critical_curves_or_caustics( + mass_obj=tracer, + grid=grid, + plane_index=plane_index, + ) + ) + + return visuals_2d_of_planes_list diff --git a/autolens/plot/__init__.py b/autolens/plot/__init__.py index 94b640250..00ccff679 100644 --- a/autolens/plot/__init__.py +++ b/autolens/plot/__init__.py @@ -28,6 +28,7 @@ GridPlot, VectorYXQuiver, PatchOverlay, + DelaunayDrawer, VoronoiDrawer, OriginScatter, MaskScatter, @@ -66,8 +67,6 @@ from autogalaxy.plot.mat_plot.one_d import MatPlot1D from autogalaxy.plot.mat_plot.two_d import MatPlot2D -from autogalaxy.plot.include.one_d import Include1D -from autogalaxy.plot.include.two_d import Include2D from autogalaxy.plot.visuals.one_d import Visuals1D from autogalaxy.plot.visuals.two_d import Visuals2D diff --git a/autolens/plot/abstract_plotters.py b/autolens/plot/abstract_plotters.py index 42ce27272..622f2ea95 100644 --- a/autolens/plot/abstract_plotters.py +++ b/autolens/plot/abstract_plotters.py @@ -2,17 +2,12 @@ set_backend() -from autolens.plot.get_visuals.one_d import GetVisuals1D -from autolens.plot.get_visuals.two_d import GetVisuals2D - from autogalaxy.plot.abstract_plotters import AbstractPlotter from autogalaxy.plot.mat_plot.one_d import MatPlot1D from autogalaxy.plot.mat_plot.two_d import MatPlot2D from autogalaxy.plot.visuals.one_d import Visuals1D from autogalaxy.plot.visuals.two_d import Visuals2D -from autogalaxy.plot.include.one_d import Include1D -from autogalaxy.plot.include.two_d import Include2D class Plotter(AbstractPlotter): @@ -21,33 +16,19 @@ def __init__( self, mat_plot_1d: MatPlot1D = None, visuals_1d: Visuals1D = None, - include_1d: Include1D = None, mat_plot_2d: MatPlot2D = None, visuals_2d: Visuals2D = None, - include_2d: Include2D = None, ): super().__init__( mat_plot_1d=mat_plot_1d, visuals_1d=visuals_1d, - include_1d=include_1d, mat_plot_2d=mat_plot_2d, visuals_2d=visuals_2d, - include_2d=include_2d, ) self.visuals_1d = visuals_1d or Visuals1D() - self.include_1d = include_1d or Include1D() self.mat_plot_1d = mat_plot_1d or MatPlot1D() self.visuals_2d = visuals_2d or Visuals2D() - self.include_2d = include_2d or Include2D() self.mat_plot_2d = mat_plot_2d or MatPlot2D() - - @property - def get_1d(self): - return GetVisuals1D(visuals=self.visuals_1d, include=self.include_1d) - - @property - def get_2d(self): - return GetVisuals2D(visuals=self.visuals_2d, include=self.include_2d) diff --git a/autolens/plot/get_visuals/__init__.py b/autolens/plot/get_visuals/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/autolens/plot/get_visuals/one_d.py b/autolens/plot/get_visuals/one_d.py deleted file mode 100644 index b9e66fb39..000000000 --- a/autolens/plot/get_visuals/one_d.py +++ /dev/null @@ -1,26 +0,0 @@ -import autogalaxy.plot as aplt - -from autogalaxy.plot.get_visuals import one_d as gv1d - - -class GetVisuals1D(gv1d.GetVisuals1D): - def __init__(self, include: aplt.Include1D, visuals: aplt.Visuals1D): - """ - Class which gets 1D attributes and adds them to a `Visuals1D` objects, such that they are plotted on 1D figures. - - For a visual to be extracted and added for plotting, it must have a `True` value in its corresponding entry in - the `Include1D` object. If this entry is `False`, the `GetVisuals1D.get` method returns a None and the attribute - is omitted from the plot. - - The `GetVisuals1D` class adds new visuals to a pre-existing `Visuals1D` object that is passed to its `__init__` - method. This only adds a new entry if the visual are not already in this object. - - Parameters - ---------- - include - Sets which 1D visuals are included on the figure that is to be plotted (only entries which are `True` - are extracted via the `GetVisuals1D` object). - visuals - The pre-existing visuals of the plotter which new visuals are added too via the `GetVisuals1D` class. - """ - super().__init__(include=include, visuals=visuals) diff --git a/autolens/plot/get_visuals/two_d.py b/autolens/plot/get_visuals/two_d.py deleted file mode 100644 index 9b1d40b54..000000000 --- a/autolens/plot/get_visuals/two_d.py +++ /dev/null @@ -1,175 +0,0 @@ -import autoarray as aa -import autogalaxy as ag -import autogalaxy.plot as aplt - -from autogalaxy.plot.get_visuals import two_d as gv2d - -from autolens.imaging.fit_imaging import FitImaging -from autolens.lens.tracer import Tracer - - -class GetVisuals2D(gv2d.GetVisuals2D): - def __init__(self, include: aplt.Include2D, visuals: aplt.Visuals2D): - """ - Class which gets 2D attributes and adds them to a `Visuals2D` objects, such that they are plotted on 2D figures. - - For a visual to be extracted and added for plotting, it must have a `True` value in its corresponding entry in - the `Include2D` object. If this entry is `False`, the `GetVisuals2D.get` method returns a None and the - attribute is omitted from the plot. - - The `GetVisuals2D` class adds new visuals to a pre-existing `Visuals2D` object that is passed to - its `__init__` method. This only adds a new entry if the visual are not already in this object. - - Parameters - ---------- - include - Sets which 2D visuals are included on the figure that is to be plotted (only entries which are `True` - are extracted via the `GetVisuals2D` object). - visuals - The pre-existing visuals of the plotter which new visuals are added too via the `GetVisuals2D` class. - """ - super().__init__(include=include, visuals=visuals) - - def via_tracer_from( - self, tracer: Tracer, grid: aa.type.Grid2DLike, plane_index: int - ) -> aplt.Visuals2D: - """ - From a `Tracer` get the attributes that can be plotted and returns them in a `Visuals2D` object. - - Only attributes with `True` entries in the `Include` object are extracted. - - From a tracer the following attributes can be extracted for plotting: - - - origin: the (y,x) origin of the coordinate system used to plot the light object's quantities in 2D. - - border: the border of the mask of the grid used to plot the light object's quantities in 2D. - - light profile centres: the (y,x) centre of every `LightProfile` in the object. - - mass profile centres: the (y,x) centre of every `MassProfile` in the object. - - tangential_critical curves: the tangential critical curves of all of the tracer's mass profiles combined. - - tangential_caustics: the tangential caustics of all of the tracer's mass profiles combined. - - radial_critical curves: the radial critical curves of all of the tracer's mass profiles combined. - - radial_caustics: the radial caustics of all of the tracer's mass profiles combined. - - When plotting a `Tracer` it is common for plots to only display quantities corresponding to one plane at a time - (e.g. the convergence in the image plane, the source in the source plane). Therefore, quantities are only - extracted from one plane, specified by the input `plane_index`. - - Parameters - ---------- - tracer - The `Tracer` object which has attributes extracted for plotting. - grid - The 2D grid of (y,x) coordinates used to plot the tracer's quantities in 2D. - plane_index - The index of the plane in the tracer which is used to extract quantities, as only one plane is plotted - at a time. - - Returns - ------- - vis.Visuals2D - A collection of attributes that can be plotted by a `Plotter` object. - """ - origin = self.get("origin", value=aa.Grid2DIrregular(values=[grid.origin])) - - border = self.get("border", value=grid.mask.derive_grid.border) - - if border is not None and len(border) > 0 and plane_index > 0: - border = tracer.traced_grid_2d_list_from(grid=border)[plane_index] - - light_profile_centres = self.get( - "light_profile_centres", - tracer.planes[plane_index].extract_attribute( - cls=ag.LightProfile, attr_name="centre" - ), - ) - - mass_profile_centres = self.get( - "mass_profile_centres", - tracer.planes[plane_index].extract_attribute( - cls=ag.mp.MassProfile, attr_name="centre" - ), - ) - - tangential_critical_curves = None - radial_critical_curves = None - tangential_caustics = None - radial_caustics = None - - if plane_index == 0: - tangential_critical_curves = self.get( - "tangential_critical_curves", - tracer.tangential_critical_curve_list_from(grid=grid), - "tangential_critical_curves", - ) - - radial_critical_curves = None - - radial_critical_curve_area_list = ( - tracer.radial_critical_curve_area_list_from(grid=grid) - ) - - if any( - [area > grid.pixel_scale for area in radial_critical_curve_area_list] - ): - radial_critical_curves = self.get( - "radial_critical_curves", - tracer.radial_critical_curve_list_from(grid=grid), - "radial_critical_curves", - ) - - if plane_index > 0: - tangential_caustics = self.get( - "tangential_caustics", - tracer.tangential_caustic_list_from(grid=grid), - "tangential_caustics", - ) - - radial_caustics = self.get( - "radial_caustics", - tracer.radial_caustic_list_from(grid=grid), - "radial_caustics", - ) - - return self.visuals + self.visuals.__class__( - origin=origin, - border=border, - light_profile_centres=light_profile_centres, - mass_profile_centres=mass_profile_centres, - tangential_critical_curves=tangential_critical_curves, - tangential_caustics=tangential_caustics, - radial_critical_curves=radial_critical_curves, - radial_caustics=radial_caustics, - ) - - def via_fit_imaging_from(self, fit: FitImaging) -> aplt.Visuals2D: - """ - From a `FitImaging` get its attributes that can be plotted and return them in a `Visuals2D` object. - - Only attributes not already in `self.visuals` and with `True` entries in the `Include2D` object are extracted - for plotting. - - From a `FitImaging` the following attributes can be extracted for plotting: - - - origin: the (y,x) origin of the 2D coordinate system. - - mask: the 2D mask. - - border: the border of the 2D mask, which are all of the mask's exterior edge pixels. - - light profile centres: the (y,x) centre of every `LightProfile` in the object. - - mass profile centres: the (y,x) centre of every `MassProfile` in the object. - - critical curves: the critical curves of all mass profile combined. - - Parameters - ---------- - fit - The fit imaging object whose attributes are extracted for plotting. - - Returns - ------- - Visuals2D - The collection of attributes that are plotted by a `Plotter` object. - """ - visuals_2d_via_mask = self.via_mask_from(mask=fit.mask) - - visuals_2d_via_tracer = self.via_tracer_from( - tracer=fit.tracer, grid=fit.grids.lp, plane_index=0 - ) - - return visuals_2d_via_mask + visuals_2d_via_tracer diff --git a/autolens/point/model/plotter_interface.py b/autolens/point/model/plotter_interface.py index 548e7ec1f..db3df0c43 100644 --- a/autolens/point/model/plotter_interface.py +++ b/autolens/point/model/plotter_interface.py @@ -34,9 +34,7 @@ def should_plot(name): mat_plot_2d = self.mat_plot_2d_from() - dataset_plotter = PointDatasetPlotter( - dataset=dataset, mat_plot_2d=mat_plot_2d, include_2d=self.include_2d - ) + dataset_plotter = PointDatasetPlotter(dataset=dataset, mat_plot_2d=mat_plot_2d) if should_plot("subplot_dataset"): dataset_plotter.subplot_dataset() @@ -69,9 +67,7 @@ def should_plot(name): mat_plot_2d = self.mat_plot_2d_from() - fit_plotter = FitPointDatasetPlotter( - fit=fit, mat_plot_2d=mat_plot_2d, include_2d=self.include_2d - ) + fit_plotter = FitPointDatasetPlotter(fit=fit, mat_plot_2d=mat_plot_2d) if should_plot("subplot_fit"): fit_plotter.subplot_fit() diff --git a/autolens/point/plot/fit_point_plotters.py b/autolens/point/plot/fit_point_plotters.py index 138fd7a8d..4d7cb8c74 100644 --- a/autolens/point/plot/fit_point_plotters.py +++ b/autolens/point/plot/fit_point_plotters.py @@ -10,10 +10,8 @@ def __init__( fit: FitPointDataset, mat_plot_1d: aplt.MatPlot1D = None, visuals_1d: aplt.Visuals1D = None, - include_1d: aplt.Include1D = None, mat_plot_2d: aplt.MatPlot2D = None, visuals_2d: aplt.Visuals2D = None, - include_2d: aplt.Include2D = None, ): """ Plots the attributes of `FitPointDataset` objects using matplotlib methods and functions which customize the @@ -24,8 +22,7 @@ def __init__( but a user can manually input values into `MatPlot2d` to customize the figure's appearance. Overlaid on the figure are visuals, contained in the `Visuals2D` object. Attributes may be extracted from - the `FitImaging` and plotted via the visuals object, if the corresponding entry is `True` in the `Include2D` - object or the `config/visualize/include.ini` file. + the `FitImaging` and plotted via the visuals object. Parameters ---------- @@ -36,26 +33,16 @@ def __init__( Contains objects which wrap the matplotlib function calls that make the plot. visuals_2d Contains visuals that can be overlaid on the plot. - include_2d - Specifies which attributes of the `Array2D` are extracted and plotted as visuals. """ super().__init__( mat_plot_1d=mat_plot_1d, visuals_1d=visuals_1d, - include_1d=include_1d, mat_plot_2d=mat_plot_2d, - include_2d=include_2d, visuals_2d=visuals_2d, ) self.fit = fit - def get_visuals_1d(self) -> aplt.Visuals1D: - return self.visuals_1d - - def get_visuals_2d(self) -> aplt.Visuals2D: - return self.visuals_2d - def figures_2d(self, positions: bool = False, fluxes: bool = False): """ Plots the individual attributes of the plotter's `FitPointDataset` object in 2D. @@ -71,7 +58,7 @@ def figures_2d(self, positions: bool = False, fluxes: bool = False): If `True`, the dataset's fluxes are plotted on the figure compared to the model fluxes. """ if positions: - visuals_2d = self.get_visuals_2d() + visuals_2d = self.visuals_2d visuals_2d += visuals_2d.__class__( multiple_images=self.fit.positions.model_data @@ -138,7 +125,7 @@ def figures_2d(self, positions: bool = False, fluxes: bool = False): if fluxes: if self.fit.dataset.fluxes is not None: - visuals_1d = self.get_visuals_1d() + visuals_1d = self.visuals_1d # Dataset may have flux but model may not diff --git a/autolens/point/plot/point_dataset_plotters.py b/autolens/point/plot/point_dataset_plotters.py index c89b7e2ec..f781fc729 100644 --- a/autolens/point/plot/point_dataset_plotters.py +++ b/autolens/point/plot/point_dataset_plotters.py @@ -10,10 +10,8 @@ def __init__( dataset: PointDataset, mat_plot_1d: aplt.MatPlot1D = None, visuals_1d: aplt.Visuals1D = None, - include_1d: aplt.Include1D = None, mat_plot_2d: aplt.MatPlot2D = None, visuals_2d: aplt.Visuals2D = None, - include_2d: aplt.Include2D = None, ): """ Plots the attributes of `PointDataset` objects using the matplotlib methods and functions functions which @@ -24,39 +22,26 @@ def __init__( but a user can manually input values into `MatPlot2d` to customize the figure's appearance. Overlaid on the figure are visuals, contained in the `Visuals2D` object. Attributes may be extracted from - the `Imaging` and plotted via the visuals object, if the corresponding entry is `True` in the `Include2D` - object or the `config/visualize/include.ini` file. + the `Imaging` and plotted via the visuals object. Parameters ---------- dataset The imaging dataset the plotter plots. - get_visuals_2d - A function which extracts from the `Imaging` the 2D visuals which are plotted on figures. mat_plot_2d Contains objects which wrap the matplotlib function calls that make 2D plots. visuals_2d Contains 2D visuals that can be overlaid on 2D plots. - include_2d - Specifies which attributes of the `Imaging` are extracted and plotted as visuals for 2D plots. """ super().__init__( mat_plot_1d=mat_plot_1d, visuals_1d=visuals_1d, - include_1d=include_1d, mat_plot_2d=mat_plot_2d, - include_2d=include_2d, visuals_2d=visuals_2d, ) self.dataset = dataset - def get_visuals_1d(self) -> aplt.Visuals1D: - return self.visuals_1d - - def get_visuals_2d(self) -> aplt.Visuals2D: - return self.visuals_2d - def figures_2d(self, positions: bool = False, fluxes: bool = False): """ Plots the individual attributes of the plotter's `PointDataset` object in 2D. @@ -76,7 +61,7 @@ def figures_2d(self, positions: bool = False, fluxes: bool = False): grid=self.dataset.positions, y_errors=self.dataset.positions_noise_map, x_errors=self.dataset.positions_noise_map, - visuals_2d=self.get_visuals_2d(), + visuals_2d=self.visuals_2d, auto_labels=aplt.AutoLabels( title=f"{self.dataset.name} Positions", filename="point_dataset_positions", @@ -100,7 +85,7 @@ def figures_2d(self, positions: bool = False, fluxes: bool = False): self.mat_plot_1d.plot_yx( y=self.dataset.fluxes, y_errors=self.dataset.fluxes_noise_map, - visuals_1d=self.get_visuals_1d(), + visuals_1d=self.visuals_1d, auto_labels=aplt.AutoLabels( title=f" {self.dataset.name} Fluxes", filename="point_dataset_fluxes", diff --git a/docs/api/plot.rst b/docs/api/plot.rst index 5c8ed390e..bef26ff77 100644 --- a/docs/api/plot.rst +++ b/docs/api/plot.rst @@ -75,8 +75,6 @@ visuals to figures. MatPlot1D MatPlot2D - Include1D - Include2D Visuals1D Visuals2D diff --git a/docs/overview/overview_3_features.rst b/docs/overview/overview_3_features.rst index 8ae00beec..16d6ef608 100644 --- a/docs/overview/overview_3_features.rst +++ b/docs/overview/overview_3_features.rst @@ -45,7 +45,7 @@ galaxy: A complete overview of pixelized source reconstructions can be found at ``notebooks/overview/overview_5_pixelizations.ipynb``. -Chapter 4 of the **HowToLens** lectures describes pixelizations in detail and teaches users how they can be used to +Chapter 4 of lectures describes pixelizations in detail and teaches users how they can be used to perform lens modeling. diff --git a/test_autolens/analysis/test_plotter_interface.py b/test_autolens/analysis/test_plotter_interface.py index 8077e02d5..2a5d4573c 100644 --- a/test_autolens/analysis/test_plotter_interface.py +++ b/test_autolens/analysis/test_plotter_interface.py @@ -14,9 +14,7 @@ def make_plotter_interface_plotter_setup(): return path.join("{}".format(directory), "files") -def test__tracer( - masked_imaging_7x7, tracer_x2_plane_7x7, include_2d_all, plot_path, plot_patch -): +def test__tracer(masked_imaging_7x7, tracer_x2_plane_7x7, plot_path, plot_patch): if os.path.exists(plot_path): shutil.rmtree(plot_path) @@ -36,9 +34,7 @@ def test__tracer( assert image.shape == (5, 5) -def test__image_with_positions( - image_7x7, positions_x2, include_2d_all, plot_path, plot_patch -): +def test__image_with_positions(image_7x7, positions_x2, plot_path, plot_patch): if os.path.exists(plot_path): shutil.rmtree(plot_path) diff --git a/test_autolens/config/visualize.yaml b/test_autolens/config/visualize.yaml index aafc2f8c9..8e7439c03 100644 --- a/test_autolens/config/visualize.yaml +++ b/test_autolens/config/visualize.yaml @@ -3,26 +3,6 @@ general: backend: default imshow_origin: upper zoom_around_mask: true -include: - include_2d: - border: false - tangential_caustics: false - radial_caustics: false - tangential_critical_curves: false - radial_critical_curves: false - grid: true - inversion_grid: true - light_profile_centres: true - mapper_image_plane_mesh_grid: false - mapper_source_plane_mesh_grid: true - mask: true - mass_profile_centres: true - multiple_images: false - origin: true - parallel_overscan: true - positions: true - serial_overscan: true - serial_prescan: true mat_wrap_2d: CausticsLine: figure: diff --git a/test_autolens/conftest.py b/test_autolens/conftest.py index 961f3fd09..f75cefec4 100644 --- a/test_autolens/conftest.py +++ b/test_autolens/conftest.py @@ -371,11 +371,6 @@ def make_adapt_images_7x7(): return fixtures.make_adapt_images_7x7() -@pytest.fixture(name="include_2d_all") -def make_include_all(): - return fixtures.make_include_2d_all() - - @pytest.fixture(name="samples_summary_with_result") def make_samples_summary_with_result(): return fixtures.make_samples_summary_with_result() diff --git a/test_autolens/imaging/model/test_plotter_interface_imaging.py b/test_autolens/imaging/model/test_plotter_interface_imaging.py index e430d2cb9..692020ebf 100644 --- a/test_autolens/imaging/model/test_plotter_interface_imaging.py +++ b/test_autolens/imaging/model/test_plotter_interface_imaging.py @@ -15,7 +15,7 @@ def make_plotter_interface_plotter_setup(): def test__fit_imaging( - fit_imaging_x2_plane_inversion_7x7, include_2d_all, plot_path, plot_patch + fit_imaging_x2_plane_inversion_7x7, plot_path, plot_patch ): if os.path.exists(plot_path): shutil.rmtree(plot_path) diff --git a/test_autolens/imaging/plot/test_fit_imaging_plotters.py b/test_autolens/imaging/plot/test_fit_imaging_plotters.py index 05fd07087..fce14771c 100644 --- a/test_autolens/imaging/plot/test_fit_imaging_plotters.py +++ b/test_autolens/imaging/plot/test_fit_imaging_plotters.py @@ -15,12 +15,11 @@ def make_fit_imaging_plotter_setup(): def test__fit_quantities_are_output( - fit_imaging_x2_plane_7x7, include_2d_all, plot_path, plot_patch + fit_imaging_x2_plane_7x7, plot_path, plot_patch ): fit_plotter = aplt.FitImagingPlotter( fit=fit_imaging_x2_plane_7x7, - include_2d=include_2d_all, mat_plot_2d=aplt.MatPlot2D(output=aplt.Output(path=plot_path, format="png")), ) @@ -62,12 +61,11 @@ def test__fit_quantities_are_output( def test__figures_of_plane( - fit_imaging_x2_plane_7x7, include_2d_all, plot_path, plot_patch + fit_imaging_x2_plane_7x7, plot_path, plot_patch ): fit_plotter = aplt.FitImagingPlotter( fit=fit_imaging_x2_plane_7x7, - include_2d=include_2d_all, mat_plot_2d=aplt.MatPlot2D(output=aplt.Output(path=plot_path, format="png")), ) @@ -99,12 +97,11 @@ def test__figures_of_plane( def test_subplot_fit_is_output( - fit_imaging_x2_plane_7x7, include_2d_all, plot_path, plot_patch + fit_imaging_x2_plane_7x7, plot_path, plot_patch ): fit_plotter = aplt.FitImagingPlotter( fit=fit_imaging_x2_plane_7x7, - include_2d=include_2d_all, mat_plot_2d=aplt.MatPlot2D(output=aplt.Output(plot_path, format="png")), ) @@ -116,12 +113,11 @@ def test_subplot_fit_is_output( def test__subplot_of_planes( - fit_imaging_x2_plane_7x7, include_2d_all, plot_path, plot_patch + fit_imaging_x2_plane_7x7, plot_path, plot_patch ): fit_plotter = aplt.FitImagingPlotter( fit=fit_imaging_x2_plane_7x7, - include_2d=include_2d_all, mat_plot_2d=aplt.MatPlot2D(output=aplt.Output(plot_path, format="png")), ) diff --git a/test_autolens/interferometer/model/test_plotter_interface_interferometer.py b/test_autolens/interferometer/model/test_plotter_interface_interferometer.py index 02c4c1a96..412ed86dd 100644 --- a/test_autolens/interferometer/model/test_plotter_interface_interferometer.py +++ b/test_autolens/interferometer/model/test_plotter_interface_interferometer.py @@ -18,7 +18,6 @@ def make_plotter_interface_plotter_setup(): def test__fit_interferometer( fit_interferometer_x2_plane_7x7, - include_2d_all, plot_path, plot_patch, ): diff --git a/test_autolens/interferometer/plot/test_fit_interferometer_plotters.py b/test_autolens/interferometer/plot/test_fit_interferometer_plotters.py index d3ca6ac10..b405caa4b 100644 --- a/test_autolens/interferometer/plot/test_fit_interferometer_plotters.py +++ b/test_autolens/interferometer/plot/test_fit_interferometer_plotters.py @@ -107,13 +107,11 @@ def test__fit_quantities_are_output( def test__fit_sub_plot_real_space( fit_interferometer_x2_plane_7x7, fit_interferometer_x2_plane_inversion_7x7, - include_2d_all, plot_path, plot_patch, ): fit_plotter = aplt.FitInterferometerPlotter( fit=fit_interferometer_x2_plane_7x7, - include_2d=include_2d_all, mat_plot_1d=aplt.MatPlot1D(output=aplt.Output(plot_path, format="png")), mat_plot_2d=aplt.MatPlot2D(output=aplt.Output(plot_path, format="png")), ) diff --git a/test_autolens/lens/plot/test_tracer_plotters.py b/test_autolens/lens/plot/test_tracer_plotters.py index d77eb2dce..67cb86e60 100644 --- a/test_autolens/lens/plot/test_tracer_plotters.py +++ b/test_autolens/lens/plot/test_tracer_plotters.py @@ -21,7 +21,6 @@ def test__all_individual_plotter( tracer_x2_plane_7x7, grid_2d_7x7, mask_2d_7x7, - include_2d_all, plot_path, plot_patch, ): @@ -54,7 +53,6 @@ def test__all_individual_plotter( tracer_plotter = aplt.TracerPlotter( tracer=tracer_x2_plane_7x7, grid=grid_2d_7x7, - include_2d=include_2d_all, mat_plot_2d=aplt.MatPlot2D(output=aplt.Output(plot_path, format="png")), ) @@ -75,14 +73,12 @@ def test__figures_of_plane( tracer_x2_plane_7x7, grid_2d_7x7, mask_2d_7x7, - include_2d_all, plot_path, plot_patch, ): tracer_plotter = aplt.TracerPlotter( tracer=tracer_x2_plane_7x7, grid=grid_2d_7x7, - include_2d=include_2d_all, mat_plot_2d=aplt.MatPlot2D(output=aplt.Output(path=plot_path, format="png")), ) @@ -99,13 +95,10 @@ def test__figures_of_plane( assert path.join(plot_path, "plane_image_of_plane_1.png") not in plot_patch.paths -def test__tracer_plot_output( - tracer_x2_plane_7x7, grid_2d_7x7, include_2d_all, plot_path, plot_patch -): +def test__tracer_plot_output(tracer_x2_plane_7x7, grid_2d_7x7, plot_path, plot_patch): tracer_plotter = aplt.TracerPlotter( tracer=tracer_x2_plane_7x7, grid=grid_2d_7x7, - include_2d=include_2d_all, mat_plot_2d=aplt.MatPlot2D(output=aplt.Output(plot_path, format="png")), ) diff --git a/test_autolens/plot/__init__.py b/test_autolens/plot/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/test_autolens/plot/test_get_visuals.py b/test_autolens/plot/test_get_visuals.py deleted file mode 100644 index 6ebcbbc2b..000000000 --- a/test_autolens/plot/test_get_visuals.py +++ /dev/null @@ -1,169 +0,0 @@ -from os import path -import pytest - -import autolens.plot as aplt - -from autolens.plot.get_visuals.two_d import GetVisuals2D - -directory = path.dirname(path.realpath(__file__)) - - -@pytest.fixture(name="plot_path") -def make_profile_plotter_setup(): - return path.join( - "{}".format(path.dirname(path.realpath(__file__))), "files", "plots", "profiles" - ) - - -def test__2d__via_tracer(tracer_x2_plane_7x7, grid_2d_7x7): - visuals_2d = aplt.Visuals2D(vectors=2) - - include_2d = aplt.Include2D( - origin=True, - border=True, - light_profile_centres=True, - mass_profile_centres=True, - tangential_critical_curves=True, - radial_critical_curves=True, - ) - - get_visuals = GetVisuals2D(include=include_2d, visuals=visuals_2d) - - visuals_2d_via = get_visuals.via_tracer_from( - tracer=tracer_x2_plane_7x7, grid=grid_2d_7x7, plane_index=0 - ) - - assert visuals_2d_via.origin.in_list == [(0.0, 0.0)] - assert (visuals_2d_via.border == grid_2d_7x7.mask.derive_grid.border).all() - assert visuals_2d_via.light_profile_centres.in_list == [ - tracer_x2_plane_7x7.galaxies[1].light_profile_0.centre - ] - assert visuals_2d_via.mass_profile_centres.in_list == [ - tracer_x2_plane_7x7.galaxies[0].mass_profile_0.centre - ] - assert ( - visuals_2d_via.tangential_critical_curves[0] - == tracer_x2_plane_7x7.tangential_critical_curve_list_from(grid=grid_2d_7x7)[0] - ).all() - assert ( - visuals_2d_via.radial_critical_curves[0] - == tracer_x2_plane_7x7.radial_critical_curve_list_from(grid=grid_2d_7x7)[0] - ).all() - assert visuals_2d_via.vectors == 2 - - include_2d = aplt.Include2D( - origin=True, - border=True, - light_profile_centres=True, - mass_profile_centres=True, - tangential_caustics=True, - radial_caustics=True, - ) - - get_visuals = GetVisuals2D(include=include_2d, visuals=visuals_2d) - - visuals_2d_via = get_visuals.via_tracer_from( - tracer=tracer_x2_plane_7x7, grid=grid_2d_7x7, plane_index=1 - ) - - assert visuals_2d_via.origin.in_list == [(0.0, 0.0)] - traced_border = tracer_x2_plane_7x7.traced_grid_2d_list_from( - grid=grid_2d_7x7.mask.derive_grid.border - )[1] - assert (visuals_2d_via.border == traced_border).all() - assert visuals_2d_via.light_profile_centres.in_list == [ - tracer_x2_plane_7x7.galaxies[1].light_profile_0.centre - ] - assert visuals_2d_via.mass_profile_centres is None - assert ( - visuals_2d_via.tangential_caustics[0] - == tracer_x2_plane_7x7.tangential_caustic_list_from(grid=grid_2d_7x7)[0] - ).all() - assert ( - visuals_2d_via.radial_caustics[0] - == tracer_x2_plane_7x7.radial_caustic_list_from(grid=grid_2d_7x7)[0] - ).all() - - include_2d = aplt.Include2D( - origin=False, - border=False, - light_profile_centres=False, - mass_profile_centres=False, - tangential_critical_curves=False, - radial_critical_curves=False, - ) - - get_visuals = GetVisuals2D(include=include_2d, visuals=visuals_2d) - - visuals_2d_via = get_visuals.via_tracer_from( - tracer=tracer_x2_plane_7x7, grid=grid_2d_7x7, plane_index=0 - ) - - assert visuals_2d_via.origin is None - assert visuals_2d_via.border is None - assert visuals_2d_via.light_profile_centres is None - assert visuals_2d_via.mass_profile_centres is None - assert visuals_2d_via.tangential_critical_curves is None - assert visuals_2d_via.radial_critical_curves is None - assert visuals_2d_via.vectors == 2 - - -def test__via_fit_imaging_from(fit_imaging_x2_plane_7x7, grid_2d_7x7): - visuals_2d = aplt.Visuals2D(origin=(1.0, 1.0), vectors=2) - include_2d = aplt.Include2D( - origin=True, - mask=True, - border=True, - light_profile_centres=True, - mass_profile_centres=True, - tangential_critical_curves=True, - radial_critical_curves=True, - ) - - get_visuals = GetVisuals2D(include=include_2d, visuals=visuals_2d) - - visuals_2d_via = get_visuals.via_fit_imaging_from(fit=fit_imaging_x2_plane_7x7) - - assert visuals_2d_via.origin == (1.0, 1.0) - assert (visuals_2d_via.mask == fit_imaging_x2_plane_7x7.mask).all() - assert ( - visuals_2d_via.border == fit_imaging_x2_plane_7x7.mask.derive_grid.border - ).all() - assert visuals_2d_via.light_profile_centres.in_list == [(0.0, 0.0)] - assert visuals_2d_via.mass_profile_centres.in_list == [(0.0, 0.0)] - assert ( - visuals_2d_via.tangential_critical_curves[0] - == fit_imaging_x2_plane_7x7.tracer.tangential_critical_curve_list_from( - grid=grid_2d_7x7 - )[0] - ).all() - assert ( - visuals_2d_via.radial_critical_curves[0] - == fit_imaging_x2_plane_7x7.tracer.radial_critical_curve_list_from( - grid=grid_2d_7x7 - )[0] - ).all() - assert visuals_2d_via.vectors == 2 - - include_2d = aplt.Include2D( - origin=False, - mask=False, - border=False, - light_profile_centres=False, - mass_profile_centres=False, - tangential_critical_curves=False, - radial_critical_curves=False, - ) - - get_visuals = GetVisuals2D(include=include_2d, visuals=visuals_2d) - - visuals_2d_via = get_visuals.via_fit_imaging_from(fit=fit_imaging_x2_plane_7x7) - - assert visuals_2d_via.origin == (1.0, 1.0) - assert visuals_2d_via.mask is None - assert visuals_2d_via.border is None - assert visuals_2d_via.light_profile_centres is None - assert visuals_2d_via.mass_profile_centres is None - assert visuals_2d_via.tangential_critical_curves is None - assert visuals_2d_via.radial_critical_curves is None - assert visuals_2d_via.vectors == 2 diff --git a/test_autolens/plot/test_subhalo_plotters.py b/test_autolens/plot/test_subhalo_plotters.py deleted file mode 100644 index 0e160ca19..000000000 --- a/test_autolens/plot/test_subhalo_plotters.py +++ /dev/null @@ -1,69 +0,0 @@ -from os import path - -import pytest - -directory = path.dirname(path.realpath(__file__)) - - -@pytest.fixture(name="plot_path") -def make_fit_imaging_plotter_setup(): - return path.join( - "{}".format(path.dirname(path.realpath(__file__))), "files", "plots", "fit" - ) - - -# def test__subhalo_detection_sub_plot( -# fit_imaging_x2_plane_7x7, -# fit_imaging_x2_plane_inversion_7x7, -# include_2d_all, -# plot_path, -# plot_patch, -# ): -# arr = al.Array2D.no_mask(values=[[1.0, 2.0], [3.0, 4.0]], pixel_scales=1.0) -# -# subhalo_plotter = aplt.SubhaloPlotter( -# include_2d=include_2d_all, -# mat_plot_2d=aplt.MatPlot2D(output=aplt.Output(path=plot_path, format="png")), -# ) -# -# subhalo_plotter.subplot_detection_imaging( -# fit_imaging_detect=fit_imaging_x2_plane_7x7, detection_array=arr, mass_array=arr -# ) -# -# assert path.join(plot_path, "subplot_detection_imaging.png") in plot_patch.paths -# -# subhalo_plotter.subplot_detection_imaging( -# fit_imaging_detect=fit_imaging_x2_plane_inversion_7x7, -# detection_array=arr, -# mass_array=arr, -# ) -# -# assert path.join(plot_path, "subplot_detection_imaging.png") in plot_patch.paths -# -# -# def test__subhalo_detection_fits( -# fit_imaging_x2_plane_7x7, -# fit_imaging_x2_plane_inversion_7x7, -# include_2d_all, -# plot_path, -# plot_patch, -# ): -# -# subhalo_plotter = aplt.SubhaloPlotter( -# include_2d=include_2d_all, -# mat_plot_2d=aplt.MatPlot2D(output=aplt.Output(path=plot_path, format="png")), -# ) -# -# subhalo_plotter.subplot_detection_fits( -# fit_imaging_before=fit_imaging_x2_plane_7x7, -# fit_imaging_detect=fit_imaging_x2_plane_7x7, -# ) -# -# assert path.join(plot_path, "subplot_detection_fits.png") in plot_patch.paths -# -# subhalo_plotter.subplot_detection_fits( -# fit_imaging_before=fit_imaging_x2_plane_inversion_7x7, -# fit_imaging_detect=fit_imaging_x2_plane_inversion_7x7, -# ) -# -# assert path.join(plot_path, "subplot_detection_fits.png") in plot_patch.paths diff --git a/test_autolens/point/model/test_plotter_interface_point.py b/test_autolens/point/model/test_plotter_interface_point.py index ebe613bec..73b4a02b4 100644 --- a/test_autolens/point/model/test_plotter_interface_point.py +++ b/test_autolens/point/model/test_plotter_interface_point.py @@ -13,7 +13,7 @@ def make_plotter_interface_plotter_setup(): return path.join("{}".format(directory), "files") -def test__fit_point(fit_point_dataset_x2_plane, include_2d_all, plot_path, plot_patch): +def test__fit_point(fit_point_dataset_x2_plane, plot_path, plot_patch): if os.path.exists(plot_path): shutil.rmtree(plot_path) diff --git a/test_autolens/point/plot/test_fit_point_plotters.py b/test_autolens/point/plot/test_fit_point_plotters.py index d544e8de7..43eed3d4f 100644 --- a/test_autolens/point/plot/test_fit_point_plotters.py +++ b/test_autolens/point/plot/test_fit_point_plotters.py @@ -18,11 +18,10 @@ def make_fit_point_plotter_setup(): def test__fit_point_quantities_are_output( - fit_point_dataset_x2_plane, include_2d_all, plot_path, plot_patch + fit_point_dataset_x2_plane, plot_path, plot_patch ): fit_point_plotter = aplt.FitPointDatasetPlotter( fit=fit_point_dataset_x2_plane, - include_2d=include_2d_all, mat_plot_1d=aplt.MatPlot1D(output=aplt.Output(path=plot_path, format="png")), mat_plot_2d=aplt.MatPlot2D(output=aplt.Output(path=plot_path, format="png")), ) @@ -45,7 +44,6 @@ def test__fit_point_quantities_are_output( fit_point_plotter = aplt.FitPointDatasetPlotter( fit=fit_point_dataset_x2_plane, - include_2d=include_2d_all, mat_plot_2d=aplt.MatPlot2D(output=aplt.Output(path=plot_path, format="png")), ) @@ -55,12 +53,9 @@ def test__fit_point_quantities_are_output( assert path.join(plot_path, "fit_point_fluxes.png") not in plot_patch.paths -def test__subplot_fit( - fit_point_dataset_x2_plane, include_2d_all, plot_path, plot_patch -): +def test__subplot_fit(fit_point_dataset_x2_plane, plot_path, plot_patch): fit_point_plotter = aplt.FitPointDatasetPlotter( fit=fit_point_dataset_x2_plane, - include_2d=include_2d_all, mat_plot_1d=aplt.MatPlot1D(output=aplt.Output(path=plot_path, format="png")), mat_plot_2d=aplt.MatPlot2D(output=aplt.Output(path=plot_path, format="png")), ) diff --git a/test_autolens/point/plot/test_point_dataset_plotters.py b/test_autolens/point/plot/test_point_dataset_plotters.py index c90bc2ce0..404015750 100644 --- a/test_autolens/point/plot/test_point_dataset_plotters.py +++ b/test_autolens/point/plot/test_point_dataset_plotters.py @@ -17,12 +17,9 @@ def make_point_dataset_plotter_setup(): ) -def test__point_dataset_quantities_are_output( - point_dataset, include_2d_all, plot_path, plot_patch -): +def test__point_dataset_quantities_are_output(point_dataset, plot_path, plot_patch): point_dataset_plotter = aplt.PointDatasetPlotter( dataset=point_dataset, - include_2d=include_2d_all, mat_plot_1d=aplt.MatPlot1D(output=aplt.Output(path=plot_path, format="png")), mat_plot_2d=aplt.MatPlot2D(output=aplt.Output(path=plot_path, format="png")), ) @@ -45,7 +42,6 @@ def test__point_dataset_quantities_are_output( point_dataset_plotter = aplt.PointDatasetPlotter( dataset=point_dataset, - include_2d=include_2d_all, mat_plot_2d=aplt.MatPlot2D(output=aplt.Output(path=plot_path, format="png")), ) @@ -55,10 +51,9 @@ def test__point_dataset_quantities_are_output( assert path.join(plot_path, "point_dataset_fluxes.png") not in plot_patch.paths -def test__subplot_dataset(point_dataset, include_2d_all, plot_path, plot_patch): +def test__subplot_dataset(point_dataset, plot_path, plot_patch): point_dataset_plotter = aplt.PointDatasetPlotter( dataset=point_dataset, - include_2d=include_2d_all, mat_plot_1d=aplt.MatPlot1D(output=aplt.Output(path=plot_path, format="png")), mat_plot_2d=aplt.MatPlot2D(output=aplt.Output(path=plot_path, format="png")), )