diff --git a/.gitignore b/.gitignore index a1df8f42d..40460efaf 100644 --- a/.gitignore +++ b/.gitignore @@ -12,6 +12,8 @@ test_autoarray/dataset/plot/files/ test_autoarray/fit/plot/files/ test_autoarray/structures/arrays/one_d/files/array/ test_autoarray/structures/plot/files/ +test_autoarray/inversion/plot/files/ +test_autoarray/plot/files/ test_autoarray/instruments/files/ .envr diff --git a/PLAN.md b/PLAN.md new file mode 100644 index 000000000..68475ffe5 --- /dev/null +++ b/PLAN.md @@ -0,0 +1,207 @@ +# Plan: Remove Visuals Classes and Pass Overlays Directly + +## Current State + +The codebase is in a *partial* refactoring state. Standalone `plot_array`, `plot_grid`, +`plot_yx`, `plot_inversion_reconstruction` functions already exist in +`autoarray/plot/plots/` and are called by the new-style plotters. However: + +- `Visuals1D` and `Visuals2D` wrapper classes still exist +- Every plotter still accepts `visuals_2d` / `visuals_1d` constructor args and stores them +- Helper functions (`_lines_from_visuals`, `_positions_from_visuals`, `_mask_edge_from`, + `_grid_from_visuals`) bridge old Visuals → new standalone functions +- `MatPlot2D.plot_array/plot_grid/plot_mapper` and `MatPlot1D.plot_yx` still exist and + `InterferometerPlotter` still calls them directly +- `InversionPlotter.subplot_of_mapper` directly mutates `self.visuals_2d` + +## Goal + +Remove `Visuals1D`, `Visuals2D`, and `AbstractVisuals` entirely. Each plotter holds its +overlay data as plain attributes and passes them straight to the `plot_*` standalone +functions. Default overlays (e.g. mask derived from `array.mask`) are computed inline. + +--- + +## Steps + +### 1. Update `AbstractPlotter` (`abstract_plotters.py`) +- Remove `visuals_1d: Visuals1D` and `visuals_2d: Visuals2D` constructor parameters and + their default instantiation (`self.visuals_1d = visuals_1d or Visuals1D()`, etc.) +- Remove the imports of `Visuals1D` and `Visuals2D` + +### 2. Update each Plotter constructor to accept individual overlay objects + +Replace `visuals_2d: Visuals2D = None` with explicit per-overlay kwargs. Plotters store +each overlay as a plain instance attribute (defaulting to `None`). + +**`Array2DPlotter`** (`structures/plot/structure_plotters.py`): +```python +def __init__(self, array, mat_plot_2d=None, + mask=None, origin=None, border=None, grid=None, + positions=None, lines=None, vectors=None, + patches=None, fill_region=None, array_overlay=None): +``` + +**`Grid2DPlotter`**: +```python +def __init__(self, grid, mat_plot_2d=None, lines=None, positions=None): +``` + +**`YX1DPlotter`**: +```python +def __init__(self, y, x=None, mat_plot_1d=None, + shaded_region=None, vertical_line=None, points=None, ...): +``` + +**`MapperPlotter`** (`inversion/plot/mapper_plotters.py`): +```python +def __init__(self, mapper, mat_plot_2d=None, + lines=None, grid=None, positions=None): +``` + +**`InversionPlotter`** (`inversion/plot/inversion_plotters.py`): +```python +def __init__(self, inversion, mat_plot_2d=None, + lines=None, grid=None, positions=None, + residuals_symmetric_cmap=True): +``` + +**`ImagingPlotterMeta` / `ImagingPlotter`** (`dataset/plot/imaging_plotters.py`): +```python +def __init__(self, dataset, mat_plot_2d=None, + mask=None, grid=None, positions=None, lines=None): +``` + +**`FitImagingPlotterMeta` / `FitImagingPlotter`** (`fit/plot/fit_imaging_plotters.py`): +```python +def __init__(self, fit, mat_plot_2d=None, + mask=None, grid=None, positions=None, lines=None, + residuals_symmetric_cmap=True): +``` + +**`InterferometerPlotter`** (`dataset/plot/interferometer_plotters.py`): +```python +def __init__(self, dataset, mat_plot_1d=None, mat_plot_2d=None, lines=None): +``` + +### 3. Inline overlay logic inside each plotter's `_plot_*` / `figure_*` methods + +Each plotter's internal plot helpers already call the standalone functions. Replace +calls like: +```python +mask=_mask_edge_from(array, self.visuals_2d), +lines=_lines_from_visuals(self.visuals_2d), +``` +with direct access to the plotter's own attributes plus inline auto-extraction: +```python +mask=self.mask if self.mask is not None else _auto_mask_edge(array), +lines=self.lines, +``` + +Where `_auto_mask_edge(array)` is a tiny module-level helper (no Visuals dependency): +```python +def _auto_mask_edge(array): + """Return edge-pixel (y,x) coords from array.mask, or None.""" + try: + if not array.mask.is_all_false: + return np.array(array.mask.derive_grid.edge.array) + except AttributeError: + pass + return None +``` + +### 4. Fix `InversionPlotter.subplot_of_mapper` — drop the `visuals_2d` mutation + +Currently this method does: +```python +self.visuals_2d += Visuals2D(mesh_grid=mapper.image_plane_mesh_grid) +``` +Replace by passing `mesh_grid` directly to the specific `figures_2d_of_pixelization` +call that needs it, or by temporarily storing `self.mesh_grid` on the plotter and +checking it in `_plot_array`. The mutation and the `Visuals2D(...)` construction are +both removed. + +Similarly remove `self.visuals_2d.indexes = indexes` in `subplot_mappings` — store as +`self._indexes` and pass through. + +### 5. Update `InterferometerPlotter.figures_2d` — replace old MatPlot calls + +`InterferometerPlotter` still calls `self.mat_plot_2d.plot_array(...)`, +`self.mat_plot_2d.plot_grid(...)`, and `self.mat_plot_1d.plot_yx(...)`. + +Replace each with the equivalent standalone function call, deriving `ax`, `output_path`, +`filename`, `fmt` via `_output_for_mat_plot` (which already exists and has no Visuals +dependency). + +### 6. Remove `MatPlot2D.plot_array`, `plot_grid`, `plot_mapper` (and private helpers) + +Once no caller uses them, delete these methods from `mat_plot/two_d.py`: +- `plot_array` +- `plot_grid` +- `plot_mapper` +- `_plot_rectangular_mapper` +- `_plot_delaunay_mapper` + +Remove the `from autoarray.plot.visuals.two_d import Visuals2D` import. + +### 7. Remove `MatPlot1D.plot_yx` + +Delete the method from `mat_plot/one_d.py` and remove the `Visuals1D` import. + +### 8. Remove helper extraction functions from `structure_plotters.py` + +Delete (no longer needed): +- `_lines_from_visuals` +- `_positions_from_visuals` +- `_mask_edge_from` +- `_grid_from_visuals` + +Keep: `_zoom_array`, `_output_for_mat_plot` (neither depends on Visuals). + +### 9. Delete `autoarray/plot/visuals/` + +Remove: +- `autoarray/plot/visuals/__init__.py` +- `autoarray/plot/visuals/abstract.py` +- `autoarray/plot/visuals/one_d.py` +- `autoarray/plot/visuals/two_d.py` + +### 10. Update `autoarray/plot/__init__.py` + +Remove `Visuals1D` and `Visuals2D` exports (lines 45–46). + +### 11. Check and update remaining plotters + +Read and update: +- `fit/plot/fit_interferometer_plotters.py` +- `fit/plot/fit_vector_yx_plotters.py` + +Both import `Visuals1D`/`Visuals2D`; apply the same pattern as above. + +### 12. Run full test suite + +```bash +python -m pytest test_autoarray/ -q --tb=short +``` + +Fix any failures before committing. + +--- + +## Summary of files changed + +| File | Change | +|------|--------| +| `autoarray/plot/abstract_plotters.py` | Remove `visuals_1d`, `visuals_2d` | +| `autoarray/plot/mat_plot/one_d.py` | Remove `plot_yx`, remove Visuals1D import | +| `autoarray/plot/mat_plot/two_d.py` | Remove `plot_array/grid/mapper` methods, remove Visuals2D import | +| `autoarray/plot/visuals/` | **Delete entire directory** | +| `autoarray/plot/__init__.py` | Remove Visuals exports | +| `autoarray/structures/plot/structure_plotters.py` | Replace visuals args with individual kwargs; remove helper functions | +| `autoarray/inversion/plot/mapper_plotters.py` | Replace visuals args with individual kwargs | +| `autoarray/inversion/plot/inversion_plotters.py` | Replace visuals args; fix subplot_of_mapper mutation | +| `autoarray/dataset/plot/imaging_plotters.py` | Replace visuals args with individual kwargs | +| `autoarray/dataset/plot/interferometer_plotters.py` | Replace visuals args; replace old MatPlot calls | +| `autoarray/fit/plot/fit_imaging_plotters.py` | Replace visuals args with individual kwargs | +| `autoarray/fit/plot/fit_interferometer_plotters.py` | Replace visuals args | +| `autoarray/fit/plot/fit_vector_yx_plotters.py` | Replace visuals args | diff --git a/autoarray/config/visualize/general.yaml b/autoarray/config/visualize/general.yaml index b6cecf50f..0c841906f 100644 --- a/autoarray/config/visualize/general.yaml +++ b/autoarray/config/visualize/general.yaml @@ -1,28 +1,30 @@ -general: - backend: default # The matploblib backend used for visualization. `default` uses the system default, can specifiy specific backend (e.g. TKAgg, Qt5Agg, WXAgg). - imshow_origin: upper # The `origin` input of `imshow`, determining if pixel values are ascending or descending on the y-axis. - log10_min_value: 1.0e-4 # If negative values are being plotted on a log10 scale, values below this value are rounded up to it (e.g. to remove negative values). - log10_max_value: 1.0e99 # If positive values are being plotted on a log10 scale, values above this value are rounded down to it (e.g. to prevent white blobs). - zoom_around_mask: true # If True, plots of data structures with a mask automatically zoom in the masked region. -inversion: - reconstruction_vmax_factor: 0.5 - total_mappings_pixels : 8 # The number of source pixels used when plotting the subplot_mappings of a pixelization. -zoom: - plane_percent: 0.01 - inversion_percent: 0.01 # Plots of an Inversion's reconstruction use the reconstructed data's bright value multiplied by this factor. -subplot_shape: # The shape of a subplots for figures with an input number of subplots (e.g. for a figure with 4 subplots, the shape is (2, 2)). - 1: (1, 1) # The shape of subplots for a figure with 1 subplot. - 2: (1, 2) # The shape of subplots for a figure with 2 subplots. - 4: (2, 2) # The shape of subplots for a figure with 4 (or less than the above value) of subplots. - 6: (2, 3) # The shape of subplots for a figure with 6 (or less than the above value) of subplots. - 9: (3, 3) # The shape of subplots for a figure with 9 (or less than the above value) of subplots. - 12: (3, 4) # The shape of subplots for a figure with 12 (or less than the above value) of subplots. - 16: (4, 4) # The shape of subplots for a figure with 16 (or less than the above value) of subplots. - 20: (4, 5) # The shape of subplots for a figure with 20 (or less than the above value) of subplots. - 36: (6, 6) # The shape of subplots for a figure with 36 (or less than the above value) of subplots. -subplot_shape_to_figsize_factor: (6, 6) # The factors by which the subplot_shape is multiplied to determine the figsize of a subplot (e.g. if the subplot_shape is (2,2), the figsize will be (2*6, 2*6). -units: - use_scaled: true # Whether to plot spatial coordinates in scaled units computed via the pixel_scale (e.g. arc-seconds) or pixel units by default. - cb_unit: $\,\,\mathrm{e^{-}}\,\mathrm{s^{-1}}$ # The string or latex unit label used for the colorbar of the image, for example electrons per second. - scaled_symbol: '"' # The symbol used when plotting spatial coordinates computed via the pixel_scale (e.g. for Astronomy data this is arc-seconds). - unscaled_symbol: pix # The symbol used when plotting spatial coordinates in unscaled pixel units. \ No newline at end of file +general: + backend: default # The matplotlib backend used for visualization. `default` uses the system default, can specify specific backend (e.g. TKAgg, Qt5Agg, WXAgg). + imshow_origin: upper # The `origin` input of `imshow`, determining if pixel values are ascending or descending on the y-axis. + log10_min_value: 1.0e-4 # If negative values are being plotted on a log10 scale, values below this value are rounded up to it (e.g. to remove negative values). + log10_max_value: 1.0e99 # If positive values are being plotted on a log10 scale, values above this value are rounded down to it. + zoom_around_mask: true # If True, plots of data structures with a mask automatically zoom in the masked region. +inversion: + reconstruction_vmax_factor: 0.5 + total_mappings_pixels: 8 # The number of source pixels used when plotting the subplot_mappings of a pixelization. +zoom: + plane_percent: 0.01 + inversion_percent: 0.01 # Plots of an Inversion's reconstruction use the reconstructed data's bright value multiplied by this factor. +units: + use_scaled: true # Whether to plot spatial coordinates in scaled units computed via the pixel_scale (e.g. arc-seconds) or pixel units by default. + cb_unit: $\,\,\mathrm{e^{-}}\,\mathrm{s^{-1}}$ # The string or latex unit label used for the colorbar of the image, for example electrons per second. + scaled_symbol: '"' # The symbol used when plotting spatial coordinates computed via the pixel_scale (e.g. for Astronomy data this is arc-seconds). + unscaled_symbol: pix # The symbol used when plotting spatial coordinates in unscaled pixel units. +mat_plot: + figure: + figsize: (7, 7) # Default figure size. Override via aplt.Figure(figsize=(...)). + yticks: + fontsize: 22 # Default y-tick font size. Override via aplt.YTicks(fontsize=...). + xticks: + fontsize: 22 # Default x-tick font size. Override via aplt.XTicks(fontsize=...). + title: + fontsize: 24 # Default title font size. Override via aplt.Title(fontsize=...). + ylabel: + fontsize: 16 # Default y-label font size. Override via aplt.YLabel(fontsize=...). + xlabel: + fontsize: 16 # Default x-label font size. Override via aplt.XLabel(fontsize=...). diff --git a/autoarray/config/visualize/mat_wrap.yaml b/autoarray/config/visualize/mat_wrap.yaml deleted file mode 100644 index 10bc8b822..000000000 --- a/autoarray/config/visualize/mat_wrap.yaml +++ /dev/null @@ -1,124 +0,0 @@ -# These settings specify the default matplotlib settings when figures and subplots are plotted. - -# For example, the `Figure` section has the following lines: - -# Figure: -# figure: -# aspect: square -# figsize: (7,7) -# subplot: -# aspect: square -# figsize: auto - -# This means that when a figure (e.g. a single image) is plotted it will use `figsize=(7,7)` and ``aspect="square`` if -# the values of these parameters are not manually set by the user via a `MatPlot2D` object. -# -# In the above example, subplots (e.g. more than one image) will always use `figsize="auto` by default. -# -# These configuration options can be customized such that the appearance of figures and subplots for a user is -# optimal for your computer set up. - -Axis: # wrapper for `plt.axis()`: customize the figure axis. - figure: {} - subplot: {} -Cmap: # wrapper for `plt.cmap()`: customize the figure colormap. - figure: - cmap: default - linscale: 0.01 - linthresh: 0.05 - norm: linear - vmax: null - vmin: null - subplot: - cmap: default - linscale: 0.01 - linthresh: 0.05 - norm: linear - vmax: null - vmin: null -Colorbar: # wrapper for `plt.colorbar()`: customize the figure colorbar. - figure: - fraction: 0.047 - pad: 0.01 - subplot: - fraction: 0.047 - pad: 0.01 -ColorbarTickParams: # wrapper for `cb.ax.tick_params()`: customize the ticks of the figure's colorbar. - figure: - labelrotation: 90 - labelsize: 22 - subplot: - labelrotation: 90 - labelsize: 18 -Figure: # wrapper for `plt.figure()`: customize the figure size. - figure: - aspect: square - figsize: (7,7) - subplot: - aspect: square - figsize: auto -Legend: # wrapper for `plt.legend()`: customize the figure legend. - figure: - fontsize: 12 - include: true - subplot: - fontsize: 12 - include: true -Text: # wrapper for `plt.text()`: customize the appearance of text placed on the figure. - figure: - fontsize: 16 - subplot: - fontsize: 10 -Annotate: # wrapper for `plt.annotate()`: customize the appearance of annotations placed on the figure. - figure: - fontsize: 16 - subplot: - fontsize: 10 -TickParams: # wrapper for `plt.tick_params()`: customize the figure tick parameters. - figure: - labelsize: 16 - subplot: - labelsize: 10 -Title: # wrapper for `plt.title()`: customize the figure title. - figure: - fontsize: 24 - subplot: - fontsize: 16 -XLabel: # wrapper for `plt.xlabel()`: customize the figure ylabel. - figure: - fontsize: 16 - xlabel: "" - subplot: - fontsize: 10 - xlabel: "" -XTicks: # wrapper for `plt.xticks()`: customize the figure xticks. - manual: - extent_factor_1d: 1.0 # For 1D plots, the fraction of the extent that the ticks appears from the edge of the figure and the center. - extent_factor_2d: 0.75 # For 2D plots, the fraction of the extent that the ticks appears from the edge of the figure and the center. - number_of_ticks_1d: 8 # For 1D plots, the number of ticks that appear on the x-axis. - number_of_ticks_2d: 3 # For 1D plots, the number of ticks that appear on the x-axis. - figure: - fontsize: 22 - subplot: - fontsize: 18 -YLabel: # wrapper for `plt.ylabel()`: customize the figure ylabel. - figure: - fontsize: 16 - ylabel: "" - subplot: - fontsize: 10 - ylabel: "" -YTicks: # wrapper for `plt.yticks()`: customize the figure yticks. - manual: - extent_factor_1d: 1.0 # For 1D plots, the fraction of the extent that the ticks appears from the edge of the figure and the center. - extent_factor_2d: 0.75 # For 2D plots, the fraction of the extent that the ticks appears from the edge of the figure and the center. - number_of_ticks_1d: 8 # For 1D plots, the number of ticks that appear on the y-axis. - number_of_ticks_2d: 3 # For 1D plots, the number of ticks that appear on the y-axis. - figure: - fontsize: 22 - rotation: vertical - va: center - subplot: - fontsize: 18 - rotation: vertical - va: center \ No newline at end of file diff --git a/autoarray/config/visualize/mat_wrap_1d.yaml b/autoarray/config/visualize/mat_wrap_1d.yaml deleted file mode 100644 index fc9dad3f0..000000000 --- a/autoarray/config/visualize/mat_wrap_1d.yaml +++ /dev/null @@ -1,40 +0,0 @@ -# These settings specify the default matplotlib settings when 1D figures and subplots are plotted. - -# For example, the `YXPlot` section has the following lines: - -# YXPlot: -# figure: -# c: k -# subplot: -# c: k - -# This means that when a figure of y vs x data is plotted it will use `c=k`, meaning the line appears black, -# provided the values of these parameters are not manually set by the user via a `MatPlot1D` object. -# -# In the above example, subplots (e.g. more than one image) will always use `c=k` by default as well. -# -# These configuration options can be customized such that the appearance of figures and subplots for a user is -# optimal for your computer set up. - -AXVLine: # wrapper for `plt.axvline()`: customize verticals lines plotted on the figure. - figure: - c: k - subplot: - c: k -FillBetween: # wrapper for `plt.fill_between()`: customize how fill between plots appear - figure: - alpha: 0.7 - color: k - subplot: - alpha: 0.7 - color: k -YXPlot: # wrapper for `plt.plot()`: customize plots of y versus x. - figure: - c: k - subplot: - c: k -YXScatter: - figure: - c: k - subplot: - c: k \ No newline at end of file diff --git a/autoarray/config/visualize/mat_wrap_2d.yaml b/autoarray/config/visualize/mat_wrap_2d.yaml deleted file mode 100644 index 087032458..000000000 --- a/autoarray/config/visualize/mat_wrap_2d.yaml +++ /dev/null @@ -1,184 +0,0 @@ -# These settings specify the default matplotlib settings when "D figures and subplots are plotted. - -# For example, the `GridScatter` section has the following lines: - -# GridScatter: -# figure: -# c: k -# subplot: -# c: k - -# This means that when a 2D grid of data is plotted it will use `c=k`, meaning the grid points appear black, -# provided the values of these parameters are not manually set by the user via a `MatPlot2D` object. -# -# In the above example, subplots (e.g. more than one image) will always use `c=k` by default as well. -# -# These configuration options can be customized such that the appearance of figures and subplots for a user is -# optimal for your computer set up. - -ArrayOverlay: # wrapper for `plt.imshow()`: customize arrays overlaid. - figure: - alpha: 0.5 - subplot: - alpha: 0.5 -Contour: # wrapper for `plt.contour()`: customize contours plotted on the figure. - figure: - colors: "k" - total_contours: 10 # Number of contours to plot - use_log10: true # If true, contours are plotted with log10 spacing, if False, linear spacing. - include_values: true # If true, the values of the contours are plotted on the figure. - subplot: - colors: "k" - total_contours: 10 # Number of contours to plot - use_log10: true # If true, contours are plotted with log10 spacing, if False, linear spacing. - include_values: true # If true, the values of the contours are plotted on the figure. -BorderScatter: # wrapper for `plt.scatter()`: customize the apperance of 2D borders. - figure: - c: r - marker: . - s: 30 - subplot: - c: r - marker: . - s: 10 -GridErrorbar: # wrapper for `plt.errrorbar()`: customize grids with errors. - figure: - alpha: 0.5 - c: k - fmt: o - linewidth: 5 - marker: o - markersize: 8 - subplot: - alpha: 0.5 - c: k - fmt: o - linewidth: 5 - marker: o - markersize: 8 -GridPlot: # wrapper for `plt.plot()`: customize how grids plotted via this method appear. - figure: - c: w - subplot: - c: w -GridScatter: # wrapper for `plt.scatter()`: customize appearances of Grid2D. - figure: - c: k - marker: . - s: 1 - subplot: - c: k - marker: . - s: 1 -IndexScatter: # wrapper for `plt.scatter()`: customize indexes (e.g. data / source plane or frame objects of an Inversion) - figure: - c: r,g,b,m,y,k - marker: . - s: 20 - subplot: - c: r,g,b,m,y,k - marker: . - s: 20 -IndexPlot: # wrapper for `plt.plot()`: customize indexes (e.g. data / source plane or frame objects of an Inversion) - figure: - c: r,g,b,m,y,k - linewidth: 3 - subplot: - c: r,g,b,m,y,k - linewidth: 3 -MaskScatter: # wrapper for `plt.scatter()`: customize the appearance of 2D masks. - figure: - c: k - marker: x - s: 10 - subplot: - c: k - marker: x - s: 10 -MeshGridScatter: # wrapper for `plt.scatter()`: customize the appearance of mesh grids of Inversions in the source-plane / source-frame. - figure: - c: r - marker: . - s: 2 - subplot: - c: r - marker: . - s: 2 -OriginScatter: # wrapper for `plt.scatter()`: customize the appearance of the (y,x) origin on figures. - figure: - c: k - marker: x - s: 80 - subplot: - c: k - marker: x - s: 80 -PatchOverlay: # wrapper for `plt.gcf().gca().add_collection`: customize how overlaid patches appear. - figure: - edgecolor: c - facecolor: null - subplot: - edgecolor: c - facecolor: null -PositionsScatter: # wrapper for `plt.scatter()`: customize the appearance of positions input via `Visuals2d.positions`. - figure: - c: k,m,y,b,r,g - marker: . - s: 32 - subplot: - c: k,m,y,b,r,g - marker: . - s: 32 -VectorYXQuiver: # wrapper for `plt.quiver()`: customize (y,x) vectors appearances (e.g. a shear field). - figure: - alpha: 1.0 - angles: xy - headlength: 0 - headwidth: 1 - linewidth: 5 - pivot: middle - units: xy - subplot: - alpha: 1.0 - angles: xy - headlength: 0 - headwidth: 1 - linewidth: 5 - pivot: middle - units: xy -DelaunayDrawer: # wrapper for `plt.fill()`: customize the appearance of Delaunay mesh's. - figure: - alpha: 0.7 - edgecolor: k - linewidth: 0.0 - subplot: - alpha: 0.7 - edgecolor: k - linewidth: 0.0 -ParallelOverscanPlot: - figure: - c: k - linestyle: '-' - linewidth: 1 - subplot: - c: k - linestyle: '-' - linewidth: 1 -SerialOverscanPlot: - figure: - c: k - linestyle: '-' - linewidth: 1 - subplot: - c: k - linestyle: '-' - linewidth: 1 -SerialPrescanPlot: - figure: - c: k - linestyle: '-' - linewidth: 1 - subplot: - c: k - linestyle: '-' - linewidth: 1 \ No newline at end of file diff --git a/autoarray/config/visualize/plots.yaml b/autoarray/config/visualize/plots.yaml index ed4da3b0c..e6c653d7c 100644 --- a/autoarray/config/visualize/plots.yaml +++ b/autoarray/config/visualize/plots.yaml @@ -3,11 +3,11 @@ # For example, if `plots: fit: subplot_fit=True``, the ``fit_dataset.png`` subplot file will # be plotted every time visualization is performed. -dataset: # Settings for plots of all datasets (e.g. ImagingPlotter, InterferometerPlotter). +dataset: # Settings for plots of all datasets (e.g. Imaging, Interferometer). subplot_dataset: true # Plot subplot containing all dataset quantities (e.g. the data, noise-map, etc.)? -imaging: # Settings for plots of imaging datasets (e.g. ImagingPlotter) +imaging: # Settings for plots of imaging datasets (e.g. Imaging) psf: false -fit: # Settings for plots of all fits (e.g. FitImagingPlotter, FitInterferometerPlotter). +fit: # Settings for plots of all fits (e.g. FitImaging, FitInterferometer). subplot_fit: true # Plot subplot of all fit quantities for any dataset (e.g. the model data, residual-map, etc.)? subplot_fit_log10: true # Plot subplot of all fit quantities for any dataset using log10 color maps (e.g. the model data, residual-map, etc.)? data: false # Plot individual plots of the data? @@ -18,8 +18,8 @@ fit: # Settings for plots of all fits (e.g normalized_residual_map: false # Plot individual plots of the normalized-residual-map? chi_squared_map: false # Plot individual plots of the chi-squared-map? residual_flux_fraction: false # Plot individual plots of the residual_flux_fraction? -fit_imaging: {} # Settings for plots of fits to imaging datasets (e.g. FitImagingPlotter). -inversion: # Settings for plots of inversions (e.g. InversionPlotter). +fit_imaging: {} # Settings for plots of fits to imaging datasets (e.g. FitImaging). +inversion: # Settings for plots of inversions (e.g. Inversion). subplot_inversion: true # Plot subplot of all quantities in each inversion (e.g. reconstrucuted image, reconstruction)? subplot_mappings: false # Plot subplot of the image-to-source pixels mappings of each pixelization? data_subtracted: false # Plot individual plots of the data with the other inversion linear objects subtracted? @@ -30,6 +30,6 @@ inversion: # Settings for plots of inversions (e reconstructed_operated_data: false # Plot image of the reconstructed data (e.g. in the image-plane)? reconstruction: false # Plot the reconstructed inversion (e.g. the pixelization's mesh in the source-plane)? regularization_weights: false # Plot the effective regularization weight of every inversion mesh pixel? -fit_interferometer: # Settings for plots of fits to interferometer datasets (e.g. FitInterferometerPlotter). +fit_interferometer: # Settings for plots of fits to interferometer datasets (e.g. FitInterferometer). subplot_fit_dirty_images: false # Plot subplot of the dirty-images of all interferometer datasets? subplot_fit_real_space: false # Plot subplot of the real-space images of all interferometer datasets? \ No newline at end of file diff --git a/autoarray/dataset/plot/__init__.py b/autoarray/dataset/plot/__init__.py index e69de29bb..1afae2d13 100644 --- a/autoarray/dataset/plot/__init__.py +++ b/autoarray/dataset/plot/__init__.py @@ -0,0 +1,5 @@ +from autoarray.dataset.plot.imaging_plots import subplot_imaging_dataset +from autoarray.dataset.plot.interferometer_plots import ( + subplot_interferometer_dataset, + subplot_interferometer_dirty_images, +) diff --git a/autoarray/dataset/plot/imaging_plots.py b/autoarray/dataset/plot/imaging_plots.py new file mode 100644 index 000000000..94e363b3f --- /dev/null +++ b/autoarray/dataset/plot/imaging_plots.py @@ -0,0 +1,126 @@ +from typing import Optional + +import matplotlib.pyplot as plt + +from autoarray.plot.array import plot_array +from autoarray.plot.utils import subplot_save + + +def subplot_imaging_dataset( + dataset, + output_path: Optional[str] = None, + output_filename: str = "subplot_dataset", + output_format: str = "png", + colormap=None, + use_log10: bool = False, + grid=None, + positions=None, + lines=None, +): + """ + 3×3 subplot of all ``Imaging`` dataset components. + + Panels (row-major): + 0. Data + 1. Data (log10) + 2. Noise-Map + 3. PSF (if present) + 4. PSF log10 (if present) + 5. Signal-To-Noise Map + 6. Over-sample size (light profiles) + 7. Over-sample size (pixelization) + + Parameters + ---------- + dataset + An ``Imaging`` dataset instance. + output_path + Directory to save the figure. ``None`` calls ``plt.show()``. + output_filename + Base filename without extension. + output_format + File format, e.g. ``"png"``. + colormap + Matplotlib colormap name. ``None`` uses the package default. + use_log10 + Apply log10 normalisation to non-log panels. + grid, positions, lines + Optional overlays forwarded to every panel. + """ + fig, axes = plt.subplots(3, 3, figsize=(21, 21)) + axes = axes.flatten() + + plot_array( + dataset.data, + ax=axes[0], + title="Data", + colormap=colormap, + use_log10=use_log10, + grid=grid, + positions=positions, + lines=lines, + ) + plot_array( + dataset.data, + ax=axes[1], + title="Data (log10)", + colormap=colormap, + use_log10=True, + grid=grid, + positions=positions, + lines=lines, + ) + plot_array( + dataset.noise_map, + ax=axes[2], + title="Noise-Map", + colormap=colormap, + use_log10=use_log10, + grid=grid, + positions=positions, + lines=lines, + ) + + if dataset.psf is not None: + plot_array( + dataset.psf.kernel, + ax=axes[3], + title="Point Spread Function", + colormap=colormap, + use_log10=use_log10, + ) + plot_array( + dataset.psf.kernel, + ax=axes[4], + title="PSF (log10)", + colormap=colormap, + use_log10=True, + ) + + plot_array( + dataset.signal_to_noise_map, + ax=axes[5], + title="Signal-To-Noise Map", + colormap=colormap, + use_log10=use_log10, + grid=grid, + positions=positions, + lines=lines, + ) + plot_array( + dataset.grids.over_sample_size_lp, + ax=axes[6], + title="Over Sample Size (Light Profiles)", + colormap=colormap, + use_log10=use_log10, + ) + plot_array( + dataset.grids.over_sample_size_pixelization, + ax=axes[7], + title="Over Sample Size (Pixelization)", + colormap=colormap, + use_log10=use_log10, + ) + + plt.tight_layout() + subplot_save(fig, output_path, output_filename, output_format) diff --git a/autoarray/dataset/plot/imaging_plotters.py b/autoarray/dataset/plot/imaging_plotters.py deleted file mode 100644 index ac1f23ded..000000000 --- a/autoarray/dataset/plot/imaging_plotters.py +++ /dev/null @@ -1,260 +0,0 @@ -import copy -from typing import Callable, Optional - -from autoarray.plot.visuals.two_d import Visuals2D -from autoarray.plot.mat_plot.two_d import MatPlot2D -from autoarray.plot.auto_labels import AutoLabels -from autoarray.plot.abstract_plotters import AbstractPlotter -from autoarray.dataset.imaging.dataset import Imaging - - -class ImagingPlotterMeta(AbstractPlotter): - def __init__( - self, - dataset: Imaging, - mat_plot_2d: MatPlot2D = None, - visuals_2d: Visuals2D = None, - ): - """ - Plots the attributes of `Imaging` objects using the matplotlib method `imshow()` and many other matplotlib - functions which customize the plot's appearance. - - The `mat_plot_2d` attribute wraps matplotlib function calls to make the figure. By default, the settings - passed to every matplotlib function called are those specified in the `config/visualize/mat_wrap/*.ini` files, - but a user can manually input values into `MatPlot2d` to customize the figure's appearance. - - Overlaid on the figure are visuals, contained in the `Visuals2D` object. Attributes may be extracted from - the `Imaging` and plotted via the visuals object. - - Parameters - ---------- - dataset - The imaging dataset the plotter plots. - mat_plot_2d - Contains objects which wrap the matplotlib function calls that make 2D plots. - visuals_2d - Contains 2D visuals that can be overlaid on 2D plots. - """ - - super().__init__(mat_plot_2d=mat_plot_2d, visuals_2d=visuals_2d) - - self.dataset = dataset - - @property - def imaging(self): - return self.dataset - - def figures_2d( - self, - data: bool = False, - noise_map: bool = False, - psf: bool = False, - signal_to_noise_map: bool = False, - over_sample_size_lp: bool = False, - over_sample_size_pixelization: bool = False, - title_str: Optional[str] = None, - ): - """ - Plots the individual attributes of the plotter's `Imaging` object in 2D. - - The API is such that every plottable attribute of the `Imaging` object is an input parameter of type bool of - the function, which if switched to `True` means that it is plotted. - - Parameters - ---------- - data - Whether to make a 2D plot (via `imshow`) of the image data. - noise_map - Whether to make a 2D plot (via `imshow`) of the noise map. - psf - Whether to make a 2D plot (via `imshow`) of the psf. - signal_to_noise_map - Whether to make a 2D plot (via `imshow`) of the signal-to-noise map. - over_sample_size_lp - Whether to make a 2D plot (via `imshow`) of the Over Sampling for input light profiles. If - adaptive sub size is used, the sub size grid for a centre of (0.0, 0.0) is used. - over_sample_size_pixelization - Whether to make a 2D plot (via `imshow`) of the Over Sampling for pixelizations. - """ - - if data: - self.mat_plot_2d.plot_array( - array=self.dataset.data, - visuals_2d=self.visuals_2d, - auto_labels=AutoLabels(title=title_str or f" Data", filename="data"), - ) - - if noise_map: - self.mat_plot_2d.plot_array( - array=self.dataset.noise_map, - visuals_2d=self.visuals_2d, - auto_labels=AutoLabels(title_str or f"Noise-Map", filename="noise_map"), - ) - - if psf: - if self.dataset.psf is not None: - self.mat_plot_2d.plot_array( - array=self.dataset.psf.kernel, - visuals_2d=self.visuals_2d, - auto_labels=AutoLabels( - title=title_str or f"Point Spread Function", - filename="psf", - cb_unit="", - ), - ) - - if signal_to_noise_map: - self.mat_plot_2d.plot_array( - array=self.dataset.signal_to_noise_map, - visuals_2d=self.visuals_2d, - auto_labels=AutoLabels( - title=title_str or f"Signal-To-Noise Map", - filename="signal_to_noise_map", - cb_unit="", - ), - ) - - if over_sample_size_lp: - self.mat_plot_2d.plot_array( - array=self.dataset.grids.over_sample_size_lp, - visuals_2d=self.visuals_2d, - auto_labels=AutoLabels( - title=title_str or f"Over Sample Size (Light Profiles)", - filename="over_sample_size_lp", - cb_unit="", - ), - ) - - if over_sample_size_pixelization: - self.mat_plot_2d.plot_array( - array=self.dataset.grids.over_sample_size_pixelization, - visuals_2d=self.visuals_2d, - auto_labels=AutoLabels( - title=title_str or f"Over Sample Size (Pixelization)", - filename="over_sample_size_pixelization", - cb_unit="", - ), - ) - - def subplot( - self, - data: bool = False, - noise_map: bool = False, - psf: bool = False, - signal_to_noise_map: bool = False, - over_sampling: bool = False, - over_sampling_pixelization: bool = False, - auto_filename: str = "subplot_dataset", - ): - """ - Plots the individual attributes of the plotter's `Imaging` object in 2D on a subplot. - - The API is such that every plottable attribute of the `Imaging` object is an input parameter of type bool of - the function, which if switched to `True` means that it is included on the subplot. - - Parameters - ---------- - data - Whether to include a 2D plot (via `imshow`) of the image data. - noise_map - Whether to include a 2D plot (via `imshow`) of the noise map. - psf - Whether to include a 2D plot (via `imshow`) of the psf. - signal_to_noise_map - Whether to include a 2D plot (via `imshow`) of the signal-to-noise map. - over_sampling - Whether to include a 2D plot (via `imshow`) of the Over Sampling. If adaptive sub size is used, the - sub size grid for a centre of (0.0, 0.0) is used. - over_sampling_pixelization - Whether to include a 2D plot (via `imshow`) of the Over Sampling for pixelizations. - auto_filename - The default filename of the output subplot if written to hard-disk. - """ - self._subplot_custom_plot( - data=data, - noise_map=noise_map, - psf=psf, - signal_to_noise_map=signal_to_noise_map, - over_sampling=over_sampling, - over_sampling_pixelization=over_sampling_pixelization, - auto_labels=AutoLabels(filename=auto_filename), - ) - - def subplot_dataset(self): - """ - Standard subplot of the attributes of the plotter's `Imaging` object. - """ - use_log10_original = self.mat_plot_2d.use_log10 - - self.open_subplot_figure(number_subplots=9) - - self.figures_2d(data=True) - - contour_original = copy.copy(self.mat_plot_2d.contour) - - self.mat_plot_2d.use_log10 = True - self.mat_plot_2d.contour = False - self.figures_2d(data=True) - self.mat_plot_2d.use_log10 = False - self.mat_plot_2d.contour = contour_original - - self.figures_2d(noise_map=True) - - self.figures_2d(psf=True) - - self.mat_plot_2d.use_log10 = True - self.figures_2d(psf=True) - self.mat_plot_2d.use_log10 = False - - self.figures_2d(signal_to_noise_map=True) - - self.figures_2d(over_sample_size_lp=True) - self.figures_2d(over_sample_size_pixelization=True) - - self.mat_plot_2d.output.subplot_to_figure(auto_filename="subplot_dataset") - self.close_subplot_figure() - - self.mat_plot_2d.use_log10 = use_log10_original - - -class ImagingPlotter(AbstractPlotter): - def __init__( - self, - dataset: Imaging, - mat_plot_2d: MatPlot2D = None, - visuals_2d: Visuals2D = None, - ): - """ - Plots the attributes of `Imaging` objects using the matplotlib method `imshow()` and many other matplotlib - functions which customize the plot's appearance. - - The `mat_plot_2d` attribute wraps matplotlib function calls to make the figure. By default, the settings - passed to every matplotlib function called are those specified in the `config/visualize/mat_wrap/*.ini` files, - but a user can manually input values into `MatPlot2d` to customize the figure's appearance. - - Overlaid on the figure are visuals, contained in the `Visuals2D` object. Attributes may be extracted from - the `Imaging` and plotted via the visuals object. - - Parameters - ---------- - imaging - The imaging dataset the plotter plots. - mat_plot_2d - Contains objects which wrap the matplotlib function calls that make 2D plots. - visuals_2d - Contains 2D visuals that can be overlaid on 2D plots. - """ - - super().__init__(mat_plot_2d=mat_plot_2d, visuals_2d=visuals_2d) - - self.dataset = dataset - - self._imaging_meta_plotter = ImagingPlotterMeta( - dataset=self.dataset, - mat_plot_2d=self.mat_plot_2d, - visuals_2d=self.visuals_2d, - ) - - self.figures_2d = self._imaging_meta_plotter.figures_2d - self.subplot = self._imaging_meta_plotter.subplot - self.subplot_dataset = self._imaging_meta_plotter.subplot_dataset diff --git a/autoarray/dataset/plot/interferometer_plots.py b/autoarray/dataset/plot/interferometer_plots.py new file mode 100644 index 000000000..7841ea50e --- /dev/null +++ b/autoarray/dataset/plot/interferometer_plots.py @@ -0,0 +1,142 @@ +import numpy as np +from typing import Optional + +import matplotlib.pyplot as plt + +from autoarray.plot.array import plot_array +from autoarray.plot.grid import plot_grid +from autoarray.plot.yx import plot_yx +from autoarray.plot.utils import subplot_save +from autoarray.structures.grids.irregular_2d import Grid2DIrregular + + +def subplot_interferometer_dataset( + dataset, + output_path: Optional[str] = None, + output_filename: str = "subplot_dataset", + output_format: str = "png", + colormap=None, + use_log10: bool = False, +): + """ + 2×3 subplot of interferometer dataset components. + + Panels: Visibilities | UV-Wavelengths | Amplitudes vs UV-distances | + Phases vs UV-distances | Dirty Image | Dirty S/N Map + + Parameters + ---------- + dataset + An ``Interferometer`` dataset instance. + output_path + Directory to save the figure. ``None`` calls ``plt.show()``. + output_filename + Base filename without extension. + output_format + File format. + colormap + Matplotlib colormap name. + use_log10 + Apply log10 normalisation to image panels. + """ + fig, axes = plt.subplots(2, 3, figsize=(21, 14)) + axes = axes.flatten() + + plot_grid(dataset.data.in_grid, ax=axes[0], title="Visibilities") + plot_grid( + Grid2DIrregular.from_yx_1d( + y=dataset.uv_wavelengths[:, 1] / 10**3.0, + x=dataset.uv_wavelengths[:, 0] / 10**3.0, + ), + ax=axes[1], + title="UV-Wavelengths", + ) + plot_yx( + dataset.amplitudes, + dataset.uv_distances / 10**3.0, + ax=axes[2], + title="Amplitudes vs UV-distances", + ylabel="Jy", + xlabel="k$\\lambda$", + plot_axis_type="scatter", + ) + plot_yx( + dataset.phases, + dataset.uv_distances / 10**3.0, + ax=axes[3], + title="Phases vs UV-distances", + ylabel="deg", + xlabel="k$\\lambda$", + plot_axis_type="scatter", + ) + plot_array( + dataset.dirty_image, + ax=axes[4], + title="Dirty Image", + colormap=colormap, + use_log10=use_log10, + ) + plot_array( + dataset.dirty_signal_to_noise_map, + ax=axes[5], + title="Dirty Signal-To-Noise Map", + colormap=colormap, + use_log10=use_log10, + ) + + plt.tight_layout() + subplot_save(fig, output_path, output_filename, output_format) + + +def subplot_interferometer_dirty_images( + dataset, + output_path: Optional[str] = None, + output_filename: str = "subplot_dirty_images", + output_format: str = "png", + colormap=None, + use_log10: bool = False, +): + """ + 1×3 subplot of dirty image, dirty noise map, and dirty S/N map. + + Parameters + ---------- + dataset + An ``Interferometer`` dataset instance. + output_path + Directory to save the figure. ``None`` calls ``plt.show()``. + output_filename + Base filename without extension. + output_format + File format. + colormap + Matplotlib colormap name. + use_log10 + Apply log10 normalisation. + """ + fig, axes = plt.subplots(1, 3, figsize=(21, 7)) + + plot_array( + dataset.dirty_image, + ax=axes[0], + title="Dirty Image", + colormap=colormap, + use_log10=use_log10, + ) + plot_array( + dataset.dirty_noise_map, + ax=axes[1], + title="Dirty Noise Map", + colormap=colormap, + use_log10=use_log10, + ) + plot_array( + dataset.dirty_signal_to_noise_map, + ax=axes[2], + title="Dirty Signal-To-Noise Map", + colormap=colormap, + use_log10=use_log10, + ) + + plt.tight_layout() + subplot_save(fig, output_path, output_filename, output_format) diff --git a/autoarray/dataset/plot/interferometer_plotters.py b/autoarray/dataset/plot/interferometer_plotters.py deleted file mode 100644 index e69f53d38..000000000 --- a/autoarray/dataset/plot/interferometer_plotters.py +++ /dev/null @@ -1,284 +0,0 @@ -from autoarray.plot.abstract_plotters import AbstractPlotter -from autoarray.plot.visuals.one_d import Visuals1D -from autoarray.plot.visuals.two_d import Visuals2D -from autoarray.plot.mat_plot.one_d import MatPlot1D -from autoarray.plot.mat_plot.two_d import MatPlot2D -from autoarray.plot.auto_labels import AutoLabels -from autoarray.dataset.interferometer.dataset import Interferometer -from autoarray.structures.grids.irregular_2d import Grid2DIrregular - - -class InterferometerPlotter(AbstractPlotter): - def __init__( - self, - dataset: Interferometer, - mat_plot_1d: MatPlot1D = None, - visuals_1d: Visuals1D = None, - mat_plot_2d: MatPlot2D = None, - visuals_2d: Visuals2D = None, - ): - """ - Plots the attributes of `Interferometer` objects using the matplotlib methods `plot()`, `scatter()` and - `imshow()` and other matplotlib functions which customize the plot's appearance. - - The `mat_plot_1d` and `mat_plot_2d` attributes wrap matplotlib function calls to make the figure. By default, - the settings passed to every matplotlib function called are those specified in - the `config/visualize/mat_wrap/*.ini` files, but a user can manually input values into `MatPlot2d` to - customize the figure's appearance. - - Overlaid on the figure are visuals, contained in the `Visuals1D` and `Visuals2D` objects. Attributes may be - extracted from the `LightProfile` and plotted via the visuals object. - - Parameters - ---------- - dataset - The interferometer dataset the plotter plots. - mat_plot_1d - Contains objects which wrap the matplotlib function calls that make 1D plots. - visuals_1d - Contains 1D visuals that can be overlaid on 1D plots. - mat_plot_2d - Contains objects which wrap the matplotlib function calls that make 2D plots. - visuals_2d - Contains 2D visuals that can be overlaid on 2D plots. - """ - self.dataset = dataset - - super().__init__( - mat_plot_1d=mat_plot_1d, - visuals_1d=visuals_1d, - mat_plot_2d=mat_plot_2d, - visuals_2d=visuals_2d, - ) - - @property - def interferometer(self): - return self.dataset - - def figures_2d( - self, - data: bool = False, - noise_map: bool = False, - u_wavelengths: bool = False, - v_wavelengths: bool = False, - uv_wavelengths: bool = False, - amplitudes_vs_uv_distances: bool = False, - phases_vs_uv_distances: bool = False, - dirty_image: bool = False, - dirty_noise_map: bool = False, - dirty_signal_to_noise_map: bool = False, - ): - """ - Plots the individual attributes of the plotter's `Interferometer` object in 1D and 2D. - - The API is such that every plottable attribute of the `Interferometer` object is an input parameter of type - bool of the function, which if switched to `True` means that it is plotted. - - Parameters - ---------- - data - Whether to make a 2D plot (via `scatter`) of the visibility data. - noise_map - Whether to make a 2D plot (via `scatter`) of the noise-map. - u_wavelengths - Whether to make a 1D plot (via `plot`) of the u-wavelengths. - v_wavelengths - Whether to make a 1D plot (via `plot`) of the v-wavelengths. - amplitudes_vs_uv_distances - Whether to make a 1D plot (via `plot`) of the amplitudes versis the uv distances. - phases_vs_uv_distances - Whether to make a 1D plot (via `plot`) of the phases versis the uv distances. - dirty_image - Whether to make a 2D plot (via `imshow`) of the dirty image. - dirty_noise_map - Whether to make a 2D plot (via `imshow`) of the dirty noise map. - dirty_signal_to_noise_map - Whether to make a 2D plot (via `imshow`) of the dirty signal-to-noise map. - """ - - if data: - self.mat_plot_2d.plot_grid( - grid=self.dataset.data.in_grid, - visuals_2d=self.visuals_2d, - auto_labels=AutoLabels(title="Visibilities", filename="data"), - ) - - if noise_map: - self.mat_plot_2d.plot_grid( - grid=self.dataset.data.in_grid, - visuals_2d=self.visuals_2d, - color_array=self.dataset.noise_map.real, - auto_labels=AutoLabels(title="Noise-Map", filename="noise_map"), - ) - - if u_wavelengths: - self.mat_plot_1d.plot_yx( - y=self.dataset.uv_wavelengths[:, 0], - x=None, - visuals_1d=self.visuals_1d, - auto_labels=AutoLabels( - title="U-Wavelengths", - filename="u_wavelengths", - ylabel="Wavelengths", - ), - plot_axis_type_override="linear", - ) - - if v_wavelengths: - self.mat_plot_1d.plot_yx( - y=self.dataset.uv_wavelengths[:, 1], - x=None, - visuals_1d=self.visuals_1d, - auto_labels=AutoLabels( - title="V-Wavelengths", - filename="v_wavelengths", - ylabel="Wavelengths", - ), - plot_axis_type_override="linear", - ) - - if uv_wavelengths: - self.mat_plot_2d.plot_grid( - grid=Grid2DIrregular.from_yx_1d( - y=self.dataset.uv_wavelengths[:, 1] / 10**3.0, - x=self.dataset.uv_wavelengths[:, 0] / 10**3.0, - ), - visuals_2d=self.visuals_2d, - auto_labels=AutoLabels( - title="UV-Wavelengths", filename="uv_wavelengths" - ), - ) - - if amplitudes_vs_uv_distances: - self.mat_plot_1d.plot_yx( - y=self.dataset.amplitudes, - x=self.dataset.uv_distances / 10**3.0, - visuals_1d=self.visuals_1d, - auto_labels=AutoLabels( - title="Amplitudes vs UV-distances", - filename="amplitudes_vs_uv_distances", - yunit="Jy", - xunit="k$\lambda$", - ), - plot_axis_type_override="scatter", - ) - - if phases_vs_uv_distances: - self.mat_plot_1d.plot_yx( - y=self.dataset.phases, - x=self.dataset.uv_distances / 10**3.0, - visuals_1d=self.visuals_1d, - auto_labels=AutoLabels( - title="Phases vs UV-distances", - filename="phases_vs_uv_distances", - yunit="deg", - xunit="k$\lambda$", - ), - plot_axis_type_override="scatter", - ) - - if dirty_image: - self.mat_plot_2d.plot_array( - array=self.dataset.dirty_image, - visuals_2d=self.visuals_2d, - auto_labels=AutoLabels(title="Dirty Image", filename="dirty_image"), - ) - - if dirty_noise_map: - self.mat_plot_2d.plot_array( - array=self.dataset.dirty_noise_map, - visuals_2d=self.visuals_2d, - auto_labels=AutoLabels( - title="Dirty Noise Map", filename="dirty_noise_map" - ), - ) - - if dirty_signal_to_noise_map: - self.mat_plot_2d.plot_array( - array=self.dataset.dirty_signal_to_noise_map, - visuals_2d=self.visuals_2d, - auto_labels=AutoLabels( - title="Dirty Signal-To-Noise Map", - filename="dirty_signal_to_noise_map", - ), - ) - - def subplot( - self, - data: bool = False, - noise_map: bool = False, - u_wavelengths: bool = False, - v_wavelengths: bool = False, - uv_wavelengths: bool = False, - amplitudes_vs_uv_distances: bool = False, - phases_vs_uv_distances: bool = False, - dirty_image: bool = False, - dirty_noise_map: bool = False, - dirty_signal_to_noise_map: bool = False, - auto_filename: str = "subplot_dataset", - ): - """ - Plots the individual attributes of the plotter's `Interferometer` object in 1D and 2D on a subplot. - - The API is such that every plottable attribute of the `Interferometer` object is an input parameter of type - bool of the function, which if switched to `True` means that it is included on the subplot. - - Parameters - ---------- - data - Whether to include a 2D plot (via `scatter`) of the visibility data. - noise_map - Whether to include a 2D plot (via `scatter`) of the noise-map. - u_wavelengths - Whether to include a 1D plot (via `plot`) of the u-wavelengths. - v_wavelengths - Whether to include a 1D plot (via `plot`) of the v-wavelengths. - amplitudes_vs_uv_distances - Whether to include a 1D plot (via `plot`) of the amplitudes versis the uv distances. - phases_vs_uv_distances - Whether to include a 1D plot (via `plot`) of the phases versis the uv distances. - dirty_image - Whether to include a 2D plot (via `imshow`) of the dirty image. - dirty_noise_map - Whether to include a 2D plot (via `imshow`) of the dirty noise map. - dirty_signal_to_noise_map - Whether to include a 2D plot (via `imshow`) of the dirty signal-to-noise map. - """ - self._subplot_custom_plot( - data=data, - noise_map=noise_map, - u_wavelengths=u_wavelengths, - v_wavelengths=v_wavelengths, - uv_wavelengths=uv_wavelengths, - amplitudes_vs_uv_distances=amplitudes_vs_uv_distances, - phases_vs_uv_distances=phases_vs_uv_distances, - dirty_image=dirty_image, - dirty_noise_map=dirty_noise_map, - dirty_signal_to_noise_map=dirty_signal_to_noise_map, - auto_labels=AutoLabels(filename=auto_filename), - ) - - def subplot_dataset(self): - """ - Standard subplot of the attributes of the plotter's `Interferometer` object. - """ - return self.subplot( - data=True, - uv_wavelengths=True, - amplitudes_vs_uv_distances=True, - phases_vs_uv_distances=True, - dirty_image=True, - dirty_signal_to_noise_map=True, - auto_filename="subplot_dataset", - ) - - def subplot_dirty_images(self): - """ - Standard subplot of the dirty attributes of the plotter's `Interferometer` object. - """ - return self.subplot( - dirty_image=True, - dirty_noise_map=True, - dirty_signal_to_noise_map=True, - auto_filename="subplot_dirty_images", - ) diff --git a/autoarray/fit/plot/__init__.py b/autoarray/fit/plot/__init__.py index e69de29bb..1279adc32 100644 --- a/autoarray/fit/plot/__init__.py +++ b/autoarray/fit/plot/__init__.py @@ -0,0 +1,5 @@ +from autoarray.fit.plot.fit_imaging_plots import subplot_fit_imaging +from autoarray.fit.plot.fit_interferometer_plots import ( + subplot_fit_interferometer, + subplot_fit_interferometer_dirty_images, +) diff --git a/autoarray/fit/plot/fit_imaging_plots.py b/autoarray/fit/plot/fit_imaging_plots.py new file mode 100644 index 000000000..4a314b4c8 --- /dev/null +++ b/autoarray/fit/plot/fit_imaging_plots.py @@ -0,0 +1,122 @@ +from typing import Optional + +import matplotlib.pyplot as plt + +from autoarray.plot.array import plot_array +from autoarray.plot.utils import subplot_save, symmetric_vmin_vmax + + +def subplot_fit_imaging( + fit, + output_path: Optional[str] = None, + output_filename: str = "subplot_fit", + output_format: str = "png", + colormap=None, + use_log10: bool = False, + residuals_symmetric_cmap: bool = True, + grid=None, + positions=None, + lines=None, +): + """ + 2×3 subplot of ``FitImaging`` components. + + Panels: Data | S/N Map | Model Image | Residual Map | Norm Residual Map | Chi-Squared Map + + Parameters + ---------- + fit + A ``FitImaging`` instance. + output_path + Directory to save the figure. ``None`` calls ``plt.show()``. + output_filename + Base filename without extension. + output_format + File format. + colormap + Matplotlib colormap name. + use_log10 + Apply log10 normalisation to non-residual panels. + residuals_symmetric_cmap + Centre residual / normalised-residual colour scale symmetrically + around zero. + grid, positions, lines + Optional overlays forwarded to every panel. + """ + fig, axes = plt.subplots(2, 3, figsize=(21, 14)) + axes = axes.flatten() + + plot_array( + fit.data, + ax=axes[0], + title="Data", + colormap=colormap, + use_log10=use_log10, + grid=grid, + positions=positions, + lines=lines, + ) + plot_array( + fit.signal_to_noise_map, + ax=axes[1], + title="Signal-To-Noise Map", + colormap=colormap, + use_log10=use_log10, + grid=grid, + positions=positions, + lines=lines, + ) + plot_array( + fit.model_data, + ax=axes[2], + title="Model Image", + colormap=colormap, + use_log10=use_log10, + grid=grid, + positions=positions, + lines=lines, + ) + + if residuals_symmetric_cmap: + vmin_r, vmax_r = symmetric_vmin_vmax(fit.residual_map) + vmin_n, vmax_n = symmetric_vmin_vmax(fit.normalized_residual_map) + else: + vmin_r = vmax_r = vmin_n = vmax_n = None + + plot_array( + fit.residual_map, + ax=axes[3], + title="Residual Map", + colormap=colormap, + use_log10=False, + vmin=vmin_r, + vmax=vmax_r, + grid=grid, + positions=positions, + lines=lines, + ) + plot_array( + fit.normalized_residual_map, + ax=axes[4], + title="Normalized Residual Map", + colormap=colormap, + use_log10=False, + vmin=vmin_n, + vmax=vmax_n, + grid=grid, + positions=positions, + lines=lines, + ) + plot_array( + fit.chi_squared_map, + ax=axes[5], + title="Chi-Squared Map", + colormap=colormap, + use_log10=use_log10, + grid=grid, + positions=positions, + lines=lines, + ) + + plt.tight_layout() + subplot_save(fig, output_path, output_filename, output_format) diff --git a/autoarray/fit/plot/fit_imaging_plotters.py b/autoarray/fit/plot/fit_imaging_plotters.py deleted file mode 100644 index 86aa0d34d..000000000 --- a/autoarray/fit/plot/fit_imaging_plotters.py +++ /dev/null @@ -1,269 +0,0 @@ -from typing import Callable - -from autoarray.plot.abstract_plotters import AbstractPlotter -from autoarray.plot.visuals.two_d import Visuals2D -from autoarray.plot.mat_plot.two_d import MatPlot2D -from autoarray.plot.auto_labels import AutoLabels -from autoarray.fit.fit_imaging import FitImaging - - -class FitImagingPlotterMeta(AbstractPlotter): - def __init__( - self, - fit, - mat_plot_2d: MatPlot2D = None, - visuals_2d: Visuals2D = None, - residuals_symmetric_cmap: bool = True, - ): - """ - Plots the attributes of `FitImaging` objects using the matplotlib method `imshow()` and many other matplotlib - functions which customize the plot's appearance. - - The `mat_plot_2d` attribute wraps matplotlib function calls to make the figure. By default, the settings - passed to every matplotlib function called are those specified in the `config/visualize/mat_wrap/*.ini` files, - but a user can manually input values into `MatPlot2d` to customize the figure's appearance. - - Overlaid on the figure are visuals, contained in the `Visuals2D` object. Attributes may be extracted from - the `FitImaging` and plotted via the visuals object. - - Parameters - ---------- - fit - The fit to an imaging dataset the plotter plots. - mat_plot_2d - Contains objects which wrap the matplotlib function calls that make the plot. - visuals_2d - Contains visuals that can be overlaid on the plot. - residuals_symmetric_cmap - If true, the `residual_map` and `normalized_residual_map` are plotted with a symmetric color map such - that `abs(vmin) = abs(vmax)`. - """ - super().__init__(mat_plot_2d=mat_plot_2d, visuals_2d=visuals_2d) - - self.fit = fit - self.residuals_symmetric_cmap = residuals_symmetric_cmap - - def figures_2d( - self, - data: bool = False, - noise_map: bool = False, - signal_to_noise_map: bool = False, - model_image: bool = False, - residual_map: bool = False, - normalized_residual_map: bool = False, - chi_squared_map: bool = False, - residual_flux_fraction_map: bool = False, - suffix: str = "", - ): - """ - Plots the individual attributes of the plotter's `FitImaging` object in 2D. - - The API is such that every plottable attribute of the `FitImaging` object is an input parameter of type bool of - the function, which if switched to `True` means that it is plotted. - - Parameters - ---------- - data - Whether to make a 2D plot (via `imshow`) of the image data. - noise_map - Whether to make a 2D plot (via `imshow`) of the noise map. - signal_to_noise_map - Whether to make a 2D plot (via `imshow`) of the signal-to-noise map. - model_image - Whether to make a 2D plot (via `imshow`) of the model image. - residual_map - Whether to make a 2D plot (via `imshow`) of the residual map. - normalized_residual_map - Whether to make a 2D plot (via `imshow`) of the normalized residual map. - chi_squared_map - Whether to make a 2D plot (via `imshow`) of the chi-squared map. - residual_flux_fraction_map - Whether to make a 2D plot (via `imshow`) of the residual flux fraction map. - """ - - if data: - self.mat_plot_2d.plot_array( - array=self.fit.data, - visuals_2d=self.visuals_2d, - auto_labels=AutoLabels(title="Data", filename=f"data{suffix}"), - ) - - if noise_map: - self.mat_plot_2d.plot_array( - array=self.fit.noise_map, - visuals_2d=self.visuals_2d, - auto_labels=AutoLabels( - title="Noise-Map", filename=f"noise_map{suffix}" - ), - ) - - if signal_to_noise_map: - self.mat_plot_2d.plot_array( - array=self.fit.signal_to_noise_map, - visuals_2d=self.visuals_2d, - auto_labels=AutoLabels( - title="Signal-To-Noise Map", filename=f"signal_to_noise_map{suffix}" - ), - ) - - if model_image: - self.mat_plot_2d.plot_array( - array=self.fit.model_data, - visuals_2d=self.visuals_2d, - auto_labels=AutoLabels( - title="Model Image", filename=f"model_image{suffix}" - ), - ) - - cmap_original = self.mat_plot_2d.cmap - - if self.residuals_symmetric_cmap: - self.mat_plot_2d.cmap = self.mat_plot_2d.cmap.symmetric_cmap_from() - - if residual_map: - self.mat_plot_2d.plot_array( - array=self.fit.residual_map, - visuals_2d=self.visuals_2d, - auto_labels=AutoLabels( - title="Residual Map", filename=f"residual_map{suffix}" - ), - ) - - if normalized_residual_map: - self.mat_plot_2d.plot_array( - array=self.fit.normalized_residual_map, - visuals_2d=self.visuals_2d, - auto_labels=AutoLabels( - title="Normalized Residual Map", - filename=f"normalized_residual_map{suffix}", - ), - ) - - self.mat_plot_2d.cmap = cmap_original - - if chi_squared_map: - self.mat_plot_2d.plot_array( - array=self.fit.chi_squared_map, - visuals_2d=self.visuals_2d, - auto_labels=AutoLabels( - title="Chi-Squared Map", filename=f"chi_squared_map{suffix}" - ), - ) - - if residual_flux_fraction_map: - self.mat_plot_2d.plot_array( - array=self.fit.residual_map, - visuals_2d=self.visuals_2d, - auto_labels=AutoLabels( - title="Residual Flux Fraction Map", - filename=f"residual_flux_fraction_map{suffix}", - ), - ) - - def subplot( - self, - data: bool = False, - noise_map: bool = False, - signal_to_noise_map: bool = False, - model_image: bool = False, - residual_map: bool = False, - normalized_residual_map: bool = False, - chi_squared_map: bool = False, - residual_flux_fraction_map: bool = False, - auto_filename: str = "subplot_fit", - ): - """ - Plots the individual attributes of the plotter's `FitImaging` object in 2D on a subplot. - - The API is such that every plottable attribute of the `FitImaging` object is an input parameter of type bool of - the function, which if switched to `True` means that it is included on the subplot. - - Parameters - ---------- - data - Whether to include a 2D plot (via `imshow`) of the image data. - noise_map - Whether to include a 2D plot (via `imshow`) of the noise map. - psf - Whether to include a 2D plot (via `imshow`) of the psf. - signal_to_noise_map - Whether to include a 2D plot (via `imshow`) of the signal-to-noise map. - model_image - Whether to include a 2D plot (via `imshow`) of the model image. - residual_map - Whether to include a 2D plot (via `imshow`) of the residual map. - normalized_residual_map - Whether to include a 2D plot (via `imshow`) of the normalized residual map. - chi_squared_map - Whether to include a 2D plot (via `imshow`) of the chi-squared map. - residual_flux_fraction_map - Whether to include a 2D plot (via `imshow`) of the residual flux fraction map. - auto_filename - The default filename of the output subplot if written to hard-disk. - """ - self._subplot_custom_plot( - data=data, - noise_map=noise_map, - signal_to_noise_map=signal_to_noise_map, - model_image=model_image, - residual_map=residual_map, - normalized_residual_map=normalized_residual_map, - chi_squared_map=chi_squared_map, - residual_flux_fraction_map=residual_flux_fraction_map, - auto_labels=AutoLabels(filename=auto_filename), - ) - - def subplot_fit(self): - """ - Standard subplot of the attributes of the plotter's `FitImaging` object. - """ - return self.subplot( - data=True, - signal_to_noise_map=True, - model_image=True, - residual_map=True, - normalized_residual_map=True, - chi_squared_map=True, - ) - - -class FitImagingPlotter(AbstractPlotter): - def __init__( - self, - fit: FitImaging, - mat_plot_2d: MatPlot2D = None, - visuals_2d: Visuals2D = None, - ): - """ - Plots the attributes of `FitImaging` objects using the matplotlib method `imshow()` and many other matplotlib - functions which customize the plot's appearance. - - The `mat_plot_2d` attribute wraps matplotlib function calls to make the figure. By default, the settings - passed to every matplotlib function called are those specified in the `config/visualize/mat_wrap/*.ini` files, - but a user can manually input values into `MatPlot2d` to customize the figure's appearance. - - Overlaid on the figure are visuals, contained in the `Visuals2D` object. Attributes may be extracted from - the `FitImaging` and plotted via the visuals object. - - Parameters - ---------- - fit - The fit to an imaging dataset the plotter plots. - mat_plot_2d - Contains objects which wrap the matplotlib function calls that make the plot. - visuals_2d - Contains visuals that can be overlaid on the plot. - """ - super().__init__(mat_plot_2d=mat_plot_2d, visuals_2d=visuals_2d) - - self.fit = fit - - self._fit_imaging_meta_plotter = FitImagingPlotterMeta( - fit=self.fit, - mat_plot_2d=self.mat_plot_2d, - visuals_2d=self.visuals_2d, - ) - - self.figures_2d = self._fit_imaging_meta_plotter.figures_2d - self.subplot = self._fit_imaging_meta_plotter.subplot - self.subplot_fit = self._fit_imaging_meta_plotter.subplot_fit diff --git a/autoarray/fit/plot/fit_interferometer_plots.py b/autoarray/fit/plot/fit_interferometer_plots.py new file mode 100644 index 000000000..aaceec086 --- /dev/null +++ b/autoarray/fit/plot/fit_interferometer_plots.py @@ -0,0 +1,195 @@ +import numpy as np +from typing import Optional + +import matplotlib.pyplot as plt + +from autoarray.plot.array import plot_array +from autoarray.plot.yx import plot_yx +from autoarray.plot.utils import subplot_save, symmetric_vmin_vmax + + +def subplot_fit_interferometer( + fit, + output_path: Optional[str] = None, + output_filename: str = "subplot_fit", + output_format: str = "png", + colormap=None, + use_log10: bool = False, + residuals_symmetric_cmap: bool = True, +): + """ + 2×3 subplot of ``FitInterferometer`` residuals in UV-plane. + + Panels (real then imaginary): Residual Map | Norm Residual Map | Chi-Squared Map + + Parameters + ---------- + fit + A ``FitInterferometer`` instance. + output_path + Directory to save the figure. ``None`` calls ``plt.show()``. + output_filename + Base filename without extension. + output_format + File format. + colormap + Matplotlib colormap name. + use_log10 + Apply log10 normalisation. + residuals_symmetric_cmap + Not used here (UV-plane residuals are scatter plots); kept for API + consistency. + """ + fig, axes = plt.subplots(2, 3, figsize=(21, 14)) + axes = axes.flatten() + + uv = fit.dataset.uv_distances / 10**3.0 + + plot_yx( + np.real(fit.residual_map), + uv, + ax=axes[0], + title="Residual vs UV-Distance (real)", + xlabel="k$\\lambda$", + plot_axis_type="scatter", + ) + plot_yx( + np.real(fit.normalized_residual_map), + uv, + ax=axes[1], + title="Norm Residual vs UV-Distance (real)", + ylabel="$\\sigma$", + xlabel="k$\\lambda$", + plot_axis_type="scatter", + ) + plot_yx( + np.real(fit.chi_squared_map), + uv, + ax=axes[2], + title="Chi-Squared vs UV-Distance (real)", + ylabel="$\\chi^2$", + xlabel="k$\\lambda$", + plot_axis_type="scatter", + ) + plot_yx( + np.imag(fit.residual_map), + uv, + ax=axes[3], + title="Residual vs UV-Distance (imag)", + xlabel="k$\\lambda$", + plot_axis_type="scatter", + ) + plot_yx( + np.imag(fit.normalized_residual_map), + uv, + ax=axes[4], + title="Norm Residual vs UV-Distance (imag)", + ylabel="$\\sigma$", + xlabel="k$\\lambda$", + plot_axis_type="scatter", + ) + plot_yx( + np.imag(fit.chi_squared_map), + uv, + ax=axes[5], + title="Chi-Squared vs UV-Distance (imag)", + ylabel="$\\chi^2$", + xlabel="k$\\lambda$", + plot_axis_type="scatter", + ) + + plt.tight_layout() + subplot_save(fig, output_path, output_filename, output_format) + + +def subplot_fit_interferometer_dirty_images( + fit, + output_path: Optional[str] = None, + output_filename: str = "subplot_fit_dirty_images", + output_format: str = "png", + colormap=None, + use_log10: bool = False, + residuals_symmetric_cmap: bool = True, +): + """ + 2×3 subplot of ``FitInterferometer`` dirty-image components. + + Panels: Dirty Image | Dirty S/N Map | Dirty Model Image | + Dirty Residual Map | Dirty Norm Residual Map | Dirty Chi-Squared Map + + Parameters + ---------- + fit + A ``FitInterferometer`` instance. + output_path + Directory to save the figure. ``None`` calls ``plt.show()``. + output_filename + Base filename without extension. + output_format + File format. + colormap + Matplotlib colormap name. + use_log10 + Apply log10 normalisation to non-residual panels. + residuals_symmetric_cmap + Centre residual colour scale symmetrically around zero. + """ + fig, axes = plt.subplots(2, 3, figsize=(21, 14)) + axes = axes.flatten() + + plot_array( + fit.dirty_image, + ax=axes[0], + title="Dirty Image", + colormap=colormap, + use_log10=use_log10, + ) + plot_array( + fit.dirty_signal_to_noise_map, + ax=axes[1], + title="Dirty Signal-To-Noise Map", + colormap=colormap, + use_log10=use_log10, + ) + plot_array( + fit.dirty_model_image, + ax=axes[2], + title="Dirty Model Image", + colormap=colormap, + use_log10=use_log10, + ) + + if residuals_symmetric_cmap: + vmin_r, vmax_r = symmetric_vmin_vmax(fit.dirty_residual_map) + vmin_n, vmax_n = symmetric_vmin_vmax(fit.dirty_normalized_residual_map) + else: + vmin_r = vmax_r = vmin_n = vmax_n = None + + plot_array( + fit.dirty_residual_map, + ax=axes[3], + title="Dirty Residual Map", + colormap=colormap, + use_log10=False, + vmin=vmin_r, + vmax=vmax_r, + ) + plot_array( + fit.dirty_normalized_residual_map, + ax=axes[4], + title="Dirty Normalized Residual Map", + colormap=colormap, + use_log10=False, + vmin=vmin_n, + vmax=vmax_n, + ) + plot_array( + fit.dirty_chi_squared_map, + ax=axes[5], + title="Dirty Chi-Squared Map", + colormap=colormap, + use_log10=use_log10, + ) + + plt.tight_layout() + subplot_save(fig, output_path, output_filename, output_format) diff --git a/autoarray/fit/plot/fit_interferometer_plotters.py b/autoarray/fit/plot/fit_interferometer_plotters.py deleted file mode 100644 index 3ab2bd1e6..000000000 --- a/autoarray/fit/plot/fit_interferometer_plotters.py +++ /dev/null @@ -1,489 +0,0 @@ -import numpy as np - -from autoarray.plot.abstract_plotters import AbstractPlotter -from autoarray.plot.visuals.one_d import Visuals1D -from autoarray.plot.visuals.two_d import Visuals2D -from autoarray.plot.mat_plot.one_d import MatPlot1D -from autoarray.plot.mat_plot.two_d import MatPlot2D -from autoarray.plot.auto_labels import AutoLabels -from autoarray.fit.fit_interferometer import FitInterferometer - - -class FitInterferometerPlotterMeta(AbstractPlotter): - def __init__( - self, - fit, - mat_plot_1d: MatPlot1D, - visuals_1d: Visuals1D, - mat_plot_2d: MatPlot2D = None, - visuals_2d: Visuals2D = None, - residuals_symmetric_cmap: bool = True, - ): - """ - Plots the attributes of `FitInterferometer` objects using the matplotlib method `imshow()` and many - other matplotlib functions which customize the plot's appearance. - - The `mat_plot_1d` and `mat_plot_2d` attributes wrap matplotlib function calls to make the figure. By default, - the settings passed to every matplotlib function called are those specified in - the `config/visualize/mat_wrap/*.ini` files, but a user can manually input values into `MatPlot2d` to - customize the figure's appearance. - - Overlaid on the figure are visuals, contained in the `Visuals1D` and `Visuals2D` objects. Attributes may be - extracted from the `FitInterferometer` and plotted via the visuals object. - - Parameters - ---------- - fit - The fit to an interferometer dataset the plotter plots. - mat_plot_1d - Contains objects which wrap the matplotlib function calls that make 1D plots. - visuals_1d - Contains 1D visuals that can be overlaid on 1D plots. - mat_plot_2d - Contains objects which wrap the matplotlib function calls that make 2D plots. - visuals_2d - Contains 2D visuals that can be overlaid on 2D plots. - residuals_symmetric_cmap - If true, the `residual_map` and `normalized_residual_map` are plotted with a symmetric color map such - that `abs(vmin) = abs(vmax)`. - """ - super().__init__( - mat_plot_1d=mat_plot_1d, - visuals_1d=visuals_1d, - mat_plot_2d=mat_plot_2d, - visuals_2d=visuals_2d, - ) - - self.fit = fit - self.residuals_symmetric_cmap = residuals_symmetric_cmap - - def figures_2d( - self, - data: bool = False, - noise_map: bool = False, - signal_to_noise_map: bool = False, - amplitudes_vs_uv_distances: bool = False, - model_data: bool = False, - residual_map_real: bool = False, - residual_map_imag: bool = False, - normalized_residual_map_real: bool = False, - normalized_residual_map_imag: bool = False, - chi_squared_map_real: bool = False, - chi_squared_map_imag: bool = False, - dirty_image: bool = False, - dirty_noise_map: bool = False, - dirty_signal_to_noise_map: bool = False, - dirty_model_image: bool = False, - dirty_residual_map: bool = False, - dirty_normalized_residual_map: bool = False, - dirty_chi_squared_map: bool = False, - ): - """ - Plots the individual attributes of the plotter's `FitInterferometer` object in 1D and 2D. - - The API is such that every plottable attribute of the `Interferometer` object is an input parameter of type - bool of the function, which if switched to `True` means that it is plotted. - - Parameters - ---------- - data - Whether to make a 2D plot (via `scatter`) of the visibility data. - noise_map - Whether to make a 2D plot (via `scatter`) of the noise-map. - signal_to_noise_map - Whether to make a 2D plot (via `scatter`) of the signal-to-noise-map. - model_data - Whether to make a 2D plot (via `scatter`) of the model visibility data. - residual_map_real - Whether to make a 1D plot (via `plot`) of the real component of the residual map. - residual_map_imag - Whether to make a 1D plot (via `plot`) of the imaginary component of the residual map. - normalized_residual_map_real - Whether to make a 1D plot (via `plot`) of the real component of the normalized residual map. - normalized_residual_map_imag - Whether to make a 1D plot (via `plot`) of the imaginary component of the normalized residual map. - chi_squared_map_real - Whether to make a 1D plot (via `plot`) of the real component of the chi-squared map. - chi_squared_map_imag - Whether to make a 1D plot (via `plot`) of the imaginary component of the chi-squared map. - dirty_image - Whether to make a 2D plot (via `imshow`) of the dirty image. - dirty_noise_map - Whether to make a 2D plot (via `imshow`) of the dirty noise map. - dirty_model_image - Whether to make a 2D plot (via `imshow`) of the dirty model image. - dirty_residual_map - Whether to make a 2D plot (via `imshow`) of the dirty residual map. - dirty_normalized_residual_map - Whether to make a 2D plot (via `imshow`) of the dirty normalized residual map. - dirty_chi_squared_map - Whether to make a 2D plot (via `imshow`) of the dirty chi-squared map. - """ - - if data: - self.mat_plot_2d.plot_grid( - grid=self.fit.data.in_grid, - visuals_2d=self.visuals_2d, - auto_labels=AutoLabels(title="Visibilities", filename="data"), - color_array=np.real(self.fit.noise_map), - ) - - if noise_map: - self.mat_plot_2d.plot_grid( - grid=self.fit.data.in_grid, - visuals_2d=self.visuals_2d, - auto_labels=AutoLabels(title="Noise-Map", filename="noise_map"), - color_array=np.real(self.fit.noise_map), - ) - - if signal_to_noise_map: - self.mat_plot_2d.plot_grid( - grid=self.fit.data.in_grid, - visuals_2d=self.visuals_2d, - auto_labels=AutoLabels( - title="Signal-To-Noise Map", filename="signal_to_noise_map" - ), - color_array=np.real(self.fit.signal_to_noise_map), - ) - - if amplitudes_vs_uv_distances: - self.mat_plot_1d.plot_yx( - y=self.fit.dataset.amplitudes, - x=self.fit.dataset.uv_distances / 10**3.0, - visuals_1d=self.visuals_1d, - auto_labels=AutoLabels( - title="Amplitudes vs UV-distances", - filename="amplitudes_vs_uv_distances", - yunit="Jy", - xunit="k$\lambda$", - ), - plot_axis_type_override="scatter", - ) - - if model_data: - self.mat_plot_2d.plot_grid( - grid=self.fit.data.in_grid, - visuals_2d=self.visuals_2d, - auto_labels=AutoLabels( - title="Model Visibilities", filename="model_data" - ), - color_array=np.real(self.fit.model_data.array), - ) - - if residual_map_real: - self.mat_plot_1d.plot_yx( - y=np.real(self.fit.residual_map), - x=self.fit.dataset.uv_distances / 10**3.0, - visuals_1d=self.visuals_1d, - auto_labels=AutoLabels( - title="Residual vs UV-Distance (real)", - filename="real_residual_map_vs_uv_distances", - xunit="k$\lambda$", - ), - plot_axis_type_override="scatter", - ) - if residual_map_imag: - self.mat_plot_1d.plot_yx( - y=np.imag(self.fit.residual_map), - x=self.fit.dataset.uv_distances / 10**3.0, - visuals_1d=self.visuals_1d, - auto_labels=AutoLabels( - title="Residual vs UV-Distance (imag)", - filename="imag_residual_map_vs_uv_distances", - xunit="k$\lambda$", - ), - plot_axis_type_override="scatter", - ) - - if normalized_residual_map_real: - self.mat_plot_1d.plot_yx( - y=np.real(self.fit.normalized_residual_map), - x=self.fit.dataset.uv_distances / 10**3.0, - visuals_1d=self.visuals_1d, - auto_labels=AutoLabels( - title="Norm Residual vs UV-Distance (real)", - filename="real_normalized_residual_map_vs_uv_distances", - yunit="$\sigma$", - xunit="k$\lambda$", - ), - plot_axis_type_override="scatter", - ) - if normalized_residual_map_imag: - self.mat_plot_1d.plot_yx( - y=np.imag(self.fit.normalized_residual_map), - x=self.fit.dataset.uv_distances / 10**3.0, - visuals_1d=self.visuals_1d, - auto_labels=AutoLabels( - title="Norm Residual vs UV-Distance (imag)", - filename="imag_normalized_residual_map_vs_uv_distances", - yunit="$\sigma$", - xunit="k$\lambda$", - ), - plot_axis_type_override="scatter", - ) - - if chi_squared_map_real: - self.mat_plot_1d.plot_yx( - y=np.real(self.fit.chi_squared_map), - x=self.fit.dataset.uv_distances / 10**3.0, - visuals_1d=self.visuals_1d, - auto_labels=AutoLabels( - title="Chi-Squared vs UV-Distance (real)", - filename="real_chi_squared_map_vs_uv_distances", - yunit="$\chi^2$", - xunit="k$\lambda$", - ), - plot_axis_type_override="scatter", - ) - if chi_squared_map_imag: - self.mat_plot_1d.plot_yx( - y=np.imag(self.fit.chi_squared_map), - x=self.fit.dataset.uv_distances / 10**3.0, - visuals_1d=self.visuals_1d, - auto_labels=AutoLabels( - title="Chi-Squared vs UV-Distance (imag)", - filename="imag_chi_squared_map_vs_uv_distances", - yunit="$\chi^2$", - xunit="k$\lambda$", - ), - plot_axis_type_override="scatter", - ) - - if dirty_image: - self.mat_plot_2d.plot_array( - array=self.fit.dirty_image, - visuals_2d=self.visuals_2d, - auto_labels=AutoLabels(title="Dirty Image", filename="dirty_image"), - ) - - if dirty_noise_map: - self.mat_plot_2d.plot_array( - array=self.fit.dirty_noise_map, - visuals_2d=self.visuals_2d, - auto_labels=AutoLabels( - title="Dirty Noise Map", filename="dirty_noise_map" - ), - ) - - if dirty_signal_to_noise_map: - self.mat_plot_2d.plot_array( - array=self.fit.dirty_signal_to_noise_map, - visuals_2d=self.visuals_2d, - auto_labels=AutoLabels( - title="Dirty Signal-To-Noise Map", - filename="dirty_signal_to_noise_map", - ), - ) - - if dirty_model_image: - self.mat_plot_2d.plot_array( - array=self.fit.dirty_model_image, - visuals_2d=self.visuals_2d, - auto_labels=AutoLabels( - title="Dirty Model Image", filename="dirty_model_image_2d" - ), - ) - - cmap_original = self.mat_plot_2d.cmap - - if self.residuals_symmetric_cmap: - self.mat_plot_2d.cmap = self.mat_plot_2d.cmap.symmetric_cmap_from() - - if dirty_residual_map: - self.mat_plot_2d.plot_array( - array=self.fit.dirty_residual_map, - visuals_2d=self.visuals_2d, - auto_labels=AutoLabels( - title="Dirty Residual Map", filename="dirty_residual_map_2d" - ), - ) - - if dirty_normalized_residual_map: - self.mat_plot_2d.plot_array( - array=self.fit.dirty_normalized_residual_map, - visuals_2d=self.visuals_2d, - auto_labels=AutoLabels( - title="Dirty Normalized Residual Map", - filename="dirty_normalized_residual_map_2d", - ), - ) - - if self.residuals_symmetric_cmap: - self.mat_plot_2d.cmap = cmap_original - - if dirty_chi_squared_map: - self.mat_plot_2d.plot_array( - array=self.fit.dirty_chi_squared_map, - visuals_2d=self.visuals_2d, - auto_labels=AutoLabels( - title="Dirty Chi-Squared Map", filename="dirty_chi_squared_map_2d" - ), - ) - - def subplot( - self, - data: bool = False, - noise_map: bool = False, - signal_to_noise_map: bool = False, - model_data: bool = False, - residual_map_real: bool = False, - residual_map_imag: bool = False, - normalized_residual_map_real: bool = False, - normalized_residual_map_imag: bool = False, - chi_squared_map_real: bool = False, - chi_squared_map_imag: bool = False, - dirty_image: bool = False, - dirty_noise_map: bool = False, - dirty_signal_to_noise_map: bool = False, - dirty_model_image: bool = False, - dirty_residual_map: bool = False, - dirty_normalized_residual_map: bool = False, - dirty_chi_squared_map: bool = False, - auto_filename: str = "subplot_fit", - ): - """ - Plots the individual attributes of the plotter's `FitInterferometer` object in 1D and 2D on a subplot. - - The API is such that every plottable attribute of the `Interferometer` object is an input parameter of type - bool of the function, which if switched to `True` means that it is included on the subplot. - - Parameters - ---------- - data - Whether to make a 2D plot (via `scatter`) of the visibility data. - noise_map - Whether to make a 2D plot (via `scatter`) of the noise-map. - signal_to_noise_map - Whether to make a 2D plot (via `scatter`) of the signal-to-noise-map. - model_data - Whether to make a 2D plot (via `scatter`) of the model visibility data. - residual_map_real - Whether to make a 1D plot (via `plot`) of the real component of the residual map. - residual_map_imag - Whether to make a 1D plot (via `plot`) of the imaginary component of the residual map. - normalized_residual_map_real - Whether to make a 1D plot (via `plot`) of the real component of the normalized residual map. - normalized_residual_map_imag - Whether to make a 1D plot (via `plot`) of the imaginary component of the normalized residual map. - chi_squared_map_real - Whether to make a 1D plot (via `plot`) of the real component of the chi-squared map. - chi_squared_map_imag - Whether to make a 1D plot (via `plot`) of the imaginary component of the chi-squared map. - dirty_image - Whether to make a 2D plot (via `imshow`) of the dirty image. - dirty_noise_map - Whether to make a 2D plot (via `imshow`) of the dirty noise map. - dirty_model_image - Whether to make a 2D plot (via `imshow`) of the dirty model image. - dirty_residual_map - Whether to make a 2D plot (via `imshow`) of the dirty residual map. - dirty_normalized_residual_map - Whether to make a 2D plot (via `imshow`) of the dirty normalized residual map. - dirty_chi_squared_map - Whether to make a 2D plot (via `imshow`) of the dirty chi-squared map. - auto_filename - The default filename of the output subplot if written to hard-disk. - """ - - self._subplot_custom_plot( - visibilities=data, - noise_map=noise_map, - signal_to_noise_map=signal_to_noise_map, - model_data=model_data, - residual_map_real=residual_map_real, - residual_map_imag=residual_map_imag, - normalized_residual_map_real=normalized_residual_map_real, - normalized_residual_map_imag=normalized_residual_map_imag, - chi_squared_map_real=chi_squared_map_real, - chi_squared_map_imag=chi_squared_map_imag, - dirty_image=dirty_image, - dirty_noise_map=dirty_noise_map, - dirty_signal_to_noise_map=dirty_signal_to_noise_map, - dirty_model_image=dirty_model_image, - dirty_residual_map=dirty_residual_map, - dirty_normalized_residual_map=dirty_normalized_residual_map, - dirty_chi_squared_map=dirty_chi_squared_map, - auto_labels=AutoLabels(filename=auto_filename), - ) - - def subplot_fit(self): - """ - Standard subplot of the attributes of the plotter's `FitInterferometer` object. - """ - return self.subplot( - residual_map_real=True, - normalized_residual_map_real=True, - chi_squared_map_real=True, - residual_map_imag=True, - normalized_residual_map_imag=True, - chi_squared_map_imag=True, - auto_filename="subplot_fit", - ) - - def subplot_fit_dirty_images(self): - """ - Standard subplot of the dirty attributes of the plotter's `FitInterferometer` object. - """ - return self.subplot( - dirty_image=True, - dirty_signal_to_noise_map=True, - dirty_model_image=True, - dirty_residual_map=True, - dirty_normalized_residual_map=True, - dirty_chi_squared_map=True, - auto_filename="subplot_fit_dirty_images", - ) - - -class FitInterferometerPlotter(AbstractPlotter): - def __init__( - self, - fit: FitInterferometer, - mat_plot_1d: MatPlot1D = None, - visuals_1d: Visuals1D = None, - mat_plot_2d: MatPlot2D = None, - visuals_2d: Visuals2D = None, - ): - """ - Plots the attributes of `FitInterferometer` objects using the matplotlib method `imshow()` and many other - matplotlib functions which customize the plot's appearance. - - The `mat_plot_2d` attribute wraps matplotlib function calls to make the figure. By default, the settings - passed to every matplotlib function called are those specified in the `config/visualize/mat_wrap/*.ini` files, - but a user can manually input values into `MatPlot2d` to customize the figure's appearance. - - Overlaid on the figure are visuals, contained in the `Visuals2D` object. Attributes may be extracted from - the `FitInterferometer` and plotted via the visuals object. - - Parameters - ---------- - fit - The fit to an interferometer dataset the plotter plots. - mat_plot_2d - Contains objects which wrap the matplotlib function calls that make the plot. - visuals_2d - Contains visuals that can be overlaid on the plot. - """ - super().__init__( - mat_plot_1d=mat_plot_1d, - visuals_1d=visuals_1d, - mat_plot_2d=mat_plot_2d, - visuals_2d=visuals_2d, - ) - - self.fit = fit - - self._fit_interferometer_meta_plotter = FitInterferometerPlotterMeta( - fit=self.fit, - mat_plot_1d=self.mat_plot_1d, - visuals_1d=self.visuals_1d, - mat_plot_2d=self.mat_plot_2d, - visuals_2d=self.visuals_2d, - ) - - self.figures_2d = self._fit_interferometer_meta_plotter.figures_2d - self.subplot = self._fit_interferometer_meta_plotter.subplot - self.subplot_fit = self._fit_interferometer_meta_plotter.subplot_fit - self.subplot_fit_dirty_images = ( - self._fit_interferometer_meta_plotter.subplot_fit_dirty_images - ) diff --git a/autoarray/fit/plot/fit_vector_yx_plotters.py b/autoarray/fit/plot/fit_vector_yx_plotters.py deleted file mode 100644 index 9691e5680..000000000 --- a/autoarray/fit/plot/fit_vector_yx_plotters.py +++ /dev/null @@ -1,233 +0,0 @@ -from typing import Callable - -from autoarray.plot.abstract_plotters import AbstractPlotter -from autoarray.plot.visuals.two_d import Visuals2D -from autoarray.plot.mat_plot.two_d import MatPlot2D -from autoarray.plot.auto_labels import AutoLabels -from autoarray.fit.fit_imaging import FitImaging -from autoarray.fit.plot.fit_imaging_plotters import FitImagingPlotterMeta - - -class FitVectorYXPlotterMeta(AbstractPlotter): - def __init__( - self, - fit, - mat_plot_2d: MatPlot2D = None, - visuals_2d: Visuals2D = None, - ): - """ - Plots the attributes of `FitImaging` objects using the matplotlib method `imshow()` and many other matplotlib - functions which customize the plot's appearance. - - The `mat_plot_2d` attribute wraps matplotlib function calls to make the figure. By default, the settings - passed to every matplotlib function called are those specified in the `config/visualize/mat_wrap/*.ini` files, - but a user can manually input values into `MatPlot2d` to customize the figure's appearance. - - Overlaid on the figure are visuals, contained in the `Visuals2D` object. Attributes may be extracted from - the `FitImaging` and plotted via the visuals object. - - Parameters - ---------- - fit - The fit to an imaging dataset the plotter plots. - mat_plot_2d - Contains objects which wrap the matplotlib function calls that make the plot. - visuals_2d - Contains visuals that can be overlaid on the plot. - """ - super().__init__(mat_plot_2d=mat_plot_2d, visuals_2d=visuals_2d) - - self.fit = fit - - def figures_2d( - self, - image: bool = False, - noise_map: bool = False, - signal_to_noise_map: bool = False, - model_image: bool = False, - residual_map: bool = False, - normalized_residual_map: bool = False, - chi_squared_map: bool = False, - ): - """ - Plots the individual attributes of the plotter's `FitImaging` object in 2D. - - The API is such that every plottable attribute of the `FitImaging` object is an input parameter of type bool of - the function, which if switched to `True` means that it is plotted. - - Parameters - ---------- - image - Whether to make a 2D plot (via `imshow`) of the image data. - noise_map - Whether to make a 2D plot (via `imshow`) of the noise map. - psf - Whether to make a 2D plot (via `imshow`) of the psf. - signal_to_noise_map - Whether to make a 2D plot (via `imshow`) of the signal-to-noise map. - residual_map - Whether to make a 2D plot (via `imshow`) of the residual map. - normalized_residual_map - Whether to make a 2D plot (via `imshow`) of the normalized residual map. - chi_squared_map - Whether to make a 2D plot (via `imshow`) of the chi-squared map. - """ - - if image: - self.mat_plot_2d.plot_array( - array=self.fit.data, - visuals_2d=self.visuals_2d, - auto_labels=AutoLabels(title="Data", filename="image_2d"), - ) - - if noise_map: - self.mat_plot_2d.plot_array( - array=self.fit.noise_map, - visuals_2d=self.visuals_2d, - auto_labels=AutoLabels(title="Noise-Map", filename="noise_map"), - ) - - if signal_to_noise_map: - self.mat_plot_2d.plot_array( - array=self.fit.signal_to_noise_map, - visuals_2d=self.visuals_2d, - auto_labels=AutoLabels( - title="Signal-To-Noise Map", filename="signal_to_noise_map" - ), - ) - - if model_image: - self.mat_plot_2d.plot_array( - array=self.fit.model_data, - visuals_2d=self.visuals_2d, - auto_labels=AutoLabels(title="Model Image", filename="model_image"), - ) - - if residual_map: - self.mat_plot_2d.plot_array( - array=self.fit.residual_map, - visuals_2d=self.visuals_2d, - auto_labels=AutoLabels(title="Residual Map", filename="residual_map"), - ) - - if normalized_residual_map: - self.mat_plot_2d.plot_array( - array=self.fit.normalized_residual_map, - visuals_2d=self.visuals_2d, - auto_labels=AutoLabels( - title="Normalized Residual Map", filename="normalized_residual_map" - ), - ) - - if chi_squared_map: - self.mat_plot_2d.plot_array( - array=self.fit.chi_squared_map, - visuals_2d=self.visuals_2d, - auto_labels=AutoLabels( - title="Chi-Squared Map", filename="chi_squared_map" - ), - ) - - def subplot( - self, - image: bool = False, - noise_map: bool = False, - signal_to_noise_map: bool = False, - model_image: bool = False, - residual_map: bool = False, - normalized_residual_map: bool = False, - chi_squared_map: bool = False, - auto_filename: str = "subplot_fit", - ): - """ - Plots the individual attributes of the plotter's `FitImaging` object in 2D on a subplot. - - The API is such that every plottable attribute of the `FitImaging` object is an input parameter of type bool of - the function, which if switched to `True` means that it is included on the subplot. - - Parameters - ---------- - image - Whether to include a 2D plot (via `imshow`) of the image data. - noise_map - Whether to include a 2D plot (via `imshow`) of the noise map. - psf - Whether to include a 2D plot (via `imshow`) of the psf. - signal_to_noise_map - Whether to include a 2D plot (via `imshow`) of the signal-to-noise map. - model_image - Whether to include a 2D plot (via `imshow`) of the model image. - residual_map - Whether to include a 2D plot (via `imshow`) of the residual map. - normalized_residual_map - Whether to include a 2D plot (via `imshow`) of the normalized residual map. - chi_squared_map - Whether to include a 2D plot (via `imshow`) of the chi-squared map. - auto_filename - The default filename of the output subplot if written to hard-disk. - """ - self._subplot_custom_plot( - image=image, - noise_map=noise_map, - signal_to_noise_map=signal_to_noise_map, - model_image=model_image, - residual_map=residual_map, - normalized_residual_map=normalized_residual_map, - chi_squared_map=chi_squared_map, - auto_labels=AutoLabels(filename=auto_filename), - ) - - def subplot_fit(self): - """ - Standard subplot of the attributes of the plotter's `FitImaging` object. - """ - return self.subplot( - image=True, - signal_to_noise_map=True, - model_image=True, - residual_map=True, - normalized_residual_map=True, - chi_squared_map=True, - ) - - -class FitImagingPlotter(AbstractPlotter): - def __init__( - self, - fit: FitImaging, - mat_plot_2d: MatPlot2D = None, - visuals_2d: Visuals2D = None, - ): - """ - Plots the attributes of `FitImaging` objects using the matplotlib method `imshow()` and many other matplotlib - functions which customize the plot's appearance. - - The `mat_plot_2d` attribute wraps matplotlib function calls to make the figure. By default, the settings - passed to every matplotlib function called are those specified in the `config/visualize/mat_wrap/*.ini` files, - but a user can manually input values into `MatPlot2d` to customize the figure's appearance. - - Overlaid on the figure are visuals, contained in the `Visuals2D` object. Attributes may be extracted from - the `FitImaging` and plotted via the visuals object. - - Parameters - ---------- - fit - The fit to an imaging dataset the plotter plots. - mat_plot_2d - Contains objects which wrap the matplotlib function calls that make the plot. - visuals_2d - Contains visuals that can be overlaid on the plot. - """ - super().__init__(mat_plot_2d=mat_plot_2d, visuals_2d=visuals_2d) - - self.fit = fit - - self._fit_imaging_meta_plotter = FitImagingPlotterMeta( - fit=self.fit, - mat_plot_2d=self.mat_plot_2d, - visuals_2d=self.visuals_2d, - ) - - self.figures_2d = self._fit_imaging_meta_plotter.figures_2d - self.subplot = self._fit_imaging_meta_plotter.subplot - self.subplot_fit = self._fit_imaging_meta_plotter.subplot_fit diff --git a/autoarray/inversion/mesh/interpolator/delaunay.py b/autoarray/inversion/mesh/interpolator/delaunay.py index b954ae236..5d2410370 100644 --- a/autoarray/inversion/mesh/interpolator/delaunay.py +++ b/autoarray/inversion/mesh/interpolator/delaunay.py @@ -91,6 +91,7 @@ def jax_delaunay(points, query_points, areas_factor=0.5): ), points, query_points, + vmap_method="sequential", ) @@ -248,6 +249,7 @@ def jax_delaunay_matern(points, query_points): (points_shape, simplices_padded_shape, mappings_shape), points, query_points, + vmap_method="sequential", ) diff --git a/autoarray/inversion/plot/__init__.py b/autoarray/inversion/plot/__init__.py index e69de29bb..9c115dc67 100644 --- a/autoarray/inversion/plot/__init__.py +++ b/autoarray/inversion/plot/__init__.py @@ -0,0 +1,8 @@ +from autoarray.inversion.plot.mapper_plots import ( + plot_mapper, + subplot_image_and_mapper, +) +from autoarray.inversion.plot.inversion_plots import ( + subplot_of_mapper, + subplot_mappings, +) diff --git a/autoarray/inversion/plot/inversion_plots.py b/autoarray/inversion/plot/inversion_plots.py new file mode 100644 index 000000000..e3f521a69 --- /dev/null +++ b/autoarray/inversion/plot/inversion_plots.py @@ -0,0 +1,338 @@ +import logging +import numpy as np +from typing import Optional + +import matplotlib.pyplot as plt +from autoconf import conf + +from autoarray.inversion.mappers.abstract import Mapper +from autoarray.plot.array import plot_array +from autoarray.plot.utils import numpy_grid, numpy_lines, numpy_positions, subplot_save +from autoarray.inversion.plot.mapper_plots import plot_mapper +from autoarray.structures.arrays.uniform_2d import Array2D + +logger = logging.getLogger(__name__) + + +def subplot_of_mapper( + inversion, + mapper_index: int = 0, + output_path: Optional[str] = None, + output_filename: str = "subplot_inversion", + output_format: str = "png", + colormap=None, + use_log10: bool = False, + mesh_grid=None, + lines=None, + grid=None, + positions=None, +): + """ + 3×4 subplot showing all pixelization diagnostics for one mapper. + + Parameters + ---------- + inversion + An ``AbstractInversion`` instance. + mapper_index + Which mapper in the inversion to visualise. + output_path + Directory to save the figure. ``None`` calls ``plt.show()``. + output_filename + Base filename prefix (``_{mapper_index}`` is appended). + output_format + File format. + colormap + Matplotlib colormap name. + use_log10 + Apply log10 normalisation. + mesh_grid, lines, grid, positions + Optional overlays. + """ + mapper = inversion.cls_list_from(cls=Mapper)[mapper_index] + + fig, axes = plt.subplots(3, 4, figsize=(28, 21)) + axes = axes.flatten() + + # panel 0: data subtracted + try: + plot_array( + inversion.data_subtracted_dict[mapper], + ax=axes[0], + title="Data Subtracted", + colormap=colormap, + use_log10=use_log10, + grid=grid, + positions=positions, + lines=lines, + ) + except (AttributeError, KeyError): + pass + + # panels 1-3: reconstructed operated data (plain, log10, + mesh grid overlay) + def _recon_array(): + array = inversion.mapped_reconstructed_operated_data_dict[mapper] + from autoarray.structures.visibilities import Visibilities + + if isinstance(array, Visibilities): + array = inversion.mapped_reconstructed_data_dict[mapper] + return array + + try: + plot_array( + _recon_array(), + ax=axes[1], + title="Reconstructed Image", + colormap=colormap, + use_log10=use_log10, + grid=grid, + positions=positions, + lines=lines, + ) + plot_array( + _recon_array(), + ax=axes[2], + title="Reconstructed Image (log10)", + colormap=colormap, + use_log10=True, + grid=grid, + positions=positions, + lines=lines, + ) + plot_array( + _recon_array(), + ax=axes[3], + title="Mesh Pixel Grid Overlaid", + colormap=colormap, + use_log10=use_log10, + grid=numpy_grid(mapper.image_plane_mesh_grid), + positions=positions, + lines=lines, + ) + except (AttributeError, KeyError): + pass + + # panels 4-5: source reconstruction zoomed / unzoomed + pixel_values = inversion.reconstruction_dict[mapper] + plot_mapper( + mapper, + solution_vector=pixel_values, + ax=axes[4], + title="Source Reconstruction", + colormap=colormap, + use_log10=use_log10, + zoom_to_brightest=True, + mesh_grid=mesh_grid, + lines=lines, + ) + plot_mapper( + mapper, + solution_vector=pixel_values, + ax=axes[5], + title="Source Reconstruction (Unzoomed)", + colormap=colormap, + use_log10=use_log10, + zoom_to_brightest=False, + mesh_grid=mesh_grid, + lines=lines, + ) + + # panel 6: noise map + try: + nm = inversion.reconstruction_noise_map_dict[mapper] + plot_mapper( + mapper, + solution_vector=nm, + ax=axes[6], + title="Noise-Map (Unzoomed)", + colormap=colormap, + use_log10=use_log10, + zoom_to_brightest=False, + mesh_grid=mesh_grid, + lines=lines, + ) + except (KeyError, TypeError): + pass + + # panel 7: regularization weights + try: + rw = inversion.regularization_weights_mapper_dict[mapper] + plot_mapper( + mapper, + solution_vector=rw, + ax=axes[7], + title="Regularization Weights (Unzoomed)", + colormap=colormap, + use_log10=use_log10, + zoom_to_brightest=False, + mesh_grid=mesh_grid, + lines=lines, + ) + except (IndexError, ValueError, KeyError, TypeError): + pass + + # panel 8: sub pixels per image pixels + try: + sub_size = Array2D( + values=mapper.over_sampler.sub_size, mask=inversion.dataset.mask + ) + plot_array( + sub_size, + ax=axes[8], + title="Sub Pixels Per Image Pixels", + colormap=colormap, + use_log10=use_log10, + ) + except Exception: + pass + + # panel 9: mesh pixels per image pixels + try: + plot_array( + mapper.mesh_pixels_per_image_pixels, + ax=axes[9], + title="Mesh Pixels Per Image Pixels", + colormap=colormap, + use_log10=use_log10, + ) + except Exception: + pass + + # panel 10: image pixels per mesh pixel + try: + pw = mapper.data_weight_total_for_pix_from() + plot_mapper( + mapper, + solution_vector=pw, + ax=axes[10], + title="Image Pixels Per Source Pixel", + colormap=colormap, + use_log10=use_log10, + zoom_to_brightest=True, + mesh_grid=mesh_grid, + lines=lines, + ) + except (TypeError, Exception): + pass + + plt.tight_layout() + subplot_save(fig, output_path, f"{output_filename}_{mapper_index}", output_format) + + +def subplot_mappings( + inversion, + pixelization_index: int = 0, + output_path: Optional[str] = None, + output_filename: str = "subplot_mappings", + output_format: str = "png", + colormap=None, + use_log10: bool = False, + mesh_grid=None, + lines=None, + grid=None, + positions=None, +): + """ + 2×2 subplot showing data, model image, reconstruction and unzoomed reconstruction. + + Parameters + ---------- + inversion + An ``AbstractInversion`` instance. + pixelization_index + Which mapper in the inversion to visualise. + output_path + Directory to save the figure. ``None`` calls ``plt.show()``. + output_filename + Base filename prefix (``_{pixelization_index}`` is appended). + output_format + File format. + colormap + Matplotlib colormap name. + use_log10 + Apply log10 normalisation. + mesh_grid, lines, grid, positions + Optional overlays. + """ + mapper = inversion.cls_list_from(cls=Mapper)[pixelization_index] + + try: + total_pixels = conf.instance["visualize"]["general"]["inversion"][ + "total_mappings_pixels" + ] + except Exception: + total_pixels = 10 + + pix_indexes = inversion.max_pixel_list_from( + total_pixels=total_pixels, + filter_neighbors=True, + mapper_index=pixelization_index, + ) + mapper.slim_indexes_for_pix_indexes(pix_indexes=pix_indexes) + + fig, axes = plt.subplots(2, 2, figsize=(14, 14)) + axes = axes.flatten() + + # panel 0: data subtracted + try: + plot_array( + inversion.data_subtracted_dict[mapper], + ax=axes[0], + title="Data Subtracted", + colormap=colormap, + use_log10=use_log10, + grid=grid, + positions=positions, + lines=lines, + ) + except (AttributeError, KeyError): + pass + + # panel 1: reconstructed operated data + try: + array = inversion.mapped_reconstructed_operated_data_dict[mapper] + from autoarray.structures.visibilities import Visibilities + + if isinstance(array, Visibilities): + array = inversion.mapped_reconstructed_data_dict[mapper] + plot_array( + array, + ax=axes[1], + title="Reconstructed Image", + colormap=colormap, + use_log10=use_log10, + grid=grid, + positions=positions, + lines=lines, + ) + except (AttributeError, KeyError): + pass + + pixel_values = inversion.reconstruction_dict[mapper] + plot_mapper( + mapper, + solution_vector=pixel_values, + ax=axes[2], + title="Source Reconstruction", + colormap=colormap, + use_log10=use_log10, + zoom_to_brightest=True, + mesh_grid=mesh_grid, + lines=lines, + ) + plot_mapper( + mapper, + solution_vector=pixel_values, + ax=axes[3], + title="Source Reconstruction (Unzoomed)", + colormap=colormap, + use_log10=use_log10, + zoom_to_brightest=False, + mesh_grid=mesh_grid, + lines=lines, + ) + + plt.tight_layout() + subplot_save( + fig, output_path, f"{output_filename}_{pixelization_index}", output_format + ) diff --git a/autoarray/inversion/plot/inversion_plotters.py b/autoarray/inversion/plot/inversion_plotters.py deleted file mode 100644 index ef2ae6ea4..000000000 --- a/autoarray/inversion/plot/inversion_plotters.py +++ /dev/null @@ -1,461 +0,0 @@ -import numpy as np - -from autoconf import conf - -from autoarray.inversion.mappers.abstract import Mapper -from autoarray.plot.abstract_plotters import AbstractPlotter -from autoarray.plot.visuals.two_d import Visuals2D -from autoarray.plot.mat_plot.two_d import MatPlot2D -from autoarray.plot.auto_labels import AutoLabels -from autoarray.structures.arrays.uniform_2d import Array2D -from autoarray.inversion.inversion.abstract import AbstractInversion -from autoarray.inversion.plot.mapper_plotters import MapperPlotter - - -class InversionPlotter(AbstractPlotter): - def __init__( - self, - inversion: AbstractInversion, - mat_plot_2d: MatPlot2D = None, - visuals_2d: Visuals2D = None, - residuals_symmetric_cmap: bool = True, - ): - """ - Plots the attributes of `Inversion` objects using the matplotlib method `imshow()` and many other matplotlib - functions which customize the plot's appearance. - - The `mat_plot_2d` attribute wraps matplotlib function calls to make the figure. By default, the settings - passed to every matplotlib function called are those specified in the `config/visualize/mat_wrap/*.ini` files, - but a user can manually input values into `MatPlot2d` to customize the figure's appearance. - - Overlaid on the figure are visuals, contained in the `Visuals2D` object. Attributes may be extracted from - the `Inversion` and plotted via the visuals object. - - Parameters - ---------- - inversion - The inversion the plotter plots. - mat_plot_2d - Contains objects which wrap the matplotlib function calls that make 2D plots. - visuals_2d - Contains 2D visuals that can be overlaid on 2D plots. - """ - super().__init__(mat_plot_2d=mat_plot_2d, visuals_2d=visuals_2d) - - self.inversion = inversion - self.residuals_symmetric_cmap = residuals_symmetric_cmap - - def mapper_plotter_from(self, mapper_index: int) -> MapperPlotter: - """ - Returns a `MapperPlotter` corresponding to the `Mapper` in the `Inversion`'s `linear_obj_list` given an input - `mapper_index`. - - Parameters - ---------- - mapper_index - The index of the mapper in the inversion which is used to create the `MapperPlotter`. - - Returns - ------- - MapperPlotter - An object that plots mappers which is used for plotting attributes of the inversion. - """ - return MapperPlotter( - mapper=self.inversion.cls_list_from(cls=Mapper)[mapper_index], - mat_plot_2d=self.mat_plot_2d, - visuals_2d=self.visuals_2d, - ) - - def figures_2d(self, reconstructed_operated_data: bool = False): - """ - Plots the individual attributes of the plotter's `Inversion` object in 2D. - - The API is such that every plottable attribute of the `Inversion` object is an input parameter of type bool of - the function, which if switched to `True` means that it is plotted. - - Parameters - ---------- - reconstructed_operated_data - Whether to make a 2D plot (via `imshow`) of the reconstructed image data. - """ - if reconstructed_operated_data: - try: - self.mat_plot_2d.plot_array( - array=self.inversion.mapped_reconstructed_operated_data, - visuals_2d=self.visuals_2d, - auto_labels=AutoLabels( - title="Reconstructed Image", - filename="reconstructed_operated_data", - ), - ) - except AttributeError: - self.mat_plot_2d.plot_array( - array=self.inversion.mapped_reconstructed_data, - visuals_2d=self.visuals_2d, - auto_labels=AutoLabels( - title="Reconstructed Image", filename="reconstructed_data" - ), - ) - - def figures_2d_of_pixelization( - self, - pixelization_index: int = 0, - data_subtracted: bool = False, - reconstructed_operated_data: bool = False, - reconstruction: bool = False, - reconstruction_noise_map: bool = False, - signal_to_noise_map: bool = False, - regularization_weights: bool = False, - sub_pixels_per_image_pixels: bool = False, - mesh_pixels_per_image_pixels: bool = False, - image_pixels_per_mesh_pixel: bool = False, - magnification_per_mesh_pixel: bool = False, - zoom_to_brightest: bool = True, - ): - """ - Plots the individual attributes of a specific `Mapper` of the plotter's `Inversion` object in 2D. - - The API is such that every plottable attribute of the `Mapper` and `Inversion` object is an input parameter of - type bool of the function, which if switched to `True` means that it is plotted. - - Parameters - ---------- - pixelization_index - The index of the `Mapper` in the `Inversion`'s `linear_obj_list` that is plotted. - reconstructed_operated_data - Whether to make a 2D plot (via `imshow`) of the mapper's reconstructed image data. - reconstruction - Whether to make a 2D plot (via `imshow` or `fill`) of the mapper's source-plane reconstruction. - reconstruction_noise_map - Whether to make a 2D plot (via `imshow` or `fill`) of the mapper's source-plane noise-map. - signal_to_noise_map - Whether to make a 2D plot (via `imshow` or `fill`) of the mapper's source-plane signal-to-noise-map. - sub_pixels_per_image_pixels - Whether to make a 2D plot (via `imshow`) of the number of sub pixels per image pixels in the 2D - data's mask. - mesh_pixels_per_image_pixels - Whether to make a 2D plot (via `imshow`) of the number of image-mesh pixels per image pixels in the 2D - data's mask (only valid for pixelizations which use an `image_mesh`, e.g. Hilbert, KMeans). - image_pixels_per_mesh_pixel - Whether to make a 2D plot (via `imshow`) of the number of image pixels per source plane pixel, therefore - indicating how many image pixels map to each source pixel. - magnification_per_mesh_pixel - Whether to make a 2D plot (via `imshow`) of the magnification of each mesh pixel, which is the area - ratio of the image pixel to source pixel. - zoom_to_brightest - For images not in the image-plane (e.g. the `plane_image`), whether to automatically zoom the plot to - the brightest regions of the galaxies being plotted as opposed to the full extent of the grid. - """ - - if not self.inversion.has(cls=Mapper): - return - - mapper_plotter = self.mapper_plotter_from(mapper_index=pixelization_index) - - if data_subtracted: - # Attribute error is cause this raises an error for interferometer inversion, because the data is - # visibilities not an image. Update this to be handled better in future. - - try: - array = self.inversion.data_subtracted_dict[mapper_plotter.mapper] - - self.mat_plot_2d.plot_array( - array=array, - visuals_2d=self.visuals_2d, - grid_indexes=mapper_plotter.mapper.over_sampler.uniform_over_sampled, - auto_labels=AutoLabels( - title="Data Subtracted", filename="data_subtracted" - ), - ) - except AttributeError: - pass - - if reconstructed_operated_data: - array = self.inversion.mapped_reconstructed_operated_data_dict[ - mapper_plotter.mapper - ] - - from autoarray.structures.visibilities import Visibilities - - if isinstance(array, Visibilities): - array = self.inversion.mapped_reconstructed_data_dict[ - mapper_plotter.mapper - ] - - self.mat_plot_2d.plot_array( - array=array, - visuals_2d=self.visuals_2d, - grid_indexes=mapper_plotter.mapper.over_sampler.uniform_over_sampled, - auto_labels=AutoLabels( - title="Reconstructed Image", filename="reconstructed_operated_data" - ), - ) - - if reconstruction: - vmax_custom = False - - if "vmax" in self.mat_plot_2d.cmap.kwargs: - if self.mat_plot_2d.cmap.kwargs["vmax"] is None: - reconstruction_vmax_factor = conf.instance["visualize"]["general"][ - "inversion" - ]["reconstruction_vmax_factor"] - - self.mat_plot_2d.cmap.kwargs["vmax"] = ( - reconstruction_vmax_factor - * np.max(self.inversion.reconstruction) - ) - vmax_custom = True - - pixel_values = self.inversion.reconstruction_dict[mapper_plotter.mapper] - - mapper_plotter.plot_source_from( - pixel_values=pixel_values, - zoom_to_brightest=zoom_to_brightest, - auto_labels=AutoLabels( - title="Source Reconstruction", filename="reconstruction" - ), - ) - - if vmax_custom: - self.mat_plot_2d.cmap.kwargs["vmax"] = None - - if reconstruction_noise_map: - try: - mapper_plotter.plot_source_from( - pixel_values=self.inversion.reconstruction_noise_map_dict[ - mapper_plotter.mapper - ], - auto_labels=AutoLabels( - title="Noise Map", filename="reconstruction_noise_map" - ), - ) - - except TypeError: - pass - - if signal_to_noise_map: - try: - signal_to_noise_values = ( - self.inversion.reconstruction_dict[mapper_plotter.mapper] - / self.inversion.reconstruction_noise_map_dict[ - mapper_plotter.mapper - ] - ) - - mapper_plotter.plot_source_from( - pixel_values=signal_to_noise_values, - auto_labels=AutoLabels( - title="Signal To Noise Map", filename="signal_to_noise_map" - ), - ) - - except TypeError: - pass - - if regularization_weights: - try: - mapper_plotter.plot_source_from( - pixel_values=self.inversion.regularization_weights_mapper_dict[ - mapper_plotter.mapper - ], - auto_labels=AutoLabels( - title="Regularization weight_list", - filename="regularization_weights", - ), - ) - except (IndexError, ValueError): - pass - - if sub_pixels_per_image_pixels: - sub_size = Array2D( - values=mapper_plotter.mapper.over_sampler.sub_size, - mask=self.inversion.dataset.mask, - ) - - self.mat_plot_2d.plot_array( - array=sub_size, - visuals_2d=self.visuals_2d, - auto_labels=AutoLabels( - title="Sub Pixels Per Image Pixels", - filename="sub_pixels_per_image_pixels", - ), - ) - - if mesh_pixels_per_image_pixels: - try: - mesh_pixels_per_image_pixels = ( - mapper_plotter.mapper.mesh_pixels_per_image_pixels - ) - - self.mat_plot_2d.plot_array( - array=mesh_pixels_per_image_pixels, - visuals_2d=self.visuals_2d, - auto_labels=AutoLabels( - title="Mesh Pixels Per Image Pixels", - filename="mesh_pixels_per_image_pixels", - ), - ) - except Exception: - pass - - if image_pixels_per_mesh_pixel: - try: - mapper_plotter.plot_source_from( - pixel_values=mapper_plotter.mapper.data_weight_total_for_pix_from(), - auto_labels=AutoLabels( - title="Image Pixels Per Source Pixel", - filename="image_pixels_per_mesh_pixel", - ), - ) - - except TypeError: - pass - - def subplot_of_mapper( - self, mapper_index: int = 0, auto_filename: str = "subplot_inversion" - ): - """ - Plots the individual attributes of a specific `Mapper` of the plotter's `Inversion` object in 2D on a subplot. - - Parameters - ---------- - mapper_index - The index of the `Mapper` in the `Inversion`'s `linear_obj_list` that is plotted. - auto_filename - The default filename of the output subplot if written to hard-disk. - """ - - self.open_subplot_figure(number_subplots=12) - - contour_original = self.mat_plot_2d.contour - - if self.mat_plot_2d.use_log10: - self.mat_plot_2d.contour = False - - self.figures_2d_of_pixelization( - pixelization_index=mapper_index, data_subtracted=True - ) - - self.figures_2d_of_pixelization( - pixelization_index=mapper_index, reconstructed_operated_data=True - ) - - self.mat_plot_2d.use_log10 = True - self.mat_plot_2d.contour = False - - self.figures_2d_of_pixelization( - pixelization_index=mapper_index, reconstructed_operated_data=True - ) - - self.mat_plot_2d.use_log10 = False - - mapper = self.inversion.cls_list_from(cls=Mapper)[mapper_index] - - self.visuals_2d += Visuals2D(mesh_grid=mapper.image_plane_mesh_grid) - - self.set_title(label="Mesh Pixel Grid Overlaid") - self.figures_2d_of_pixelization( - pixelization_index=mapper_index, reconstructed_operated_data=True - ) - self.set_title(label=None) - - self.visuals_2d.mesh_grid = None - - # self.include_2d._mapper_image_plane_mesh_grid = False - - self.figures_2d_of_pixelization( - pixelization_index=mapper_index, reconstruction=True - ) - - self.set_title(label="Source Reconstruction (Unzoomed)") - self.figures_2d_of_pixelization( - pixelization_index=mapper_index, - reconstruction=True, - zoom_to_brightest=False, - ) - self.set_title(label=None) - - self.set_title(label="Noise-Map (Unzoomed)") - self.figures_2d_of_pixelization( - pixelization_index=mapper_index, - reconstruction_noise_map=True, - zoom_to_brightest=False, - ) - - self.set_title(label="Regularization Weights (Unzoomed)") - try: - self.figures_2d_of_pixelization( - pixelization_index=mapper_index, - regularization_weights=True, - zoom_to_brightest=False, - ) - except IndexError: - pass - self.set_title(label=None) - - self.figures_2d_of_pixelization( - pixelization_index=mapper_index, sub_pixels_per_image_pixels=True - ) - - self.figures_2d_of_pixelization( - pixelization_index=mapper_index, mesh_pixels_per_image_pixels=True - ) - - self.figures_2d_of_pixelization( - pixelization_index=mapper_index, image_pixels_per_mesh_pixel=True - ) - - self.mat_plot_2d.output.subplot_to_figure( - auto_filename=f"{auto_filename}_{mapper_index}" - ) - - self.mat_plot_2d.contour = contour_original - - self.close_subplot_figure() - - def subplot_mappings( - self, pixelization_index: int = 0, auto_filename: str = "subplot_mappings" - ): - self.open_subplot_figure(number_subplots=4) - - self.figures_2d_of_pixelization( - pixelization_index=pixelization_index, data_subtracted=True - ) - - total_pixels = conf.instance["visualize"]["general"]["inversion"][ - "total_mappings_pixels" - ] - - mapper = self.inversion.cls_list_from(cls=Mapper)[pixelization_index] - - pix_indexes = self.inversion.max_pixel_list_from( - total_pixels=total_pixels, - filter_neighbors=True, - mapper_index=pixelization_index, - ) - - indexes = mapper.slim_indexes_for_pix_indexes(pix_indexes=pix_indexes) - - self.visuals_2d.indexes = indexes - - self.figures_2d_of_pixelization( - pixelization_index=pixelization_index, reconstructed_operated_data=True - ) - - self.figures_2d_of_pixelization( - pixelization_index=pixelization_index, reconstruction=True - ) - - self.set_title(label="Source Reconstruction (Unzoomed)") - self.figures_2d_of_pixelization( - pixelization_index=pixelization_index, - reconstruction=True, - zoom_to_brightest=False, - ) - self.set_title(label=None) - - self.mat_plot_2d.output.subplot_to_figure( - auto_filename=f"{auto_filename}_{pixelization_index}" - ) - - self.close_subplot_figure() diff --git a/autoarray/inversion/plot/mapper_plots.py b/autoarray/inversion/plot/mapper_plots.py new file mode 100644 index 000000000..6bddd8626 --- /dev/null +++ b/autoarray/inversion/plot/mapper_plots.py @@ -0,0 +1,131 @@ +import logging +from typing import Optional + +import matplotlib.pyplot as plt + +from autoarray.plot.array import plot_array +from autoarray.plot.inversion import plot_inversion_reconstruction +from autoarray.plot.utils import numpy_grid, numpy_lines, subplot_save + +logger = logging.getLogger(__name__) + + +def plot_mapper( + mapper, + solution_vector=None, + output_path: Optional[str] = None, + output_filename: str = "mapper", + output_format: str = "png", + colormap=None, + use_log10: bool = False, + mesh_grid=None, + lines=None, + title: str = "Pixelization Mesh (Source-Plane)", + zoom_to_brightest: bool = True, + ax=None, +): + """ + Plot a pixelization mapper reconstruction. + + Parameters + ---------- + mapper + A ``Mapper`` instance. + solution_vector + Per-pixel flux values. ``None`` uses uniform colours. + output_path + Directory to save the figure. ``None`` calls ``plt.show()``. + output_filename + Base filename without extension. + output_format + File format. + colormap + Matplotlib colormap name. + use_log10 + Apply log10 normalisation. + mesh_grid + Mesh grid to overlay as scatter points. + lines + Lines to overlay. + title + Figure title. + zoom_to_brightest + Zoom the source plane to the brightest region. + ax + Existing ``Axes`` to draw onto. + """ + try: + plot_inversion_reconstruction( + pixel_values=solution_vector, + mapper=mapper, + ax=ax, + title=title, + colormap=colormap, + use_log10=use_log10, + zoom_to_brightest=zoom_to_brightest, + lines=numpy_lines(lines), + grid=numpy_grid(mesh_grid), + output_path=output_path, + output_filename=output_filename, + output_format=output_format, + ) + except Exception as exc: + logger.info(f"Could not plot the source-plane via the Mapper: {exc}") + + +def subplot_image_and_mapper( + mapper, + image, + output_path: Optional[str] = None, + output_filename: str = "subplot_image_and_mapper", + output_format: str = "png", + colormap=None, + use_log10: bool = False, + mesh_grid=None, + lines=None, +): + """ + 1×2 subplot: image-plane image (left) and pixelization mesh (right). + + Parameters + ---------- + mapper + A ``Mapper`` instance. + image + An ``Array2D`` instance to show in the image plane. + output_path + Directory to save the figure. ``None`` calls ``plt.show()``. + output_filename + Base filename without extension. + output_format + File format. + colormap + Matplotlib colormap name. + use_log10 + Apply log10 normalisation. + mesh_grid + Mesh grid to overlay on the reconstruction panel. + lines + Lines to overlay on both panels. + """ + fig, axes = plt.subplots(1, 2, figsize=(14, 7)) + + plot_array( + image, + ax=axes[0], + title="Image (Image-Plane)", + colormap=colormap, + use_log10=use_log10, + lines=lines, + ) + plot_mapper( + mapper, + colormap=colormap, + use_log10=use_log10, + mesh_grid=mesh_grid, + lines=lines, + ax=axes[1], + ) + + plt.tight_layout() + subplot_save(fig, output_path, output_filename, output_format) diff --git a/autoarray/inversion/plot/mapper_plotters.py b/autoarray/inversion/plot/mapper_plotters.py deleted file mode 100644 index b9a446792..000000000 --- a/autoarray/inversion/plot/mapper_plotters.py +++ /dev/null @@ -1,130 +0,0 @@ -import numpy as np - -from autoarray.plot.abstract_plotters import AbstractPlotter -from autoarray.plot.visuals.two_d import Visuals2D -from autoarray.plot.mat_plot.two_d import MatPlot2D -from autoarray.plot.auto_labels import AutoLabels -from autoarray.structures.arrays.uniform_2d import Array2D - -import logging - -logger = logging.getLogger(__name__) - - -class MapperPlotter(AbstractPlotter): - def __init__( - self, - mapper, - mat_plot_2d: MatPlot2D = None, - visuals_2d: Visuals2D = None, - ): - """ - Plots the attributes of `Mapper` objects using the matplotlib method `imshow()` and many other matplotlib - functions which customize the plot's appearance. - - The `mat_plot_2d` attribute wraps matplotlib function calls to make the figure. By default, the settings - passed to every matplotlib function called are those specified in the `config/visualize/mat_wrap/*.ini` files, - but a user can manually input values into `MatPlot2d` to customize the figure's appearance. - - Overlaid on the figure are visuals, contained in the `Visuals2D` object. Attributes may be extracted from - the `Mapper` and plotted via the visuals object. - - Parameters - ---------- - mapper - The mapper the plotter plots. - mat_plot_2d - Contains objects which wrap the matplotlib function calls that make 2D plots. - visuals_2d - Contains 2D visuals that can be overlaid on 2D plots. - """ - super().__init__(visuals_2d=visuals_2d, mat_plot_2d=mat_plot_2d) - - self.mapper = mapper - - def figure_2d(self, solution_vector: bool = None): - """ - Plots the plotter's `Mapper` object in 2D. - - Parameters - ---------- - solution_vector - A vector of values which can culor the pixels of the mapper's source pixels. - """ - self.mat_plot_2d.plot_mapper( - mapper=self.mapper, - visuals_2d=self.visuals_2d, - pixel_values=solution_vector, - auto_labels=AutoLabels( - title="Pixelization Mesh (Source-Plane)", filename="mapper" - ), - ) - - def figure_2d_image(self, image): - - self.mat_plot_2d.plot_array( - array=image, - visuals_2d=self.visuals_2d, - grid_indexes=self.mapper.image_plane_data_grid.over_sampled, - auto_labels=AutoLabels( - title="Image (Image-Plane)", filename="mapper_image" - ), - ) - - def subplot_image_and_mapper( - self, - image: Array2D, - ): - """ - Make a subplot of an input image and the `Mapper`'s source-plane reconstruction. - - This function can include colored points that mark the mappings between the image pixels and their - corresponding locations in the `Mapper` source-plane and reconstruction. This therefore visually illustrates - the mapping process. - - Parameters - ---------- - image - The image which is plotted on the subplot. - """ - self.open_subplot_figure(number_subplots=2) - - self.figure_2d_image(image=image) - self.figure_2d() - - self.mat_plot_2d.output.subplot_to_figure( - auto_filename="subplot_image_and_mapper" - ) - self.close_subplot_figure() - - def plot_source_from( - self, - pixel_values: np.ndarray, - zoom_to_brightest: bool = True, - auto_labels: AutoLabels = AutoLabels(), - ): - """ - Plot the source of the `Mapper` where the coloring is specified by an input set of values. - - Parameters - ---------- - pixel_values - The values of the mapper's source pixels used for coloring the figure. - zoom_to_brightest - For images not in the image-plane (e.g. the `plane_image`), whether to automatically zoom the plot to - the brightest regions of the galaxies being plotted as opposed to the full extent of the grid. - auto_labels - The labels given to the figure. - """ - try: - self.mat_plot_2d.plot_mapper( - mapper=self.mapper, - visuals_2d=self.visuals_2d, - auto_labels=auto_labels, - pixel_values=pixel_values, - zoom_to_brightest=zoom_to_brightest, - ) - except ValueError: - logger.info( - "Could not plot the source-plane via the Mapper because of a ValueError." - ) diff --git a/autoarray/operators/transformer.py b/autoarray/operators/transformer.py index c63bbb736..582f6c810 100644 --- a/autoarray/operators/transformer.py +++ b/autoarray/operators/transformer.py @@ -356,6 +356,7 @@ def visibilities_from_jax(self, image: np.ndarray) -> np.ndarray: lambda img: self._pynufft_forward_numpy(img), result_shape, image, + vmap_method="sequential", ) def visibilities_from(self, image, xp=np): diff --git a/autoarray/plot/__init__.py b/autoarray/plot/__init__.py index 1ec4ff8ad..7456c7273 100644 --- a/autoarray/plot/__init__.py +++ b/autoarray/plot/__init__.py @@ -1,61 +1,70 @@ -from autoarray.plot.wrap.base.units import Units -from autoarray.plot.wrap.base.figure import Figure -from autoarray.plot.wrap.base.axis import Axis -from autoarray.plot.wrap.base.cmap import Cmap -from autoarray.plot.wrap.base.colorbar import Colorbar -from autoarray.plot.wrap.base.colorbar_tickparams import ColorbarTickParams -from autoarray.plot.wrap.base.tickparams import TickParams -from autoarray.plot.wrap.base.ticks import YTicks -from autoarray.plot.wrap.base.ticks import XTicks -from autoarray.plot.wrap.base.title import Title -from autoarray.plot.wrap.base.label import YLabel -from autoarray.plot.wrap.base.label import XLabel -from autoarray.plot.wrap.base.text import Text -from autoarray.plot.wrap.base.annotate import Annotate -from autoarray.plot.wrap.base.legend import Legend -from autoarray.plot.wrap.base.output import Output - -from autoarray.plot.wrap.one_d.yx_plot import YXPlot -from autoarray.plot.wrap.one_d.yx_scatter import YXScatter -from autoarray.plot.wrap.one_d.avxline import AXVLine -from autoarray.plot.wrap.one_d.fill_between import FillBetween - -from autoarray.plot.wrap.two_d.array_overlay import ArrayOverlay -from autoarray.plot.wrap.two_d.contour import Contour -from autoarray.plot.wrap.two_d.fill import Fill -from autoarray.plot.wrap.two_d.grid_scatter import GridScatter -from autoarray.plot.wrap.two_d.grid_plot import GridPlot -from autoarray.plot.wrap.two_d.grid_errorbar import GridErrorbar -from autoarray.plot.wrap.two_d.vector_yx_quiver import VectorYXQuiver -from autoarray.plot.wrap.two_d.patch_overlay import PatchOverlay -from autoarray.plot.wrap.two_d.delaunay_drawer import DelaunayDrawer -from autoarray.plot.wrap.two_d.origin_scatter import OriginScatter -from autoarray.plot.wrap.two_d.mask_scatter import MaskScatter -from autoarray.plot.wrap.two_d.border_scatter import BorderScatter -from autoarray.plot.wrap.two_d.positions_scatter import PositionsScatter -from autoarray.plot.wrap.two_d.index_scatter import IndexScatter -from autoarray.plot.wrap.two_d.index_plot import IndexPlot -from autoarray.plot.wrap.two_d.mesh_grid_scatter import MeshGridScatter -from autoarray.plot.wrap.two_d.parallel_overscan_plot import ParallelOverscanPlot -from autoarray.plot.wrap.two_d.serial_prescan_plot import SerialPrescanPlot -from autoarray.plot.wrap.two_d.serial_overscan_plot import SerialOverscanPlot - -from autoarray.plot.mat_plot.one_d import MatPlot1D -from autoarray.plot.mat_plot.two_d import MatPlot2D -from autoarray.plot.visuals.one_d import Visuals1D -from autoarray.plot.visuals.two_d import Visuals2D -from autoarray.plot.auto_labels import AutoLabels - -from autoarray.structures.plot.structure_plotters import Array2DPlotter -from autoarray.structures.plot.structure_plotters import Grid2DPlotter -from autoarray.structures.plot.structure_plotters import YX1DPlotter -from autoarray.structures.plot.structure_plotters import YX1DPlotter as Array1DPlotter -from autoarray.inversion.plot.mapper_plotters import MapperPlotter -from autoarray.inversion.plot.inversion_plotters import InversionPlotter -from autoarray.dataset.plot.imaging_plotters import ImagingPlotter -from autoarray.dataset.plot.interferometer_plotters import InterferometerPlotter -from autoarray.fit.plot.fit_imaging_plotters import FitImagingPlotter -from autoarray.fit.plot.fit_interferometer_plotters import FitInterferometerPlotter - -from autoarray.plot.multi_plotters import MultiFigurePlotter -from autoarray.plot.multi_plotters import MultiYX1DPlotter +def _set_backend(): + try: + import matplotlib + from autoconf import conf + + backend = conf.get_matplotlib_backend() + if backend != "default": + matplotlib.use(backend) + try: + hpc_mode = conf.instance["general"]["hpc"]["hpc_mode"] + except KeyError: + hpc_mode = False + if hpc_mode: + matplotlib.use("Agg") + except Exception: + pass + + +_set_backend() + +from autoarray.plot.output import Output + +from autoarray.plot.array import plot_array +from autoarray.plot.grid import plot_grid +from autoarray.plot.yx import plot_yx +from autoarray.plot.inversion import plot_inversion_reconstruction +from autoarray.plot.utils import ( + apply_extent, + apply_labels, + conf_figsize, + conf_mat_plot_fontsize, + save_figure, + subplot_save, + auto_mask_edge, + zoom_array, + numpy_grid, + numpy_lines, + numpy_positions, + symmetric_vmin_vmax, + symmetric_cmap_from, + set_with_color_values, + plot_visibilities_1d, +) + +from autoarray.structures.plot.structure_plots import ( + plot_array_2d, + plot_grid_2d, + plot_yx_1d, +) + +from autoarray.dataset.plot.imaging_plots import subplot_imaging_dataset +from autoarray.dataset.plot.interferometer_plots import ( + subplot_interferometer_dataset, + subplot_interferometer_dirty_images, +) + +from autoarray.fit.plot.fit_imaging_plots import subplot_fit_imaging +from autoarray.fit.plot.fit_interferometer_plots import ( + subplot_fit_interferometer, + subplot_fit_interferometer_dirty_images, +) + +from autoarray.inversion.plot.mapper_plots import ( + plot_mapper, + subplot_image_and_mapper, +) +from autoarray.inversion.plot.inversion_plots import ( + subplot_of_mapper, + subplot_mappings, +) diff --git a/autoarray/plot/abstract_plotters.py b/autoarray/plot/abstract_plotters.py deleted file mode 100644 index 07ec41354..000000000 --- a/autoarray/plot/abstract_plotters.py +++ /dev/null @@ -1,213 +0,0 @@ -from autoconf import conf - -from autoarray.plot.wrap.base.abstract import set_backend - -set_backend() - -from typing import Optional, Tuple - -from autoarray.plot.visuals.one_d import Visuals1D -from autoarray.plot.visuals.two_d import Visuals2D -from autoarray.plot.mat_plot.one_d import MatPlot1D -from autoarray.plot.mat_plot.two_d import MatPlot2D - - -class AbstractPlotter: - def __init__( - self, - mat_plot_1d: MatPlot1D = None, - visuals_1d: Visuals1D = None, - mat_plot_2d: MatPlot2D = None, - visuals_2d: Visuals2D = None, - ): - self.visuals_1d = visuals_1d or Visuals1D() - self.mat_plot_1d = mat_plot_1d or MatPlot1D() - - self.visuals_2d = visuals_2d or Visuals2D() - self.mat_plot_2d = mat_plot_2d or MatPlot2D() - - self.subplot_figsize = None - - def set_title(self, label): - if self.mat_plot_1d is not None: - self.mat_plot_1d.title.manual_label = label - - if self.mat_plot_2d is not None: - self.mat_plot_2d.title.manual_label = label - - def set_filename(self, filename): - if self.mat_plot_1d is not None: - self.mat_plot_1d.output.filename = filename - - if self.mat_plot_2d is not None: - self.mat_plot_2d.output.filename = filename - - def set_format(self, format): - if self.mat_plot_1d is not None: - self.mat_plot_1d.output._format = format - - if self.mat_plot_2d is not None: - self.mat_plot_2d.output._format = format - - def set_mat_plot_1d_for_multi_plot( - self, is_for_multi_plot, color: str, xticks=None, yticks=None - ): - self.mat_plot_1d.set_for_multi_plot( - is_for_multi_plot=is_for_multi_plot, - color=color, - xticks=xticks, - yticks=yticks, - ) - - def set_mat_plots_for_subplot( - self, is_for_subplot, number_subplots=None, subplot_shape=None - ): - if self.mat_plot_1d is not None: - self.mat_plot_1d.set_for_subplot(is_for_subplot=is_for_subplot) - self.mat_plot_1d.number_subplots = number_subplots - self.mat_plot_1d.subplot_shape = subplot_shape - self.mat_plot_1d.subplot_index = 1 - if self.mat_plot_2d is not None: - self.mat_plot_2d.set_for_subplot(is_for_subplot=is_for_subplot) - self.mat_plot_2d.number_subplots = number_subplots - self.mat_plot_2d.subplot_shape = subplot_shape - self.mat_plot_2d.subplot_index = 1 - - @property - def is_for_subplot(self): - if self.mat_plot_1d is not None: - if self.mat_plot_1d.is_for_subplot: - return True - - if self.mat_plot_2d is not None: - if self.mat_plot_2d.is_for_subplot: - return True - - return False - - def open_subplot_figure( - self, - number_subplots: int, - subplot_shape: Optional[Tuple[int, int]] = None, - subplot_figsize: Optional[Tuple[int, int]] = None, - subplot_title: Optional[str] = None, - ): - """ - Setup a figure for plotting an image. - - Parameters - ---------- - figsize - The size of the figure in (total_y_pixels, total_x_pixels). - as_subplot - If the figure is a subplot, the setup_figure function is omitted to ensure that each subplot does not create a \ - new figure and so that it can be output using the *output.output_figure(structure=None)* function. - """ - import matplotlib.pyplot as plt - - self.set_mat_plots_for_subplot( - is_for_subplot=True, - number_subplots=number_subplots, - subplot_shape=subplot_shape, - ) - - self.subplot_figsize = subplot_figsize - - figsize = self.get_subplot_figsize(number_subplots=number_subplots) - plt.figure(figsize=figsize) - plt.suptitle(subplot_title, fontsize=40, y=0.93) - - def close_subplot_figure(self): - try: - self.mat_plot_2d.figure.close() - except AttributeError: - self.mat_plot_1d.figure.close() - self.set_mat_plots_for_subplot(is_for_subplot=False) - self.subplot_figsize = None - - def get_subplot_figsize(self, number_subplots): - """ - Get the size of a sub plotter in (total_y_pixels, total_x_pixels), based on the number of subplots that are going to be plotted. - - Parameters - ---------- - number_subplots - The number of subplots that are to be plotted in the figure. - """ - - if self.subplot_figsize is not None: - return self.subplot_figsize - - if self.mat_plot_1d is not None: - if self.mat_plot_1d.figure.config_dict["figsize"] is not None: - return self.mat_plot_1d.figure.config_dict["figsize"] - - if self.mat_plot_2d is not None: - if self.mat_plot_2d.figure.config_dict["figsize"] is not None: - return self.mat_plot_2d.figure.config_dict["figsize"] - - try: - subplot_shape = self.mat_plot_1d.get_subplot_shape( - number_subplots=number_subplots - ) - except AttributeError: - subplot_shape = self.mat_plot_2d.get_subplot_shape( - number_subplots=number_subplots - ) - - subplot_shape_to_figsize_factor = conf.instance["visualize"]["general"][ - "subplot_shape_to_figsize_factor" - ] - subplot_shape_to_figsize_factor = tuple( - map(int, subplot_shape_to_figsize_factor[1:-1].split(",")) - ) - - return ( - subplot_shape[1] * subplot_shape_to_figsize_factor[1], - subplot_shape[0] * subplot_shape_to_figsize_factor[0], - ) - - def _subplot_custom_plot(self, **kwargs): - figures_dict = dict( - (key, value) for key, value in kwargs.items() if value is True - ) - - self.open_subplot_figure(number_subplots=len(figures_dict)) - - for index, (key, value) in enumerate(figures_dict.items()): - if value: - try: - self.figures_2d(**{key: True}) - except AttributeError: - self.figures_1d(**{key: True}) - - try: - self.mat_plot_2d.subplot_index = max( - self.mat_plot_1d.subplot_index, self.mat_plot_2d.subplot_index - ) - self.mat_plot_1d.subplot_index = max( - self.mat_plot_1d.subplot_index, self.mat_plot_2d.subplot_index - ) - except AttributeError: - pass - - try: - self.mat_plot_2d.output.subplot_to_figure( - auto_filename=kwargs["auto_labels"].filename - ) - except AttributeError: - self.mat_plot_1d.output.subplot_to_figure( - auto_filename=kwargs["auto_labels"].filename - ) - - self.close_subplot_figure() - - def subplot_of_plotters_figure(self, plotter_list, name): - self.open_subplot_figure(number_subplots=len(plotter_list)) - - for i, plotter in enumerate(plotter_list): - plotter.figures_2d(**{name: True}) - - self.mat_plot_2d.output.subplot_to_figure(auto_filename=f"subplot_{name}") - - self.close_subplot_figure() diff --git a/autoarray/plot/array.py b/autoarray/plot/array.py new file mode 100644 index 000000000..3c7d8f9af --- /dev/null +++ b/autoarray/plot/array.py @@ -0,0 +1,279 @@ +""" +Standalone function for plotting a 2D array (image) directly with matplotlib. +""" + +import os +from typing import List, Optional, Tuple + +import matplotlib.pyplot as plt +import numpy as np +from matplotlib.colors import LogNorm, Normalize + +from autoarray.plot.utils import ( + apply_extent, + apply_labels, + conf_figsize, + save_figure, + zoom_array, + auto_mask_edge, + numpy_grid, + numpy_lines, + numpy_positions, +) + +_zoom_array_2d = zoom_array +_mask_edge_coords = auto_mask_edge + + +def plot_array( + array, + ax: Optional[plt.Axes] = None, + # --- spatial metadata ------------------------------------------------------- + extent: Optional[Tuple[float, float, float, float]] = None, + # --- overlays --------------------------------------------------------------- + mask: Optional[np.ndarray] = None, + border=None, + origin=None, + grid=None, + mesh_grid=None, + positions=None, + lines=None, + vector_yx: Optional[np.ndarray] = None, + array_overlay=None, + patches: Optional[List] = None, + fill_region: Optional[List] = None, + contours: Optional[int] = None, + # --- cosmetics -------------------------------------------------------------- + title: str = "", + xlabel: str = 'x (")', + ylabel: str = 'y (")', + colormap: str = "jet", + vmin: Optional[float] = None, + vmax: Optional[float] = None, + use_log10: bool = False, + aspect: str = "auto", + origin_imshow: str = "upper", + # --- figure control (used only when ax is None) ----------------------------- + figsize: Optional[Tuple[int, int]] = None, + output_path: Optional[str] = None, + output_filename: str = "array", + output_format: str = "png", + structure=None, +) -> None: + """ + Plot a 2D array (image) using ``plt.imshow``. + + Parameters + ---------- + array + 2D numpy array of pixel values. + ax + Existing matplotlib ``Axes`` to draw onto. If ``None`` a new figure + is created and saved / shown according to *output_path*. + extent + ``[xmin, xmax, ymin, ymax]`` spatial extent in data coordinates. + mask + Array of shape ``(N, 2)`` with ``(y, x)`` coordinates of masked + pixels to overlay as black dots (auto-derived from array.mask by caller). + border + Array of shape ``(N, 2)`` with ``(y, x)`` border pixel coordinates. + origin + ``(y, x)`` origin coordinate(s) to scatter as a marker. + grid + Array of shape ``(N, 2)`` with ``(y, x)`` coordinates to scatter. + mesh_grid + Array of shape ``(N, 2)`` mesh grid coordinates to scatter. + positions + List of ``(N, 2)`` arrays; each is scattered as a distinct group. + lines + List of ``(N, 2)`` arrays with ``(y, x)`` columns to plot as lines. + vector_yx + Array of shape ``(N, 4)`` — ``(y, x, vy, vx)`` — plotted as quiver. + array_overlay + A second 2D array rendered on top of *array* with partial alpha. + patches + List of matplotlib ``Patch`` objects to draw over the image. + fill_region + List of two arrays ``[y1_arr, y2_arr]`` passed to ``ax.fill_between``. + title + Figure title string. + xlabel, ylabel + Axis label strings. + colormap + Matplotlib colormap name. + vmin, vmax + Explicit color scale limits. + use_log10 + When ``True`` a ``LogNorm`` is applied. + aspect + Passed directly to ``imshow``. + origin_imshow + Passed directly to ``imshow`` (``"upper"`` or ``"lower"``). + figsize + Figure size in inches. + output_path + Directory to save the figure. When empty / ``None`` ``plt.show()`` + is called instead. + output_filename + Base file name (without extension). + output_format + File format, e.g. ``"png"``. + """ + # --- autoarray extraction -------------------------------------------------- + array = zoom_array(array) + try: + if structure is None: + structure = array + if extent is None: + extent = array.geometry.extent + if mask is None: + mask = auto_mask_edge(array) + array = array.native.array + except AttributeError: + array = np.asarray(array) + + if array is None or array.size == 0: + return + + # convert overlay params (safe for None and already-numpy inputs) + border = numpy_grid(border) + origin = numpy_grid(origin) + grid = numpy_grid(grid) + mesh_grid = numpy_grid(mesh_grid) + positions = numpy_positions(positions) + lines = numpy_lines(lines) + if array_overlay is not None: + try: + array_overlay = array_overlay.native.array + except AttributeError: + array_overlay = np.asarray(array_overlay) + + owns_figure = ax is None + if owns_figure: + figsize = figsize or conf_figsize("figures") + fig, ax = plt.subplots(1, 1, figsize=figsize) + else: + fig = ax.get_figure() + + # --- colour normalisation -------------------------------------------------- + if use_log10: + try: + from autoconf import conf as _conf + + log10_min = _conf.instance["visualize"]["general"]["general"][ + "log10_min_value" + ] + except Exception: + log10_min = 1.0e-4 + clipped = np.clip(array, log10_min, None) + vmin_log = vmin if (vmin is not None and np.isfinite(vmin)) else log10_min + if vmax is not None and np.isfinite(vmax): + vmax_log = vmax + else: + with np.errstate(all="ignore"): + vmax_log = np.nanmax(clipped) + if not np.isfinite(vmax_log) or vmax_log <= vmin_log: + vmax_log = vmin_log * 10.0 + norm = LogNorm(vmin=vmin_log, vmax=vmax_log) + elif vmin is not None or vmax is not None: + norm = Normalize(vmin=vmin, vmax=vmax) + else: + norm = None + + im = ax.imshow( + array, + cmap=colormap, + norm=norm, + extent=extent, + aspect=aspect, + origin=origin_imshow, + ) + + plt.colorbar(im, ax=ax) + + # --- overlays -------------------------------------------------------------- + if array_overlay is not None: + ax.imshow( + array_overlay, + cmap="Greys", + alpha=0.5, + extent=extent, + aspect=aspect, + origin=origin_imshow, + ) + + if mask is not None: + ax.scatter(mask[:, 1], mask[:, 0], s=1, c="k") + + if border is not None: + ax.scatter(border[:, 1], border[:, 0], s=1, c="b") + + if origin is not None: + origin_arr = np.asarray(origin) + if origin_arr.ndim == 1: + origin_arr = origin_arr[np.newaxis, :] + ax.scatter( + origin_arr[:, 1], origin_arr[:, 0], s=20, c="r", marker="x", zorder=6 + ) + + if grid is not None: + ax.scatter(grid[:, 1], grid[:, 0], s=1, c="k") + + if mesh_grid is not None: + ax.scatter(mesh_grid[:, 1], mesh_grid[:, 0], s=1, c="w", alpha=0.5) + + if positions is not None: + colors = ["r", "g", "b", "m", "c", "y"] + for i, pos in enumerate(positions): + ax.scatter(pos[:, 1], pos[:, 0], s=20, c=colors[i % len(colors)], zorder=5) + + if lines is not None: + for line in lines: + if line is not None and len(line) > 0: + ax.plot(line[:, 1], line[:, 0], linewidth=2) + + if vector_yx is not None: + ax.quiver( + vector_yx[:, 1], + vector_yx[:, 0], + vector_yx[:, 3], + vector_yx[:, 2], + ) + + if patches is not None: + for patch in patches: + import copy + + ax.add_patch(copy.copy(patch)) + + if fill_region is not None: + y1, y2 = fill_region[0], fill_region[1] + x_fill = np.arange(len(y1)) + ax.fill_between(x_fill, y1, y2, alpha=0.3) + + if contours is not None and contours > 0: + try: + levels = np.linspace(np.nanmin(array), np.nanmax(array), contours) + cs = ax.contour(array[::-1], levels=levels, extent=extent, colors="k") + try: + ax.clabel(cs, levels=levels, inline=True, fontsize=8) + except (ValueError, IndexError): + pass + except Exception: + pass + + # --- labels / ticks -------------------------------------------------------- + apply_labels(ax, title=title, xlabel=xlabel, ylabel=ylabel) + + if extent is not None: + apply_extent(ax, extent) + + # --- output ---------------------------------------------------------------- + if owns_figure: + save_figure( + fig, + path=output_path or "", + filename=output_filename, + format=output_format, + structure=structure, + ) diff --git a/autoarray/plot/auto_labels.py b/autoarray/plot/auto_labels.py deleted file mode 100644 index 34d4a495e..000000000 --- a/autoarray/plot/auto_labels.py +++ /dev/null @@ -1,20 +0,0 @@ -class AutoLabels: - def __init__( - self, - title=None, - ylabel=None, - xlabel=None, - yunit=None, - xunit=None, - cb_unit=None, - legend=None, - filename=None, - ): - self.title = title - self.ylabel = ylabel - self.xlabel = xlabel - self.yunit = yunit - self.xunit = xunit - self.cb_unit = cb_unit - self.legend = legend - self.filename = filename diff --git a/autoarray/plot/grid.py b/autoarray/plot/grid.py new file mode 100644 index 000000000..88cbac5f9 --- /dev/null +++ b/autoarray/plot/grid.py @@ -0,0 +1,184 @@ +""" +Standalone function for plotting a 2D grid of (y, x) coordinates. + +This replaces the ``MatPlot2D.plot_grid`` / ``MatWrap`` system. +""" + +from typing import Iterable, List, Optional, Tuple + +import matplotlib.pyplot as plt +import numpy as np + +from autoarray.plot.utils import ( + apply_extent, + apply_labels, + conf_figsize, + save_figure, + numpy_lines, +) + + +def plot_grid( + grid, + ax: Optional[plt.Axes] = None, + # --- errors ----------------------------------------------------------------- + y_errors: Optional[np.ndarray] = None, + x_errors: Optional[np.ndarray] = None, + # --- overlays --------------------------------------------------------------- + lines=None, + color_array: Optional[np.ndarray] = None, + indexes: Optional[List] = None, + # --- cosmetics -------------------------------------------------------------- + title: str = "", + xlabel: str = 'x (")', + ylabel: str = 'y (")', + colormap: str = "jet", + buffer: float = 0.1, + extent: Optional[Tuple[float, float, float, float]] = None, + force_symmetric_extent: bool = True, + # --- figure control (used only when ax is None) ----------------------------- + figsize: Optional[Tuple[int, int]] = None, + output_path: Optional[str] = None, + output_filename: str = "grid", + output_format: str = "png", +) -> None: + """ + Plot a 2D grid of ``(y, x)`` coordinates as a scatter plot. + + This is the direct-matplotlib replacement for ``MatPlot2D.plot_grid``. + + Parameters + ---------- + grid + Array of shape ``(N, 2)``; column 0 is *y*, column 1 is *x*. + ax + Existing ``Axes`` to draw onto. ``None`` creates a new figure. + y_errors, x_errors + Per-point error values; when provided ``plt.errorbar`` is used. + lines + Iterable of ``(N, 2)`` arrays (y, x columns) drawn as lines. + color_array + 1D array of scalar values for colouring each point; triggers a + colorbar. + title + Figure title. + xlabel, ylabel + Axis labels. + colormap + Matplotlib colormap name. + buffer + Fractional padding for the auto-computed extent. The grid's + ``extent_with_buffer_from`` method is called when *extent* is + ``None`` and the grid object exposes that method. + extent + Manual axis limits ``[xmin, xmax, ymin, ymax]``. Auto-computed + when ``None``. + force_symmetric_extent + When ``True`` (and *extent* is auto-computed) the limits are made + symmetric about the origin so the plot is centred. + figsize + Figure size in inches ``(width, height)``. + output_path + Directory for saving. Empty string / ``None`` triggers + ``plt.show()``. + output_filename + Base file name without extension. + output_format + File format, e.g. ``"png"``. + """ + # --- autoarray extraction -------------------------------------------------- + # Compute extent before converting to numpy so grid methods are available. + if extent is None: + try: + extent = grid.extent_with_buffer_from(buffer=buffer) + except AttributeError: + pass # computed from numpy values below + + if hasattr(grid, "array"): + grid = np.array(grid.array) + else: + grid = np.asarray(grid) + + lines = numpy_lines(lines) + + owns_figure = ax is None + if owns_figure: + figsize = figsize or conf_figsize("figures") + fig, ax = plt.subplots(1, 1, figsize=figsize) + else: + fig = ax.get_figure() + + # --- scatter / errorbar ---------------------------------------------------- + if color_array is not None: + cmap = plt.get_cmap(colormap) + colors = cmap((color_array - color_array.min()) / (np.ptp(color_array) or 1)) + + if y_errors is None and x_errors is None: + sc = ax.scatter(grid[:, 1], grid[:, 0], s=1, c=color_array, cmap=colormap) + else: + sc = ax.scatter(grid[:, 1], grid[:, 0], s=1, c=color_array, cmap=colormap) + ax.errorbar( + grid[:, 1], + grid[:, 0], + yerr=y_errors, + xerr=x_errors, + fmt="none", + ecolor=colors, + ) + + plt.colorbar(sc, ax=ax) + else: + if y_errors is None and x_errors is None: + ax.scatter(grid[:, 1], grid[:, 0], s=1, c="k") + else: + ax.errorbar( + grid[:, 1], + grid[:, 0], + yerr=y_errors, + xerr=x_errors, + fmt="o", + markersize=2, + color="k", + ) + + # --- line overlays --------------------------------------------------------- + if lines is not None: + for line in lines: + if line is not None and len(line) > 0: + ax.plot(line[:, 1], line[:, 0], linewidth=2) + + # --- labels ---------------------------------------------------------------- + apply_labels(ax, title=title, xlabel=xlabel, ylabel=ylabel) + + # --- extent ---------------------------------------------------------------- + if extent is None: + y_vals = grid[:, 0] + x_vals = grid[:, 1] + extent = [x_vals.min(), x_vals.max(), y_vals.min(), y_vals.max()] + + if indexes is not None: + colors = ["r", "g", "b", "m", "c", "y"] + for i, idx_list in enumerate(indexes): + ax.scatter( + grid[idx_list, 1], + grid[idx_list, 0], + s=10, + c=colors[i % len(colors)], + zorder=5, + ) + + if force_symmetric_extent and extent is not None: + x_abs = max(abs(extent[0]), abs(extent[1])) + y_abs = max(abs(extent[2]), abs(extent[3])) + extent = [-x_abs, x_abs, -y_abs, y_abs] + + apply_extent(ax, extent) + + # --- output ---------------------------------------------------------------- + if owns_figure: + save_figure( + fig, + path=output_path or "", + filename=output_filename, + format=output_format, + ) diff --git a/autoarray/plot/inversion.py b/autoarray/plot/inversion.py new file mode 100644 index 000000000..baa01436d --- /dev/null +++ b/autoarray/plot/inversion.py @@ -0,0 +1,243 @@ +""" +Standalone functions for plotting inversion / pixelization reconstructions. + +Replaces the inversion-specific paths in ``MatPlot2D.plot_mapper``. +""" + +from typing import List, Optional, Tuple + +import matplotlib.pyplot as plt +import numpy as np +from matplotlib.colors import LogNorm, Normalize + +from autoarray.plot.utils import apply_extent, apply_labels, conf_figsize, save_figure + + +def plot_inversion_reconstruction( + pixel_values: np.ndarray, + mapper, + ax: Optional[plt.Axes] = None, + # --- cosmetics -------------------------------------------------------------- + title: str = "Reconstruction", + xlabel: str = 'x (")', + ylabel: str = 'y (")', + colormap: str = "jet", + vmin: Optional[float] = None, + vmax: Optional[float] = None, + use_log10: bool = False, + zoom_to_brightest: bool = True, + # --- overlays --------------------------------------------------------------- + lines: Optional[List[np.ndarray]] = None, + grid: Optional[np.ndarray] = None, + # --- figure control (used only when ax is None) ----------------------------- + figsize: Optional[Tuple[int, int]] = None, + output_path: Optional[str] = None, + output_filename: str = "reconstruction", + output_format: str = "png", +) -> None: + """ + Plot an inversion reconstruction using the appropriate mapper type. + + Chooses between rectangular (``imshow``/``pcolormesh``) and Delaunay + (``tripcolor``) rendering based on the mapper's interpolator type. + + Parameters + ---------- + pixel_values + 1D array of reconstructed flux values, one per source pixel. + mapper + Autoarray mapper object exposing ``interpolator``, ``mesh_geometry``, + ``source_plane_mesh_grid``, etc. + ax + Existing ``Axes``. ``None`` creates a new figure. + title, xlabel, ylabel + Text labels. + colormap + Matplotlib colormap name. + vmin, vmax + Explicit colour scale limits. + use_log10 + Apply ``LogNorm``. + zoom_to_brightest + Pass through to ``mapper.extent_from``. + lines + Line overlays (e.g. critical curves). + grid + Scatter overlay (e.g. data-plane grid). + figsize, output_path, output_filename, output_format + Figure output controls. + """ + from autoarray.inversion.mesh.interpolator.rectangular import ( + InterpolatorRectangular, + ) + from autoarray.inversion.mesh.interpolator.rectangular_uniform import ( + InterpolatorRectangularUniform, + ) + from autoarray.inversion.mesh.interpolator.delaunay import InterpolatorDelaunay + from autoarray.inversion.mesh.interpolator.knn import InterpolatorKNearestNeighbor + + owns_figure = ax is None + if owns_figure: + figsize = figsize or conf_figsize("figures") + fig, ax = plt.subplots(1, 1, figsize=figsize) + else: + fig = ax.get_figure() + + # --- colour normalisation -------------------------------------------------- + if use_log10: + vmin_log = vmin if (vmin is not None and np.isfinite(vmin)) else 1e-4 + if vmax is not None and np.isfinite(vmax): + vmax_log = vmax + elif pixel_values is not None: + with np.errstate(all="ignore"): + vmax_log = float(np.nanmax(np.asarray(pixel_values))) + if not np.isfinite(vmax_log) or vmax_log <= vmin_log: + vmax_log = vmin_log * 10.0 + else: + vmax_log = vmin_log * 10.0 + norm = LogNorm(vmin=vmin_log, vmax=vmax_log) + elif vmin is not None or vmax is not None: + norm = Normalize(vmin=vmin, vmax=vmax) + else: + norm = None + + extent = mapper.extent_from( + values=pixel_values, zoom_to_brightest=zoom_to_brightest + ) + + if isinstance( + mapper.interpolator, (InterpolatorRectangular, InterpolatorRectangularUniform) + ): + _plot_rectangular(ax, pixel_values, mapper, norm, colormap, extent) + elif isinstance( + mapper.interpolator, (InterpolatorDelaunay, InterpolatorKNearestNeighbor) + ): + _plot_delaunay(ax, pixel_values, mapper, norm, colormap) + + # --- overlays -------------------------------------------------------------- + if lines is not None: + for line in lines: + if line is not None and len(line) > 0: + ax.plot(line[:, 1], line[:, 0], linewidth=2) + + if grid is not None: + ax.scatter(grid[:, 1], grid[:, 0], s=1, c="w", alpha=0.5) + + apply_extent(ax, extent) + + apply_labels(ax, title=title, xlabel=xlabel, ylabel=ylabel) + + if owns_figure: + save_figure( + fig, + path=output_path or "", + filename=output_filename, + format=output_format, + ) + + +def _plot_rectangular(ax, pixel_values, mapper, norm, colormap, extent): + """Render a rectangular pixelization reconstruction onto *ax*. + + Uses ``imshow`` for uniform rectangular grids + (``InterpolatorRectangularUniform``) and ``pcolormesh`` for non-uniform + rectangular grids. Both paths add a colorbar. + + Parameters + ---------- + ax + Matplotlib ``Axes`` to draw onto. + pixel_values + 1-D array of reconstructed flux values, one per source pixel. + ``None`` renders a zero-filled image. + mapper + Mapper object exposing ``interpolator``, ``mesh_geometry``, and + (for uniform grids) ``pixel_scales`` / ``origin``. + norm + ``matplotlib.colors.Normalize`` (or ``LogNorm``) instance, or + ``None`` for automatic scaling. + colormap + Matplotlib colormap name. + extent + ``[xmin, xmax, ymin, ymax]`` spatial extent; passed to ``imshow``. + """ + from autoarray.inversion.mesh.interpolator.rectangular_uniform import ( + InterpolatorRectangularUniform, + ) + import numpy as np + + shape_native = mapper.mesh_geometry.shape + + if pixel_values is None: + pixel_values = np.zeros(shape_native[0] * shape_native[1]) + + if isinstance(mapper.interpolator, InterpolatorRectangularUniform): + from autoarray.structures.arrays.uniform_2d import Array2D + from autoarray.structures.arrays import array_2d_util + + solution_array_2d = array_2d_util.array_2d_native_from( + array_2d_slim=pixel_values, + mask_2d=np.full(fill_value=False, shape=shape_native), + ) + pix_array = Array2D.no_mask( + values=solution_array_2d, + pixel_scales=mapper.mesh_geometry.pixel_scales, + origin=mapper.mesh_geometry.origin, + ) + im = ax.imshow( + pix_array.native.array, + cmap=colormap, + norm=norm, + extent=pix_array.geometry.extent, + aspect="auto", + origin="upper", + ) + plt.colorbar(im, ax=ax) + else: + y_edges, x_edges = mapper.mesh_geometry.edges_transformed.T + Y, X = np.meshgrid(y_edges, x_edges, indexing="ij") + im = ax.pcolormesh( + X, + Y, + pixel_values.reshape(shape_native), + shading="flat", + norm=norm, + cmap=colormap, + ) + plt.colorbar(im, ax=ax) + + +def _plot_delaunay(ax, pixel_values, mapper, norm, colormap): + """Render a Delaunay or KNN pixelization reconstruction onto *ax*. + + Uses ``ax.tripcolor`` with Gouraud shading so that the reconstructed + flux is interpolated smoothly across the triangulated source-plane mesh. + A colorbar is attached after rendering. + + Parameters + ---------- + ax + Matplotlib ``Axes`` to draw onto. + pixel_values + 1-D array of reconstructed flux values (one per source-plane pixel), + or an autoarray object exposing a ``.array`` attribute. + mapper + Mapper object exposing ``source_plane_mesh_grid`` — an ``(N, 2)`` + array of ``(y, x)`` mesh-point coordinates. + norm + ``matplotlib.colors.Normalize`` (or ``LogNorm``) instance, or + ``None`` for automatic scaling. + colormap + Matplotlib colormap name. + """ + mesh_grid = mapper.source_plane_mesh_grid + x = mesh_grid[:, 1] + y = mesh_grid[:, 0] + + if hasattr(pixel_values, "array"): + vals = pixel_values.array + else: + vals = pixel_values + + tc = ax.tripcolor(x, y, vals, cmap=colormap, norm=norm, shading="gouraud") + plt.colorbar(tc, ax=ax) diff --git a/autoarray/plot/mat_plot/__init__.py b/autoarray/plot/mat_plot/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/autoarray/plot/mat_plot/abstract.py b/autoarray/plot/mat_plot/abstract.py deleted file mode 100644 index 36f293fc3..000000000 --- a/autoarray/plot/mat_plot/abstract.py +++ /dev/null @@ -1,264 +0,0 @@ -from autoconf import conf - -from autoarray.plot.wrap.base.abstract import set_backend - -set_backend() - -import copy -import matplotlib.pyplot as plt -from typing import Optional, List, Tuple, Union - -from autoarray.plot.wrap import base as wb -from autoarray import exc - - -class AbstractMatPlot: - def __init__( - self, - units: Optional[wb.Units] = None, - figure: Optional[wb.Figure] = None, - axis: Optional[wb.Axis] = None, - cmap: Optional[wb.Cmap] = None, - colorbar: Optional[wb.Colorbar] = None, - colorbar_tickparams: Optional[wb.ColorbarTickParams] = None, - tickparams: Optional[wb.TickParams] = None, - yticks: Optional[wb.YTicks] = None, - xticks: Optional[wb.XTicks] = None, - title: Optional[wb.Title] = None, - ylabel: Optional[wb.YLabel] = None, - xlabel: Optional[wb.XLabel] = None, - text: Optional[Union[wb.Text, List[wb.Text]]] = None, - annotate: Optional[Union[wb.Annotate, List[wb.Annotate]]] = None, - legend: Optional[wb.Legend] = None, - output: Optional[wb.Output] = None, - ): - """ - Visualizes data structures (e.g an `Array2D`, `Grid2D`, `VectorField`, etc.) using Matplotlib. - - The `Plotter` is passed objects from the `wrap_base` package which wrap matplotlib plot functions and customize - the appearance of the plots of the data structure. If the values of these matplotlib wrapper objects are not - manually specified, they assume the default values provided in the `config.visualize.mat_*` `.ini` config files. - - The following data structures can be plotted using the following matplotlib functions: - - - `Array2D`:, using `plt.imshow`. - - `Grid2D`: using `plt.scatter`. - - `Line`: using `plt.plot`, `plt.semilogy`, `plt.loglog` or `plt.scatter`. - - `VectorField`: using `plt.quiver`. - - `RectangularMapper`: using `plt.imshow`. - - `MapperVorono`: using `plt.fill`. - - Parameters - ---------- - units - The units of the figure used to plot the data structure which sets the y and x ticks and labels. - figure - Opens the matplotlib figure before plotting via `plt.figure` and closes it once plotting is complete - via `plt.close` - axis - Sets the extent of the figure axis via `plt.axis` and allows for a manual axis range. - cmap - Customizes the colormap of the plot and its normalization via matplotlib `colors` objects such - as `colors.Normalize` and `colors.LogNorm`. - colorbar - Plots the colorbar of the plot via `plt.colorbar` and customizes its tick labels and values using method - like `cb.set_yticklabels`. - colorbar_tickparams - Customizes the yticks of the colorbar plotted via `plt.colorbar`. - tickparams - Customizes the appearances of the y and x ticks on the plot (e.g. the fontsize) using `plt.tick_params`. - yticks - Sets the yticks of the plot, including scaling them to new units depending on the `Units` object, via - `plt.yticks`. - xticks - Sets the xticks of the plot, including scaling them to new units depending on the `Units` object, via - `plt.xticks`. - title - Sets the figure title and customizes its appearance using `plt.title`. - ylabel - Sets the figure ylabel and customizes its appearance using `plt.ylabel`. - xlabel - Sets the figure xlabel and customizes its appearance using `plt.xlabel`. - text - Sets any text on the figure and customizes its appearance using `plt.text`. - annotate - Sets any annotations on the figure and customizes its appearance using `plt.annotate`. - legend - Sets whether the plot inclues a legend and customizes its appearance and labels using `plt.legend`. - output - Sets if the figure is displayed on the user's screen or output to `.png` using `plt.show` and `plt.savefig` - """ - - self.units = units or wb.Units(is_default=True) - self.figure = figure or wb.Figure(is_default=True) - self.axis = axis or wb.Axis(is_default=True) - - self.cmap = cmap or wb.Cmap(is_default=True) - - if colorbar is not False: - self.colorbar = colorbar or wb.Colorbar(is_default=True) - else: - self.colorbar = False - - self.colorbar_tickparams = colorbar_tickparams or wb.ColorbarTickParams( - is_default=True - ) - - self.tickparams = tickparams or wb.TickParams(is_default=True) - self.yticks = yticks or wb.YTicks(is_default=True) - self.xticks = xticks or wb.XTicks(is_default=True) - - self.title = title or wb.Title(is_default=True) - self.ylabel = ylabel or wb.YLabel(is_default=True) - self.xlabel = xlabel or wb.XLabel(is_default=True) - - self.text = text or wb.Text(is_default=True) - self.annotate = annotate or wb.Annotate(is_default=True) - self.legend = legend or wb.Legend(is_default=True) - self.output = output or wb.Output(is_default=True) - - self.number_subplots = None - self.subplot_shape = None - self.subplot_index = None - - def __add__(self, other): - """ - Adds two `MatPlot` classes together. - - A `MatPlot` class contains many of the `MatWrap` objects which customize matplotlib figures. One - may have a standard `MatPlot` object, which customizes many figures in the same way, for example: - - mat_plot_2d_base = aplt.MatPlot2D( - yticks=aplt.YTicks(fontsize=18), - xticks=aplt.XTicks(fontsize=18), - ylabel=aplt.YLabel(ylabel=""), - xlabel=aplt.XLabel(xlabel=""), - ) - - However, one may require many unique `MatPlot` objects for a number of different figures, which all use - these settings. These can be created by creating the unique `MatPlot` objects and adding the above object - to each: - - mat_plot_2d = aplt.MatPlot2D( - title=aplt.Title(label="Example Figure 1"), - ) - - mat_plot_2d = mat_plot_2d + mat_plot_2d_base - - mat_plot_2d = aplt.MatPlot2D( - title=aplt.Title(label="Example Figure 2"), - ) - - mat_plot_2d = mat_plot_2d + mat_plot_2d_base - """ - - other = copy.deepcopy(other) - - for attr, value in self.__dict__.items(): - try: - if value.kwargs.get("is_default") is not True: - other.__dict__[attr] = value - except AttributeError: - pass - - return other - - def set_for_subplot(self, is_for_subplot: bool): - """ - Sets the `is_for_subplot` attribute for every `MatWrap` object in this `MatPlot` object by updating - the `is_for_subplot`. By changing this tag: - - - The subplot: section of the config file of every `MatWrap` object is used instead of figure:. - - Calls which output or close the matplotlib figure are over-ridden so that the subplot is not removed. - - Parameters - ---------- - is_for_subplot - The entry the `is_for_subplot` attribute of every `MatWrap` object is set too. - """ - self.is_for_subplot = is_for_subplot - self.output.bypass = is_for_subplot - - for attr, value in self.__dict__.items(): - if hasattr(value, "is_for_subplot"): - value.is_for_subplot = is_for_subplot - - def get_subplot_shape(self, number_subplots): - """ - Get the size of a sub plotter in (total_y_pixels, total_x_pixels), based on the number of subplots that are - going to be plotted. - - Parameters - ---------- - number_subplots - The number of subplots that are to be plotted in the figure. - """ - - if self.subplot_shape is not None: - return self.subplot_shape - - subplot_shape_dict = conf.instance["visualize"]["general"]["subplot_shape"] - - try: - subplot_shape = subplot_shape_dict[number_subplots] - except KeyError: - try: - key = min( - filter(lambda x: x > number_subplots, subplot_shape_dict.keys()) - ) - subplot_shape = subplot_shape_dict[key] - except ValueError: - raise exc.PlottingException( - f""" - The number of subplots is greater than the maximum number of subplots specified - in the visualization/general.yaml config file, in the section "subplot_shape". - - The total number of subplots in the figure is {number_subplots}, whereas this config file only - specifies the subplot shape for up to a number of subplots less than this. - - To fix this, add a new entry to the "subplot_shape" section of the visualization/general.yaml. - """ - ) - - return tuple(map(int, subplot_shape[1:-1].split(","))) - - def setup_subplot( - self, - aspect: Optional[Tuple[float, float]] = None, - subplot_shape: Tuple[int, int] = None, - ): - """ - Setup a new figure to be plotted on a subplot, which is used by a `Plotter` when plotting multiple images - on a subplot. - - Every time a new figure is plotted on the subplot, the counter `subplot_index` increases by 1. - - The shape of the subplot is determined by the number of figures on the subplot. - - The aspect ratio of the subplot can be customized based on the size of the figures. - - Every time - - Parameters - ---------- - aspect - The aspect ratio of the overall subplot. - subplot_shape - The number of rows and columns in the subplot. - """ - if subplot_shape is None: - subplot_shape = self.get_subplot_shape(number_subplots=self.number_subplots) - - if aspect is None: - ax = plt.subplot(subplot_shape[0], subplot_shape[1], self.subplot_index) - else: - ax = plt.subplot( - subplot_shape[0], - subplot_shape[1], - self.subplot_index, - aspect=float(aspect), - ) - - self.subplot_index += 1 - - return ax diff --git a/autoarray/plot/mat_plot/one_d.py b/autoarray/plot/mat_plot/one_d.py deleted file mode 100644 index 79b721bba..000000000 --- a/autoarray/plot/mat_plot/one_d.py +++ /dev/null @@ -1,325 +0,0 @@ -import matplotlib.pyplot as plt -import numpy as np -from typing import Iterable, Optional, List, Union - -from autoarray.plot.mat_plot.abstract import AbstractMatPlot -from autoarray.plot.auto_labels import AutoLabels -from autoarray.plot.visuals.one_d import Visuals1D -from autoarray.plot.wrap import base as wb -from autoarray.plot.wrap import one_d as w1d -from autoarray.structures.arrays.uniform_1d import Array1D - - -class MatPlot1D(AbstractMatPlot): - def __init__( - self, - units: Optional[wb.Units] = None, - figure: Optional[wb.Figure] = None, - axis: Optional[wb.Axis] = None, - cmap: Optional[wb.Cmap] = None, - colorbar: Optional[wb.Colorbar] = None, - colorbar_tickparams: Optional[wb.ColorbarTickParams] = None, - tickparams: Optional[wb.TickParams] = None, - yticks: Optional[wb.YTicks] = None, - xticks: Optional[wb.XTicks] = None, - title: Optional[wb.Title] = None, - ylabel: Optional[wb.YLabel] = None, - xlabel: Optional[wb.XLabel] = None, - text: Optional[Union[wb.Text, List[wb.Text]]] = None, - annotate: Optional[Union[wb.Annotate, List[wb.Annotate]]] = None, - legend: Optional[wb.Legend] = None, - output: Optional[wb.Output] = None, - yx_plot: Optional[w1d.YXPlot] = None, - vertical_line_axvline: Optional[w1d.AXVLine] = None, - yx_scatter: Optional[w1d.YXPlot] = None, - fill_between: Optional[w1d.FillBetween] = None, - ): - """ - Visualizes 1D data structures (e.g a `Line`, etc.) using Matplotlib. - - The `Plotter` is passed objects from the `wrap_base` package which wrap matplotlib plot functions and customize - the appearance of the plots of the data structure. If the values of these matplotlib wrapper objects are not - manually specified, they assume the default values provided in the `config.visualize.mat_*` `.ini` config files. - - The following 1D data structures can be plotted using the following matplotlib functions: - - - `Line` using `plt.plot`. - - Parameters - ---------- - units - The units of the figure used to plot the data structure which sets the y and x ticks and labels. - figure - Opens the matplotlib figure before plotting via `plt.figure` and closes it once plotting is complete - via `plt.close`. - axis - Sets the extent of the figure axis via `plt.axis` and allows for a manual axis range. - cmap - Customizes the colormap of the plot and its normalization via matplotlib `colors` objects such - as `colors.Normalize` and `colors.LogNorm`. - colorbar - Plots the colorbar of the plot via `plt.colorbar` and customizes its tick labels and values using method - like `cb.set_yticklabels`. - colorbar_tickparams - Customizes the yticks of the colorbar plotted via `plt.colorbar`. - tickparams - Customizes the appearances of the y and x ticks on the plot, (e.g. the fontsize), using `plt.tick_params`. - yticks - Sets the yticks of the plot, including scaling them to new units depending on the `Units` object, via - `plt.yticks`. - xticks - Sets the xticks of the plot, including scaling them to new units depending on the `Units` object, via - `plt.xticks`. - title - Sets the figure title and customizes its appearance using `plt.title`. - ylabel - Sets the figure ylabel and customizes its appearance using `plt.ylabel`. - xlabel - Sets the figure xlabel and customizes its appearance using `plt.xlabel`. - text - Sets any text on the figure and customizes its appearance using `plt.text`. - annotate - Sets any annotations on the figure and customizes its appearance using `plt.annotate`. - legend - Sets whether the plot inclues a legend and customizes its appearance and labels using `plt.legend`. - output - Sets if the figure is displayed on the user's screen or output to `.png` using `plt.show` and `plt.savefig` - yx_plot - Sets how the y versus x plot appears, for example if it each axis is linear or log, using `plt.plot`. - vertical_line_axvline - Sets how a vertical line plotted on the figure using the `plt.axvline` method. - """ - - super().__init__( - units=units, - figure=figure, - axis=axis, - cmap=cmap, - colorbar=colorbar, - colorbar_tickparams=colorbar_tickparams, - tickparams=tickparams, - yticks=yticks, - xticks=xticks, - title=title, - ylabel=ylabel, - xlabel=xlabel, - text=text, - annotate=annotate, - legend=legend, - output=output, - ) - - self.yx_plot = yx_plot or w1d.YXPlot(is_default=True) - self.vertical_line_axvline = vertical_line_axvline or w1d.AXVLine( - is_default=True - ) - self.yx_scatter = yx_scatter or w1d.YXScatter(is_default=True) - self.fill_between = fill_between or w1d.FillBetween(is_default=True) - - self.is_for_multi_plot = False - self.is_for_subplot = False - - def set_for_multi_plot( - self, is_for_multi_plot: bool, color: str, xticks=None, yticks=None - ): - """ - Sets the `is_for_subplot` attribute for every `MatWrap` object in this `MatPlot` object by updating - the `is_for_subplot`. By changing this tag: - - - The subplot: section of the config file of every `MatWrap` object is used instead of figure:. - - Calls which output or close the matplotlib figure are over-ridden so that the subplot is not removed. - - Parameters - ---------- - is_for_subplot - The entry the `is_for_subplot` attribute of every `MatWrap` object is set too. - """ - self.is_for_multi_plot = is_for_multi_plot - self.output.bypass = is_for_multi_plot - - self.yx_plot.kwargs["c"] = color - self.vertical_line_axvline.kwargs["c"] = color - - self.vertical_line_axvline.no_label = True - - if yticks is not None: - self.yticks = yticks - - if xticks is not None: - self.xticks = xticks - - def plot_yx( - self, - y: Union[Array1D], - visuals_1d: Visuals1D, - auto_labels: AutoLabels, - x: Optional[Union[np.ndarray, Iterable, List, Array1D]] = None, - plot_axis_type_override: Optional[str] = None, - y_errors=None, - x_errors=None, - y_extra=None, - y_extra_2=None, - ls_errorbar="", - should_plot_grid=False, - should_plot_zero=False, - text_manual_dict=None, - text_manual_dict_y=None, - bypass: bool = False, - ): - - try: - y = y.array - except AttributeError: - pass - - try: - x = x.array - except AttributeError: - pass - - if (y is None) or np.count_nonzero(y) == 0 or np.isnan(y).all(): - return - - ax = None - - if not self.is_for_subplot: - fig, ax = self.figure.open() - else: - if not bypass: - ax = self.setup_subplot() - - self.title.set(auto_title=auto_labels.title) - - use_integers = False - - if x is None: - x = np.arange(len(y)) - use_integers = True - pixel_scales = (x[1] - x[0],) - x = Array1D.no_mask(values=x, pixel_scales=pixel_scales).array - - if self.yx_plot.plot_axis_type is None: - plot_axis_type = "linear" - else: - plot_axis_type = self.yx_plot.plot_axis_type - - if plot_axis_type_override is not None: - plot_axis_type = plot_axis_type_override - - label = self.legend.label or auto_labels.legend - - self.yx_plot.plot_y_vs_x( - y=y, - x=x, - label=label, - plot_axis_type=plot_axis_type, - y_errors=y_errors, - x_errors=x_errors, - y_extra=y_extra, - y_extra_2=y_extra_2, - ls_errorbar=ls_errorbar, - ) - - if should_plot_zero: - plt.plot(x, 1.0e-6 * np.ones(shape=y.shape), c="b", ls="--") - - if should_plot_grid: - plt.grid(True) - - if visuals_1d.shaded_region is not None: - self.fill_between.fill_between_shaded_regions( - x=x, y1=visuals_1d.shaded_region[0], y2=visuals_1d.shaded_region[1] - ) - - if "extent" in self.axis.config_dict: - self.axis.set() - - self.tickparams.set() - - if plot_axis_type == "symlog": - plt.yscale("symlog") - - if x_errors is not None: - min_value_x = np.nanmin(x - x_errors) - max_value_x = np.nanmax(x + x_errors) - else: - min_value_x = np.nanmin(x) - max_value_x = np.nanmax(x) - - if y_errors is not None: - min_value_y = np.nanmin(y - y_errors) - max_value_y = np.nanmax(y + y_errors) - else: - min_value_y = np.nanmin(y) - max_value_y = np.nanmax(y) - - if should_plot_zero: - if min_value_y > 0: - min_value_y = 0 - - self.xticks.set( - min_value=min_value_x, - max_value=max_value_x, - pixels=len(x), - units=self.units, - use_integers=use_integers, - is_for_1d_plot=True, - is_log10="loglog" in plot_axis_type, - ) - - self.yticks.set( - min_value=min_value_y, - max_value=max_value_y, - pixels=len(y), - units=self.units, - yunit=auto_labels.yunit, - is_for_1d_plot=True, - is_log10="log" in plot_axis_type, - ) - - self.title.set(auto_title=auto_labels.title) - self.ylabel.set(auto_label=auto_labels.ylabel) - self.xlabel.set(auto_label=auto_labels.xlabel) - - if not isinstance(self.text, list): - self.text.set() - else: - [text.set() for text in self.text] - - # This is a horrific hack to get CTI plots to work, refactor one day. - - from autoarray.plot.wrap.base.text import Text - - if text_manual_dict is not None and ax is not None: - y = text_manual_dict_y - text_manual_list = [] - - for key, value in text_manual_dict.items(): - text_manual_list.append( - Text( - x=0.95, - y=y, - s=f"{key} : {value}", - c="b", - transform=ax.transAxes, - horizontalalignment="right", - fontsize=12, - ) - ) - y = y - 0.05 - - [text.set() for text in text_manual_list] - - if not isinstance(self.annotate, list): - self.annotate.set() - else: - [annotate.set() for annotate in self.annotate] - - visuals_1d.plot_via_plotter(plotter=self) - - if label is not None: - self.legend.set() - - if (not self.is_for_subplot) and (not self.is_for_multi_plot): - self.output.to_figure(structure=None, auto_filename=auto_labels.filename) - self.figure.close() diff --git a/autoarray/plot/mat_plot/two_d.py b/autoarray/plot/mat_plot/two_d.py deleted file mode 100644 index 55e929869..000000000 --- a/autoarray/plot/mat_plot/two_d.py +++ /dev/null @@ -1,723 +0,0 @@ -import matplotlib.pyplot as plt -import numpy as np -from typing import Optional, List, Union - -from autoconf import conf - -from autoarray.inversion.mesh.interpolator.rectangular import ( - InterpolatorRectangular, -) -from autoarray.inversion.mesh.interpolator.rectangular_uniform import ( - InterpolatorRectangularUniform, -) -from autoarray.inversion.mesh.interpolator.delaunay import InterpolatorDelaunay -from autoarray.inversion.mesh.interpolator.knn import ( - InterpolatorKNearestNeighbor, -) -from autoarray.mask.derive.zoom_2d import Zoom2D -from autoarray.plot.mat_plot.abstract import AbstractMatPlot -from autoarray.plot.auto_labels import AutoLabels -from autoarray.plot.visuals.two_d import Visuals2D -from autoarray.structures.arrays.uniform_2d import Array2D -from autoarray.structures.arrays.rgb import Array2DRGB - -from autoarray.structures.arrays import array_2d_util - -from autoarray import exc -from autoarray.plot.wrap import base as wb -from autoarray.plot.wrap import two_d as w2d - - -class MatPlot2D(AbstractMatPlot): - def __init__( - self, - units: Optional[wb.Units] = None, - figure: Optional[wb.Figure] = None, - axis: Optional[wb.Axis] = None, - cmap: Optional[wb.Cmap] = None, - colorbar: Optional[wb.Colorbar] = None, - colorbar_tickparams: Optional[wb.ColorbarTickParams] = None, - tickparams: Optional[wb.TickParams] = None, - yticks: Optional[wb.YTicks] = None, - xticks: Optional[wb.XTicks] = None, - title: Optional[wb.Title] = None, - ylabel: Optional[wb.YLabel] = None, - xlabel: Optional[wb.XLabel] = None, - text: Optional[Union[wb.Text, List[wb.Text]]] = None, - annotate: Optional[Union[wb.Annotate, List[wb.Annotate]]] = None, - legend: Optional[wb.Legend] = None, - output: Optional[wb.Output] = None, - array_overlay: Optional[w2d.ArrayOverlay] = None, - fill: Optional[w2d.Fill] = None, - contour: Optional[w2d.Contour] = None, - grid_scatter: Optional[w2d.GridScatter] = None, - grid_plot: Optional[w2d.GridPlot] = None, - grid_errorbar: Optional[w2d.GridErrorbar] = None, - vector_yx_quiver: Optional[w2d.VectorYXQuiver] = None, - patch_overlay: Optional[w2d.PatchOverlay] = None, - delaunay_drawer: Optional[w2d.DelaunayDrawer] = None, - origin_scatter: Optional[w2d.OriginScatter] = None, - mask_scatter: Optional[w2d.MaskScatter] = None, - border_scatter: Optional[w2d.BorderScatter] = None, - positions_scatter: Optional[w2d.PositionsScatter] = None, - index_scatter: Optional[w2d.IndexScatter] = None, - index_plot: Optional[w2d.IndexPlot] = None, - mesh_grid_scatter: Optional[w2d.MeshGridScatter] = None, - parallel_overscan_plot: Optional[w2d.ParallelOverscanPlot] = None, - serial_prescan_plot: Optional[w2d.SerialPrescanPlot] = None, - serial_overscan_plot: Optional[w2d.SerialOverscanPlot] = None, - use_log10: bool = False, - plot_mask: bool = True, - quick_update: bool = False, - ): - """ - Visualizes 2D data structures (e.g an `Array2D`, `Grid2D`, `VectorField`, etc.) using Matplotlib. - - The `Plotter` is passed objects from the `wrap` package which wrap matplotlib plot functions and customize - the appearance of the plots of the data structure. If the values of these matplotlib wrapper objects are not - manually specified, they assume the default values provided in the `config.visualize.mat_*` `.ini` config files. - - The following 2D data structures can be plotted using the following matplotlib functions: - - - `Array2D`:, using `plt.imshow`. - - `Grid2D`: using `plt.scatter`. - - `Line`: using `plt.plot`, `plt.semilogy`, `plt.loglog` or `plt.scatter`. - - `VectorField`: using `plt.quiver`. - - `RectangularMapper`: using `plt.imshow`. - - Parameters - ---------- - units - The units of the figure used to plot the data structure which sets the y and x ticks and labels. - figure - Opens the matplotlib figure before plotting via `plt.figure` and closes it once plotting is complete - via `plt.close`. - axis - Sets the extent of the figure axis via `plt.axis` and allows for a manual axis range. - cmap - Customizes the colormap of the plot and its normalization via matplotlib `colors` objects such - as `colors.Normalize` and `colors.LogNorm`. - colorbar - Plots the colorbar of the plot via `plt.colorbar` and customizes its tick labels and values using method - like `cb.set_yticklabels`. - colorbar_tickparams - Customizes the yticks of the colorbar plotted via `plt.colorbar`. - tickparams - Customizes the appearances of the y and x ticks on the plot, (e.g. the fontsize), using `plt.tick_params`. - yticks - Sets the yticks of the plot, including scaling them to new units depending on the `Units` object, via - `plt.yticks`. - xticks - Sets the xticks of the plot, including scaling them to new units depending on the `Units` object, via - `plt.xticks`. - title - Sets the figure title and customizes its appearance using `plt.title`. - ylabel - Sets the figure ylabel and customizes its appearance using `plt.ylabel`. - xlabel - Sets the figure xlabel and customizes its appearance using `plt.xlabel`. - text - Sets any text on the figure and customizes its appearance using `plt.text`. - annotate - Sets any annotations on the figure and customizes its appearance using `plt.annotate`. - legend - Sets whether the plot inclues a legend and customizes its appearance and labels using `plt.legend`. - output - Sets if the figure is displayed on the user's screen or output to `.png` using `plt.show` and `plt.savefig` - array_overlay - Overlays an input `Array2D` over the figure using `plt.imshow`. - fill - Sets the fill of the figure using `plt.fill` and customizes its appearance, such as the color and alpha. - contour - Overlays contours of an input `Array2D` over the figure using `plt.contour`. - grid_scatter - Scatters a `Grid2D` of (y,x) coordinates over the figure using `plt.scatter`. - grid_plot - Plots lines of data (e.g. a y versus x plot via `plt.plot`, vertical lines via `plt.avxline`, etc.) - vector_yx_quiver - Plots a `VectorField` object using the matplotlib function `plt.quiver`. - patch_overlay - Overlays matplotlib `patches.Patch` objects over the figure, such as an `Ellipse`. - delaunay_drawer - Draws a colored Delaunay mesh of pixels using `plt.tripcolor`. - origin_scatter - Scatters the (y,x) origin of the data structure on the figure. - mask_scatter - Scatters an input `Mask2d` over the plotted data structure's figure. - border_scatter - Scatters the border of an input `Mask2d` over the plotted data structure's figure. - positions_scatter - Scatters specific (y,x) coordinates input as a `Grid2DIrregular` object over the figure. - index_scatter - Scatters specific coordinates of an input `Grid2D` based on input values of the `Grid2D`'s 1D or 2D indexes. - mesh_grid_scatter - Scatters the `PixelizationGrid` of a `Mesh` object. - parallel_overscan_plot - Plots the parallel overscan on an `Array2D` data structure representing a CCD imaging via `plt.plot`. - serial_prescan_plot - Plots the serial prescan on an `Array2D` data structure representing a CCD imaging via `plt.plot`. - serial_overscan_plot - Plots the serial overscan on an `Array2D` data structure representing a CCD imaging via `plt.plot`. - use_log10 - If True, the plot has a log10 colormap, colorbar and contours showing the values. - """ - - super().__init__( - units=units, - figure=figure, - axis=axis, - cmap=cmap, - colorbar=colorbar, - colorbar_tickparams=colorbar_tickparams, - tickparams=tickparams, - yticks=yticks, - xticks=xticks, - title=title, - ylabel=ylabel, - xlabel=xlabel, - text=text, - annotate=annotate, - legend=legend, - output=output, - ) - - self.array_overlay = array_overlay or w2d.ArrayOverlay(is_default=True) - self.fill = fill or w2d.Fill(is_default=True) - - self.contour = contour or w2d.Contour(is_default=True) - - self.grid_scatter = grid_scatter or w2d.GridScatter(is_default=True) - self.grid_plot = grid_plot or w2d.GridPlot(is_default=True) - self.grid_errorbar = grid_errorbar or w2d.GridErrorbar(is_default=True) - - self.vector_yx_quiver = vector_yx_quiver or w2d.VectorYXQuiver(is_default=True) - self.patch_overlay = patch_overlay or w2d.PatchOverlay(is_default=True) - - self.delaunay_drawer = delaunay_drawer or w2d.DelaunayDrawer(is_default=True) - - self.origin_scatter = origin_scatter or w2d.OriginScatter(is_default=True) - self.mask_scatter = mask_scatter or w2d.MaskScatter(is_default=True) - self.border_scatter = border_scatter or w2d.BorderScatter(is_default=True) - self.positions_scatter = positions_scatter or w2d.PositionsScatter( - is_default=True - ) - self.index_scatter = index_scatter or w2d.IndexScatter(is_default=True) - self.index_plot = index_plot or w2d.IndexPlot(is_default=True) - self.mesh_grid_scatter = mesh_grid_scatter or w2d.MeshGridScatter( - is_default=True - ) - - self.parallel_overscan_plot = ( - parallel_overscan_plot or w2d.ParallelOverscanPlot(is_default=True) - ) - self.serial_prescan_plot = serial_prescan_plot or w2d.SerialPrescanPlot( - is_default=True - ) - self.serial_overscan_plot = serial_overscan_plot or w2d.SerialOverscanPlot( - is_default=True - ) - - self.use_log10 = use_log10 - self.plot_mask = plot_mask - - self.is_for_subplot = False - self.quick_update = quick_update - - def plot_array( - self, - array: Array2D, - visuals_2d: Visuals2D, - auto_labels: AutoLabels, - grid_indexes=None, - bypass: bool = False, - ): - """ - Plot an `Array2D` data structure as a figure using the matplotlib wrapper objects and tools. - - This `Array2D` is plotted using `plt.imshow`. - - Parameters - ---------- - array - The 2D array of data_type which is plotted. - visuals_2d - Contains all the visuals that are plotted over the `Array2D` (e.g. the origin, mask, grids, etc.). - bypass - If `True`, `plt.close` is omitted and the matplotlib figure remains open. This is used when making subplots. - """ - - if array is None or np.all(array == 0): - return - - if self.use_log10 and (np.all(array == array[0]) or np.all(array < 0)): - return - - if array.pixel_scales is None and self.units.use_scaled: - raise exc.ArrayException( - "You cannot plot an array using its scaled unit_label if the input array does not have " - "a pixel scales attribute." - ) - - if conf.instance["visualize"]["general"]["general"]["zoom_around_mask"]: - - zoom = Zoom2D(mask=array.mask) - - buffer = 0 if array.mask.is_all_false else 1 - - array = zoom.array_2d_from(array=array, buffer=buffer) - - extent = array.geometry.extent - - ax = None - - if not self.is_for_subplot: - fig, ax = self.figure.open() - else: - if not bypass: - ax = self.setup_subplot() - - aspect = self.figure.aspect_from(shape_native=array.shape_native) - - norm = self.cmap.norm_from(array=array.array, use_log10=self.use_log10) - - origin = conf.instance["visualize"]["general"]["general"]["imshow_origin"] - - if isinstance(array, Array2DRGB): - - plt.imshow( - X=array.native.array, - aspect=aspect, - extent=extent, - origin=origin, - ) - - else: - - plt.imshow( - X=array.native.array, - aspect=aspect, - cmap=self.cmap.cmap, - norm=norm, - extent=extent, - origin=origin, - ) - - if visuals_2d.array_overlay is not None: - self.array_overlay.overlay_array( - array=visuals_2d.array_overlay, figure=self.figure - ) - - extent_axis = self.axis.config_dict.get("extent") - - if extent_axis is None: - extent_axis = extent - - self.axis.set(extent=extent_axis) - - self.tickparams.set() - - self.yticks.set( - min_value=extent_axis[2], - max_value=extent_axis[3], - units=self.units, - pixels=array.shape_native[0], - ) - - self.xticks.set( - min_value=extent_axis[0], - max_value=extent_axis[1], - units=self.units, - pixels=array.shape_native[1], - ) - - if isinstance(array, Array2DRGB): - title = "RGB" - else: - title = auto_labels.title - - self.title.set(auto_title=title, use_log10=self.use_log10) - self.ylabel.set() - self.xlabel.set() - - if not isinstance(self.text, list): - self.text.set() - else: - [text.set() for text in self.text] - - if not isinstance(self.annotate, list): - self.annotate.set() - else: - [annotate.set() for annotate in self.annotate] - - if self.colorbar is not False: - cb = self.colorbar.set( - units=self.units, - ax=ax, - norm=norm, - cb_unit=auto_labels.cb_unit, - use_log10=self.use_log10, - ) - self.colorbar_tickparams.set(cb=cb) - - if self.contour is not False: - try: - self.contour.set(array=array, extent=extent, use_log10=self.use_log10) - except ValueError: - pass - - if self.plot_mask and visuals_2d.mask is None: - - if not array.mask.is_all_false: - - self.mask_scatter.scatter_grid(grid=array.mask.derive_grid.edge.array) - - visuals_2d.plot_via_plotter(plotter=self, grid_indexes=grid_indexes) - - if not self.is_for_subplot and not bypass: - self.output.to_figure(structure=array, auto_filename=auto_labels.filename) - self.figure.close() - - def plot_grid( - self, - grid, - visuals_2d: Visuals2D, - auto_labels: AutoLabels, - color_array=None, - y_errors=None, - x_errors=None, - plot_grid_lines=False, - plot_over_sampled_grid=False, - buffer=0.1, - ): - """Plot a grid of (y,x) Cartesian coordinates as a scatter plotter of points. - - Parameters - ---------- - grid - The (y,x) coordinates of the grid, in an array of shape (total_coordinates, 2). - indexes - A set of points that are plotted in a different colour for emphasis (e.g. to show the mappings between \ - different planes). - """ - - if not self.is_for_subplot: - fig, ax = self.figure.open() - else: - ax = self.setup_subplot() - - if plot_over_sampled_grid: - grid_plot = grid.over_sampled - else: - grid_plot = grid - - if color_array is None: - if y_errors is None and x_errors is None: - self.grid_scatter.scatter_grid(grid=grid_plot.array) - else: - self.grid_errorbar.errorbar_grid( - grid=grid_plot.array, y_errors=y_errors, x_errors=x_errors - ) - - elif color_array is not None: - cmap = plt.get_cmap(self.cmap.cmap) - - if y_errors is None and x_errors is None: - self.grid_scatter.scatter_grid_colored( - grid=grid.array, color_array=color_array, cmap=cmap - ) - else: - self.grid_errorbar.errorbar_grid_colored( - grid=grid.array, - cmap=cmap, - color_array=color_array, - y_errors=y_errors, - x_errors=x_errors, - ) - - if self.colorbar is not None: - - colorbar = self.colorbar.set_with_color_values( - units=self.units, - cmap=self.cmap.cmap, - color_values=color_array, - ax=ax, - ) - if colorbar is not None and self.colorbar_tickparams is not None: - self.colorbar_tickparams.set(cb=colorbar) - - self.title.set(auto_title=auto_labels.title) - self.ylabel.set() - self.xlabel.set() - - if not isinstance(self.text, list): - self.text.set() - else: - [text.set() for text in self.text] - - if not isinstance(self.annotate, list): - self.annotate.set() - else: - [annotate.set() for annotate in self.annotate] - - extent = self.axis.config_dict.get("extent") - - if extent is None: - extent = grid.extent_with_buffer_from(buffer=buffer) - - if plot_grid_lines: - self.grid_plot.plot_rectangular_grid_lines( - extent=grid.geometry.extent, - shape_native=grid.shape_native, - ) - - self.axis.set(extent=extent, grid=grid) - - self.tickparams.set() - - if not self.axis.symmetric_around_centre: - self.yticks.set(min_value=extent[2], max_value=extent[3], units=self.units) - self.xticks.set(min_value=extent[0], max_value=extent[1], units=self.units) - - if self.contour is not False: - self.contour.set(array=color_array, extent=extent, use_log10=self.use_log10) - - visuals_2d.plot_via_plotter(plotter=self, grid_indexes=grid.array) - - if not self.is_for_subplot: - self.output.to_figure(structure=grid, auto_filename=auto_labels.filename) - self.figure.close() - - def plot_mapper( - self, - mapper, - visuals_2d: Visuals2D, - auto_labels: AutoLabels, - pixel_values: np.ndarray = Optional[None], - zoom_to_brightest: bool = True, - ): - if isinstance(mapper.interpolator, InterpolatorRectangular) or isinstance( - mapper.interpolator, InterpolatorRectangularUniform - ): - self._plot_rectangular_mapper( - mapper=mapper, - visuals_2d=visuals_2d, - auto_labels=auto_labels, - pixel_values=pixel_values, - zoom_to_brightest=zoom_to_brightest, - ) - - elif isinstance(mapper.interpolator, InterpolatorDelaunay) or isinstance( - mapper.interpolator, InterpolatorKNearestNeighbor - ): - self._plot_delaunay_mapper( - mapper=mapper, - visuals_2d=visuals_2d, - auto_labels=auto_labels, - pixel_values=pixel_values, - zoom_to_brightest=zoom_to_brightest, - ) - - def _plot_rectangular_mapper( - self, - mapper, - visuals_2d: Visuals2D, - auto_labels: AutoLabels, - pixel_values: np.ndarray = Optional[None], - zoom_to_brightest: bool = True, - ): - if pixel_values is not None: - solution_array_2d = array_2d_util.array_2d_native_from( - array_2d_slim=pixel_values, - mask_2d=np.full(fill_value=False, shape=mapper.mesh_geometry.shape), - ) - - pixel_values = Array2D.no_mask( - values=solution_array_2d, - pixel_scales=mapper.mesh_geometry.pixel_scales, - origin=mapper.mesh_geometry.origin, - ) - - extent = self.axis.config_dict.get("extent") - if extent is None: - extent = mapper.extent_from( - values=pixel_values, zoom_to_brightest=zoom_to_brightest - ) - - aspect_inv = self.figure.aspect_for_subplot_from(extent=extent) - - if not self.is_for_subplot: - fig, ax = self.figure.open() - else: - ax = self.setup_subplot(aspect=aspect_inv) - - shape_native = mapper.mesh_geometry.shape - - if pixel_values is not None: - - from autoarray.inversion.mesh.interpolator.rectangular_uniform import ( - InterpolatorRectangularUniform, - ) - from autoarray.inversion.mesh.interpolator.rectangular import ( - InterpolatorRectangular, - ) - - if isinstance(mapper.interpolator, InterpolatorRectangularUniform): - - self.plot_array( - array=pixel_values, - visuals_2d=visuals_2d, - auto_labels=auto_labels, - bypass=True, - ) - - else: - - norm = self.cmap.norm_from( - array=pixel_values.array, use_log10=self.use_log10 - ) - - # Unpack edges (assuming shape is (N_edges, 2) → (y_edges, x_edges)) - y_edges, x_edges = ( - mapper.mesh_geometry.edges_transformed.T - ) # explicit, safe - - # Build meshes with ij-indexing: (row = y, col = x) - Y, X = np.meshgrid(y_edges, x_edges, indexing="ij") - - plt.pcolormesh( - X, # x-grid - Y, # y-grid - pixel_values.array.reshape(shape_native), # (ny, nx) - shading="flat", - norm=norm, - cmap=self.cmap.cmap, - ) - - if self.colorbar is not False: - - cb = self.colorbar.set( - units=self.units, - ax=ax, - norm=norm, - cb_unit=auto_labels.cb_unit, - use_log10=self.use_log10, - ) - self.colorbar_tickparams.set(cb=cb) - - extent_axis = self.axis.config_dict.get("extent") - - if extent_axis is None: - extent_axis = extent - - self.axis.set(extent=extent_axis) - - self.tickparams.set() - self.yticks.set( - min_value=extent_axis[2], - max_value=extent_axis[3], - units=self.units, - pixels=shape_native[0], - ) - - self.xticks.set( - min_value=extent_axis[0], - max_value=extent_axis[1], - units=self.units, - pixels=shape_native[1], - ) - - if not isinstance(self.text, list): - self.text.set() - else: - [text.set() for text in self.text] - - if not isinstance(self.annotate, list): - self.annotate.set() - else: - [annotate.set() for annotate in self.annotate] - - # self.grid_plot.plot_rectangular_grid_lines( - # extent=mapper.source_plane_mesh_grid.geometry.extent, - # shape_native=mapper.shape_native, - # ) - - self.title.set(auto_title=auto_labels.title) - self.ylabel.set() - self.xlabel.set() - - visuals_2d.plot_via_plotter( - plotter=self, grid_indexes=mapper.source_plane_data_grid.over_sampled - ) - - if not self.is_for_subplot: - self.output.to_figure(structure=None, auto_filename=auto_labels.filename) - self.figure.close() - - def _plot_delaunay_mapper( - self, - mapper, - visuals_2d: Visuals2D, - auto_labels: AutoLabels, - pixel_values: np.ndarray = Optional[None], - zoom_to_brightest: bool = True, - ): - extent = self.axis.config_dict.get("extent") - if extent is None: - extent = mapper.extent_from( - values=pixel_values, zoom_to_brightest=zoom_to_brightest - ) - - aspect_inv = self.figure.aspect_for_subplot_from(extent=extent) - - if not self.is_for_subplot: - fig, ax = self.figure.open() - else: - ax = self.setup_subplot(aspect=aspect_inv) - - self.axis.set(extent=extent, grid=mapper.source_plane_mesh_grid) - - plt.gca().set_aspect(aspect_inv) - - self.tickparams.set() - self.yticks.set(min_value=extent[2], max_value=extent[3], units=self.units) - self.xticks.set(min_value=extent[0], max_value=extent[1], units=self.units) - - if not isinstance(self.text, list): - self.text.set() - else: - [text.set() for text in self.text] - - if not isinstance(self.annotate, list): - self.annotate.set() - else: - [annotate.set() for annotate in self.annotate] - - interpolation_array = None - - if hasattr(pixel_values, "array"): - pixel_values = pixel_values.array - - self.delaunay_drawer.draw_delaunay_pixels( - mapper=mapper, - pixel_values=pixel_values, - units=self.units, - cmap=self.cmap, - colorbar=self.colorbar, - colorbar_tickparams=self.colorbar_tickparams, - ax=ax, - use_log10=self.use_log10, - ) - - self.title.set(auto_title=auto_labels.title) - self.ylabel.set() - self.xlabel.set() - - visuals_2d.plot_via_plotter( - plotter=self, grid_indexes=mapper.source_plane_data_grid.over_sampled - ) - - if not self.is_for_subplot: - self.output.to_figure( - structure=interpolation_array, auto_filename=auto_labels.filename - ) - self.figure.close() diff --git a/autoarray/plot/multi_plotters.py b/autoarray/plot/multi_plotters.py deleted file mode 100644 index 84926a416..000000000 --- a/autoarray/plot/multi_plotters.py +++ /dev/null @@ -1,420 +0,0 @@ -import os -from pathlib import Path -from typing import List, Optional, Tuple - -from autoarray.plot.wrap.base.ticks import YTicks -from autoarray.plot.wrap.base.ticks import XTicks - - -class MultiFigurePlotter: - def __init__( - self, - plotter_list, - subplot_shape: Tuple[int, int] = None, - subplot_title: Optional[str] = None, - ): - """ - Plots multiple figures of plotter objects on the same subplot. - - For example, suppose you have multiple `ImagingPlotter` objects corresponding to different `Imaging` objects. - You may want to plot the `data`, `noise_map` and `psf` of each imaging dataset on the same subplot, so that - their data, noise-map and psf can be easily compared. - - The `MultiFigurePlotter` object allows you to do this, receiving a list of `Plotter` objects and calling - their `figure` methods to plot each on the same subplot. - - This requires careful inputs to the plotting functions in order to ensure that the correct plotting - functions are called for each plotter object. - - Parameters - ---------- - plotter_list - The list of plotter objects that are plotted on the same subplot. - subplot_shape - Optionally input the shape of the subplot (e.g. 2, 2) which is used to determine the shape of the figures - on the subplot. If not input, the subplot shape is determined automatically via config files. - subplot_title - Optionally input a title for the subplot. - """ - - self.plotter_list = plotter_list - self.subplot_shape = subplot_shape - self.subplot_title = subplot_title - - def setup_subplot_via_mat_plot( - self, plotter, number_subplots: int, subplot_index: int - ): - """ - Sets the `MatPlot` internal attributes which track if the plot is being made on a subplot and what index - the subplot is in the figure. - - Outside of the `MultiFigurePlotter` class when subplots are made from a single `Plotter` object, these - attributes are updated after each subplot figure is made. - - This class plots multiple plotter objects on the same subplot, so these attributes tracked and updated - separately for each plotter object by this class. - - Parameters - ---------- - plotter - The plotter which is used to plot the next figure on the subplot and therefore requires its mat-plot - subplot attributes to be updated. - number_subplots - The number of subplots that are being made on the figure. - subplot_index - The index of the subplot that is next being made on the figure uing the input plotter object. - """ - - try: - plotter.mat_plot_2d.set_for_subplot(is_for_subplot=True) - plotter.mat_plot_2d.number_subplots = number_subplots - plotter.mat_plot_2d.subplot_shape = self.subplot_shape - plotter.mat_plot_2d.subplot_index = subplot_index - except AttributeError: - plotter.mat_plot_1d.set_for_subplot(is_for_subplot=True) - plotter.mat_plot_1d.number_subplots = number_subplots - plotter.mat_plot_1d.subplot_shape = self.subplot_shape - plotter.mat_plot_1d.subplot_index = subplot_index - - def plot_via_func(self, plotter, figure_name: str, func_name: str, kwargs): - """ - Plots a figure on the subplot using an input plotter object, figure name and function name. - - For example, if you have an `ImagingPlotter` object and you want to plot the `data` on the subplot, you would - input `plotter=imaging_plotter`, `figure_name='data'` and `func_name='figures_2d'`. - - The code then knows to call the `figures_2d` function of the `ImagingPlotter` object and plot the `data`. - - This function is called repeatedly for each plotter object in the `plotter_list` to plot each figure - on the subplot. - - Parameters - ---------- - plotter - The plotter object that is used to plot the figure on the subplot. - figure_name - The name of the figure that is plotted on the subplot. - func_name - The name of the function that is called to plot the figure on the subplot. - kwargs - Any additional keyword arguments that are passed to the function that plots the figure on the subplot. - """ - func = getattr(plotter, func_name) - - if figure_name is None: - func(**{**{}, **kwargs}) - else: - func(**{**{figure_name: True}, **kwargs}) - - def subplot_of_figure( - self, func_name: str, figure_name: str, filename_suffix: str = "", **kwargs - ): - """ - Outputs a subplot of figures of the plotter objects in the `plotter_list`, where only a single function name - and figure name is input. - - For example, if you have multiple `ImagingPlotter` objects and you want to plot the `data` of each on the same - subplot, you would input `func_name='figures_2d'` and `figure_name='data'`. - - This function cannot plot different attributes of the plotter objects on the same subplot, for example the - `data` and `noise_map` of the `ImagingPlotter` objects. For this, use the `subplot_of_figures_multi` function. - - Parameters - ---------- - func_name - The name of the function that is called to plot the figure on the subplot. - figure_name - The name of the figure that is plotted on the subplot. - filename_suffix - The suffix of the filename that the subplot is output to. - kwargs - Any additional keyword arguments that are passed to the function that plots the figure on the subplot. - """ - number_subplots = len(self.plotter_list) - - self.plotter_list[0].open_subplot_figure( - number_subplots=number_subplots, subplot_shape=self.subplot_shape - ) - - for i, plotter in enumerate(self.plotter_list): - self.setup_subplot_via_mat_plot( - plotter=plotter, number_subplots=number_subplots, subplot_index=i + 1 - ) - - self.plot_via_func( - plotter=plotter, - figure_name=figure_name, - func_name=func_name, - kwargs=kwargs, - ) - - self.output_subplot(filename_suffix=f"{figure_name}{filename_suffix}") - - def subplot_of_figures_multi( - self, - func_name_list: List[str], - figure_name_list: List[str], - filename_suffix: str = "", - subplot_index_offset: int = 0, - number_subplots: Optional[int] = None, - open_subplot: bool = True, - close_subplot: bool = True, - **kwargs, - ): - """ - Outputs a subplot of figures of the plotter objects in the `plotter_list`, where multiple function names and - figure names are input. - - For example, if you have multiple `ImagingPlotter` objects and you want to plot the `data` and `noise_map` of - each on the same subplot, you would input `func_name_list=['figures_2d', 'figures_2d']` and - `figure_name_list=['data', 'noise_map']`. - - Parameters - ---------- - func_name_list - The list of function names that are called to plot the figures on the subplot. - figure_name_list - The list of figure names that are plotted on the subplot. - filename_suffix - The suffix of the filename that the subplot is output to. - kwargs - Any additional keyword arguments that are passed to the function that plots the figure on the subplot. - """ - if number_subplots is None: - number_subplots = len(self.plotter_list) * len(func_name_list) - - if open_subplot: - self.plotter_list[0].open_subplot_figure( - number_subplots=number_subplots, subplot_shape=self.subplot_shape - ) - - for i, plotter in enumerate(self.plotter_list): - for j, (func_name, figure_name) in enumerate( - zip(func_name_list, figure_name_list) - ): - subplot_shape = self.plotter_list[0].mat_plot_2d.subplot_shape - - subplot_index = subplot_index_offset + (i * subplot_shape[1]) + j + 1 - - self.setup_subplot_via_mat_plot( - plotter=plotter, - number_subplots=number_subplots, - subplot_index=subplot_index, - ) - - self.plot_via_func( - plotter=plotter, - figure_name=figure_name, - func_name=func_name, - kwargs=kwargs, - ) - - if close_subplot: - self.output_subplot(filename_suffix=filename_suffix) - - def subplot_of_multi_yx_1d(self, filename_suffix="", **kwargs): - number_subplots = len(self.plotter_list) - - self.plotter_list[0].plotter_list[0].open_subplot_figure( - number_subplots=number_subplots, - subplot_shape=self.subplot_shape, - subplot_title=self.subplot_title, - ) - - for i, plotter in enumerate(self.plotter_list): - for plott in plotter.plotter_list: - plott.mat_plot_1d.set_for_subplot(is_for_subplot=True) - plott.mat_plot_1d.number_subplots = number_subplots - plott.mat_plot_1d.subplot_shape = self.subplot_shape - plott.mat_plot_1d.subplot_index = i + 1 - - func = getattr(plotter, "figure_1d") - func( - **{ - **{ - "func_name": "figure_1d", - "figure_name": None, - "is_for_subplot": True, - }, - **kwargs, - } - ) - - self.plotter_list[0].plotter_list[0].mat_plot_1d.output.subplot_to_figure( - auto_filename=f"subplot_{filename_suffix}" - ) - self.plotter_list[0].plotter_list[0].close_subplot_figure() - - def output_subplot(self, filename_suffix: str = ""): - """ - Outplot the subplot to a file after all figures have been plotted on the subplot. - - The multi-plotter requires its own output function to ensure that the subplot is output to a file, which - this provides. - - Parameters - ---------- - filename_suffix - The suffix of the filename that the subplot is output to. - """ - - plotter = self.plotter_list[0] - - if plotter.mat_plot_1d is not None: - plotter.mat_plot_1d.output.subplot_to_figure( - auto_filename=f"subplot_{filename_suffix}" - ) - if plotter.mat_plot_2d is not None: - plotter.mat_plot_2d.output.subplot_to_figure( - auto_filename=f"subplot_{filename_suffix}" - ) - plotter.close_subplot_figure() - - def output_to_fits( - self, - func_name_list: List[str], - figure_name_list: List[str], - filename: str, - tag_list: Optional[List[str]] = None, - remove_fits_first: bool = False, - **kwargs, - ): - """ - Outputs a list of figures of the plotter objects in the `plotter_list` to a single .fits file. - - This function takes as input lists of function names and figure names and then calls them via - the `plotter_list` with an interface that outputs each to a .fits file. - - For example, if you have multiple `ImagingPlotter` objects and want to output the `data` and `noise_map` of - each to a single .fits files, you would input: - - - `func_name_list=['figures_2d', 'figures_2d']` and - - `figure_name_list=['data', 'noise_map']`. - - The implementation of this code is hacky, with it using a specific interface in the `Output` object - which sets the format to `fits_multi` to call a function which outputs the .fits files. A major visualuzation - refactor is required to make this more elegant. - - Parameters - ---------- - func_name_list - The list of function names that are called to plot the figures on the subplot. - figure_name_list - The list of figure names that are plotted on the subplot. - filename - The filename that the .fits file is output to. - tag_list - The list of tags that are used to set the `EXTNAME` of each hdu of the .fits file. - remove_fits_first - If the .fits file already exists, it is removed before the new .fits file is output, else it is updated - with the figure going into the next hdu. - kwargs - Any additional keyword arguments that are passed to the function that plots the figure on the subplot. - """ - - output_path = self.plotter_list[0].mat_plot_2d.output.output_path_from( - format="fits_multi" - ) - output_fits_file = Path(output_path) / f"{filename}.fits" - - if remove_fits_first: - output_fits_file.unlink(missing_ok=True) - - for i, plotter in enumerate(self.plotter_list): - plotter.mat_plot_2d.output._format = "fits_multi" - - plotter.set_filename(filename=f"{filename}") - - for j, (func_name, figure_name) in enumerate( - zip(func_name_list, figure_name_list) - ): - if tag_list is not None: - plotter.mat_plot_2d.output._tag_fits_multi = tag_list[j] - - self.plot_via_func( - plotter=plotter, - figure_name=figure_name, - func_name=func_name, - kwargs=kwargs, - ) - - -class MultiYX1DPlotter: - def __init__( - self, - plotter_list, - color_list=None, - legend_labels=None, - y_manual_min_max_value=None, - x_manual_min_max_value=None, - ): - self.plotter_list = plotter_list - - if color_list is None: - color_list = 10 * ["k", "r", "b", "g", "c", "m", "y"] - - self.color_list = color_list - self.legend_labels = legend_labels - - self.y_manual_min_max_value = y_manual_min_max_value - self.x_manual_min_max_value = x_manual_min_max_value - - def figure_1d(self, func_name, figure_name, is_for_subplot=False, **kwargs): - if not is_for_subplot: - self.plotter_list[0].mat_plot_1d.figure.open() - - for i, plotter in enumerate(self.plotter_list): - plotter.set_mat_plot_1d_for_multi_plot( - is_for_multi_plot=True, - color=self.color_list[i], - yticks=self.yticks, - xticks=self.xticks, - ) - - if self.legend_labels is not None: - plotter.mat_plot_1d.yx_plot.label = self.legend_labels[i] - - func = getattr(plotter, func_name) - - if figure_name is None: - func(**{**{}, **kwargs}) - else: - func(**{**{figure_name: True}, **kwargs}) - - plotter.set_mat_plot_1d_for_multi_plot(is_for_multi_plot=False, color=None) - - if not is_for_subplot: - self.plotter_list[0].mat_plot_1d.output.subplot_to_figure( - auto_filename=f"multi_{figure_name}" - ) - self.plotter_list[0].mat_plot_1d.figure.close() - - @property - def yticks(self): - # TODO: Need to make this work for all plotters, rather than just y x, for example - # TODO : GalaxyPlotters where y and x are computed inside the function called via - # TODO : func(**{**{figure_name: True}, **kwargs}) - - if self.y_manual_min_max_value is not None: - return YTicks(manual_min_max_value=self.y_manual_min_max_value) - - try: - min_value = min([min(plotter.y) for plotter in self.plotter_list]) - max_value = max([max(plotter.y) for plotter in self.plotter_list]) - except AttributeError: - return - - return YTicks(manual_min_max_value=(min_value, max_value)) - - @property - def xticks(self): - if self.x_manual_min_max_value is not None: - return XTicks(manual_min_max_value=self.x_manual_min_max_value) - - try: - min_value = min([min(plotter.x) for plotter in self.plotter_list]) - max_value = max([max(plotter.x) for plotter in self.plotter_list]) - except AttributeError: - return - - return XTicks(manual_min_max_value=(min_value, max_value)) diff --git a/autoarray/plot/wrap/base/output.py b/autoarray/plot/output.py similarity index 71% rename from autoarray/plot/wrap/base/output.py rename to autoarray/plot/output.py index 9f7e111bc..eea6a63b5 100644 --- a/autoarray/plot/wrap/base/output.py +++ b/autoarray/plot/output.py @@ -66,18 +66,37 @@ def __init__( @property def format(self) -> str: + """The output format string; defaults to ``"show"`` when none was given.""" if self._format is None: return "show" return self._format @property def format_list(self): + """The output format(s) as a list, so iteration always works.""" if not isinstance(self.format, list): return [self.format] return self.format def output_path_from(self, format): - if format in "show": + """Return the directory path for *format*, creating it if necessary. + + When *format* is ``"show"`` returns ``None`` (no file is written). + When ``format_folder`` is ``True`` the format name is appended as a + sub-directory so that ``png`` and ``pdf`` outputs are kept separate. + + Parameters + ---------- + format + File format string, e.g. ``"png"``, ``"pdf"``, or ``"show"``. + + Returns + ------- + str or None + Absolute path to the output directory, or ``None`` for + ``format == "show"``. + """ + if format == "show": return None if self.format_folder: @@ -90,6 +109,23 @@ def output_path_from(self, format): return output_path def filename_from(self, auto_filename): + """Build the final filename string by applying prefix / suffix. + + When no explicit ``filename`` was passed to ``__init__`` the + *auto_filename* supplied by the calling plotter is used as the base. + + Parameters + ---------- + auto_filename + Fallback filename (without extension) when ``self.filename`` is + ``None``. + + Returns + ------- + str + The resolved filename with any configured prefix and suffix + applied. + """ filename = auto_filename if self.filename is None else self.filename if self.prefix is not None: @@ -101,7 +137,21 @@ def filename_from(self, auto_filename): return filename def savefig(self, filename: str, output_path: str, format: str): + """Call ``plt.savefig`` with the configured ``bbox_inches`` setting. + Catches ``ValueError`` exceptions (e.g. unsupported format) and logs + them without raising, so a single bad output format does not abort + the whole script. + + Parameters + ---------- + filename + Base file name without extension. + output_path + Directory to write the file (must already exist). + format + File format extension string, e.g. ``"png"``. + """ import matplotlib.pyplot as plt try: @@ -190,6 +240,20 @@ def subplot_to_figure( plt.show() def to_figure_output_mode(self, filename: str): + """Save the current figure as a numbered PNG snapshot in *output mode*. + + Output mode is activated by setting the environment variable + ``PYAUTOARRAY_OUTPUT_MODE=1``. Each call increments a global counter + so that figures are saved as ``0_filename.png``, ``1_filename.png``, + etc. in a sub-directory named after the running script. This is useful + for collecting a sequence of figures during automated testing or + demonstration scripts. + + Parameters + ---------- + filename + Base file name (without extension) for this figure. + """ global COUNT try: diff --git a/autoarray/plot/wrap/segmentdata.py b/autoarray/plot/segmentdata.py similarity index 100% rename from autoarray/plot/wrap/segmentdata.py rename to autoarray/plot/segmentdata.py diff --git a/autoarray/plot/utils.py b/autoarray/plot/utils.py new file mode 100644 index 000000000..0a23603f3 --- /dev/null +++ b/autoarray/plot/utils.py @@ -0,0 +1,519 @@ +""" +Shared utilities for the direct-matplotlib plot functions. +""" + +import logging +import os +from typing import List, Optional, Tuple + +import matplotlib.pyplot as plt +import numpy as np + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# autoarray → numpy conversion helpers (used by high-level plot functions) +# --------------------------------------------------------------------------- + + +def auto_mask_edge(array) -> Optional[np.ndarray]: + """Return the edge-pixel ``(y, x)`` coordinates of an autoarray mask. + + Used to overlay the mask boundary on ``plot_array`` images. If *array* + has no ``mask`` attribute, or the mask is fully unmasked, ``None`` is + returned so no overlay is drawn. + + Parameters + ---------- + array + An autoarray ``Array2D`` (or any object with a ``.mask`` attribute + that exposes ``.derive_grid.edge.array``). + + Returns + ------- + numpy.ndarray or None + Shape ``(N, 2)`` float array of ``(y, x)`` edge coordinates, or + ``None`` when the array is unmasked or has no mask. + """ + try: + if not array.mask.is_all_false: + return np.array(array.mask.derive_grid.edge.array) + except AttributeError: + pass + return None + + +def zoom_array(array): + """Crop *array* around its mask when ``zoom_around_mask`` is enabled in config. + + Reads ``visualize/general/general/zoom_around_mask`` from the autoconf + configuration. When the flag is ``True`` and *array* carries a non-trivial + mask the array is cropped via ``Zoom2D`` so that downstream ``imshow`` + calls fill the axes without empty black borders. + + Parameters + ---------- + array + An autoarray ``Array2D`` (or any object). Plain numpy arrays are + returned unchanged. + + Returns + ------- + array + The (potentially cropped) array. If the config flag is ``False``, or + *array* has no mask / the mask is all-``False``, the input is returned + unmodified. + """ + try: + from autoconf import conf + + zoom_around_mask = conf.instance["visualize"]["general"]["general"][ + "zoom_around_mask" + ] + except Exception: + zoom_around_mask = False + + if zoom_around_mask and hasattr(array, "mask") and not array.mask.is_all_false: + from autoarray.mask.derive.zoom_2d import Zoom2D + + return Zoom2D(mask=array.mask).array_2d_from(array=array, buffer=1) + return array + + +def numpy_grid(grid) -> Optional[np.ndarray]: + """Convert a grid-like object to a plain ``(N, 2)`` numpy array, or ``None``. + + Accepts autoarray ``Grid2D`` / ``Grid2DIrregular`` objects (via their + ``.array`` attribute) as well as bare numpy arrays. ``None`` inputs are + passed through so callers can use this as a safe no-op. + + Parameters + ---------- + grid + An autoarray grid, a ``(N, 2)`` numpy array, or ``None``. + + Returns + ------- + numpy.ndarray or None + Plain ``(N, 2)`` float array with ``(y, x)`` columns, or ``None``. + """ + if grid is None: + return None + try: + return np.array(grid.array if hasattr(grid, "array") else grid) + except Exception: + return None + + +def numpy_lines(lines) -> Optional[List[np.ndarray]]: + """Convert a collection of lines to a list of ``(N, 2)`` numpy arrays. + + Accepts autoarray ``Grid2DIrregular`` objects or any iterable of + ``(N, 2)`` array-like sequences. Each element is converted to a plain + numpy array; elements that cannot be converted are silently skipped. + + Parameters + ---------- + lines + An autoarray grid collection, a list of ``(N, 2)`` arrays, or ``None``. + + Returns + ------- + list of numpy.ndarray or None + List of ``(N, 2)`` float arrays (``y`` column 0, ``x`` column 1), or + ``None`` when *lines* is ``None`` or no valid lines are found. + """ + if lines is None: + return None + result = [] + try: + for line in lines: + try: + arr = np.array(line.array if hasattr(line, "array") else line) + if arr.ndim == 2 and arr.shape[1] == 2: + result.append(arr) + except Exception: + pass + except TypeError: + pass + return result or None + + +def numpy_positions(positions) -> Optional[List[np.ndarray]]: + """Convert a positions object to a list of ``(N, 2)`` numpy arrays. + + Positions can be a single ``Grid2DIrregular`` (treated as one group), + a plain ``(N, 2)`` array (treated as one group), or a list of such + objects (each becomes one group, scatter-plotted in a distinct colour). + + Parameters + ---------- + positions + An autoarray ``Grid2DIrregular``, a ``(N, 2)`` numpy array, a list + of the above, or ``None``. + + Returns + ------- + list of numpy.ndarray or None + Each element is a ``(N, 2)`` array of ``(y, x)`` coordinates + representing one group of positions, or ``None`` when *positions* + is ``None`` or cannot be converted. + """ + if positions is None: + return None + try: + arr = np.array(positions.array if hasattr(positions, "array") else positions) + if arr.ndim == 2 and arr.shape[1] == 2: + return [arr] + except Exception: + pass + if isinstance(positions, list): + result = [] + for p in positions: + try: + result.append(np.array(p.array if hasattr(p, "array") else p)) + except Exception: + pass + return result or None + return None + + +def symmetric_vmin_vmax(array): + """Return ``(-abs_max, abs_max)`` colour limits for a symmetric residual colormap. + + Computes the maximum absolute value of *array* and returns symmetric limits + so that zero maps to the centre of the colormap. Typically applied to + residual maps and normalised residual maps. + + Parameters + ---------- + array + An autoarray ``Array2D`` (uses ``.native.array``) or a plain numpy + array. + + Returns + ------- + tuple of (float, float) or (None, None) + ``(vmin, vmax)`` where ``vmin == -vmax == -abs_max``. Returns + ``(None, None)`` if the computation fails (e.g. all-NaN input). + """ + try: + arr = array.native.array if hasattr(array, "native") else np.asarray(array) + abs_max = float(np.nanmax(np.abs(arr))) + return -abs_max, abs_max + except Exception: + return None, None + + +def symmetric_cmap_from(array, symmetric_value=None): + """Return a matplotlib ``Normalize`` centred on zero for a symmetric colormap. + + Parameters + ---------- + array + The data array (autoarray or numpy). Used to compute ``abs_max`` when + *symmetric_value* is not provided. + symmetric_value + If given, fix the half-range to this value (``vmin=-symmetric_value``, + ``vmax=+symmetric_value``). + + Returns + ------- + matplotlib.colors.Normalize or None + """ + import matplotlib.colors as colors + + if symmetric_value is not None: + abs_max = float(symmetric_value) + else: + vmin, vmax = symmetric_vmin_vmax(array) + if vmin is None: + return None + abs_max = max(abs(vmin), abs(vmax)) + + return colors.Normalize(vmin=-abs_max, vmax=abs_max) + + +def set_with_color_values(ax, cmap, color_values, norm=None, fraction=0.047, pad=0.01): + """Attach a colorbar to *ax* driven by *color_values* rather than a plotted artist. + + Useful for Delaunay mapper visualisation where ``ax.tripcolor`` already draws + the mesh but we need a separate colorbar tied to specific solution values. + + Parameters + ---------- + ax + The matplotlib axes to attach the colorbar to. + cmap + A matplotlib colormap name or object. + color_values + The 1-D array of values that define the colorbar range. + norm + A ``matplotlib.colors.Normalize`` instance. If ``None`` a default + ``Normalize(vmin, vmax)`` is created from *color_values*. + fraction, pad + Passed directly to ``plt.colorbar``. + """ + import matplotlib.cm as cm + import matplotlib.colors as mcolors + + if norm is None: + arr = np.asarray(color_values) + norm = mcolors.Normalize(vmin=float(np.nanmin(arr)), vmax=float(np.nanmax(arr))) + + mappable = cm.ScalarMappable(norm=norm, cmap=cmap) + mappable.set_array(color_values) + return plt.colorbar(mappable=mappable, ax=ax, fraction=fraction, pad=pad) + + +def subplot_save(fig, output_path, output_filename, output_format): + """Save a subplot figure to disk, or display it, then close it. + + All ``subplot_*`` functions call this as their final step. When + *output_path* is non-empty the figure is written to + ``/.``; otherwise + ``plt.show()`` is called. ``plt.close(fig)`` is always called to + release memory. + + Parameters + ---------- + fig + The matplotlib ``Figure`` to save or show. + output_path + Directory to write the file. Creates the directory if needed. + ``None`` or an empty string causes ``plt.show()`` to be called. + output_filename + Base file name without extension. + output_format + File format string, e.g. ``"png"`` or ``"pdf"``. + """ + if output_path: + os.makedirs(output_path, exist_ok=True) + try: + fig.savefig( + os.path.join(output_path, f"{output_filename}.{output_format}"), + bbox_inches="tight", + pad_inches=0.1, + ) + except Exception as exc: + logger.warning( + f"subplot_save: could not save {output_filename}.{output_format}: {exc}" + ) + else: + plt.show() + plt.close(fig) + + +def conf_mat_plot_fontsize(section: str, default: int) -> int: + """Read a font size from the ``mat_plot`` section of ``visualize/general.yaml``. + + Parameters + ---------- + section + Sub-key inside ``mat_plot``, e.g. ``"title"``, ``"xlabel"``, + ``"ylabel"``, ``"xticks"``, or ``"yticks"``. + default + Value returned when the config key is absent or unreadable. + + Returns + ------- + int + The configured font size. + """ + try: + from autoconf import conf + + return int( + conf.instance["visualize"]["general"]["mat_plot"][section]["fontsize"] + ) + except Exception: + return default + + +def _parse_figsize(raw) -> Tuple[int, int]: + """Convert *raw* (a tuple/list or a string like ``"(7, 7)"``) to a 2-tuple.""" + if isinstance(raw, (tuple, list)): + return tuple(raw) + import ast + + return tuple(ast.literal_eval(str(raw))) + + +def conf_figsize(context: str = "figures") -> Tuple[int, int]: + """ + Read figsize from ``visualize/general.yaml`` for the given context. + + For single-panel figures the value is taken from + ``mat_plot/figure/figsize``; the *context* argument is kept for + backward compatibility with subplot callers that pass ``"subplots"``. + + Parameters + ---------- + context + ``"figures"`` (single-panel) or ``"subplots"`` (multi-panel). + """ + try: + from autoconf import conf + + if context == "figures": + raw = conf.instance["visualize"]["general"]["mat_plot"]["figure"]["figsize"] + return _parse_figsize(raw) + return tuple(conf.instance["visualize"]["general"][context]["figsize"]) + except Exception: + return (7, 7) if context == "figures" else (19, 16) + + +def apply_labels( + ax: plt.Axes, + title: str = "", + xlabel: str = "", + ylabel: str = "", +) -> None: + """Apply title, axis labels, and tick font sizes to *ax* from config. + + Reads font sizes from the ``mat_plot`` section of + ``visualize/general.yaml`` so that users can override them globally + without touching call sites. Falls back to the values that the + old ``MatWrap`` system used when the config is unavailable. + + Parameters + ---------- + ax + The matplotlib axes to configure. + title + Title string. + xlabel + X-axis label string. + ylabel + Y-axis label string. + """ + title_fs = conf_mat_plot_fontsize("title", default=16) + xlabel_fs = conf_mat_plot_fontsize("xlabel", default=14) + ylabel_fs = conf_mat_plot_fontsize("ylabel", default=14) + xticks_fs = conf_mat_plot_fontsize("xticks", default=12) + yticks_fs = conf_mat_plot_fontsize("yticks", default=12) + + ax.set_title(title, fontsize=title_fs) + ax.set_xlabel(xlabel, fontsize=xlabel_fs) + ax.set_ylabel(ylabel, fontsize=ylabel_fs) + ax.tick_params(axis="x", labelsize=xticks_fs) + ax.tick_params(axis="y", labelsize=yticks_fs) + + +def save_figure( + fig: plt.Figure, + path: str, + filename: str, + format: str = "png", + dpi: int = 300, + structure=None, +) -> None: + """ + Save *fig* to ``/.`` then close it. + + If *path* is an empty string or ``None``, ``plt.show()`` is called instead. + After either action ``plt.close(fig)`` is always called to free memory. + + Parameters + ---------- + fig + The matplotlib figure to save. + path + Directory where the file is written. Created if it does not exist. + filename + File name without extension. + format + File format(s) passed to ``fig.savefig``. Either a single string + (e.g. ``"png"``) or a list/tuple of strings (e.g. ``["png", "pdf"]``) + to save in multiple formats in one call. + dpi + Resolution in dots per inch. + structure + Optional autoarray structure (e.g. ``Array2D``). Required when + *format* is ``"fits"`` — its ``output_to_fits`` method is used + instead of ``fig.savefig``. + """ + if path: + os.makedirs(path, exist_ok=True) + formats = format if isinstance(format, (list, tuple)) else [format] + for fmt in formats: + if fmt == "fits": + if structure is not None and hasattr(structure, "output_to_fits"): + structure.output_to_fits( + file_path=os.path.join(path, f"{filename}.fits"), + overwrite=True, + ) + else: + logger.warning( + f"save_figure: fits format requested for {filename} but no " + "compatible structure was provided; skipping." + ) + else: + try: + fig.savefig( + os.path.join(path, f"{filename}.{fmt}"), + dpi=dpi, + bbox_inches="tight", + pad_inches=0.1, + ) + except Exception as exc: + logger.warning( + f"save_figure: could not save {filename}.{fmt}: {exc}" + ) + else: + plt.show() + + plt.close(fig) + + +def plot_visibilities_1d(vis, ax: plt.Axes, title: str = "") -> None: + """Plot the real and imaginary components of a visibilities array as 1D line plots. + + Draws two overlapping lines — one for the real part and one for the + imaginary part — with a legend. Used by interferometer subplot functions + to visualise raw or residual visibilities. + + Parameters + ---------- + vis + A ``Visibilities`` autoarray object (accessed via ``.slim``) or any + array-like that can be cast to a complex numpy array. + ax + Matplotlib ``Axes`` to draw onto. + title + Axes title string. + """ + try: + y = np.array(vis.slim if hasattr(vis, "slim") else vis) + except Exception: + y = np.asarray(vis) + ax.plot(y.real, label="Real", alpha=0.7) + ax.plot(y.imag, label="Imaginary", alpha=0.7) + ax.set_title(title) + ax.legend(fontsize=8) + + +def apply_extent( + ax: plt.Axes, + extent: Tuple[float, float, float, float], + n_ticks: int = 3, +) -> None: + """ + Apply axis limits and evenly spaced linear ticks to *ax*. + + Parameters + ---------- + ax + The matplotlib axes to configure. + extent + ``[xmin, xmax, ymin, ymax]`` limits. + n_ticks + Number of ticks on each axis. ``3`` produces ``[-R, 0, R]`` for + a symmetric extent, matching the reference ``plot_grid`` example. + """ + xmin, xmax, ymin, ymax = extent + ax.set_xlim(xmin, xmax) + ax.set_ylim(ymin, ymax) + ax.set_xticks(np.linspace(xmin, xmax, n_ticks)) + ax.set_yticks(np.linspace(ymin, ymax, n_ticks)) diff --git a/autoarray/plot/visuals/__init__.py b/autoarray/plot/visuals/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/autoarray/plot/visuals/abstract.py b/autoarray/plot/visuals/abstract.py deleted file mode 100644 index 35583e985..000000000 --- a/autoarray/plot/visuals/abstract.py +++ /dev/null @@ -1,47 +0,0 @@ -from abc import ABC - - -class AbstractVisuals(ABC): - def __add__(self, other): - """ - Adds two `Visuals` classes together. - - When we perform plotting, the `Include` class is used to create additional `Visuals` class from the data - structures that are plotted, for example: - - mask = Mask2D.circular(shape_native=(100, 100), pixel_scales=0.1, radius=3.0) - array = Array2D.ones(shape_native=(100, 100), pixel_scales=0.1) - masked_array = al.Array2D(values=array, mask=mask) - array_plotter = aplt.Array2DPlotter(array=masked_array) - array_plotter.figure() - - If the user did not manually input a `Visuals2D` object, the one created in `function_array` is the one used to - plot the image - - However, if the user specifies their own `Visuals2D` object and passed it to the plotter, e.g.: - - visuals_2d = Visuals2D(origin=(0.0, 0.0)) - array_plotter = aplt.Array2DPlotter(array=masked_array) - - We now wish for the `Plotter` to plot the `origin` in the user's input `Visuals2D` object. To achieve this, - one `Visuals2D` object is created: (i) the user's input instance (with an origin). - - This `__add__` override means we can add the two together to make the final `Visuals2D` object that is - plotted on the figure containing both the `origin` and `Mask2D`.: - - visuals_2d = visuals_2d_via_user + visuals_2d_via_include - - The ordering of the addition has been specifically chosen to ensure that the `visuals_2d_via_user` does not - retain the attributes that are added to it by the `visuals_2d_via_include`. This ensures that if multiple plots - are made, the same `visuals_2d_via_user` is used for every plot. If this were not the case, it would - permanently inherit attributes from the `Visuals` from the `Include` method and plot them on all figures. - """ - - for attr, value in self.__dict__.items(): - try: - if other.__dict__[attr] is None and self.__dict__[attr] is not None: - other.__dict__[attr] = self.__dict__[attr] - except KeyError: - pass - - return other diff --git a/autoarray/plot/visuals/one_d.py b/autoarray/plot/visuals/one_d.py deleted file mode 100644 index b84a832b3..000000000 --- a/autoarray/plot/visuals/one_d.py +++ /dev/null @@ -1,32 +0,0 @@ -import numpy as np -from typing import List, Optional, Union - -from autoarray.mask.mask_1d import Mask1D -from autoarray.plot.visuals.abstract import AbstractVisuals -from autoarray.structures.arrays.uniform_1d import Array1D -from autoarray.structures.grids.uniform_1d import Grid1D - - -class Visuals1D(AbstractVisuals): - def __init__( - self, - origin: Optional[Grid1D] = None, - mask: Optional[Mask1D] = None, - points: Optional[Grid1D] = None, - vertical_line: Optional[float] = None, - shaded_region: Optional[List[Union[List, Array1D, np.ndarray]]] = None, - ): - self.origin = origin - self.mask = mask - self.points = points - self.vertical_line = vertical_line - self.shaded_region = shaded_region - - def plot_via_plotter(self, plotter): - if self.points is not None: - plotter.yx_scatter.scatter_yx(y=self.points, x=np.arange(len(self.points))) - - if self.vertical_line is not None: - plotter.vertical_line_axvline.axvline_vertical_line( - vertical_line=self.vertical_line - ) diff --git a/autoarray/plot/visuals/two_d.py b/autoarray/plot/visuals/two_d.py deleted file mode 100644 index a7fb6ca0b..000000000 --- a/autoarray/plot/visuals/two_d.py +++ /dev/null @@ -1,104 +0,0 @@ -from matplotlib import patches as ptch -import numpy as np -from typing import List, Optional, Union - -from autoarray.mask.mask_2d import Mask2D -from autoarray.plot.visuals.abstract import AbstractVisuals -from autoarray.structures.arrays.uniform_1d import Array1D -from autoarray.structures.arrays.uniform_2d import Array2D -from autoarray.structures.grids.uniform_2d import Grid2D -from autoarray.structures.grids.irregular_2d import Grid2DIrregular -from autoarray.structures.vectors.irregular import VectorYX2DIrregular - - -class Visuals2D(AbstractVisuals): - def __init__( - self, - origin: Optional[Grid2D] = None, - mask: Optional[Mask2D] = None, - border: Optional[Grid2D] = None, - lines: Optional[Union[List[Array1D], Grid2DIrregular]] = None, - positions: Optional[Union[Grid2DIrregular, List[Grid2DIrregular]]] = None, - grid: Optional[Grid2D] = None, - mesh_grid: Optional[Grid2D] = None, - vectors: Optional[VectorYX2DIrregular] = None, - patches: Optional[List[ptch.Patch]] = None, - fill_region: Optional[List] = None, - array_overlay: Optional[Array2D] = None, - parallel_overscan=None, - serial_prescan=None, - serial_overscan=None, - indexes=None, - ): - self.origin = origin - self.mask = mask - self.border = border - self.lines = lines - self.positions = positions - self.grid = grid - self.mesh_grid = mesh_grid - self.vectors = vectors - self.patches = patches - self.fill_region = fill_region - self.array_overlay = array_overlay - self.parallel_overscan = parallel_overscan - self.serial_prescan = serial_prescan - self.serial_overscan = serial_overscan - self.indexes = indexes - - def plot_via_plotter(self, plotter, grid_indexes=None): - - if self.mask is not None: - plotter.mask_scatter.scatter_grid(grid=self.mask.derive_grid.edge.array) - - if self.origin is not None: - - origin = self.origin - - if isinstance(origin, tuple): - - origin = Grid2DIrregular(values=[origin]) - - plotter.origin_scatter.scatter_grid( - grid=Grid2DIrregular(values=origin).array - ) - - if self.border is not None: - try: - plotter.border_scatter.scatter_grid(grid=self.border.array) - except AttributeError: - plotter.border_scatter.scatter_grid(grid=self.border) - - if self.grid is not None: - try: - plotter.grid_scatter.scatter_grid(grid=self.grid.array) - except AttributeError: - plotter.grid_scatter.scatter_grid(grid=self.grid) - - if self.mesh_grid is not None: - plotter.mesh_grid_scatter.scatter_grid(grid=self.mesh_grid.array) - - if self.positions is not None: - try: - plotter.positions_scatter.scatter_grid(grid=self.positions.array) - except (AttributeError, ValueError): - plotter.positions_scatter.scatter_grid(grid=self.positions) - - if self.vectors is not None: - plotter.vector_yx_quiver.quiver_vectors(vectors=self.vectors) - - if self.patches is not None: - plotter.patch_overlay.overlay_patches(patches=self.patches) - - if self.fill_region is not None: - plotter.fill.plot_fill(fill_region=self.fill_region) - - if self.lines is not None: - plotter.grid_plot.plot_grid(grid=self.lines) - - if self.indexes is not None and grid_indexes is not None: - - plotter.index_scatter.scatter_grid_indexes( - grid=np.array(grid_indexes), - indexes=self.indexes, - ) diff --git a/autoarray/plot/wrap/__init__.py b/autoarray/plot/wrap/__init__.py deleted file mode 100644 index 3da942a38..000000000 --- a/autoarray/plot/wrap/__init__.py +++ /dev/null @@ -1,37 +0,0 @@ -from autoarray.plot.wrap.base.units import Units -from autoarray.plot.wrap.base.figure import Figure -from autoarray.plot.wrap.base.axis import Axis -from autoarray.plot.wrap.base.cmap import Cmap -from autoarray.plot.wrap.base.colorbar import Colorbar -from autoarray.plot.wrap.base.colorbar_tickparams import ColorbarTickParams -from autoarray.plot.wrap.base.tickparams import TickParams -from autoarray.plot.wrap.base.ticks import YTicks -from autoarray.plot.wrap.base.ticks import XTicks -from autoarray.plot.wrap.base.title import Title -from autoarray.plot.wrap.base.label import YLabel -from autoarray.plot.wrap.base.label import XLabel -from autoarray.plot.wrap.base.text import Text -from autoarray.plot.wrap.base.legend import Legend -from autoarray.plot.wrap.base.output import Output - -from autoarray.plot.wrap.one_d.yx_plot import YXPlot -from autoarray.plot.wrap.one_d.yx_scatter import YXScatter -from autoarray.plot.wrap.one_d.avxline import AXVLine -from autoarray.plot.wrap.one_d.fill_between import FillBetween - -from autoarray.plot.wrap.two_d.array_overlay import ArrayOverlay -from autoarray.plot.wrap.two_d.grid_scatter import GridScatter -from autoarray.plot.wrap.two_d.grid_plot import GridPlot -from autoarray.plot.wrap.two_d.grid_errorbar import GridErrorbar -from autoarray.plot.wrap.two_d.vector_yx_quiver import VectorYXQuiver -from autoarray.plot.wrap.two_d.patch_overlay import PatchOverlay -from autoarray.plot.wrap.two_d.origin_scatter import OriginScatter -from autoarray.plot.wrap.two_d.mask_scatter import MaskScatter -from autoarray.plot.wrap.two_d.border_scatter import BorderScatter -from autoarray.plot.wrap.two_d.positions_scatter import PositionsScatter -from autoarray.plot.wrap.two_d.index_scatter import IndexScatter -from autoarray.plot.wrap.two_d.index_plot import IndexPlot -from autoarray.plot.wrap.two_d.mesh_grid_scatter import MeshGridScatter -from autoarray.plot.wrap.two_d.parallel_overscan_plot import ParallelOverscanPlot -from autoarray.plot.wrap.two_d.serial_prescan_plot import SerialPrescanPlot -from autoarray.plot.wrap.two_d.serial_overscan_plot import SerialOverscanPlot diff --git a/autoarray/plot/wrap/base/__init__.py b/autoarray/plot/wrap/base/__init__.py deleted file mode 100644 index 5d1f316c9..000000000 --- a/autoarray/plot/wrap/base/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ -from .units import Units -from .figure import Figure -from .axis import Axis -from .cmap import Cmap -from .colorbar import Colorbar -from .colorbar_tickparams import ColorbarTickParams -from .tickparams import TickParams -from .ticks import YTicks -from .ticks import XTicks -from .title import Title -from .label import YLabel -from .label import XLabel -from .text import Text -from .annotate import Annotate -from .legend import Legend -from .output import Output diff --git a/autoarray/plot/wrap/base/abstract.py b/autoarray/plot/wrap/base/abstract.py deleted file mode 100644 index c246c1afa..000000000 --- a/autoarray/plot/wrap/base/abstract.py +++ /dev/null @@ -1,194 +0,0 @@ -import numpy as np - -from autoconf import conf - - -def set_backend(): - """ - The matplotlib end used by default is the default matplotlib backend on a user's computer. - - The backend can be customized via the `config.visualize.general.ini` config file, if a user needs to overwrite - the backend for visualization to work. - - This has been the case in order to circumvent compatibility issues with MACs. - - It is also common for high perforamcne computers (HPCs) to not support visualization and raise an error when - a graphical backend (e.g. TKAgg) is used. Setting the backend to `Agg` addresses this. - """ - import matplotlib - - backend = conf.get_matplotlib_backend() - - if backend not in "default": - matplotlib.use(backend) - - try: - hpc_mode = conf.instance["general"]["hpc"]["hpc_mode"] - except KeyError: - hpc_mode = False - - if hpc_mode: - matplotlib.use("Agg") - - -def remove_spaces_and_commas_from(colors): - colors = [color.strip(",").strip(" ") for color in colors] - colors = list(filter(None, colors)) - if len(colors) == 1: - return colors[0] - return colors - - -class AbstractMatWrap: - def __init__(self, **kwargs): - """ - An abstract base class for wrapping matplotlib plotting methods. - - Classes are used to wrap matplotlib so that the data structures in the `autoarray.structures` package can be - plotted in standardized withs. This exploits how these structures have specific formats, units, properties etc. - This allows us to make a simple API for plotting structures, for example to plot an `Array2D` structure: - - import autoarray as aa - import autoarray.plot as aplt - - arr = aa.Array2D.no_mask(values=[[1.0, 1.0], [2.0, 2.0]], pixel_scales=2.0) - aplt.Array2D(values=arr) - - The wrapped Mat objects make it simple to customize how matplotlib visualizes this data structure, for example - we can customize the figure size and colormap using the `Figure` and `Cmap` objects. - - figure = aplt.Figure(figsize=(7,7), aspect="square") - cmap = aplt.Cmap(cmap="jet", vmin=1.0, vmax=2.0) - - plotter = aplt.MatPlot2D(figure=figure, cmap=cmap) - - aplt.Array2D(values=arr, plotter=plotter) - - The `Plotter` object is detailed in the `autoarray.plot.plotter` package. - - The matplotlib wrapper objects in ths module also use configuration files to choose their default settings. - For example, in `autoarray.config.visualize.mat_base.Figure.ini` you will note the section: - - figure: - figsize=(7, 7) - - subplot: - figsize=auto - - This specifies that when a data structure (like the `Array2D` above) is plotted, the figsize will always - be (7,7) when a single figure is plotted and it will be chosen automatically if a subplot is plotted. This - allows one to customize the matplotlib settings of every plot in a project. - """ - - self.is_for_subplot = False - self.kwargs = kwargs - - @property - def config_dict(self): - config_dict = conf.instance["visualize"][self.config_folder][ - self.__class__.__name__ - ][self.config_category] - - if "c" in config_dict: - config_dict["c"] = remove_spaces_and_commas_from(colors=config_dict["c"]) - - config_dict = {**config_dict, **self.kwargs} - - if "c" in config_dict: - if config_dict["c"] is None: - config_dict.pop("c") - - if "is_default" in config_dict: - config_dict.pop("is_default") - - return config_dict - - @property - def config_folder(self): - return "mat_wrap" - - @property - def config_category(self): - if self.is_for_subplot: - return "subplot" - return "figure" - - @property - def log10_min_value(self): - return conf.instance["visualize"]["general"]["general"]["log10_min_value"] - - @property - def log10_max_value(self): - return float( - conf.instance["visualize"]["general"]["general"]["log10_max_value"] - ) - - def vmin_from(self, array: np.ndarray, use_log10: bool = False) -> float: - """ - The vmin of a plot, for example the minimum value of the colormap and colorbar. - - If the vmin is manually input by the user, this value is used. Otherwise, the minimum value of the data being - plotted is used, which is computed via nanmin to ensure that NaN entries in the data are ignored. - - If use_log10 is True, the minimum value of the colormap is the log10 of the minimum value of the data. To - ensure negative values are not plotted, which often causes matplotlib errors, the minimum value of the colormap - is rounded up to the log10_min_value attribute of the config file. - - Parameters - ---------- - array - The array of data which is to be plotted. - use_log10 - If True, the minimum value of the colormap is the log10 of the minimum value of the data. - - Returns - ------- - The minimum value of the colormap. - """ - if self.config_dict["norm"] in "log": - use_log10 = True - - if self.config_dict["vmin"] is None: - vmin = np.nanmin(array) - else: - vmin = self.config_dict["vmin"] - - if use_log10 and (vmin < self.log10_min_value): - vmin = self.log10_min_value - - return vmin - - def vmax_from(self, array: np.ndarray, use_log10: bool = False) -> float: - """ - The vmax of a plot, for example the maximum value of the colormap and colorbar. - - If the vmax is manually input by the user, this value is used. Otherwise, the maximum value of the data being - plotted is used, which is computed via nanmax to ensure that NaN entries in the data are ignored. - - If use_log10 is True, the maximum value of the colormap is the log10 of the maximum value of the data. To - ensure values above the log10_max_value attribute of the config file are not plotted, this value is used - as the maximum value of the colormap. - - Parameters - ---------- - array - The array of data which is to be plotted. - use_log10 - If True, the maximum value of the colormap is the log10 of the maximum value of the data. - - Returns - ------- - The maximum value of the colormap. - """ - if self.config_dict["norm"] in "log": - use_log10 = True - - if self.config_dict["vmax"] is None: - vmax = np.nanmax(array) - else: - vmax = self.config_dict["vmax"] - - if use_log10 and (vmax > self.log10_max_value): - vmax = self.log10_max_value - - return vmax diff --git a/autoarray/plot/wrap/base/annotate.py b/autoarray/plot/wrap/base/annotate.py deleted file mode 100644 index e1f1e917f..000000000 --- a/autoarray/plot/wrap/base/annotate.py +++ /dev/null @@ -1,20 +0,0 @@ -from autoarray.plot.wrap.base.abstract import AbstractMatWrap - - -class Annotate(AbstractMatWrap): - """ - The settings used to customize annotations on the figure. - - This object wraps the following Matplotlib methods: - - - plt.annotate: https://matplotlib.org/3.3.2/api/_as_gen/matplotlib.pyplot.text.html - """ - - def set(self): - - import matplotlib.pyplot as plt - - if "x" not in self.kwargs and "y" not in self.kwargs and "s" not in self.kwargs: - return - - plt.annotate(**self.config_dict) diff --git a/autoarray/plot/wrap/base/axis.py b/autoarray/plot/wrap/base/axis.py deleted file mode 100644 index cd57c5d20..000000000 --- a/autoarray/plot/wrap/base/axis.py +++ /dev/null @@ -1,60 +0,0 @@ -import numpy as np -from typing import List - -from autoarray.plot.wrap.base.abstract import AbstractMatWrap - - -class Axis(AbstractMatWrap): - def __init__(self, symmetric_source_centre: bool = False, **kwargs): - """ - Customizes the axis of the plotted figure. - - This object wraps the following Matplotlib method: - - - plt.axis: https://matplotlib.org/3.3.2/api/_as_gen/matplotlib.pyplot.axis.html - - Parameters - ---------- - symmetric_source_centre - If `True`, the axis is symmetric around the centre of the plotted structure's coordinates. - """ - - super().__init__(**kwargs) - - self.symmetric_around_centre = symmetric_source_centre - - def set(self, extent: List[float] = None, grid=None): - """ - Set the axis limits of the figure the grid is plotted on. - - Parameters - ---------- - extent - The extent of the figure which set the axis-limits on the figure the grid is plotted, - following the format [xmin, xmax, ymin, ymax]. - """ - import matplotlib.pyplot as plt - - config_dict = self.config_dict - extent_dict = config_dict.get("extent") - - if extent_dict is not None: - config_dict.pop("extent") - - if self.symmetric_around_centre: - ymin = np.min(grid[:, 0]) - ymax = np.max(grid[:, 0]) - xmin = np.min(grid[:, 1]) - xmax = np.max(grid[:, 1]) - - x = np.max([np.abs(xmin), np.abs(xmax)]) - y = np.max([np.abs(ymin), np.abs(ymax)]) - - extent_symmetric = [-x, x, -y, y] - - return plt.axis(extent_symmetric, **config_dict) - - else: - if extent_dict is not None: - return plt.axis(extent_dict, **config_dict) - return plt.axis(extent, **config_dict) diff --git a/autoarray/plot/wrap/base/cmap.py b/autoarray/plot/wrap/base/cmap.py deleted file mode 100644 index 51144ba46..000000000 --- a/autoarray/plot/wrap/base/cmap.py +++ /dev/null @@ -1,107 +0,0 @@ -import copy -import logging -import numpy as np - - -from autoarray.plot.wrap.base.abstract import AbstractMatWrap - -from autoarray import exc - -logger = logging.getLogger(__name__) - - -class Cmap(AbstractMatWrap): - def __init__(self, symmetric: bool = False, **kwargs): - """ - Customizes the Matplotlib colormap and its normalization. - - This object wraps the following Matplotlib methods: - - - colors.Linear: https://matplotlib.org/3.3.2/tutorials/colors/colormaps.html - - colors.LogNorm: https://matplotlib.org/3.3.2/tutorials/colors/colormapnorms.html - - colors.SymLogNorm: https://matplotlib.org/3.3.2/api/_as_gen/matplotlib.colors.SymLogNorm.html - - The cmap that is created is passed into various Matplotlib methods, most notably imshow: - - - https://matplotlib.org/3.3.2/api/_as_gen/matplotlib.pyplot.imshow.html - - Parameters - ---------- - symmetric - If True, the colormap normalization (e.g. `vmin` and `vmax`) span the same absolute values producing a - symmetric color bar. - """ - - super().__init__(**kwargs) - - self._symmetric = symmetric - self.symmetric_value = None - - def symmetric_cmap_from(self, symmetric_value=None): - cmap = copy.copy(self) - - cmap._symmetric = True - cmap.symmetric_value = symmetric_value - - return cmap - - def norm_from(self, array: np.ndarray, use_log10: bool = False) -> object: - """ - Returns the `Normalization` object which scales of the colormap. - - If vmin / vmax are not manually input by the user, the minimum / maximum values of the data being plotted - are used. - - Parameters - ---------- - array - The array of data which is to be plotted. - """ - import matplotlib.colors as colors - - vmin = self.vmin_from(array=array, use_log10=use_log10) - vmax = self.vmax_from(array=array, use_log10=use_log10) - - if self._symmetric: - if vmin < 0.0 and vmax > 0.0: - if self.symmetric_value is None: - if abs(vmin) > abs(vmax): - vmax = abs(vmin) - else: - vmin = -vmax - else: - vmin = -self.symmetric_value - vmax = self.symmetric_value - - if isinstance(self.config_dict["norm"], colors.Normalize): - return self.config_dict["norm"] - - if self.config_dict["norm"] in "log" or use_log10: - return colors.LogNorm(vmin=vmin, vmax=vmax) - elif self.config_dict["norm"] in "linear": - return colors.Normalize(vmin=vmin, vmax=vmax) - elif self.config_dict["norm"] in "symmetric_log": - return colors.SymLogNorm( - vmin=vmin, - vmax=vmax, - linthresh=self.config_dict["linthresh"], - linscale=self.config_dict["linscale"], - ) - elif self.config_dict["norm"] in "diverge": - return colors.TwoSlopeNorm(vmin=vmin, vcenter=0, vmax=vmax) - - raise exc.PlottingException( - "The normalization (norm) supplied to the plotter is not a valid string must be " - "{linear, log, symmetric_log}" - ) - - @property - def cmap(self): - from matplotlib.colors import LinearSegmentedColormap - - if self.config_dict["cmap"] == "default": - from autoarray.plot.wrap.segmentdata import segmentdata - - return LinearSegmentedColormap(name="default", segmentdata=segmentdata) - - return self.config_dict["cmap"] diff --git a/autoarray/plot/wrap/base/colorbar.py b/autoarray/plot/wrap/base/colorbar.py deleted file mode 100644 index f5d24f5f3..000000000 --- a/autoarray/plot/wrap/base/colorbar.py +++ /dev/null @@ -1,215 +0,0 @@ -import numpy as np -from typing import List, Optional - -from autoconf import conf - -from autoarray.plot.wrap.base.abstract import AbstractMatWrap -from autoarray.plot.wrap.base.units import Units - -from autoarray import exc - - -class Colorbar(AbstractMatWrap): - def __init__( - self, - manual_tick_labels: Optional[List[float]] = None, - manual_tick_values: Optional[List[float]] = None, - manual_alignment: Optional[str] = None, - manual_unit: Optional[str] = None, - manual_log10: bool = False, - **kwargs, - ): - """ - Customizes the colorbar of the plotted figure. - - This object wraps the following Matplotlib method: - - - plt.colorbar: https://matplotlib.org/3.3.2/api/_as_gen/matplotlib.pyplot.colorbar.html - - The colorbar object `cb` that is created is also customized using the following methods: - - - cb.set_yticklabels: https://matplotlib.org/3.3.2/api/_as_gen/matplotlib.axes.Axes.set_yticklabels.html - - Parameters - ---------- - manual_tick_labels - Manually override the colorbar tick labels to an input list of float. - manual_tick_values - If the colorbar tick labels are manually specified the locations on the colorbar they appear running 0 -> 1. - manual_alignment - The vertical alignment of the colorbar tick labels, specified via the matplotlib method `set_yticklabels` - and input `va`. - manual_unit - The unit label that appears next to the colorbar tick labels, which if not input uses a default unit label - specified as `cb_unit` in the config file `config/visualize/general.yaml. - """ - - super().__init__(**kwargs) - - self.manual_tick_labels = manual_tick_labels - self.manual_tick_values = manual_tick_values - self.manual_alignment = manual_alignment - self.manual_unit = manual_unit - self.manual_log10 = manual_log10 - - @property - def cb_unit(self): - if self.manual_unit is None: - return conf.instance["visualize"]["general"]["units"]["cb_unit"] - return self.manual_unit - - def tick_values_from(self, norm=None, use_log10: bool = False): - if ( - sum( - x is not None - for x in [self.manual_tick_values, self.manual_tick_labels] - ) - == 1 - ): - raise exc.PlottingException( - "You can only manually specify the colorbar tick labels and values if both are input." - ) - - if self.manual_tick_values is not None: - return self.manual_tick_values - - if norm is not None: - min_value = norm.vmin - max_value = norm.vmax - - if use_log10: - if min_value < self.log10_min_value: - min_value = self.log10_min_value - - log_mid_value = (np.log10(max_value) + np.log10(min_value)) / 2.0 - mid_value = 10**log_mid_value - - else: - mid_value = (max_value + min_value) / 2.0 - - return [min_value, mid_value, max_value] - - def tick_labels_from( - self, - units: Units, - manual_tick_values: List[float], - cb_unit=None, - ): - if manual_tick_values is None: - return None - - convert_factor = units.colorbar_convert_factor or 1.0 - - if self.manual_tick_labels is not None: - manual_tick_labels = self.manual_tick_labels - else: - manual_tick_labels = [ - np.round(value * convert_factor, 2) for value in manual_tick_values - ] - - if self.manual_log10: - manual_tick_labels = [ - "{:.0e}".format(label) for label in manual_tick_labels - ] - - manual_tick_labels = [ - label.replace("1e", "$10^{") + "}$" for label in manual_tick_labels - ] - - manual_tick_labels = [ - label.replace("{-0", "{-").replace("{+0", "{+").replace("+", "") - for label in manual_tick_labels - ] - - if units.colorbar_label is None: - if cb_unit is None: - cb_unit = self.cb_unit - else: - cb_unit = units.colorbar_label - - middle_index = (len(manual_tick_labels) - 1) // 2 - manual_tick_labels[middle_index] = ( - rf"{manual_tick_labels[middle_index]}{cb_unit}" - ) - - return manual_tick_labels - - def set( - self, units: Units, ax=None, norm=None, cb_unit=None, use_log10: bool = False - ): - """ - Set the figure's colorbar, optionally overriding the tick labels and values with manual inputs. - """ - import matplotlib.pyplot as plt - - tick_values = self.tick_values_from(norm=norm, use_log10=use_log10) - tick_labels = self.tick_labels_from( - manual_tick_values=tick_values, - units=units, - cb_unit=cb_unit, - ) - - if tick_values is None and tick_labels is None: - cb = plt.colorbar(ax=ax, **self.config_dict) - else: - cb = plt.colorbar(ticks=tick_values, ax=ax, **self.config_dict) - cb.ax.set_yticklabels( - labels=tick_labels, va=self.manual_alignment or "center" - ) - - return cb - - def set_with_color_values( - self, - units: Units, - cmap: str, - color_values: np.ndarray, - ax=None, - norm=None, - use_log10: bool = False, - ): - """ - Set the figure's colorbar using an array of already known color values. - - This method is used for producing the color bar on a Delaunay mesh plot, which is unable to use the in-built - Matplotlib colorbar method. - - Parameters - ---------- - cmap - The colormap used to map normalized data values to RGBA - colors (see https://matplotlib.org/3.3.2/api/cm_api.html). - color_values - The values of the pixels on the mesh which are used to create the colorbar. - """ - import matplotlib.pyplot as plt - import matplotlib.cm as cm - - mappable = cm.ScalarMappable(norm=norm, cmap=cmap) - mappable.set_array(color_values) - - tick_values = self.tick_values_from(norm=norm, use_log10=use_log10) - tick_labels = self.tick_labels_from( - manual_tick_values=tick_values, - units=units, - ) - - if tick_values is None and tick_labels is None: - - cb = plt.colorbar( - mappable=mappable, - ax=ax, - **self.config_dict, - ) - else: - cb = plt.colorbar( - mappable=mappable, - ax=ax, - ticks=tick_values, - **self.config_dict, - ) - cb.ax.set_yticklabels( - labels=tick_labels, va=self.manual_alignment or "center" - ) - - return cb diff --git a/autoarray/plot/wrap/base/colorbar_tickparams.py b/autoarray/plot/wrap/base/colorbar_tickparams.py deleted file mode 100644 index 59ec73857..000000000 --- a/autoarray/plot/wrap/base/colorbar_tickparams.py +++ /dev/null @@ -1,14 +0,0 @@ -from autoarray.plot.wrap.base.abstract import AbstractMatWrap - - -class ColorbarTickParams(AbstractMatWrap): - """ - Customizes the ticks of the colorbar of the plotted figure. - - This object wraps the following Matplotlib colorbar method: - - - cb.set_yticklabels: https://matplotlib.org/3.3.2/api/_as_gen/matplotlib.axes.Axes.set_yticklabels.html - """ - - def set(self, cb): - cb.ax.tick_params(**self.config_dict) diff --git a/autoarray/plot/wrap/base/figure.py b/autoarray/plot/wrap/base/figure.py deleted file mode 100644 index 61a3347b6..000000000 --- a/autoarray/plot/wrap/base/figure.py +++ /dev/null @@ -1,100 +0,0 @@ -from enum import Enum -import gc -from typing import Union, Tuple - -from autoarray.plot.wrap.base.abstract import AbstractMatWrap - - -class Aspect(Enum): - square = 1 - auto = 2 - equal = 3 - - -class Figure(AbstractMatWrap): - """ - Sets up the Matplotlib figure before plotting (this is used when plotting individual figures and subplots). - - This object wraps the following Matplotlib methods: - - - plt.figure: https://matplotlib.org/3.3.2/api/_as_gen/matplotlib.pyplot.figure.html - - plt.close: https://matplotlib.org/3.3.2/api/_as_gen/matplotlib.pyplot.close.html - - It also controls the aspect ratio of the figure plotted. - """ - - @property - def config_dict(self): - """ - Creates a config dict of valid inputs of the method `plt.figure` from the object's config_dict. - """ - - config_dict = super().config_dict - - if config_dict["figsize"] == "auto": - config_dict["figsize"] = None - elif isinstance(config_dict["figsize"], str): - config_dict["figsize"] = tuple( - map(int, config_dict["figsize"][1:-1].split(",")) - ) - - return config_dict - - def aspect_for_subplot_from(self, extent): - ratio = float((extent[1] - extent[0]) / (extent[3] - extent[2])) - - aspect = Aspect[self.config_dict["aspect"]] - - if aspect == Aspect.square: - return ratio - elif aspect == Aspect.auto: - return 1.0 / ratio - elif aspect == Aspect.equal: - return 1.0 - - raise ValueError( - f""" - The `aspect` variable used to set up the figure is {aspect}. - - This is not a valid value, which must be one of square / auto / equal. - """ - ) - - def aspect_from(self, shape_native: Union[Tuple[int, int]]) -> Union[float, str]: - """ - Returns the aspect ratio of the figure from the 2D shape of a data structure. - - This is used to ensure that rectangular arrays are plotted as square figures on sub-plots. - - Parameters - ---------- - shape_native - The two dimensional shape of an `Array2D` that is to be plotted. - """ - if isinstance(self.config_dict["aspect"], str): - if self.config_dict["aspect"] in "square": - return float(shape_native[1]) / float(shape_native[0]) - - return self.config_dict["aspect"] - - def open(self): - """ - Wraps the Matplotlib method 'plt.figure' for opening a figure. - """ - import matplotlib.pyplot as plt - - if not plt.fignum_exists(num=1): - config_dict = self.config_dict - config_dict.pop("aspect") - fig = plt.figure(**config_dict) - return fig, plt.gca() - return None, None - - def close(self): - """ - Wraps the Matplotlib method 'plt.close' for closing a figure. - """ - import matplotlib.pyplot as plt - - plt.close() - gc.collect() diff --git a/autoarray/plot/wrap/base/label.py b/autoarray/plot/wrap/base/label.py deleted file mode 100644 index a54d5c474..000000000 --- a/autoarray/plot/wrap/base/label.py +++ /dev/null @@ -1,88 +0,0 @@ -from typing import Optional - -from autoarray.plot.wrap.base.abstract import AbstractMatWrap - - -class AbstractLabel(AbstractMatWrap): - def __init__(self, **kwargs): - """ - The settings used to customize the figure's title and y and x labels. - - This object wraps the following Matplotlib methods: - - - plt.ylabel: https://matplotlib.org/3.3.2/api/_as_gen/matplotlib.pyplot.ylabel.html - - plt.xlabel: https://matplotlib.org/3.3.2/api/_as_gen/matplotlib.pyplot.xlabel.html - - The y and x labels will automatically be set if not specified, using the input units. - - Parameters - ---------- - units - The units the data is plotted using. - manual_label - A manual label which overrides the default computed via the units if input. - """ - - super().__init__(**kwargs) - - self.manual_label = self.kwargs.get("label") - - -class YLabel(AbstractLabel): - def set( - self, - auto_label: Optional[str] = None, - ): - """ - Set the y labels of the figure, including the fontsize. - - The y labels are always the distance scales, thus the labels are either arc-seconds or kpc and depending on - the unit_label the figure is plotted in. - - Parameters - ---------- - units - The units of the image that is plotted which informs the appropriate y label text. - """ - import matplotlib.pyplot as plt - - config_dict = self.config_dict - - if self.manual_label is not None: - config_dict.pop("ylabel") - plt.ylabel(ylabel=self.manual_label, **config_dict) - elif auto_label is not None: - config_dict.pop("ylabel") - plt.ylabel(ylabel=auto_label, **config_dict) - else: - plt.ylabel(**config_dict) - - -class XLabel(AbstractLabel): - def set( - self, - auto_label: Optional[str] = None, - ): - """ - Set the x labels of the figure, including the fontsize. - - The x labels are always the distance scales, thus the labels are either arc-seconds or kpc and depending on - the unit_label the figure is plotted in. - - Parameters - ---------- - units - The units of the image that is plotted which informs the appropriate x label text. - """ - import matplotlib.pyplot as plt - - config_dict = self.config_dict - - if self.manual_label is not None: - config_dict.pop("xlabel") - plt.xlabel(xlabel=self.manual_label, **config_dict) - elif auto_label is not None: - config_dict.pop("xlabel") - plt.xlabel(xlabel=auto_label, **config_dict) - else: - plt.xlabel(**config_dict) diff --git a/autoarray/plot/wrap/base/legend.py b/autoarray/plot/wrap/base/legend.py deleted file mode 100644 index 09a6e9d4d..000000000 --- a/autoarray/plot/wrap/base/legend.py +++ /dev/null @@ -1,28 +0,0 @@ -from autoarray.plot.wrap.base.abstract import AbstractMatWrap - - -class Legend(AbstractMatWrap): - """ - The settings used to include and customize a legend on a figure. - - This object wraps the following Matplotlib methods: - - - plt.legend: https://matplotlib.org/3.3.2/api/_as_gen/matplotlib.pyplot.legend.html - """ - - def __init__(self, label=None, include=True, **kwargs): - super().__init__(**kwargs) - - self.label = label - self.include = include - - def set(self): - - import matplotlib.pyplot as plt - - if self.include: - config_dict = self.config_dict - config_dict.pop("include") if "include" in config_dict else None - config_dict.pop("include_2d") if "include_2d" in config_dict else None - - plt.legend(**config_dict) diff --git a/autoarray/plot/wrap/base/text.py b/autoarray/plot/wrap/base/text.py deleted file mode 100644 index 4141bc0ac..000000000 --- a/autoarray/plot/wrap/base/text.py +++ /dev/null @@ -1,20 +0,0 @@ -from autoarray.plot.wrap.base.abstract import AbstractMatWrap - - -class Text(AbstractMatWrap): - """ - The settings used to customize text on the figure. - - This object wraps the following Matplotlib methods: - - - plt.text: https://matplotlib.org/3.3.2/api/_as_gen/matplotlib.pyplot.text.html - """ - - def set(self): - - import matplotlib.pyplot as plt - - if "x" not in self.kwargs and "y" not in self.kwargs and "s" not in self.kwargs: - return - - plt.text(**self.config_dict) diff --git a/autoarray/plot/wrap/base/tickparams.py b/autoarray/plot/wrap/base/tickparams.py deleted file mode 100644 index 7369a29a0..000000000 --- a/autoarray/plot/wrap/base/tickparams.py +++ /dev/null @@ -1,18 +0,0 @@ -from autoarray.plot.wrap.base.abstract import AbstractMatWrap - - -class TickParams(AbstractMatWrap): - """ - The settings used to customize a figure's y and x ticks parameters. - - This object wraps the following Matplotlib methods: - - - plt.tick_params: https://matplotlib.org/3.3.2/api/_as_gen/matplotlib.pyplot.tick_params.html - """ - - def set(self): - """Set the tick_params of the figure using the method `plt.tick_params`.""" - - import matplotlib.pyplot as plt - - plt.tick_params(**self.config_dict) diff --git a/autoarray/plot/wrap/base/ticks.py b/autoarray/plot/wrap/base/ticks.py deleted file mode 100644 index 3b7c72e57..000000000 --- a/autoarray/plot/wrap/base/ticks.py +++ /dev/null @@ -1,452 +0,0 @@ -import numpy as np -from typing import List, Tuple, Optional - -from autoconf import conf - -from autoarray.plot.wrap.base.abstract import AbstractMatWrap -from autoarray.plot.wrap.base.units import Units - - -class TickMaker: - def __init__( - self, - min_value: float, - max_value: float, - factor: float, - number_of_ticks: int, - units, - ): - self.min_value = min_value - self.max_value = max_value - self.factor = factor - self.number_of_ticks = number_of_ticks - self.units = units - - @property - def centre(self): - return self.max_value - ((self.max_value - self.min_value) / 2.0) - - @property - def tick_values_linear(self): - value_0 = self.centre - ((self.centre - self.max_value)) * self.factor - value_1 = self.centre + ((self.min_value - self.centre)) * self.factor - - return np.linspace(value_0, value_1, self.number_of_ticks) - - @property - def tick_values_log10(self): - min_value = self.min_value - - if self.min_value < 0.001: - min_value = 0.001 - - min_value = 10 ** np.floor(np.log10(min_value)) - - max_value = 10 ** np.ceil(np.log10(self.max_value)) - number = int(abs(np.log10(max_value) - np.log10(min_value))) + 1 - - return np.logspace(np.log10(min_value), np.log10(max_value), number) - - @property - def tick_values_integers(self): - ticks = np.arange(int(self.max_value - self.min_value)) - - if not self.units.use_scaled: - ticks = ticks.astype("int") - - return ticks - - -class LabelMaker: - def __init__( - self, - tick_values, - min_value: float, - max_value: float, - units, - pixels: Optional[int] = None, - round_sf: int = 2, - yunit=None, - xunit=None, - manual_suffix=None, - ): - self.tick_values = tick_values - self.min_value = min_value - self.max_value = max_value - self.units = units - self.pixels = pixels - self.convert_factor = self.units.ticks_convert_factor or 1.0 - self.yunit = yunit - self.xunit = xunit - self.round_sf = round_sf - self.manual_suffix = manual_suffix - - @property - def suffix(self) -> Optional[str]: - """ - Returns the label of an object, by determining it from the figure units if the label is not manually specified. - - Parameters - ---------- - units - The units of the data structure that is plotted which informs the appropriate label text. - """ - - if self.manual_suffix is not None: - return self.manual_suffix - - if self.yunit is not None: - return self.yunit - - if self.xunit is not None: - return self.xunit - - if self.units.ticks_label is not None: - return self.units.ticks_label - - units_conf = conf.instance["visualize"]["general"]["units"] - - if self.units is None: - return units_conf["unscaled_symbol"] - - if self.units.use_scaled: - return units_conf["scaled_symbol"] - - return units_conf["unscaled_symbol"] - - @property - def span(self): - return self.max_value - self.min_value - - @property - def tick_values_rounded(self): - values = np.asarray(self.tick_values) * self.convert_factor - values_positive = np.where( - np.isfinite(values) & (values != 0), - np.abs(values), - 10 ** (self.round_sf - 1), - ) - mags = 10 ** (self.round_sf - 1 - np.floor(np.log10(values_positive))) - return np.round(values * mags) / mags - - @property - def labels_linear(self): - if self.units.use_raw: - return self.with_appended_suffix(self.tick_values_rounded) - - if not self.units.use_scaled and self.yunit is None: - return self.labels_linear_pixels - - labels = np.asarray([value for value in self.tick_values_rounded]) - - if not self.units.use_scaled and self.yunit is None: - labels = [f"{int(label)}" for label in labels] - return self.with_appended_suffix(labels) - - @property - def labels_linear_pixels(self): - if self.max_value == self.min_value: - labels = [f"{int(label)}" for label in self.tick_values] - return self.with_appended_suffix(labels) - - ticks_from_zero = [tick - self.min_value for tick in self.tick_values] - labels = [(tick / self.span) * self.pixels for tick in ticks_from_zero] - - labels = [f"{int(label)}" for label in labels] - - return self.with_appended_suffix(labels) - - @property - def labels_log10(self): - labels = ["{:.0e}".format(label) for label in self.tick_values] - labels = [label.replace("1e", "$10^{") + "}$" for label in labels] - labels = [ - label.replace("{-0", "{-").replace("{+0", "{+").replace("+", "") - for label in labels - ] - # labels = [label.replace("1e", "").replace("-0", "-").replace("+0", "+").replace("+0", "0") for label in labels] - - return self.with_appended_suffix(labels) - - def with_appended_suffix(self, labels): - """ - The labels used for the y and x ticks can be append with a suffix. - - For example, if the labels were [-1.0, 0.0, 1.0] and the suffix is ", the labels with the suffix appended - is [-1.0", 0.0", 1.0"]. - - Parameters - ---------- - labels - The y and x labels which are append with the suffix. - """ - - labels = [str(label) for label in labels] - - all_end_0 = True - - for label in labels: - if not label.endswith(".0"): - all_end_0 = False - - if all_end_0: - labels = [label[:-2] for label in labels] - - return [f"{label}{self.suffix}" for label in labels] - - -class AbstractTicks(AbstractMatWrap): - def __init__( - self, - manual_factor: Optional[float] = None, - manual_values: Optional[List[float]] = None, - manual_min_max_value: Optional[Tuple[float, float]] = None, - manual_units: Optional[str] = None, - manual_suffix: Optional[str] = None, - **kwargs, - ): - """ - The settings used to customize a figure's y and x ticks using the `YTicks` and `XTicks` objects. - - This object wraps the following Matplotlib methods: - - - plt.yticks: https://matplotlib.org/3.3.1/api/_as_gen/matplotlib.pyplot.yticks.html - - plt.xticks: https://matplotlib.org/3.3.1/api/_as_gen/matplotlib.pyplot.xticks.html - - Parameters - ---------- - manual_values - Manually override the tick labels to display the labels as the input list of floats. - manual_units - Manually override the units in brackets of the tick label. - manual_suffix - A suffix applied to every tick label (e.g. for the suffix `kpc` 0.0 becomes 0.0kpc). - """ - super().__init__(**kwargs) - - self.manual_factor = manual_factor - self.manual_values = manual_values - self.manual_min_max_value = manual_min_max_value - self.manual_units = manual_units - self.manual_suffix = manual_suffix - - def factor_from(self, suffix): - if self.manual_factor is not None: - return self.manual_factor - return conf.instance["visualize"][self.config_folder][self.__class__.__name__][ - "manual" - ][f"extent_factor{suffix}"] - - def number_of_ticks_from(self, suffix): - return conf.instance["visualize"][self.config_folder][self.__class__.__name__][ - "manual" - ][f"number_of_ticks{suffix}"] - - def tick_maker_from( - self, min_value: float, max_value: float, units, is_for_1d_plot: bool - ): - suffix = "_1d" if is_for_1d_plot else "_2d" - - factor = self.factor_from(suffix=suffix) - number_of_ticks = self.number_of_ticks_from(suffix=suffix) - - return TickMaker( - min_value=min_value, - max_value=max_value, - factor=factor, - units=units, - number_of_ticks=number_of_ticks, - ) - - def ticks_from( - self, - min_value: float, - max_value: float, - units: Units, - is_log10: bool = False, - is_for_1d_plot: bool = False, - ): - tick_maker = self.tick_maker_from( - min_value=min_value, - max_value=max_value, - units=units, - is_for_1d_plot=is_for_1d_plot, - ) - - if self.manual_values: - return self.manual_values - elif is_log10: - return tick_maker.tick_values_log10 - return tick_maker.tick_values_linear - - def labels_from( - self, - ticks, - min_value: float, - max_value: float, - units, - yunit, - xunit, - pixels: Optional[int] = None, - is_log10: bool = False, - ): - label_maker = LabelMaker( - tick_values=ticks, - min_value=min_value, - max_value=max_value, - units=units, - pixels=pixels, - yunit=yunit, - xunit=xunit, - manual_suffix=self.manual_suffix, - ) - - if self.manual_units: - return ticks - elif is_log10: - return label_maker.labels_log10 - return label_maker.labels_linear - - def ticks_and_labels_from( - self, - min_value, - max_value, - units, - pixels: Optional[int] = None, - use_integers: bool = False, - yunit=None, - xunit=None, - is_log10: bool = False, - is_for_1d_plot: bool = False, - ): - if use_integers: - ticks = np.arange(int(max_value - min_value)) - return ticks, ticks - - ticks = self.ticks_from( - min_value=min_value, - max_value=max_value, - units=units, - is_log10=is_log10, - is_for_1d_plot=is_for_1d_plot, - ) - - labels = self.labels_from( - ticks=ticks, - min_value=min_value, - max_value=max_value, - units=units, - yunit=yunit, - xunit=xunit, - pixels=pixels, - is_log10=is_log10, - ) - return ticks, labels - - -class YTicks(AbstractTicks): - def set( - self, - min_value: float, - max_value: float, - units: Units, - pixels: Optional[int] = None, - yunit=None, - is_for_1d_plot: bool = False, - is_log10: bool = False, - ): - """ - Set the y ticks of a figure using the shape of an input `Array2D` object and input units. - - Parameters - ---------- - array - The 2D array of data which is plotted. - min_value - the minimum value of the yticks that figure is plotted using. - max_value - the maximum value of the yticks that figure is plotted using. - units - The units of the figure. - """ - import matplotlib.pyplot as plt - from matplotlib.ticker import FormatStrFormatter - - if self.manual_min_max_value: - min_value = self.manual_min_max_value[0] - max_value = self.manual_min_max_value[1] - - ticks, labels = self.ticks_and_labels_from( - min_value=min_value, - max_value=max_value, - units=units, - pixels=pixels, - yunit=yunit, - is_log10=is_log10, - is_for_1d_plot=is_for_1d_plot, - ) - - if is_log10: - plt.ylim(max(min_value, self.log10_min_value), max_value) - - if not is_for_1d_plot and not units.use_scaled: - labels = reversed(labels) - - plt.yticks(ticks=ticks, labels=labels, **self.config_dict) - - if self.manual_units is not None: - plt.gca().yaxis.set_major_formatter( - FormatStrFormatter(f"{self.manual_units}") - ) - - -class XTicks(AbstractTicks): - def set( - self, - min_value: float, - max_value: float, - units: Units, - pixels: Optional[int] = None, - xunit=None, - use_integers=False, - is_for_1d_plot: bool = False, - is_log10: bool = False, - ): - """ - Set the x ticks of a figure using the shape of an input `Array2D` object and input units. - - Parameters - ---------- - array - The 2D array of data which is plotted. - min_value - the minimum value of the xticks that figure is plotted using. - max_value - the maximum value of the xticks that figure is plotted using. - units - The units of the figure. - """ - import matplotlib.pyplot as plt - from matplotlib.ticker import FormatStrFormatter - - if self.manual_min_max_value: - min_value = self.manual_min_max_value[0] - max_value = self.manual_min_max_value[1] - - ticks, labels = self.ticks_and_labels_from( - min_value=min_value, - max_value=max_value, - pixels=pixels, - units=units, - yunit=xunit, - use_integers=use_integers, - is_for_1d_plot=is_for_1d_plot, - is_log10=is_log10, - ) - - plt.xticks(ticks=ticks, labels=labels, **self.config_dict) - - if self.manual_units is not None: - plt.gca().xaxis.set_major_formatter( - FormatStrFormatter(f"{self.manual_units}") - ) diff --git a/autoarray/plot/wrap/base/title.py b/autoarray/plot/wrap/base/title.py deleted file mode 100644 index 60aec1d30..000000000 --- a/autoarray/plot/wrap/base/title.py +++ /dev/null @@ -1,46 +0,0 @@ -from autoarray.plot.wrap.base.abstract import AbstractMatWrap - - -class Title(AbstractMatWrap): - def __init__(self, prefix: str = None, disable_log10_label: bool = False, **kwargs): - """ - The settings used to customize the figure's title. - - This object wraps the following Matplotlib methods: - - - plt.title: https://matplotlib.org/3.3.2/api/_as_gen/matplotlib.pyplot.title.html - - The title will automatically be set if not specified, using the name of the function used to plot the data. - - Parameters - ---------- - prefix - A string that is added before the title, for example to put the name of the dataset and galaxy in the title. - disable_log10_label - If True, the (log10) label is not added to the title if the data is plotted on a log-scale. - """ - - super().__init__(**kwargs) - - self.prefix = prefix - self.disable_log10_label = disable_log10_label - self.manual_label = self.kwargs.get("label") - - def set(self, auto_title=None, use_log10: bool = False): - - import matplotlib.pyplot as plt - - config_dict = self.config_dict - - label = auto_title if self.manual_label is None else self.manual_label - - if self.prefix is not None: - label = f"{self.prefix} {label}" - - if use_log10 and not self.disable_log10_label: - label = f"{label} (log10)" - - if "label" in config_dict: - config_dict.pop("label") - - plt.title(label=label, **config_dict) diff --git a/autoarray/plot/wrap/base/units.py b/autoarray/plot/wrap/base/units.py deleted file mode 100644 index db1048b32..000000000 --- a/autoarray/plot/wrap/base/units.py +++ /dev/null @@ -1,61 +0,0 @@ -import logging -from typing import Optional - -from autoconf import conf - -logger = logging.getLogger(__name__) - - -class Units: - def __init__( - self, - use_scaled: Optional[bool] = None, - use_raw: Optional[bool] = False, - ticks_convert_factor: Optional[float] = None, - ticks_label: Optional[str] = None, - colorbar_convert_factor: Optional[float] = None, - colorbar_label: Optional[str] = None, - **kwargs, - ): - """ - This object controls the units of a plotted figure, and performs multiple tasks when making the plot: - - 1: Species the units of the plot (e.g. meters, kilometers) and contains a conversion factor which converts - the plotted data from its current units (e.g. meters) to the units plotted (e.g. kilometeters). Pixel units - can be used if `use_scaled=False`. - - 2: Uses the conversion above to manually override the yticks and xticks of the figure, so it appears in the - converted units. - - 3: Sets the ylabel and xlabel to include a string containing the units. - - Parameters - ---------- - use_scaled - If True, plot the 2D data with y and x ticks corresponding to its scaled - coordinates (its `pixel_scales` attribute is used as the `ticks_convert_factor`). If `False` plot them in - pixel units. - ticks_convert_factor - If plotting the labels in scaled units, this factor multiplies the values that are used for the labels. - This allows for additional unit conversions of the figure labels. - """ - - self.ticks_convert_factor = ticks_convert_factor - self.ticks_label = ticks_label - - if use_scaled is not None: - self.use_scaled = use_scaled - else: - try: - self.use_scaled = conf.instance["visualize"]["general"]["units"][ - "use_scaled" - ] - except KeyError: - self.use_scaled = True - - self.use_raw = use_raw - - self.colorbar_convert_factor = colorbar_convert_factor - self.colorbar_label = colorbar_label - - self.kwargs = kwargs diff --git a/autoarray/plot/wrap/one_d/__init__.py b/autoarray/plot/wrap/one_d/__init__.py deleted file mode 100644 index 29373197f..000000000 --- a/autoarray/plot/wrap/one_d/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from .yx_plot import YXPlot -from .yx_scatter import YXScatter -from .avxline import AXVLine -from .fill_between import FillBetween diff --git a/autoarray/plot/wrap/one_d/abstract.py b/autoarray/plot/wrap/one_d/abstract.py deleted file mode 100644 index ee1651ffc..000000000 --- a/autoarray/plot/wrap/one_d/abstract.py +++ /dev/null @@ -1,18 +0,0 @@ -from autoarray.plot.wrap.base.abstract import set_backend - -set_backend() - -from autoarray.plot.wrap.base.abstract import AbstractMatWrap - - -class AbstractMatWrap1D(AbstractMatWrap): - """ - An abstract base class for wrapping matplotlib plotting methods which take as input and plot data structures. For - example, the `ArrayOverlay` object specifically plots `Array2D` data structures. - - As full description of the matplotlib wrapping can be found in `mat_base.AbstractMatWrap`. - """ - - @property - def config_folder(self): - return "mat_wrap_1d" diff --git a/autoarray/plot/wrap/one_d/avxline.py b/autoarray/plot/wrap/one_d/avxline.py deleted file mode 100644 index 7d36672d8..000000000 --- a/autoarray/plot/wrap/one_d/avxline.py +++ /dev/null @@ -1,61 +0,0 @@ -from typing import List, Optional - -from autoarray.plot.wrap.one_d.abstract import AbstractMatWrap1D - - -class AXVLine(AbstractMatWrap1D): - def __init__(self, no_label=False, **kwargs): - """ - Plots vertical lines on 1D plot of y versus x using the method `plt.axvline`. - - This method is typically called after `plot_y_vs_x` to add vertical lines to the figure. - - This object wraps the following Matplotlib methods: - - - plt.avxline: https://matplotlib.org/3.3.2/api/_as_gen/matplotlib.pyplot.axvline.html - - Parameters - ---------- - vertical_line - The vertical lines of data that are plotted on the figure. - label - Labels for each vertical line used by a `Legend`. - """ - super().__init__(**kwargs) - - self.no_label = no_label - - def axvline_vertical_line( - self, - vertical_line: float, - vertical_errors: Optional[List[float]] = None, - label: Optional[str] = None, - ): - """ - Plot an input vertical line given by its x coordinate as a float using the method `plt.axvline`. - - Parameters - ---------- - vertical_line - The vertical lines of data that are plotted on the figure. - label - Labels for each vertical line used by a `Legend`. - """ - import matplotlib.pyplot as plt - - if vertical_line is [] or vertical_line is None: - return - - if self.no_label: - label = None - - plt.axvline(x=vertical_line, label=label, **self.config_dict) - - if vertical_errors is not None: - config_dict = self.config_dict - - if "linestyle" in config_dict: - config_dict.pop("linestyle") - - plt.axvline(x=vertical_errors[0], linestyle="--", **config_dict) - plt.axvline(x=vertical_errors[1], linestyle="--", **config_dict) diff --git a/autoarray/plot/wrap/one_d/fill_between.py b/autoarray/plot/wrap/one_d/fill_between.py deleted file mode 100644 index 8a91b9a73..000000000 --- a/autoarray/plot/wrap/one_d/fill_between.py +++ /dev/null @@ -1,53 +0,0 @@ -import numpy as np -from typing import List, Union - -from autoarray.plot.wrap.one_d.abstract import AbstractMatWrap1D -from autoarray.structures.arrays.uniform_1d import Array1D - - -class FillBetween(AbstractMatWrap1D): - def __init__(self, match_color_to_yx: bool = True, **kwargs): - """ - Fills between two lines on a 1D plot of y versus x using the method `plt.fill_between`. - - This method is typically called after `plot_y_vs_x` to add a shaded region to the figure. - - This object wraps the following Matplotlib methods: - - - plt.fill_between: https://matplotlib.org/3.3.2/api/_as_gen/matplotlib.pyplot.fill_between.html - - Parameters - ---------- - match_color_to_yx - If True, the color of the shaded region is automatically matched to that of the yx line that is plotted, - irrespective of the user inputs. - """ - super().__init__(**kwargs) - self.match_color_to_yx = match_color_to_yx - - def fill_between_shaded_regions( - self, - x: Union[np.ndarray, Array1D, List], - y1: Union[np.ndarray, Array1D, List], - y2: Union[np.ndarray, Array1D, List], - ): - """ - Fill in between two lines `y1` and `y2` on a plot of y vs x. - - Parameters - ---------- - x - The xdata that is plotted. - y1 - The first line of ydata that defines the region that is filled in. - y1 - The second line of ydata that defines the region that is filled in. - """ - import matplotlib.pyplot as plt - - config_dict = self.config_dict - - if self.match_color_to_yx: - config_dict["color"] = plt.gca().lines[-1].get_color() - - plt.fill_between(x=x, y1=y1, y2=y2, **config_dict) diff --git a/autoarray/plot/wrap/one_d/yx_plot.py b/autoarray/plot/wrap/one_d/yx_plot.py deleted file mode 100644 index d679b91e3..000000000 --- a/autoarray/plot/wrap/one_d/yx_plot.py +++ /dev/null @@ -1,93 +0,0 @@ -import numpy as np -from typing import Union - -from autoarray.plot.wrap.one_d.abstract import AbstractMatWrap1D -from autoarray.structures.arrays.uniform_1d import Array1D - -from autoarray import exc - - -class YXPlot(AbstractMatWrap1D): - def __init__(self, plot_axis_type=None, label=None, **kwargs): - """ - Plots 1D data structures as a y vs x figure. - - This object wraps the following Matplotlib methods: - - - plt.plot: https://matplotlib.org/3.3.3/api/_as_gen/matplotlib.pyplot.plot.html - """ - - super().__init__(**kwargs) - - self.plot_axis_type = plot_axis_type - self.label = label - - def plot_y_vs_x( - self, - y: Union[np.ndarray, Array1D], - x: Union[np.ndarray, Array1D], - label: str = None, - plot_axis_type=None, - y_errors=None, - x_errors=None, - y_extra=None, - y_extra_2=None, - ls_errorbar="", - ): - """ - Plots 1D y-data against 1D x-data using the matplotlib method `plt.plot`, `plt.semilogy`, `plt.loglog`, - or `plt.scatter`. - - Parameters - ---------- - y - The ydata that is plotted. - x - The xdata that is plotted. - plot_axis_type - The method used to make the plot that defines the scale of the axes {"linear", "semilogy", "loglog", - "scatter"}. - label - Optionally include a label on the plot for a `Legend` to display. - """ - import matplotlib.pyplot as plt - - if self.label is not None: - label = self.label - - if plot_axis_type == "linear" or plot_axis_type == "symlog": - plt.plot(x, y, label=label, **self.config_dict) - elif plot_axis_type == "semilogy": - plt.semilogy(x, y, label=label, **self.config_dict) - elif plot_axis_type == "loglog": - plt.loglog(x, y, label=label, **self.config_dict) - elif plot_axis_type == "scatter": - plt.scatter(x, y, label=label, **self.config_dict) - elif plot_axis_type == "errorbar" or plot_axis_type == "errorbar_logy": - plt.errorbar( - x, - y, - yerr=y_errors, - xerr=x_errors, - # marker="o", - fmt="o", - # ls=ls_errorbar, - **self.config_dict, - ) - if plot_axis_type == "errorbar_logy": - plt.yscale("log") - else: - raise exc.PlottingException( - "The plot_axis_type supplied to the plotter is not a valid string (must be linear " - "{semilogy, loglog})" - ) - - if y_extra is not None: - if isinstance(y_extra, list): - for y_extra_ in y_extra: - plt.plot(x, y_extra_) - else: - plt.plot(x, y_extra, c="r") - - if y_extra_2 is not None: - plt.plot(x, y_extra_2, c="r") diff --git a/autoarray/plot/wrap/one_d/yx_scatter.py b/autoarray/plot/wrap/one_d/yx_scatter.py deleted file mode 100644 index 5e3c1d93a..000000000 --- a/autoarray/plot/wrap/one_d/yx_scatter.py +++ /dev/null @@ -1,38 +0,0 @@ -import numpy as np -from typing import Union - -from autoarray.plot.wrap.one_d.abstract import AbstractMatWrap1D -from autoarray.structures.grids.uniform_1d import Grid1D - - -class YXScatter(AbstractMatWrap1D): - def __init__(self, **kwargs): - """ - Scatters a 1D set of points on a 1D plot. Unlike the `YXPlot` object these are scattered over an existing plot. - - This object wraps the following Matplotlib methods: - - - plt.scatter: https://matplotlib.org/3.1.1/api/_as_gen/matplotlib.pyplot.scatter.html - """ - - super().__init__(**kwargs) - - def scatter_yx(self, y: Union[np.ndarray, Grid1D], x: list): - """ - Plot an input grid of (y,x) coordinates using the matplotlib method `plt.scatter`. - - Parameters - ---------- - grid - The points that are - errors - The error on every point of the grid that is plotted. - """ - import matplotlib.pyplot as plt - - config_dict = self.config_dict - - if len(config_dict["c"]) > 1: - config_dict["c"] = config_dict["c"][0] - - plt.scatter(y=y, x=x, **config_dict) diff --git a/autoarray/plot/wrap/two_d/__init__.py b/autoarray/plot/wrap/two_d/__init__.py deleted file mode 100644 index 8490a9af0..000000000 --- a/autoarray/plot/wrap/two_d/__init__.py +++ /dev/null @@ -1,19 +0,0 @@ -from .array_overlay import ArrayOverlay -from .contour import Contour -from .fill import Fill -from .grid_scatter import GridScatter -from .grid_plot import GridPlot -from .grid_errorbar import GridErrorbar -from .vector_yx_quiver import VectorYXQuiver -from .patch_overlay import PatchOverlay -from .delaunay_drawer import DelaunayDrawer -from .origin_scatter import OriginScatter -from .mask_scatter import MaskScatter -from .border_scatter import BorderScatter -from .positions_scatter import PositionsScatter -from .index_scatter import IndexScatter -from .index_plot import IndexPlot -from .mesh_grid_scatter import MeshGridScatter -from .parallel_overscan_plot import ParallelOverscanPlot -from .serial_prescan_plot import SerialPrescanPlot -from .serial_overscan_plot import SerialOverscanPlot diff --git a/autoarray/plot/wrap/two_d/abstract.py b/autoarray/plot/wrap/two_d/abstract.py deleted file mode 100644 index 9348e1c9e..000000000 --- a/autoarray/plot/wrap/two_d/abstract.py +++ /dev/null @@ -1,18 +0,0 @@ -from autoarray.plot.wrap.base.abstract import set_backend - -set_backend() - -from autoarray.plot.wrap.base.abstract import AbstractMatWrap - - -class AbstractMatWrap2D(AbstractMatWrap): - """ - An abstract base class for wrapping matplotlib plotting methods which take as input and plot data structures. For - example, the `ArrayOverlay` object specifically plots `Array2D` data structures. - - As full description of the matplotlib wrapping can be found in `mat_base.AbstractMatWrap`. - """ - - @property - def config_folder(self): - return "mat_wrap_2d" diff --git a/autoarray/plot/wrap/two_d/array_overlay.py b/autoarray/plot/wrap/two_d/array_overlay.py deleted file mode 100644 index 372bb5f6c..000000000 --- a/autoarray/plot/wrap/two_d/array_overlay.py +++ /dev/null @@ -1,26 +0,0 @@ -import matplotlib.pyplot as plt - -from autoarray.plot.wrap.two_d.abstract import AbstractMatWrap2D -from autoarray.mask.derive.zoom_2d import Zoom2D - - -class ArrayOverlay(AbstractMatWrap2D): - """ - Overlays an `Array2D` data structure over a figure. - - This object wraps the following Matplotlib method: - - - plt.imshow: https://matplotlib.org/3.3.2/api/_as_gen/matplotlib.pyplot.imshow.html - - This uses the `Units` and coordinate system of the `Array2D` to overlay it on on the coordinate system of the - figure that is plotted. - """ - - def overlay_array(self, array, figure): - aspect = figure.aspect_from(shape_native=array.shape_native) - - zoom = Zoom2D(mask=array.mask) - array_zoom = zoom.array_2d_from(array=array, buffer=0) - extent = array_zoom.geometry.extent - - plt.imshow(X=array.native, aspect=aspect, extent=extent, **self.config_dict) diff --git a/autoarray/plot/wrap/two_d/border_scatter.py b/autoarray/plot/wrap/two_d/border_scatter.py deleted file mode 100644 index 25839470d..000000000 --- a/autoarray/plot/wrap/two_d/border_scatter.py +++ /dev/null @@ -1,11 +0,0 @@ -from autoarray.plot.wrap.two_d.grid_scatter import GridScatter - - -class BorderScatter(GridScatter): - """ - Plots a border over an image, using the `Mask2d` object's (y,x) `border` property. - - See `wrap.base.Scatter` for a description of how matplotlib is wrapped to make this plot. - """ - - pass diff --git a/autoarray/plot/wrap/two_d/contour.py b/autoarray/plot/wrap/two_d/contour.py deleted file mode 100644 index 164fe86be..000000000 --- a/autoarray/plot/wrap/two_d/contour.py +++ /dev/null @@ -1,105 +0,0 @@ -import matplotlib.pyplot as plt -import numpy as np -from typing import List, Optional, Union - -from autoarray.plot.wrap.two_d.abstract import AbstractMatWrap2D -from autoarray.structures.arrays.uniform_2d import Array2D - - -class Contour(AbstractMatWrap2D): - def __init__( - self, - manual_levels: Optional[List[float]] = None, - total_contours: Optional[int] = None, - use_log10: Optional[bool] = None, - include_values: Optional[bool] = None, - **kwargs, - ): - """ - Customizes the contours of the plotted figure. - - This object wraps the following Matplotlib method: - - - plt.contours: https://matplotlib.org/3.3.2/api/_as_gen/matplotlib.pyplot.contours.html - - Parameters - ---------- - manual_levels - Manually override the levels at which the contours are plotted. - total_contours - The total number of contours plotted, which also determines the spacing between each contour. - use_log10 - Whether the contours are plotted with a log10 spacing between each contour (alternative is linear). - include_values - Whether the values of the contours are included on the figure. - """ - - super().__init__(**kwargs) - - self.manual_levels = manual_levels - self.total_contours = total_contours or self.config_dict.get("total_contours") - self.use_log10 = use_log10 or self.config_dict.get("use_log10") - self.include_values = include_values or self.config_dict.get("include_values") - - def levels_from( - self, array: Union[np.ndarray, Array2D] - ) -> Union[np.ndarray, List[float]]: - """ - The levels at which the contours are plotted, which may be determined in the following ways: - - - Automatically computed from the minimum and maximum values of the array, using a log10 or linear spacing. - - Overriden by the input `manual_levels` (e.g. if it is not None). - - Returns - ------- - The levels at which the contours are plotted. - """ - if self.manual_levels is None: - if self.use_log10: - min_value = np.min(array) - if min_value < self.log10_min_value: - min_value = self.log10_min_value - - return np.logspace( - np.log10(min_value), - np.log10(np.max(array)), - self.total_contours, - ) - return np.linspace(np.min(array), np.max(array), self.total_contours) - - return self.manual_levels - - def set( - self, - array: Union[np.ndarray, Array2D], - extent: List[float] = None, - use_log10: bool = False, - ): - """ - Plot an input grid of (y,x) coordinates using the matplotlib method `plt.scatter`. - - Parameters - ---------- - array - The array of values the contours are plotted over. - """ - - if not use_log10: - if self.kwargs.get("is_default") is True: - return - - config_dict = self.config_dict - config_dict.pop("total_contours") - config_dict.pop("use_log10") - config_dict.pop("include_values") - - levels = self.levels_from(array.array) - - ax = plt.contour( - array.native.array[::-1], levels=levels, extent=extent, **config_dict - ) - if self.include_values: - try: - ax.clabel(levels=levels, inline=True, fontsize=10) - except ValueError: - pass diff --git a/autoarray/plot/wrap/two_d/delaunay_drawer.py b/autoarray/plot/wrap/two_d/delaunay_drawer.py deleted file mode 100644 index 915f5d660..000000000 --- a/autoarray/plot/wrap/two_d/delaunay_drawer.py +++ /dev/null @@ -1,120 +0,0 @@ -import matplotlib.pyplot as plt -import numpy as np -from typing import Optional - -from autoarray.plot.wrap.two_d.abstract import AbstractMatWrap2D -from autoarray.plot.wrap.base.units import Units - -from autoarray.plot.wrap import base as wb - - -def facecolors_from(values, simplices): - facecolors = np.zeros(shape=simplices.shape[0]) - for i in range(simplices.shape[0]): - facecolors[i] = np.sum(1.0 / 3.0 * values[simplices[i, :]]) - - return facecolors - - -class DelaunayDrawer(AbstractMatWrap2D): - """ - Draws Delaunay pixels from a `MapperDelaunay` object (see `inversions.mapper`). This includes both drawing - each Delaunay cell and coloring it according to a color value. - - The mapper contains the grid of (y,x) coordinate where the centre of each Delaunay cell is plotted. - - This object wraps methods described in below: - - https://matplotlib.org/3.3.2/api/_as_gen/matplotlib.pyplot.fill.html - """ - - def draw_delaunay_pixels( - self, - mapper, - pixel_values: Optional[np.ndarray], - units: Units, - cmap: Optional[wb.Cmap], - colorbar: Optional[wb.Colorbar], - colorbar_tickparams: Optional[wb.ColorbarTickParams] = None, - ax=None, - use_log10: bool = False, - ): - """ - Draws the Delaunay pixels of the input `mapper` using its `mesh_grid` which contains the (y,x) - coordinate of the centre of every Delaunay cell. This uses the method `plt.fill`. - - Parameters - ---------- - mapper - A mapper object which contains the Delaunay mesh. - pixel_values - An array used to compute the color values that every Delaunay cell is plotted using. - cmap - The colormap used to plot each Delaunay cell. - colorbar - The `Colorbar` object in `mat_base` used to set the colorbar of the figure the Delaunay mesh is plotted on. - colorbar_tickparams - The `ColorbarTickParams` object in `mat_base` used to set the tick labels of the colorbar. - ax - The matplotlib axis the Delaunay mesh is plotted on. - use_log10 - If `True`, the colorbar is plotted using a log10 scale. - """ - - if pixel_values is None: - pixel_values = np.zeros(shape=mapper.source_plane_mesh_grid.shape[0]) - - pixel_values = np.asarray(pixel_values) - - if ax is None: - ax = plt.gca() - - source_pixelization_grid = mapper.source_plane_mesh_grid - - simplices = mapper.interpolator.delaunay.simplices - - # Remove padded -1 values required for JAX - simplices = np.asarray(simplices) - valid_mask = np.all(simplices >= 0, axis=1) - simplices = simplices[valid_mask] - - facecolors = facecolors_from(values=pixel_values, simplices=simplices) - - norm = cmap.norm_from(array=pixel_values, use_log10=use_log10) - - if use_log10: - pixel_values[pixel_values < 1e-4] = 1e-4 - pixel_values = np.log10(pixel_values) - - vmin = cmap.vmin_from(array=pixel_values, use_log10=use_log10) - vmax = cmap.vmax_from(array=pixel_values, use_log10=use_log10) - - color_values = np.where(pixel_values > vmax, vmax, pixel_values) - color_values = np.where(pixel_values < vmin, vmin, color_values) - - cmap = plt.get_cmap(cmap.cmap) - - if colorbar is not None: - cb = colorbar.set_with_color_values( - units=units, - norm=norm, - cmap=cmap, - color_values=color_values, - ax=ax, - use_log10=use_log10, - ) - - if cb is not None and colorbar_tickparams is not None: - colorbar_tickparams.set(cb=cb) - - ax.tripcolor( - source_pixelization_grid.array[:, 1], - source_pixelization_grid.array[:, 0], - simplices, - facecolors=facecolors, - edgecolors="None", - cmap=cmap, - vmin=vmin, - vmax=vmax, - **self.config_dict, - ) diff --git a/autoarray/plot/wrap/two_d/fill.py b/autoarray/plot/wrap/two_d/fill.py deleted file mode 100644 index f580dde54..000000000 --- a/autoarray/plot/wrap/two_d/fill.py +++ /dev/null @@ -1,38 +0,0 @@ -import logging - -import matplotlib.pyplot as plt - -from autoarray.plot.wrap.two_d.abstract import AbstractMatWrap2D - - -logger = logging.getLogger(__name__) - - -class Fill(AbstractMatWrap2D): - def __init__(self, **kwargs): - """ - The settings used to customize plots using fill on a figure - - This object wraps the following Matplotlib methods: - - - plt.fill https://matplotlib.org/3.3.2/api/_as_gen/matplotlib.pyplot.fill.html - - Parameters - ---------- - symmetric - If True, the colormap normalization (e.g. `vmin` and `vmax`) span the same absolute values producing a - symmetric color bar. - """ - - super().__init__(**kwargs) - - def plot_fill(self, fill_region): - - try: - y_fill = fill_region[:, 0] - x_fill = fill_region[:, 1] - except TypeError: - y_fill = fill_region[0] - x_fill = fill_region[1] - - plt.fill(x_fill, y_fill, **self.config_dict) diff --git a/autoarray/plot/wrap/two_d/grid_errorbar.py b/autoarray/plot/wrap/two_d/grid_errorbar.py deleted file mode 100644 index f29a259a1..000000000 --- a/autoarray/plot/wrap/two_d/grid_errorbar.py +++ /dev/null @@ -1,147 +0,0 @@ -import matplotlib.pyplot as plt -import numpy as np -import itertools -from typing import List, Union, Optional - -from autoarray.plot.wrap.two_d.abstract import AbstractMatWrap2D -from autoarray.structures.grids.uniform_2d import Grid2D -from autoarray.structures.grids.irregular_2d import Grid2DIrregular - - -class GridErrorbar(AbstractMatWrap2D): - """ - Plots an input set of grid points with 2D errors, for example (y,x) coordinates or data structures representing 2D - (y,x) coordinates like a `Grid2D` or `Grid2DIrregular`. Multiple lists of (y,x) coordinates are plotted with - varying colors. - - This object wraps the following Matplotlib methods: - - - plt.errorbar: https://matplotlib.org/3.1.1/api/_as_gen/matplotlib.pyplot.errorbar.html - - Parameters - ---------- - colors : [str] - The color or list of colors that the grid is plotted using. For plotting indexes or a grid list, a - list of colors can be specified which the plot cycles through. - """ - - def config_dict_remove_marker(self, config_dict): - if config_dict.get("fmt") and config_dict.get("marker"): - config_dict.pop("marker") - - return config_dict - - def errorbar_grid( - self, - grid: Union[np.ndarray, Grid2D], - y_errors: Optional[Union[np.ndarray, List]] = None, - x_errors: Optional[Union[np.ndarray, List]] = None, - ): - """ - Plot an input grid of (y,x) coordinates using the matplotlib method `plt.errorbar`. - - The (y,x) coordinates are plotted as dots, with a line / cross for its errors. - - Parameters - ---------- - grid : Grid2D - The grid of (y,x) coordinates that is plotted. - y_errors - The y values of the error on every point of the grid that is plotted (e.g. vertically). - x_errors - The x values of the error on every point of the grid that is plotted (e.g. horizontally). - """ - - config_dict = self.config_dict - - if len(config_dict["c"]) > 1: - config_dict["c"] = config_dict["c"][0] - - config_dict = self.config_dict_remove_marker(config_dict=config_dict) - - try: - plt.errorbar( - y=grid[:, 0], x=grid[:, 1], yerr=y_errors, xerr=x_errors, **config_dict - ) - except (IndexError, TypeError): - return self.errorbar_grid_list(grid_list=grid) - - def errorbar_grid_list( - self, - grid_list: Union[List[Grid2D], List[Grid2DIrregular]], - y_errors: Optional[Union[np.ndarray, List]] = None, - x_errors: Optional[Union[np.ndarray, List]] = None, - ): - """ - Plot an input list of grids of (y,x) coordinates using the matplotlib method `plt.errorbar`. - - The (y,x) coordinates are plotted as dots, with a line / cross for its errors. - - This method colors each grid in each entry of the list the same, so that the different grids are visible in - the plot. - - Parameters - ---------- - grid_list - The list of grids of (y,x) coordinates that are plotted. - """ - if len(grid_list) == 0: - return - - color = itertools.cycle(self.config_dict["c"]) - config_dict = self.config_dict - config_dict.pop("c") - - config_dict = self.config_dict_remove_marker(config_dict=config_dict) - - try: - for grid in grid_list: - plt.errorbar( - y=grid[:, 0], - x=grid[:, 1], - yerr=np.asarray(y_errors), - xerr=np.asarray(x_errors), - c=next(color), - **config_dict, - ) - except IndexError: - return None - - def errorbar_grid_colored( - self, - grid: Union[np.ndarray, Grid2D], - color_array: np.ndarray, - cmap: str, - y_errors: Optional[Union[np.ndarray, List]] = None, - x_errors: Optional[Union[np.ndarray, List]] = None, - ): - """ - Plot an input grid of (y,x) coordinates using the matplotlib method `plt.errorbar`. - - The method colors the errorbared grid according to an input ndarray of color values, using an input colormap. - - Parameters - ---------- - grid : Grid2D - The grid of (y,x) coordinates that is plotted. - color_array : ndarray - The array of RGB color values used to color the grid. - cmap - The Matplotlib colormap used for the grid point coloring. - """ - - config_dict = self.config_dict - config_dict.pop("c") - - plt.scatter(y=grid[:, 0], x=grid[:, 1], c=color_array, cmap=cmap) - - config_dict = self.config_dict_remove_marker(config_dict=self.config_dict) - - plt.errorbar( - y=grid[:, 0], - x=grid[:, 1], - yerr=np.asarray(y_errors), - xerr=np.asarray(x_errors), - zorder=0.0, - **config_dict, - ) diff --git a/autoarray/plot/wrap/two_d/grid_plot.py b/autoarray/plot/wrap/two_d/grid_plot.py deleted file mode 100644 index c591c214d..000000000 --- a/autoarray/plot/wrap/two_d/grid_plot.py +++ /dev/null @@ -1,116 +0,0 @@ -import numpy as np -import itertools -from typing import List, Union, Tuple - -from autoarray.geometry.geometry_2d import Geometry2D -from autoarray.operators.contour import Grid2DContour -from autoarray.plot.wrap.two_d.abstract import AbstractMatWrap2D -from autoarray.structures.grids.uniform_2d import Grid2D -from autoarray.structures.grids.irregular_2d import Grid2DIrregular - - -class GridPlot(AbstractMatWrap2D): - """ - Plots `Grid2D` data structure that are better visualized as solid lines, for example rectangular lines that are - plotted over an image and grids of (y,x) coordinates as lines (as opposed to a scatter of points - using the `GridScatter` object). - - This object wraps the following Matplotlib methods: - - - plt.plot: https://matplotlib.org/3.3.3/api/_as_gen/matplotlib.pyplot.plot.html - - Parameters - ---------- - colors : [str] - The color or list of colors that the grid is plotted using. For plotting indexes or a grid list, a - list of colors can be specified which the plot cycles through. - """ - - def plot_rectangular_grid_lines( - self, extent: Tuple[float, float, float, float], shape_native: Tuple[int, int] - ): - """ - Plots a rectangular grid of lines on a plot, using the coordinate system of the figure. - - The size and shape of the grid is specified by the `extent` and `shape_native` properties of a data structure - which will provide the rectangaular grid lines on a suitable coordinate system for the plot. - - Parameters - ---------- - extent : (float, float, float, float) - The extent of the rectangular grid, with format [xmin, xmax, ymin, ymax] - shape_native - The 2D shape of the mask the array is paired with. - """ - import matplotlib.pyplot as plt - - ys = np.linspace(extent[2], extent[3], shape_native[1] + 1) - xs = np.linspace(extent[0], extent[1], shape_native[0] + 1) - - config_dict = self.config_dict - config_dict.pop("c") - config_dict["c"] = "k" - - # grid lines - for x in xs: - plt.plot([x, x], [ys[0], ys[-1]], **config_dict) - for y in ys: - plt.plot([xs[0], xs[-1]], [y, y], **config_dict) - - def plot_grid(self, grid: Union[np.ndarray, Grid2D]): - """ - Plot an input grid of (y,x) coordinates using the matplotlib method `plt.plot`. - - Parameters - ---------- - grid - The grid of (y,x) coordinates that is plotted. - """ - import matplotlib.pyplot as plt - - try: - color = self.config_dict["c"] - - if isinstance(color, list): - color = color[0] - - config_dict = self.config_dict - config_dict.pop("c") - - plt.plot(grid[:, 1], grid[:, 0], c=color, **config_dict) - except (IndexError, TypeError): - self.plot_grid_list(grid_list=grid) - - def plot_grid_list(self, grid_list: Union[List[Grid2D], List[Grid2DIrregular]]): - """ - Plot an input list of grids of (y,x) coordinates using the matplotlib method `plt.line`. - - This method colors each grid in the list the same, so that the different grids are visible in the plot. - - This provides an alternative to `GridScatter.scatter_grid_list` where the plotted grids appear as lines - instead of scattered points. - - Parameters - ---------- - grid_list - The list of grids of (y,x) coordinates that are plotted. - """ - import matplotlib.pyplot as plt - - if len(grid_list) == 0: - return None - - color = itertools.cycle(self.config_dict["c"]) - config_dict = self.config_dict - config_dict.pop("c") - - try: - for grid in grid_list: - try: - plt.plot(grid[:, 1], grid[:, 0], c=next(color), **config_dict) - except ValueError: - plt.plot( - grid.array[:, 1], grid.array[:, 0], c=next(color), **config_dict - ) - except IndexError: - pass diff --git a/autoarray/plot/wrap/two_d/grid_scatter.py b/autoarray/plot/wrap/two_d/grid_scatter.py deleted file mode 100644 index e9b9879d0..000000000 --- a/autoarray/plot/wrap/two_d/grid_scatter.py +++ /dev/null @@ -1,147 +0,0 @@ -import matplotlib.pyplot as plt -import numpy as np -import itertools -from typing import List, Union - - -from autoarray.plot.wrap.two_d.abstract import AbstractMatWrap2D -from autoarray.structures.grids.uniform_2d import Grid2D -from autoarray.structures.grids.irregular_2d import Grid2DIrregular - - -class GridScatter(AbstractMatWrap2D): - """ - Scatters an input set of grid points, for example (y,x) coordinates or data structures representing 2D (y,x) - coordinates like a `Grid2D` or `Grid2DIrregular`. List of (y,x) coordinates are plotted with varying colors. - - This object wraps the following Matplotlib methods: - - - plt.scatter: https://matplotlib.org/3.1.1/api/_as_gen/matplotlib.pyplot.scatter.html - - There are a number of children of this method in the `mat_obj.py` module that plot specific sets of (y,x) - points. Each of these objects uses uses their own config file and settings so that each has a unique appearance - on every figure: - - - `OriginScatter`: plots the (y,x) coordinates of the origin of a data structure (e.g. as a black cross). - - `MaskScatter`: plots a mask over an image, using the `Mask2d` object's (y,x) `edge` property. - - `BorderScatter: plots a border over an image, using the `Mask2d` object's (y,x) `border` property. - - `PositionsScatter`: plots the (y,x) coordinates that are input in a plotter via the `positions` input. - - `IndexScatter`: plots specific (y,x) coordinates of a grid (or grids) via their 1d or 2d indexes. - - `MeshGridScatter`: plots the grid of a `Mesh` object (see `autoarray.inversion`). - - Parameters - ---------- - colors : [str] - The color or list of colors that the grid is plotted using. For plotting indexes or a grid list, a - list of colors can be specified which the plot cycles through. - """ - - def scatter_grid(self, grid: Union[np.ndarray, Grid2D]): - """ - Plot an input grid of (y,x) coordinates using the matplotlib method `plt.scatter`. - - Parameters - ---------- - grid : Grid2D - The grid of (y,x) coordinates that is plotted. - errors - The error on every point of the grid that is plotted. - """ - - config_dict = self.config_dict - - if len(config_dict["c"]) > 1: - config_dict["c"] = config_dict["c"][0] - - try: - plt.scatter(y=grid[:, 0], x=grid[:, 1], **config_dict) - except (IndexError, TypeError): - return self.scatter_grid_list(grid_list=grid) - - def scatter_grid_list(self, grid_list: Union[List[Grid2D], List[Grid2DIrregular]]): - """ - Plot an input list of grids of (y,x) coordinates using the matplotlib method `plt.scatter`. - - This method colors each grid in each entry of the list the same, so that the different grids are visible in - the plot. - - Parameters - ---------- - grid_list - The list of grids of (y,x) coordinates that are plotted. - """ - if len(grid_list) == 0: - return - - color = itertools.cycle(self.config_dict["c"]) - config_dict = self.config_dict - config_dict.pop("c") - - try: - for grid in grid_list: - try: - plt.scatter( - y=grid[:, 0], x=grid[:, 1], c=next(color), **config_dict - ) - except ValueError: - plt.scatter( - y=grid.array[:, 0], - x=grid.array[:, 1], - c=next(color), - **config_dict, - ) - except IndexError: - return None - - def scatter_grid_colored( - self, grid: Union[np.ndarray, Grid2D], color_array: np.ndarray, cmap: str - ): - """ - Plot an input grid of (y,x) coordinates using the matplotlib method `plt.scatter`. - - The method colors the scattered grid according to an input ndarray of color values, using an input colormap. - - Parameters - ---------- - grid : Grid2D - The grid of (y,x) coordinates that is plotted. - color_array : ndarray - The array of RGB color values used to color the grid. - cmap - The Matplotlib colormap used for the grid point coloring. - """ - - config_dict = self.config_dict - config_dict.pop("c") - - plt.scatter(y=grid[:, 0], x=grid[:, 1], c=color_array, cmap=cmap, **config_dict) - - def scatter_grid_indexes( - self, - grid: Union[np.ndarray, Grid2D, Grid2DIrregular], - indexes: np.ndarray, - ): - """ - Plot specific points of an input grid of (y,x) coordinates, which are specified according to the 1D or 2D - indexes of the `Grid2D`. - - This method allows us to color in points on grids that map between one another. - - Parameters - ---------- - grid : Grid2D - The grid of (y,x) coordinates that is plotted. - indexes - The 1D indexes of the grid that are colored in when plotted. - """ - color = itertools.cycle(self.config_dict["c"]) - config_dict = self.config_dict - config_dict.pop("c") - - for index_list in indexes: - plt.scatter( - y=grid[index_list, 0], - x=grid[index_list, 1], - color=next(color), - **config_dict, - ) diff --git a/autoarray/plot/wrap/two_d/index_plot.py b/autoarray/plot/wrap/two_d/index_plot.py deleted file mode 100644 index e6dd584ba..000000000 --- a/autoarray/plot/wrap/two_d/index_plot.py +++ /dev/null @@ -1,11 +0,0 @@ -from autoarray.plot.wrap.two_d.grid_plot import GridPlot - - -class IndexPlot(GridPlot): - """ - Plots specific (y,x) coordinates of a grid (or grids) via their 1d or 2d indexes. - - See `wrap.base.Scatter` for a description of how matplotlib is wrapped to make this plot. - """ - - pass diff --git a/autoarray/plot/wrap/two_d/index_scatter.py b/autoarray/plot/wrap/two_d/index_scatter.py deleted file mode 100644 index a427c036b..000000000 --- a/autoarray/plot/wrap/two_d/index_scatter.py +++ /dev/null @@ -1,11 +0,0 @@ -from autoarray.plot.wrap.two_d.grid_scatter import GridScatter - - -class IndexScatter(GridScatter): - """ - Plots specific (y,x) coordinates of a grid (or grids) via their 1d or 2d indexes. - - See `wrap.base.Scatter` for a description of how matplotlib is wrapped to make this plot. - """ - - pass diff --git a/autoarray/plot/wrap/two_d/mask_scatter.py b/autoarray/plot/wrap/two_d/mask_scatter.py deleted file mode 100644 index ed264571c..000000000 --- a/autoarray/plot/wrap/two_d/mask_scatter.py +++ /dev/null @@ -1,9 +0,0 @@ -from autoarray.plot.wrap.two_d.grid_scatter import GridScatter - - -class MaskScatter(GridScatter): - """ - Plots a mask over an image, using the `Mask2d` object's (y,x) `edge` property. - - See `wrap.base.Scatter` for a description of how matplotlib is wrapped to make this plot. - """ diff --git a/autoarray/plot/wrap/two_d/mesh_grid_scatter.py b/autoarray/plot/wrap/two_d/mesh_grid_scatter.py deleted file mode 100644 index 7826fd55e..000000000 --- a/autoarray/plot/wrap/two_d/mesh_grid_scatter.py +++ /dev/null @@ -1,9 +0,0 @@ -from autoarray.plot.wrap.two_d.grid_scatter import GridScatter - - -class MeshGridScatter(GridScatter): - """ - Plots the grid of a `Mesh` object (see `autoarray.inversion`). - - See `wrap.base.Scatter` for a description of how matplotlib is wrapped to make this plot. - """ diff --git a/autoarray/plot/wrap/two_d/origin_scatter.py b/autoarray/plot/wrap/two_d/origin_scatter.py deleted file mode 100644 index 97a438f8f..000000000 --- a/autoarray/plot/wrap/two_d/origin_scatter.py +++ /dev/null @@ -1,9 +0,0 @@ -from autoarray.plot.wrap.two_d.grid_scatter import GridScatter - - -class OriginScatter(GridScatter): - """ - Plots the (y,x) coordinates of the origin of a data structure (e.g. as a black cross). - - See `wrap.base.Scatter` for a description of how matplotlib is wrapped to make this plot. - """ diff --git a/autoarray/plot/wrap/two_d/parallel_overscan_plot.py b/autoarray/plot/wrap/two_d/parallel_overscan_plot.py deleted file mode 100644 index 81fd49ef4..000000000 --- a/autoarray/plot/wrap/two_d/parallel_overscan_plot.py +++ /dev/null @@ -1,9 +0,0 @@ -from autoarray.plot.wrap.two_d.grid_plot import GridPlot - - -class ParallelOverscanPlot(GridPlot): - """ - Plots the lines of a parallel overscan `Region2D` object. - - See `wrap.base.Scatter` for a description of how matplotlib is wrapped to make this plot. - """ diff --git a/autoarray/plot/wrap/two_d/patch_overlay.py b/autoarray/plot/wrap/two_d/patch_overlay.py deleted file mode 100644 index 5075eb5a4..000000000 --- a/autoarray/plot/wrap/two_d/patch_overlay.py +++ /dev/null @@ -1,30 +0,0 @@ -from matplotlib import patches as ptch -from typing import Union - -from autoarray.plot.wrap.two_d.abstract import AbstractMatWrap2D - - -class PatchOverlay(AbstractMatWrap2D): - """ - Adds patches to a plotted figure using matplotlib `patches` objects. - - The coordinate system of each `Patch` uses that of the figure, which is typically set up using the plotted - data structure. This makes it straight forward to add patches in specific locations. - - This object wraps methods described in below: - - https://matplotlib.org/3.3.2/api/collections_api.html - """ - - def overlay_patches(self, patches: Union[ptch.Patch]): - """ - Overlay a list of patches on a figure, for example an `Ellipse`. - ` - Parameters - ---------- - patches : [Patch] - The patches that are laid over the figure. - """ - - # patch_collection = PatchCollection(patches=patches, **self.config_dict) - # plt.gcf().gca().add_collection(patch_collection) diff --git a/autoarray/plot/wrap/two_d/positions_scatter.py b/autoarray/plot/wrap/two_d/positions_scatter.py deleted file mode 100644 index cdebaffec..000000000 --- a/autoarray/plot/wrap/two_d/positions_scatter.py +++ /dev/null @@ -1,11 +0,0 @@ -from autoarray.plot.wrap.two_d.grid_scatter import GridScatter - - -class PositionsScatter(GridScatter): - """ - Plots the (y,x) coordinates that are input in a plotter via the `positions` input. - - See `wrap.base.Scatter` for a description of how matplotlib is wrapped to make this plot. - """ - - pass diff --git a/autoarray/plot/wrap/two_d/serial_overscan_plot.py b/autoarray/plot/wrap/two_d/serial_overscan_plot.py deleted file mode 100644 index af25c3891..000000000 --- a/autoarray/plot/wrap/two_d/serial_overscan_plot.py +++ /dev/null @@ -1,9 +0,0 @@ -from autoarray.plot.wrap.two_d.grid_plot import GridPlot - - -class SerialOverscanPlot(GridPlot): - """ - Plots the lines of a serial overscan `Region2D` object. - - See `wrap.base.Scatter` for a description of how matplotlib is wrapped to make this plot. - """ diff --git a/autoarray/plot/wrap/two_d/serial_prescan_plot.py b/autoarray/plot/wrap/two_d/serial_prescan_plot.py deleted file mode 100644 index 4b0360120..000000000 --- a/autoarray/plot/wrap/two_d/serial_prescan_plot.py +++ /dev/null @@ -1,9 +0,0 @@ -from autoarray.plot.wrap.two_d.grid_plot import GridPlot - - -class SerialPrescanPlot(GridPlot): - """ - Plots the lines of a serial prescan `Region2D` object. - - See `wrap.base.Scatter` for a description of how matplotlib is wrapped to make this plot. - """ diff --git a/autoarray/plot/wrap/two_d/vector_yx_quiver.py b/autoarray/plot/wrap/two_d/vector_yx_quiver.py deleted file mode 100644 index e8dd7523d..000000000 --- a/autoarray/plot/wrap/two_d/vector_yx_quiver.py +++ /dev/null @@ -1,34 +0,0 @@ -import matplotlib.pyplot as plt - -from autoarray.plot.wrap.two_d.abstract import AbstractMatWrap2D -from autoarray.structures.vectors.irregular import VectorYX2DIrregular - - -class VectorYXQuiver(AbstractMatWrap2D): - """ - Plots a `VectorField` data structure. A vector field is a set of 2D vectors on a grid of 2d (y,x) coordinates. - These are plotted as arrows representing the (y,x) components of each vector at each (y,x) coordinate of it - grid. - - This object wraps the following Matplotlib method: - - https://matplotlib.org/3.3.2/api/_as_gen/matplotlib.pyplot.quiver.html - """ - - def quiver_vectors(self, vectors: VectorYX2DIrregular): - """ - Plot a vector field using the matplotlib method `plt.quiver` such that each vector appears as an arrow whose - direction depends on the y and x magnitudes of the vector. - - Parameters - ---------- - vectors : VectorYX2DIrregular - The vector field that is plotted using `plt.quiver`. - """ - plt.quiver( - vectors.grid[:, 1], - vectors.grid[:, 0], - vectors[:, 1], - vectors[:, 0], - **self.config_dict, - ) diff --git a/autoarray/plot/yx.py b/autoarray/plot/yx.py new file mode 100644 index 000000000..fc26ad7ba --- /dev/null +++ b/autoarray/plot/yx.py @@ -0,0 +1,144 @@ +""" +Standalone function for plotting 1D y-vs-x data. + +Replaces ``MatPlot1D.plot_yx`` / ``MatWrap`` system. +""" + +from typing import List, Optional, Tuple + +import matplotlib.pyplot as plt +import numpy as np + +from autoarray.plot.utils import apply_labels, conf_figsize, save_figure + + +def plot_yx( + y, + x=None, + ax: Optional[plt.Axes] = None, + # --- errors / extras -------------------------------------------------------- + y_errors: Optional[np.ndarray] = None, + x_errors: Optional[np.ndarray] = None, + y_extra: Optional[np.ndarray] = None, + shaded_region: Optional[Tuple[np.ndarray, np.ndarray]] = None, + # --- cosmetics -------------------------------------------------------------- + title: str = "", + xlabel: str = "", + ylabel: str = "", + label: Optional[str] = None, + color: str = "b", + linestyle: str = "-", + plot_axis_type: str = "linear", + # --- figure control (used only when ax is None) ----------------------------- + figsize: Optional[Tuple[int, int]] = None, + output_path: Optional[str] = None, + output_filename: str = "yx", + output_format: str = "png", +) -> None: + """ + Plot 1D y versus x data. + + Replaces ``MatPlot1D.plot_yx`` with direct matplotlib calls. + + Parameters + ---------- + y + 1D numpy array of y values. + x + 1D numpy array of x values. When ``None`` integer indices are used. + ax + Existing ``Axes`` to draw onto. ``None`` creates a new figure. + y_errors, x_errors + Per-point error values; trigger ``plt.errorbar``. + y_extra + Optional second y series to overlay. + shaded_region + Tuple ``(y1, y2)`` arrays; filled region drawn with alpha. + title + Figure title. + xlabel, ylabel + Axis labels. + label + Legend label for the main series. + color + Line / marker colour. + linestyle + Line style string. + plot_axis_type + One of ``"linear"``, ``"log"``, ``"loglog"``, ``"symlog"``. + figsize + Figure size in inches. + output_path + Directory for saving. Empty / ``None`` calls ``plt.show()``. + output_filename + Base file name without extension. + output_format + File format, e.g. ``"png"``. + """ + # --- autoarray extraction -------------------------------------------------- + if x is None and hasattr(y, "grid_radial"): + x = y.grid_radial + y = y.array if hasattr(y, "array") else np.asarray(y) + if x is not None: + x = x.array if hasattr(x, "array") else np.asarray(x) + + # guard: nothing to draw + if y is None or len(y) == 0 or np.isnan(y).all(): + return + + owns_figure = ax is None + if owns_figure: + figsize = figsize or conf_figsize("figures") + fig, ax = plt.subplots(1, 1, figsize=figsize) + else: + fig = ax.get_figure() + + if x is None: + x = np.arange(len(y)) + + # --- main line / scatter --------------------------------------------------- + if y_errors is not None or x_errors is not None: + ax.errorbar( + x, + y, + yerr=y_errors, + xerr=x_errors, + fmt="-o", + color=color, + label=label, + markersize=3, + ) + elif plot_axis_type == "scatter": + ax.scatter(x, y, s=2, c=color, label=label) + elif plot_axis_type in ("log", "semilogy"): + ax.semilogy(x, y, color=color, linestyle=linestyle, label=label) + elif plot_axis_type == "loglog": + ax.loglog(x, y, color=color, linestyle=linestyle, label=label) + else: + ax.plot(x, y, color=color, linestyle=linestyle, label=label) + + if plot_axis_type == "symlog": + ax.set_yscale("symlog") + + # --- extras ---------------------------------------------------------------- + if y_extra is not None: + ax.plot(x, y_extra, color="r", linestyle="--", alpha=0.7) + + if shaded_region is not None: + y1, y2 = shaded_region + ax.fill_between(x, y1, y2, alpha=0.3) + + # --- labels ---------------------------------------------------------------- + apply_labels(ax, title=title, xlabel=xlabel, ylabel=ylabel) + + if label is not None: + ax.legend(fontsize=12) + + # --- output ---------------------------------------------------------------- + if owns_figure: + save_figure( + fig, + path=output_path or "", + filename=output_filename, + format=output_format, + ) diff --git a/autoarray/structures/arrays/rgb.py b/autoarray/structures/arrays/rgb.py index 8e2171f23..fae210db6 100644 --- a/autoarray/structures/arrays/rgb.py +++ b/autoarray/structures/arrays/rgb.py @@ -10,7 +10,7 @@ def __init__(self, values, mask): the same functionality as `Array2D` objects. By passing an RGB image to this class, the following visualization functionality is used when the RGB - image is used in `Plotter` objects: + image is used in plotting: - The RGB image is plotted using the `imshow` function of Matplotlib. - Functionality which sets the scale of the axis, zooms the image, and sets the axis limits is used. diff --git a/autoarray/structures/plot/__init__.py b/autoarray/structures/plot/__init__.py index e69de29bb..56a5b7c31 100644 --- a/autoarray/structures/plot/__init__.py +++ b/autoarray/structures/plot/__init__.py @@ -0,0 +1,5 @@ +from autoarray.structures.plot.structure_plots import ( + plot_array_2d, + plot_grid_2d, + plot_yx_1d, +) diff --git a/autoarray/structures/plot/structure_plots.py b/autoarray/structures/plot/structure_plots.py new file mode 100644 index 000000000..7b739c99f --- /dev/null +++ b/autoarray/structures/plot/structure_plots.py @@ -0,0 +1,12 @@ +""" +Thin convenience aliases that forward directly to the core plot functions. + +``plot_array``, ``plot_grid``, and ``plot_yx`` now accept autoarray objects +natively, so these wrappers exist only for name-discoverability. +""" + +from autoarray.plot.array import plot_array as plot_array_2d +from autoarray.plot.grid import plot_grid as plot_grid_2d +from autoarray.plot.yx import plot_yx as plot_yx_1d + +__all__ = ["plot_array_2d", "plot_grid_2d", "plot_yx_1d"] diff --git a/autoarray/structures/plot/structure_plotters.py b/autoarray/structures/plot/structure_plotters.py deleted file mode 100644 index 7e7cf655e..000000000 --- a/autoarray/structures/plot/structure_plotters.py +++ /dev/null @@ -1,185 +0,0 @@ -import numpy as np -from typing import List, Optional, Union - -from autoarray.plot.abstract_plotters import AbstractPlotter -from autoarray.plot.visuals.one_d import Visuals1D -from autoarray.plot.visuals.two_d import Visuals2D -from autoarray.plot.mat_plot.one_d import MatPlot1D -from autoarray.plot.mat_plot.two_d import MatPlot2D -from autoarray.plot.auto_labels import AutoLabels -from autoarray.structures.arrays.uniform_1d import Array1D -from autoarray.structures.arrays.uniform_2d import Array2D -from autoarray.structures.grids.uniform_1d import Grid1D -from autoarray.structures.grids.uniform_2d import Grid2D - - -class Array2DPlotter(AbstractPlotter): - def __init__( - self, - array: Array2D, - mat_plot_2d: MatPlot2D = None, - visuals_2d: Visuals2D = None, - ): - """ - Plots `Array2D` objects using the matplotlib method `imshow()` and many other matplotlib functions which - customize the plot's appearance. - - The `mat_plot_2d` attribute wraps matplotlib function calls to make the figure. By default, the settings - passed to every matplotlib function called are those specified in the `config/visualize/mat_wrap/*.ini` files, - but a user can manually input values into `MatPlot2d` to customize the figure's appearance. - - Overlaid on the figure are visuals, contained in the `Visuals2D` object. Attributes may be extracted from - the `Array2D` and plotted via the visuals object. - - Parameters - ---------- - array - The 2D array the plotter plot. - mat_plot_2d - Contains objects which wrap the matplotlib function calls that make 2D plots. - visuals_2d - Contains 2D visuals that can be overlaid on 2D plots. - """ - super().__init__(visuals_2d=visuals_2d, mat_plot_2d=mat_plot_2d) - - self.array = array - - def figure_2d(self): - """ - Plots the plotter's `Array2D` object in 2D. - """ - self.mat_plot_2d.plot_array( - array=self.array, - visuals_2d=self.visuals_2d, - auto_labels=AutoLabels(title="Array2D", filename="array"), - ) - - -class Grid2DPlotter(AbstractPlotter): - def __init__( - self, - grid: Grid2D, - mat_plot_2d: MatPlot2D = None, - visuals_2d: Visuals2D = None, - ): - """ - Plots `Grid2D` objects using the matplotlib method `scatter()` and many other matplotlib functions which - customize the plot's appearance. - - The `mat_plot_2d` attribute wraps matplotlib function calls to make the figure. By default, the settings - passed to every matplotlib function called are those specified in the `config/visualize/mat_wrap/*.ini` files, - but a user can manually input values into `MatPlot2d` to customize the figure's appearance. - - Overlaid on the figure are visuals, contained in the `Visuals2D` object. Attributes may be extracted from - the `Grid2D` and plotted via the visuals object. - - Parameters - ---------- - grid - The 2D grid the plotter plot. - mat_plot_2d - Contains objects which wrap the matplotlib function calls that make 2D plots. - visuals_2d - Contains 2D visuals that can be overlaid on 2D plots. - """ - super().__init__(visuals_2d=visuals_2d, mat_plot_2d=mat_plot_2d) - - self.grid = grid - - def figure_2d( - self, - color_array: np.ndarray = None, - plot_grid_lines: bool = False, - plot_over_sampled_grid: bool = False, - ): - """ - Plots the plotter's `Grid2D` object in 2D. - - Parameters - ---------- - color_array - An array of RGB color values which can be used to give the plotted 2D grid a colorscale (w/ colorbar). - plot_grid_lines - If True, a rectangular grid of lines is plotted on the figure showing the pixels which the grid coordinates - are centred on. - plot_over_sampled_grid - If True, the grid is plotted with over-sampled sub-gridded coordinates based on the `sub_size` attribute - of the grid's over-sampling object. - """ - self.mat_plot_2d.plot_grid( - grid=self.grid, - visuals_2d=self.visuals_2d, - auto_labels=AutoLabels(title="Grid2D", filename="grid"), - color_array=color_array, - plot_grid_lines=plot_grid_lines, - plot_over_sampled_grid=plot_over_sampled_grid, - ) - - -class YX1DPlotter(AbstractPlotter): - def __init__( - self, - y: Union[Array1D, List], - x: Optional[Union[Array1D, Grid1D, List]] = None, - mat_plot_1d: MatPlot1D = None, - visuals_1d: Visuals1D = None, - should_plot_grid: bool = False, - should_plot_zero: bool = False, - plot_axis_type: Optional[str] = None, - plot_yx_dict=None, - auto_labels=AutoLabels(), - ): - """ - Plots two 1D objects using the matplotlib method `plot()` (or a similar method) and many other matplotlib - functions which customize the plot's appearance. - - The `mat_plot_1d` attribute wraps matplotlib function calls to make the figure. By default, the settings - passed to every matplotlib function called are those specified in the `config/visualize/mat_wrap/*.ini` files, - but a user can manually input values into `MatPlot1d` to customize the figure's appearance. - - Overlaid on the figure are visuals, contained in the `Visuals1D` object. Attributes may be extracted from - the `Array1D` and plotted via the visuals object. - - Parameters - ---------- - y - The 1D y values the plotter plot. - x - The 1D x values the plotter plot. - mat_plot_1d - Contains objects which wrap the matplotlib function calls that make 1D plots. - visuals_1d - Contains 1D visuals that can be overlaid on 1D plots. - """ - - if isinstance(y, list): - y = Array1D.no_mask(values=y, pixel_scales=1.0) - - if isinstance(x, list): - x = Array1D.no_mask(values=x, pixel_scales=1.0) - - super().__init__(visuals_1d=visuals_1d, mat_plot_1d=mat_plot_1d) - - self.y = y - self.x = y.grid_radial if x is None else x - self.should_plot_grid = should_plot_grid - self.should_plot_zero = should_plot_zero - self.plot_axis_type = plot_axis_type - self.plot_yx_dict = plot_yx_dict or {} - self.auto_labels = auto_labels - - def figure_1d(self): - """ - Plots the plotter's y and x values in 1D. - """ - - self.mat_plot_1d.plot_yx( - y=self.y, - x=self.x, - visuals_1d=self.visuals_1d, - auto_labels=self.auto_labels, - should_plot_grid=self.should_plot_grid, - should_plot_zero=self.should_plot_zero, - plot_axis_type_override=self.plot_axis_type, - **self.plot_yx_dict, - ) diff --git a/pyproject.toml b/pyproject.toml index b70191ea6..044254db3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -57,7 +57,7 @@ optional=[ "tensorflow-probability==0.25.0" ] test = ["pytest"] -dev = ["pytest", "black"] +dev = ["pytest", "black", "numba", "pynufft==2022.2.2"] [tool.pytest.ini_options] testpaths = ["test_autoarray"] \ No newline at end of file diff --git a/test_autoarray/conftest.py b/test_autoarray/conftest.py index c35f26bf6..5ac460b65 100644 --- a/test_autoarray/conftest.py +++ b/test_autoarray/conftest.py @@ -1,5 +1,8 @@ +import jax import jax.numpy as jnp +jax.config.update("jax_enable_x64", True) + def pytest_configure(): _ = jnp.sum(jnp.array([0.0])) # Force backend init @@ -26,6 +29,9 @@ def __call__(self, path, *args, **kwargs): def make_plot_patch(monkeypatch): plot_patch = PlotPatch() monkeypatch.setattr(pyplot, "savefig", plot_patch) + from matplotlib.figure import Figure + + monkeypatch.setattr(Figure, "savefig", plot_patch) return plot_patch diff --git a/test_autoarray/dataset/plot/test_imaging_plotters.py b/test_autoarray/dataset/plot/test_imaging_plotters.py index 86c881325..fdd071968 100644 --- a/test_autoarray/dataset/plot/test_imaging_plotters.py +++ b/test_autoarray/dataset/plot/test_imaging_plotters.py @@ -1,84 +1,108 @@ -from os import path -import pytest -import autoarray as aa -import autoarray.plot as aplt - - -directory = path.dirname(path.realpath(__file__)) - - -@pytest.fixture(name="plot_path") -def make_plot_path_setup(): - return path.join( - "{}".format(path.dirname(path.realpath(__file__))), "files", "plots", "imaging" - ) - - -def test__individual_attributes_are_output( - imaging_7x7, grid_2d_irregular_7x7_list, mask_2d_7x7, plot_path, plot_patch -): - visuals = aplt.Visuals2D(mask=mask_2d_7x7, positions=grid_2d_irregular_7x7_list) - - dataset_plotter = aplt.ImagingPlotter( - dataset=imaging_7x7, - visuals_2d=visuals, - mat_plot_2d=aplt.MatPlot2D(output=aplt.Output(plot_path, format="png")), - ) - - dataset_plotter.figures_2d( - data=True, - noise_map=True, - psf=True, - signal_to_noise_map=True, - over_sample_size_lp=True, - over_sample_size_pixelization=True, - ) - - assert path.join(plot_path, "data.png") in plot_patch.paths - assert path.join(plot_path, "noise_map.png") in plot_patch.paths - assert path.join(plot_path, "psf.png") in plot_patch.paths - assert path.join(plot_path, "signal_to_noise_map.png") in plot_patch.paths - assert path.join(plot_path, "over_sample_size_lp.png") in plot_patch.paths - assert path.join(plot_path, "over_sample_size_pixelization.png") in plot_patch.paths - - plot_patch.paths = [] - - dataset_plotter.figures_2d( - data=True, - psf=True, - ) - - assert path.join(plot_path, "data.png") in plot_patch.paths - assert not path.join(plot_path, "noise_map.png") in plot_patch.paths - assert path.join(plot_path, "psf.png") in plot_patch.paths - assert not path.join(plot_path, "signal_to_noise_map.png") in plot_patch.paths - - -def test__subplot_is_output( - imaging_7x7, grid_2d_irregular_7x7_list, mask_2d_7x7, plot_path, plot_patch -): - dataset_plotter = aplt.ImagingPlotter( - dataset=imaging_7x7, - mat_plot_2d=aplt.MatPlot2D(output=aplt.Output(plot_path, format="png")), - ) - - dataset_plotter.subplot_dataset() - - assert path.join(plot_path, "subplot_dataset.png") in plot_patch.paths - - -def test__output_as_fits__correct_output_format( - imaging_7x7, grid_2d_irregular_7x7_list, mask_2d_7x7, plot_path, plot_patch -): - dataset_plotter = aplt.ImagingPlotter( - dataset=imaging_7x7, - mat_plot_2d=aplt.MatPlot2D(output=aplt.Output(path=plot_path, format="fits")), - ) - - dataset_plotter.figures_2d(data=True, psf=True) - - image_from_plot = aa.ndarray_via_fits_from( - file_path=path.join(plot_path, "data.fits"), hdu=0 - ) - - assert image_from_plot.shape == (7, 7) +from os import path +import pytest +import autoarray as aa +import autoarray.plot as aplt + + +directory = path.dirname(path.realpath(__file__)) + + +@pytest.fixture(name="plot_path") +def make_plot_path_setup(): + return path.join( + "{}".format(path.dirname(path.realpath(__file__))), "files", "plots", "imaging" + ) + + +def test__individual_attributes_are_output( + imaging_7x7, grid_2d_irregular_7x7_list, mask_2d_7x7, plot_path, plot_patch +): + aplt.plot_array_2d( + array=imaging_7x7.data, + positions=grid_2d_irregular_7x7_list, + output_path=plot_path, + output_filename="data", + output_format="png", + ) + assert path.join(plot_path, "data.png") in plot_patch.paths + + aplt.plot_array_2d( + array=imaging_7x7.noise_map, + output_path=plot_path, + output_filename="noise_map", + output_format="png", + ) + assert path.join(plot_path, "noise_map.png") in plot_patch.paths + + if imaging_7x7.psf is not None: + aplt.plot_array_2d( + array=imaging_7x7.psf.kernel, + output_path=plot_path, + output_filename="psf", + output_format="png", + ) + assert path.join(plot_path, "psf.png") in plot_patch.paths + + aplt.plot_array_2d( + array=imaging_7x7.signal_to_noise_map, + output_path=plot_path, + output_filename="signal_to_noise_map", + output_format="png", + ) + assert path.join(plot_path, "signal_to_noise_map.png") in plot_patch.paths + + aplt.plot_array_2d( + array=imaging_7x7.grids.over_sample_size_lp, + output_path=plot_path, + output_filename="over_sample_size_lp", + output_format="png", + ) + assert path.join(plot_path, "over_sample_size_lp.png") in plot_patch.paths + + aplt.plot_array_2d( + array=imaging_7x7.grids.over_sample_size_pixelization, + output_path=plot_path, + output_filename="over_sample_size_pixelization", + output_format="png", + ) + assert path.join(plot_path, "over_sample_size_pixelization.png") in plot_patch.paths + + plot_patch.paths = [] + + aplt.plot_array_2d( + array=imaging_7x7.data, + output_path=plot_path, + output_filename="data", + output_format="png", + ) + assert path.join(plot_path, "data.png") in plot_patch.paths + assert not path.join(plot_path, "noise_map.png") in plot_patch.paths + + +def test__subplot_is_output( + imaging_7x7, grid_2d_irregular_7x7_list, mask_2d_7x7, plot_path, plot_patch +): + aplt.subplot_imaging_dataset( + dataset=imaging_7x7, + output_path=plot_path, + output_format="png", + ) + + assert path.join(plot_path, "subplot_dataset.png") in plot_patch.paths + + +def test__output_as_fits__correct_output_format( + imaging_7x7, grid_2d_irregular_7x7_list, mask_2d_7x7, plot_path, plot_patch +): + aplt.plot_array_2d( + array=imaging_7x7.data, + output_path=plot_path, + output_filename="data", + output_format="fits", + ) + + image_from_plot = aa.ndarray_via_fits_from( + file_path=path.join(plot_path, "data.fits"), hdu=0 + ) + + assert image_from_plot.shape == (7, 7) diff --git a/test_autoarray/dataset/plot/test_interferometer_plotters.py b/test_autoarray/dataset/plot/test_interferometer_plotters.py index 0297f081a..1715c1d4c 100644 --- a/test_autoarray/dataset/plot/test_interferometer_plotters.py +++ b/test_autoarray/dataset/plot/test_interferometer_plotters.py @@ -1,79 +1,79 @@ -from os import path - -import pytest -import autoarray.plot as aplt - -directory = path.dirname(path.realpath(__file__)) - - -@pytest.fixture(name="plot_path") -def make_plot_path_setup(): - return path.join( - "{}".format(path.dirname(path.realpath(__file__))), - "files", - "plots", - "interferometer", - ) - - -def test__individual_attributes_are_output(interferometer_7, plot_path, plot_patch): - dataset_plotter = aplt.InterferometerPlotter( - dataset=interferometer_7, - mat_plot_1d=aplt.MatPlot1D(output=aplt.Output(path=plot_path, format="png")), - mat_plot_2d=aplt.MatPlot2D(output=aplt.Output(path=plot_path, format="png")), - ) - - dataset_plotter.figures_2d( - data=True, - noise_map=True, - u_wavelengths=True, - v_wavelengths=True, - uv_wavelengths=True, - amplitudes_vs_uv_distances=True, - phases_vs_uv_distances=True, - dirty_image=True, - dirty_noise_map=True, - dirty_signal_to_noise_map=True, - ) - - assert path.join(plot_path, "data.png") in plot_patch.paths - assert path.join(plot_path, "noise_map.png") in plot_patch.paths - assert path.join(plot_path, "u_wavelengths.png") in plot_patch.paths - assert path.join(plot_path, "v_wavelengths.png") in plot_patch.paths - assert path.join(plot_path, "uv_wavelengths.png") in plot_patch.paths - assert path.join(plot_path, "amplitudes_vs_uv_distances.png") in plot_patch.paths - assert path.join(plot_path, "phases_vs_uv_distances.png") in plot_patch.paths - assert path.join(plot_path, "dirty_image.png") in plot_patch.paths - assert path.join(plot_path, "dirty_noise_map.png") in plot_patch.paths - assert path.join(plot_path, "dirty_signal_to_noise_map.png") in plot_patch.paths - - plot_patch.paths = [] - - dataset_plotter.figures_2d( - data=True, - u_wavelengths=False, - v_wavelengths=True, - amplitudes_vs_uv_distances=True, - ) - - assert path.join(plot_path, "data.png") in plot_patch.paths - assert not path.join(plot_path, "u_wavelengths.png") in plot_patch.paths - assert path.join(plot_path, "v_wavelengths.png") in plot_patch.paths - assert path.join(plot_path, "amplitudes_vs_uv_distances.png") in plot_patch.paths - assert path.join(plot_path, "phases_vs_uv_distances.png") not in plot_patch.paths - - -def test__subplots_are_output(interferometer_7, plot_path, plot_patch): - dataset_plotter = aplt.InterferometerPlotter( - dataset=interferometer_7, - mat_plot_1d=aplt.MatPlot1D(output=aplt.Output(path=plot_path, format="png")), - mat_plot_2d=aplt.MatPlot2D(output=aplt.Output(path=plot_path, format="png")), - ) - - dataset_plotter.subplot_dataset() - - assert path.join(plot_path, "subplot_dataset.png") in plot_patch.paths - - dataset_plotter.subplot_dirty_images() - - assert path.join(plot_path, "subplot_dirty_images.png") in plot_patch.paths +from os import path + +import pytest +import autoarray.plot as aplt + +directory = path.dirname(path.realpath(__file__)) + + +@pytest.fixture(name="plot_path") +def make_plot_path_setup(): + return path.join( + "{}".format(path.dirname(path.realpath(__file__))), + "files", + "plots", + "interferometer", + ) + + +def test__individual_attributes_are_output(interferometer_7, plot_path, plot_patch): + aplt.plot_grid_2d( + grid=interferometer_7.data.in_grid, + output_path=plot_path, + output_filename="data", + output_format="png", + ) + assert path.join(plot_path, "data.png") in plot_patch.paths + + aplt.plot_array_2d( + array=interferometer_7.dirty_image, + output_path=plot_path, + output_filename="dirty_image", + output_format="png", + ) + assert path.join(plot_path, "dirty_image.png") in plot_patch.paths + + aplt.plot_array_2d( + array=interferometer_7.dirty_noise_map, + output_path=plot_path, + output_filename="dirty_noise_map", + output_format="png", + ) + assert path.join(plot_path, "dirty_noise_map.png") in plot_patch.paths + + aplt.plot_array_2d( + array=interferometer_7.dirty_signal_to_noise_map, + output_path=plot_path, + output_filename="dirty_signal_to_noise_map", + output_format="png", + ) + assert path.join(plot_path, "dirty_signal_to_noise_map.png") in plot_patch.paths + + plot_patch.paths = [] + + aplt.plot_grid_2d( + grid=interferometer_7.data.in_grid, + output_path=plot_path, + output_filename="data", + output_format="png", + ) + assert path.join(plot_path, "data.png") in plot_patch.paths + assert not path.join(plot_path, "dirty_image.png") in plot_patch.paths + + +def test__subplots_are_output(interferometer_7, plot_path, plot_patch): + aplt.subplot_interferometer_dataset( + dataset=interferometer_7, + output_path=plot_path, + output_format="png", + ) + + assert path.join(plot_path, "subplot_dataset.png") in plot_patch.paths + + aplt.subplot_interferometer_dirty_images( + dataset=interferometer_7, + output_path=plot_path, + output_format="png", + ) + + assert path.join(plot_path, "subplot_dirty_images.png") in plot_patch.paths diff --git a/test_autoarray/fit/plot/test_fit_imaging_plotters.py b/test_autoarray/fit/plot/test_fit_imaging_plotters.py index 22223ff61..7930f1a80 100644 --- a/test_autoarray/fit/plot/test_fit_imaging_plotters.py +++ b/test_autoarray/fit/plot/test_fit_imaging_plotters.py @@ -1,87 +1,130 @@ -import autoarray as aa -import autoarray.plot as aplt -import pytest -from os import path - -directory = path.dirname(path.realpath(__file__)) - - -@pytest.fixture(name="plot_path") -def make_plot_path_setup(): - return path.join( - "{}".format(path.dirname(path.realpath(__file__))), - "files", - "plots", - "fit_dataset", - ) - - -def test__fit_quantities_are_output(fit_imaging_7x7, plot_path, plot_patch): - fit_plotter = aplt.FitImagingPlotter( - fit=fit_imaging_7x7, - mat_plot_2d=aplt.MatPlot2D(output=aplt.Output(path=plot_path, format="png")), - ) - - fit_plotter.figures_2d( - data=True, - noise_map=True, - signal_to_noise_map=True, - model_image=True, - residual_map=True, - normalized_residual_map=True, - chi_squared_map=True, - ) - - assert path.join(plot_path, "data.png") in plot_patch.paths - assert path.join(plot_path, "noise_map.png") in plot_patch.paths - assert path.join(plot_path, "signal_to_noise_map.png") in plot_patch.paths - assert path.join(plot_path, "model_image.png") in plot_patch.paths - assert path.join(plot_path, "residual_map.png") in plot_patch.paths - assert path.join(plot_path, "normalized_residual_map.png") in plot_patch.paths - assert path.join(plot_path, "chi_squared_map.png") in plot_patch.paths - - plot_patch.paths = [] - - fit_plotter.figures_2d( - data=True, - noise_map=False, - signal_to_noise_map=False, - model_image=True, - chi_squared_map=True, - ) - - assert path.join(plot_path, "data.png") in plot_patch.paths - assert path.join(plot_path, "noise_map.png") not in plot_patch.paths - assert path.join(plot_path, "signal_to_noise_map.png") not in plot_patch.paths - assert path.join(plot_path, "model_image.png") in plot_patch.paths - assert path.join(plot_path, "residual_map.png") not in plot_patch.paths - assert path.join(plot_path, "normalized_residual_map.png") not in plot_patch.paths - assert path.join(plot_path, "chi_squared_map.png") in plot_patch.paths - - -def test__fit_sub_plot(fit_imaging_7x7, plot_path, plot_patch): - fit_plotter = aplt.FitImagingPlotter( - fit=fit_imaging_7x7, - mat_plot_2d=aplt.MatPlot2D(output=aplt.Output(path=plot_path, format="png")), - ) - - fit_plotter.subplot_fit() - - assert path.join(plot_path, "subplot_fit.png") in plot_patch.paths - - -def test__output_as_fits__correct_output_format( - fit_imaging_7x7, grid_2d_irregular_7x7_list, mask_2d_7x7, plot_path, plot_patch -): - fit_plotter = aplt.FitImagingPlotter( - fit=fit_imaging_7x7, - mat_plot_2d=aplt.MatPlot2D(output=aplt.Output(path=plot_path, format="fits")), - ) - - fit_plotter.figures_2d(data=True) - - image_from_plot = aa.ndarray_via_fits_from( - file_path=path.join(plot_path, "data.fits"), hdu=0 - ) - - assert image_from_plot.shape == (5, 5) +import autoarray as aa +import autoarray.plot as aplt +import pytest +from os import path + +directory = path.dirname(path.realpath(__file__)) + + +@pytest.fixture(name="plot_path") +def make_plot_path_setup(): + return path.join( + "{}".format(path.dirname(path.realpath(__file__))), + "files", + "plots", + "fit_dataset", + ) + + +def test__fit_quantities_are_output(fit_imaging_7x7, plot_path, plot_patch): + aplt.plot_array_2d( + array=fit_imaging_7x7.data, + output_path=plot_path, + output_filename="data", + output_format="png", + ) + assert path.join(plot_path, "data.png") in plot_patch.paths + + aplt.plot_array_2d( + array=fit_imaging_7x7.noise_map, + output_path=plot_path, + output_filename="noise_map", + output_format="png", + ) + assert path.join(plot_path, "noise_map.png") in plot_patch.paths + + aplt.plot_array_2d( + array=fit_imaging_7x7.signal_to_noise_map, + output_path=plot_path, + output_filename="signal_to_noise_map", + output_format="png", + ) + assert path.join(plot_path, "signal_to_noise_map.png") in plot_patch.paths + + aplt.plot_array_2d( + array=fit_imaging_7x7.model_data, + output_path=plot_path, + output_filename="model_image", + output_format="png", + ) + assert path.join(plot_path, "model_image.png") in plot_patch.paths + + aplt.plot_array_2d( + array=fit_imaging_7x7.residual_map, + output_path=plot_path, + output_filename="residual_map", + output_format="png", + ) + assert path.join(plot_path, "residual_map.png") in plot_patch.paths + + aplt.plot_array_2d( + array=fit_imaging_7x7.normalized_residual_map, + output_path=plot_path, + output_filename="normalized_residual_map", + output_format="png", + ) + assert path.join(plot_path, "normalized_residual_map.png") in plot_patch.paths + + aplt.plot_array_2d( + array=fit_imaging_7x7.chi_squared_map, + output_path=plot_path, + output_filename="chi_squared_map", + output_format="png", + ) + assert path.join(plot_path, "chi_squared_map.png") in plot_patch.paths + + plot_patch.paths = [] + + aplt.plot_array_2d( + array=fit_imaging_7x7.data, + output_path=plot_path, + output_filename="data", + output_format="png", + ) + aplt.plot_array_2d( + array=fit_imaging_7x7.model_data, + output_path=plot_path, + output_filename="model_image", + output_format="png", + ) + aplt.plot_array_2d( + array=fit_imaging_7x7.chi_squared_map, + output_path=plot_path, + output_filename="chi_squared_map", + output_format="png", + ) + + assert path.join(plot_path, "data.png") in plot_patch.paths + assert path.join(plot_path, "noise_map.png") not in plot_patch.paths + assert path.join(plot_path, "signal_to_noise_map.png") not in plot_patch.paths + assert path.join(plot_path, "model_image.png") in plot_patch.paths + assert path.join(plot_path, "residual_map.png") not in plot_patch.paths + assert path.join(plot_path, "normalized_residual_map.png") not in plot_patch.paths + assert path.join(plot_path, "chi_squared_map.png") in plot_patch.paths + + +def test__fit_sub_plot(fit_imaging_7x7, plot_path, plot_patch): + aplt.subplot_fit_imaging( + fit=fit_imaging_7x7, + output_path=plot_path, + output_format="png", + ) + + assert path.join(plot_path, "subplot_fit.png") in plot_patch.paths + + +def test__output_as_fits__correct_output_format( + fit_imaging_7x7, grid_2d_irregular_7x7_list, mask_2d_7x7, plot_path, plot_patch +): + aplt.plot_array_2d( + array=fit_imaging_7x7.data, + output_path=plot_path, + output_filename="data", + output_format="fits", + ) + + image_from_plot = aa.ndarray_via_fits_from( + file_path=path.join(plot_path, "data.fits"), hdu=0 + ) + + assert image_from_plot.shape == (5, 5) diff --git a/test_autoarray/fit/plot/test_fit_interferometer_plotters.py b/test_autoarray/fit/plot/test_fit_interferometer_plotters.py index f4b55c264..c9f7decdf 100644 --- a/test_autoarray/fit/plot/test_fit_interferometer_plotters.py +++ b/test_autoarray/fit/plot/test_fit_interferometer_plotters.py @@ -1,138 +1,133 @@ -import autoarray.plot as aplt -import pytest - -from os import path - -directory = path.dirname(path.realpath(__file__)) - - -@pytest.fixture(name="plot_path") -def make_plot_path_setup(): - return path.join( - "{}".format(path.dirname(path.realpath(__file__))), - "files", - "plots", - "fit_dataset", - ) - - -def test__fit_quantities_are_output(fit_interferometer_7, plot_path, plot_patch): - fit_plotter = aplt.FitInterferometerPlotter( - fit=fit_interferometer_7, - mat_plot_1d=aplt.MatPlot1D(output=aplt.Output(path=plot_path, format="png")), - mat_plot_2d=aplt.MatPlot2D(output=aplt.Output(path=plot_path, format="png")), - ) - - fit_plotter.figures_2d( - data=True, - noise_map=True, - signal_to_noise_map=True, - model_data=True, - residual_map_real=True, - residual_map_imag=True, - normalized_residual_map_real=True, - normalized_residual_map_imag=True, - chi_squared_map_real=True, - chi_squared_map_imag=True, - dirty_image=True, - dirty_noise_map=True, - dirty_signal_to_noise_map=True, - dirty_model_image=True, - dirty_residual_map=True, - dirty_normalized_residual_map=True, - dirty_chi_squared_map=True, - ) - - assert path.join(plot_path, "data.png") in plot_patch.paths - assert path.join(plot_path, "noise_map.png") in plot_patch.paths - assert path.join(plot_path, "signal_to_noise_map.png") in plot_patch.paths - assert path.join(plot_path, "model_data.png") in plot_patch.paths - assert ( - path.join(plot_path, "real_residual_map_vs_uv_distances.png") - in plot_patch.paths - ) - assert ( - path.join(plot_path, "real_residual_map_vs_uv_distances.png") - in plot_patch.paths - ) - assert ( - path.join(plot_path, "real_normalized_residual_map_vs_uv_distances.png") - in plot_patch.paths - ) - assert ( - path.join(plot_path, "imag_normalized_residual_map_vs_uv_distances.png") - in plot_patch.paths - ) - assert ( - path.join(plot_path, "imag_chi_squared_map_vs_uv_distances.png") - in plot_patch.paths - ) - assert ( - path.join(plot_path, "imag_chi_squared_map_vs_uv_distances.png") - in plot_patch.paths - ) - assert path.join(plot_path, "dirty_image.png") in plot_patch.paths - assert path.join(plot_path, "dirty_noise_map.png") in plot_patch.paths - assert path.join(plot_path, "dirty_signal_to_noise_map.png") in plot_patch.paths - assert path.join(plot_path, "dirty_model_image_2d.png") in plot_patch.paths - assert path.join(plot_path, "dirty_residual_map_2d.png") in plot_patch.paths - assert ( - path.join(plot_path, "dirty_normalized_residual_map_2d.png") in plot_patch.paths - ) - assert path.join(plot_path, "dirty_chi_squared_map_2d.png") in plot_patch.paths - - plot_patch.paths = [] - - fit_plotter.figures_2d( - data=True, - noise_map=False, - signal_to_noise_map=False, - model_data=True, - chi_squared_map_real=True, - chi_squared_map_imag=True, - ) - - assert path.join(plot_path, "data.png") in plot_patch.paths - assert path.join(plot_path, "noise_map.png") not in plot_patch.paths - assert path.join(plot_path, "signal_to_noise_map.png") not in plot_patch.paths - assert path.join(plot_path, "model_data.png") in plot_patch.paths - assert ( - path.join(plot_path, "real_residual_map_vs_uv_distances.png") - not in plot_patch.paths - ) - assert ( - path.join(plot_path, "imag_residual_map_vs_uv_distances.png") - not in plot_patch.paths - ) - assert ( - path.join(plot_path, "real_normalized_residual_map_vs_uv_distances.png") - not in plot_patch.paths - ) - assert ( - path.join(plot_path, "imag_normalized_residual_map_vs_uv_distances.png") - not in plot_patch.paths - ) - assert ( - path.join(plot_path, "real_chi_squared_map_vs_uv_distances.png") - in plot_patch.paths - ) - assert ( - path.join(plot_path, "imag_chi_squared_map_vs_uv_distances.png") - in plot_patch.paths - ) - - -def test__fit_sub_plots(fit_interferometer_7, plot_path, plot_patch): - fit_plotter = aplt.FitInterferometerPlotter( - fit=fit_interferometer_7, - mat_plot_1d=aplt.MatPlot1D(output=aplt.Output(path=plot_path, format="png")), - mat_plot_2d=aplt.MatPlot2D(output=aplt.Output(path=plot_path, format="png")), - ) - - fit_plotter.subplot_fit() - - assert path.join(plot_path, "subplot_fit.png") in plot_patch.paths - - fit_plotter.subplot_fit_dirty_images() - - assert path.join(plot_path, "subplot_fit_dirty_images.png") in plot_patch.paths +import autoarray.plot as aplt +import numpy as np +import pytest + +from os import path + +directory = path.dirname(path.realpath(__file__)) + + +@pytest.fixture(name="plot_path") +def make_plot_path_setup(): + return path.join( + "{}".format(path.dirname(path.realpath(__file__))), + "files", + "plots", + "fit_dataset", + ) + + +def test__fit_quantities_are_output(fit_interferometer_7, plot_path, plot_patch): + uv = fit_interferometer_7.dataset.uv_distances / 10**3.0 + + aplt.plot_grid_2d( + grid=fit_interferometer_7.data.in_grid, + output_path=plot_path, + output_filename="data", + output_format="png", + ) + assert path.join(plot_path, "data.png") in plot_patch.paths + + aplt.plot_yx_1d( + y=np.real(fit_interferometer_7.residual_map), + x=uv, + output_path=plot_path, + output_filename="real_residual_map_vs_uv_distances", + output_format="png", + plot_axis_type="scatter", + ) + assert ( + path.join(plot_path, "real_residual_map_vs_uv_distances.png") + in plot_patch.paths + ) + + aplt.plot_yx_1d( + y=np.real(fit_interferometer_7.chi_squared_map), + x=uv, + output_path=plot_path, + output_filename="real_chi_squared_map_vs_uv_distances", + output_format="png", + plot_axis_type="scatter", + ) + assert ( + path.join(plot_path, "real_chi_squared_map_vs_uv_distances.png") + in plot_patch.paths + ) + + aplt.plot_yx_1d( + y=np.imag(fit_interferometer_7.chi_squared_map), + x=uv, + output_path=plot_path, + output_filename="imag_chi_squared_map_vs_uv_distances", + output_format="png", + plot_axis_type="scatter", + ) + assert ( + path.join(plot_path, "imag_chi_squared_map_vs_uv_distances.png") + in plot_patch.paths + ) + + aplt.plot_array_2d( + array=fit_interferometer_7.dirty_image, + output_path=plot_path, + output_filename="dirty_image", + output_format="png", + ) + assert path.join(plot_path, "dirty_image.png") in plot_patch.paths + + plot_patch.paths = [] + + aplt.plot_grid_2d( + grid=fit_interferometer_7.data.in_grid, + output_path=plot_path, + output_filename="data", + output_format="png", + ) + aplt.plot_yx_1d( + y=np.real(fit_interferometer_7.chi_squared_map), + x=uv, + output_path=plot_path, + output_filename="real_chi_squared_map_vs_uv_distances", + output_format="png", + plot_axis_type="scatter", + ) + aplt.plot_yx_1d( + y=np.imag(fit_interferometer_7.chi_squared_map), + x=uv, + output_path=plot_path, + output_filename="imag_chi_squared_map_vs_uv_distances", + output_format="png", + plot_axis_type="scatter", + ) + + assert path.join(plot_path, "data.png") in plot_patch.paths + assert ( + path.join(plot_path, "real_chi_squared_map_vs_uv_distances.png") + in plot_patch.paths + ) + assert ( + path.join(plot_path, "imag_chi_squared_map_vs_uv_distances.png") + in plot_patch.paths + ) + assert ( + path.join(plot_path, "real_residual_map_vs_uv_distances.png") + not in plot_patch.paths + ) + + +def test__fit_sub_plots(fit_interferometer_7, plot_path, plot_patch): + aplt.subplot_fit_interferometer( + fit=fit_interferometer_7, + output_path=plot_path, + output_format="png", + ) + + assert path.join(plot_path, "subplot_fit.png") in plot_patch.paths + + aplt.subplot_fit_interferometer_dirty_images( + fit=fit_interferometer_7, + output_path=plot_path, + output_format="png", + ) + + assert path.join(plot_path, "subplot_fit_dirty_images.png") in plot_patch.paths diff --git a/test_autoarray/fit/test_fit_imaging.py b/test_autoarray/fit/test_fit_imaging.py index 411330d46..980a54e68 100644 --- a/test_autoarray/fit/test_fit_imaging.py +++ b/test_autoarray/fit/test_fit_imaging.py @@ -7,6 +7,7 @@ # Helper: build the "identical model, no masking" fit used by multiple tests # --------------------------------------------------------------------------- + def _make_identical_fit_no_mask(): mask = aa.Mask2D(mask=[[False, False], [False, False]], pixel_scales=(1.0, 1.0)) data = aa.Array2D(values=[1.0, 2.0, 3.0, 4.0], mask=mask) diff --git a/test_autoarray/fit/test_fit_interferometer.py b/test_autoarray/fit/test_fit_interferometer.py index 97e7edf09..5ea71c10d 100644 --- a/test_autoarray/fit/test_fit_interferometer.py +++ b/test_autoarray/fit/test_fit_interferometer.py @@ -8,6 +8,7 @@ # Helpers # --------------------------------------------------------------------------- + def _make_dataset(): real_space_mask = aa.Mask2D( mask=[[False, False], [False, False]], pixel_scales=(1.0, 1.0) diff --git a/test_autoarray/inversion/pixelization/mappers/test_abstract.py b/test_autoarray/inversion/pixelization/mappers/test_abstract.py index eca7894df..48e7dfaa0 100644 --- a/test_autoarray/inversion/pixelization/mappers/test_abstract.py +++ b/test_autoarray/inversion/pixelization/mappers/test_abstract.py @@ -133,7 +133,9 @@ def test__adaptive_pixel_signals_from___matches_util(grid_2d_7x7, image_7x7): assert (pixel_signals == pixel_signals_util).all() -def test__mapped_to_source_from__delaunay_mapper__matches_mapping_matrix_util(grid_2d_7x7): +def test__mapped_to_source_from__delaunay_mapper__matches_mapping_matrix_util( + grid_2d_7x7, +): mesh_grid = aa.Grid2D.no_mask( values=[[0.1, 0.1], [1.1, 0.6], [2.1, 0.1], [0.4, 1.1], [1.1, 7.1], [2.1, 1.1]], shape_native=(3, 2), diff --git a/test_autoarray/inversion/pixelization/mappers/test_delaunay.py b/test_autoarray/inversion/pixelization/mappers/test_delaunay.py index f0f478baf..451dd35c1 100644 --- a/test_autoarray/inversion/pixelization/mappers/test_delaunay.py +++ b/test_autoarray/inversion/pixelization/mappers/test_delaunay.py @@ -29,7 +29,9 @@ def test__pixel_weights_delaunay_from__two_data_points__returns_correct_barycent assert (pixel_weights == np.array([[0.25, 0.5, 0.25], [1.0, 0.0, 0.0]])).all() -def test__pix_indexes_for_sub_slim_index__delaunay_mesh__matches_util_and_expected_values(grid_2d_sub_1_7x7): +def test__pix_indexes_for_sub_slim_index__delaunay_mesh__matches_util_and_expected_values( + grid_2d_sub_1_7x7, +): mesh_grid = aa.Grid2D.no_mask( values=[[0.1, 0.1], [1.1, 0.6], [2.1, 0.1], [0.4, 1.1], [1.1, 7.1], [2.1, 1.1]], shape_native=(3, 2), diff --git a/test_autoarray/inversion/pixelization/mappers/test_rectangular.py b/test_autoarray/inversion/pixelization/mappers/test_rectangular.py index d7d3c26b6..183246ef1 100644 --- a/test_autoarray/inversion/pixelization/mappers/test_rectangular.py +++ b/test_autoarray/inversion/pixelization/mappers/test_rectangular.py @@ -51,7 +51,9 @@ def test__pix_indexes_for_sub_slim_index__rectangular_uniform_mesh__matches_util assert mapper.pix_weights_for_sub_slim_index == pytest.approx(weights, 1.0e-4) -def test__pixel_signals_from__rectangular_adapt_density_mesh__matches_util(grid_2d_sub_1_7x7, image_7x7): +def test__pixel_signals_from__rectangular_adapt_density_mesh__matches_util( + grid_2d_sub_1_7x7, image_7x7 +): mesh_grid = overlay_grid_from( shape_native=(3, 3), grid=grid_2d_sub_1_7x7.over_sampled, buffer=1e-8 diff --git a/test_autoarray/inversion/plot/test_inversion_plotters.py b/test_autoarray/inversion/plot/test_inversion_plotters.py index 225470cb8..d97874161 100644 --- a/test_autoarray/inversion/plot/test_inversion_plotters.py +++ b/test_autoarray/inversion/plot/test_inversion_plotters.py @@ -1,65 +1,69 @@ -from os import path -import autoarray.plot as aplt - -import pytest - -directory = path.dirname(path.realpath(__file__)) - - -@pytest.fixture(name="plot_path") -def make_plot_path_setup(): - return path.join( - "{}".format(path.dirname(path.realpath(__file__))), - "files", - "plots", - "inversion", - ) - - -def test__individual_attributes_are_output_for_all_mappers( - rectangular_inversion_7x7_3x3, - grid_2d_irregular_7x7_list, - plot_path, - plot_patch, -): - inversion_plotter = aplt.InversionPlotter( - inversion=rectangular_inversion_7x7_3x3, - visuals_2d=aplt.Visuals2D(indexes=[0]), - mat_plot_2d=aplt.MatPlot2D(output=aplt.Output(path=plot_path, format="png")), - ) - - inversion_plotter.figures_2d(reconstructed_operated_data=True) - - assert path.join(plot_path, "reconstructed_operated_data.png") in plot_patch.paths - - inversion_plotter.figures_2d_of_pixelization( - pixelization_index=0, - reconstructed_operated_data=True, - reconstruction=True, - reconstruction_noise_map=True, - regularization_weights=True, - ) - - assert path.join(plot_path, "reconstructed_operated_data.png") in plot_patch.paths - assert path.join(plot_path, "reconstruction.png") in plot_patch.paths - assert path.join(plot_path, "reconstruction_noise_map.png") in plot_patch.paths - assert path.join(plot_path, "regularization_weights.png") in plot_patch.paths - - -def test__inversion_subplot_of_mapper__is_output_for_all_inversions( - imaging_7x7, - rectangular_inversion_7x7_3x3, - plot_path, - plot_patch, -): - inversion_plotter = aplt.InversionPlotter( - inversion=rectangular_inversion_7x7_3x3, - visuals_2d=aplt.Visuals2D(indexes=[0]), - mat_plot_2d=aplt.MatPlot2D(output=aplt.Output(path=plot_path, format="png")), - ) - - inversion_plotter.subplot_of_mapper(mapper_index=0) - assert path.join(plot_path, "subplot_inversion_0.png") in plot_patch.paths - - inversion_plotter.subplot_mappings(pixelization_index=0) - assert path.join(plot_path, "subplot_mappings_0.png") in plot_patch.paths +from os import path +import autoarray.plot as aplt +from autoarray.inversion.mappers.abstract import Mapper + +import pytest + +directory = path.dirname(path.realpath(__file__)) + + +@pytest.fixture(name="plot_path") +def make_plot_path_setup(): + return path.join( + "{}".format(path.dirname(path.realpath(__file__))), + "files", + "plots", + "inversion", + ) + + +def test__individual_attributes_are_output_for_all_mappers( + rectangular_inversion_7x7_3x3, + grid_2d_irregular_7x7_list, + plot_path, + plot_patch, +): + aplt.plot_array_2d( + array=rectangular_inversion_7x7_3x3.mapped_reconstructed_operated_data, + output_path=plot_path, + output_filename="reconstructed_operated_data", + output_format="png", + ) + + assert path.join(plot_path, "reconstructed_operated_data.png") in plot_patch.paths + + mapper = rectangular_inversion_7x7_3x3.cls_list_from(cls=Mapper)[0] + pixel_values = rectangular_inversion_7x7_3x3.reconstruction_dict[mapper] + + aplt.plot_mapper( + mapper=mapper, + solution_vector=pixel_values, + output_path=plot_path, + output_filename="reconstruction", + output_format="png", + ) + + assert path.join(plot_path, "reconstruction.png") in plot_patch.paths + + +def test__inversion_subplot_of_mapper__is_output_for_all_inversions( + imaging_7x7, + rectangular_inversion_7x7_3x3, + plot_path, + plot_patch, +): + aplt.subplot_of_mapper( + inversion=rectangular_inversion_7x7_3x3, + mapper_index=0, + output_path=plot_path, + output_format="png", + ) + assert path.join(plot_path, "subplot_inversion_0.png") in plot_patch.paths + + aplt.subplot_mappings( + inversion=rectangular_inversion_7x7_3x3, + pixelization_index=0, + output_path=plot_path, + output_format="png", + ) + assert path.join(plot_path, "subplot_mappings_0.png") in plot_patch.paths diff --git a/test_autoarray/inversion/plot/test_mapper_plotters.py b/test_autoarray/inversion/plot/test_mapper_plotters.py index 2f0f7cb04..dc60c3025 100644 --- a/test_autoarray/inversion/plot/test_mapper_plotters.py +++ b/test_autoarray/inversion/plot/test_mapper_plotters.py @@ -1,60 +1,46 @@ -from os import path -import pytest - -import autoarray as aa -import autoarray.plot as aplt - -directory = path.dirname(path.realpath(__file__)) - - -@pytest.fixture(name="plot_path") -def make_plot_path_setup(): - return path.join( - "{}".format(path.dirname(path.realpath(__file__))), "files", "structures" - ) - - -def test__figure_2d( - rectangular_mapper_7x7_3x3, - delaunay_mapper_9_3x3, - plot_path, - plot_patch, -): - visuals_2d = aplt.Visuals2D( - indexes=[[(0, 0), (0, 1)], [(1, 2)]], - ) - - mat_plot_2d = aplt.MatPlot2D( - output=aplt.Output(path=plot_path, filename="mapper1", format="png") - ) - - mapper_plotter = aplt.MapperPlotter( - mapper=rectangular_mapper_7x7_3x3, - visuals_2d=visuals_2d, - mat_plot_2d=mat_plot_2d, - ) - - mapper_plotter.figure_2d() - - assert path.join(plot_path, "mapper1.png") in plot_patch.paths - - -def test__subplot_image_and_mapper( - imaging_7x7, - rectangular_mapper_7x7_3x3, - delaunay_mapper_9_3x3, - plot_path, - plot_patch, -): - visuals_2d = aplt.Visuals2D(indexes=[0, 1, 2]) - - mapper_plotter = aplt.MapperPlotter( - mapper=rectangular_mapper_7x7_3x3, - visuals_2d=visuals_2d, - mat_plot_2d=aplt.MatPlot2D(output=aplt.Output(path=plot_path, format="png")), - ) - - mapper_plotter.subplot_image_and_mapper( - image=imaging_7x7.data, - ) - assert path.join(plot_path, "subplot_image_and_mapper.png") in plot_patch.paths +from os import path +import pytest + +import autoarray as aa +import autoarray.plot as aplt + +directory = path.dirname(path.realpath(__file__)) + + +@pytest.fixture(name="plot_path") +def make_plot_path_setup(): + return path.join( + "{}".format(path.dirname(path.realpath(__file__))), "files", "structures" + ) + + +def test__plot_mapper( + rectangular_mapper_7x7_3x3, + delaunay_mapper_9_3x3, + plot_path, + plot_patch, +): + aplt.plot_mapper( + mapper=rectangular_mapper_7x7_3x3, + output_path=plot_path, + output_filename="mapper1", + output_format="png", + ) + + assert path.join(plot_path, "mapper1.png") in plot_patch.paths + + +def test__subplot_image_and_mapper( + imaging_7x7, + rectangular_mapper_7x7_3x3, + delaunay_mapper_9_3x3, + plot_path, + plot_patch, +): + aplt.subplot_image_and_mapper( + mapper=rectangular_mapper_7x7_3x3, + image=imaging_7x7.data, + output_path=plot_path, + output_format="png", + ) + assert path.join(plot_path, "subplot_image_and_mapper.png") in plot_patch.paths diff --git a/test_autoarray/inversion/regularizations/test_regularization_util.py b/test_autoarray/inversion/regularizations/test_regularization_util.py index 97b80332b..e2ff20a22 100644 --- a/test_autoarray/inversion/regularizations/test_regularization_util.py +++ b/test_autoarray/inversion/regularizations/test_regularization_util.py @@ -602,7 +602,9 @@ def splitted_data(): return splitted_mappings, splitted_sizes, splitted_weights -def test__reg_split_from__splitted_mapping_data__produces_correct_mappings_sizes_and_weights(splitted_data): +def test__reg_split_from__splitted_mapping_data__produces_correct_mappings_sizes_and_weights( + splitted_data, +): splitted_mappings, splitted_sizes, splitted_weights = splitted_data @@ -673,7 +675,9 @@ def test__reg_split_from__splitted_mapping_data__produces_correct_mappings_sizes assert splitted_weights == pytest.approx(expected_weights, abs=1.0e-4) -def test__pixel_splitted_regularization_matrix_from__uniform_weights__correct_regularization_matrix(splitted_data): +def test__pixel_splitted_regularization_matrix_from__uniform_weights__correct_regularization_matrix( + splitted_data, +): splitted_mappings, splitted_sizes, splitted_weights = splitted_data @@ -701,7 +705,9 @@ def test__pixel_splitted_regularization_matrix_from__uniform_weights__correct_re assert pytest.approx(regularization_matrix, 1e-4) == np.array(expected_reg_matrix) -def test__pixel_splitted_regularization_matrix_from__non_uniform_weights__correct_regularization_matrix(splitted_data): +def test__pixel_splitted_regularization_matrix_from__non_uniform_weights__correct_regularization_matrix( + splitted_data, +): splitted_mappings, splitted_sizes, splitted_weights = splitted_data diff --git a/test_autoarray/mask/test_mask_1d.py b/test_autoarray/mask/test_mask_1d.py index f415964d4..035d1f504 100644 --- a/test_autoarray/mask/test_mask_1d.py +++ b/test_autoarray/mask/test_mask_1d.py @@ -62,24 +62,30 @@ def test__constructor__input_is_2d_mask__raises_exception(): # --------------------------------------------------------------------------- -@pytest.mark.parametrize("mask_values,expected", [ - ([False, False, False, False], False), - ([False, False], False), - ([False, True, False, False], False), - ([True, True, True, True], True), -]) +@pytest.mark.parametrize( + "mask_values,expected", + [ + ([False, False, False, False], False), + ([False, False], False), + ([False, True, False, False], False), + ([True, True, True, True], True), + ], +) def test__is_all_true__various_masks__returns_correct_boolean(mask_values, expected): mask = aa.Mask1D(mask=mask_values, pixel_scales=1.0) assert mask.is_all_true == expected -@pytest.mark.parametrize("mask_values,expected", [ - ([False, False, False, False], True), - ([False, False], True), - ([False, True, False, False], False), - ([True, True, False, False], False), -]) +@pytest.mark.parametrize( + "mask_values,expected", + [ + ([False, False, False, False], True), + ([False, False], True), + ([False, True, False, False], False), + ([True, True, False, False], False), + ], +) def test__is_all_false__various_masks__returns_correct_boolean(mask_values, expected): mask = aa.Mask1D(mask=mask_values, pixel_scales=1.0) diff --git a/test_autoarray/mask/test_mask_2d.py b/test_autoarray/mask/test_mask_2d.py index e45e7e9a7..2295c6647 100644 --- a/test_autoarray/mask/test_mask_2d.py +++ b/test_autoarray/mask/test_mask_2d.py @@ -414,7 +414,9 @@ def test__from_fits__output_to_fits__roundtrip_preserves_values_pixel_scales_and @pytest.mark.parametrize("resized_shape", [(1, 1), (5, 5)]) -def test__from_fits__with_resized_mask_shape__output_shape_matches_requested_shape(resized_shape): +def test__from_fits__with_resized_mask_shape__output_shape_matches_requested_shape( + resized_shape, +): mask = aa.Mask2D.from_fits( file_path=path.join(test_data_path, "3x3_ones.fits"), hdu=0, @@ -443,24 +445,30 @@ def test__constructor__1d_mask_without_shape_native__raises_mask_exception(): # --------------------------------------------------------------------------- -@pytest.mark.parametrize("mask_values,expected", [ - ([[False, False], [False, False]], False), - ([[False, False]], False), - ([[False, True], [False, False]], False), - ([[True, True], [True, True]], True), -]) +@pytest.mark.parametrize( + "mask_values,expected", + [ + ([[False, False], [False, False]], False), + ([[False, False]], False), + ([[False, True], [False, False]], False), + ([[True, True], [True, True]], True), + ], +) def test__is_all_true__various_masks__returns_correct_boolean(mask_values, expected): mask = aa.Mask2D(mask=mask_values, pixel_scales=1.0) assert mask.is_all_true == expected -@pytest.mark.parametrize("mask_values,expected", [ - ([[False, False], [False, False]], True), - ([[False, False]], True), - ([[False, True], [False, False]], False), - ([[True, True], [False, False]], False), -]) +@pytest.mark.parametrize( + "mask_values,expected", + [ + ([[False, False], [False, False]], True), + ([[False, False]], True), + ([[False, True], [False, False]], False), + ([[True, True], [False, False]], False), + ], +) def test__is_all_false__various_masks__returns_correct_boolean(mask_values, expected): mask = aa.Mask2D(mask=mask_values, pixel_scales=1.0) @@ -472,50 +480,53 @@ def test__is_all_false__various_masks__returns_correct_boolean(mask_values, expe # --------------------------------------------------------------------------- -@pytest.mark.parametrize("mask_values,expected_shape", [ - ( - [ - [True, True, True, True, True, True, True, True, True], - [True, False, False, False, False, False, False, False, True], - [True, False, True, True, True, True, True, False, True], - [True, False, True, False, False, False, True, False, True], - [True, False, True, False, True, False, True, False, True], - [True, False, True, False, False, False, True, False, True], - [True, False, True, True, True, True, True, False, True], - [True, False, False, False, False, False, False, False, True], - [True, True, True, True, True, True, True, True, True], - ], - (7, 7), - ), - ( - [ - [True, True, True, True, True, True, True, True, False], - [True, False, False, False, False, False, False, False, True], - [True, False, True, True, True, True, True, False, True], - [True, False, True, False, False, False, True, False, True], - [True, False, True, False, True, False, True, False, True], - [True, False, True, False, False, False, True, False, True], - [True, False, True, True, True, True, True, False, True], - [True, False, False, False, False, False, False, False, True], - [True, True, True, True, True, True, True, True, True], - ], - (8, 8), - ), - ( - [ - [True, True, True, True, True, True, True, False, True], - [True, False, False, False, False, False, False, False, True], - [True, False, True, True, True, True, True, False, True], - [True, False, True, False, False, False, True, False, True], - [True, False, True, False, True, False, True, False, True], - [True, False, True, False, False, False, True, False, True], - [True, False, True, True, True, True, True, False, True], - [True, False, False, False, False, False, False, False, True], - [True, True, True, True, True, True, True, True, True], - ], - (8, 7), - ), -]) +@pytest.mark.parametrize( + "mask_values,expected_shape", + [ + ( + [ + [True, True, True, True, True, True, True, True, True], + [True, False, False, False, False, False, False, False, True], + [True, False, True, True, True, True, True, False, True], + [True, False, True, False, False, False, True, False, True], + [True, False, True, False, True, False, True, False, True], + [True, False, True, False, False, False, True, False, True], + [True, False, True, True, True, True, True, False, True], + [True, False, False, False, False, False, False, False, True], + [True, True, True, True, True, True, True, True, True], + ], + (7, 7), + ), + ( + [ + [True, True, True, True, True, True, True, True, False], + [True, False, False, False, False, False, False, False, True], + [True, False, True, True, True, True, True, False, True], + [True, False, True, False, False, False, True, False, True], + [True, False, True, False, True, False, True, False, True], + [True, False, True, False, False, False, True, False, True], + [True, False, True, True, True, True, True, False, True], + [True, False, False, False, False, False, False, False, True], + [True, True, True, True, True, True, True, True, True], + ], + (8, 8), + ), + ( + [ + [True, True, True, True, True, True, True, False, True], + [True, False, False, False, False, False, False, False, True], + [True, False, True, True, True, True, True, False, True], + [True, False, True, False, False, False, True, False, True], + [True, False, True, False, True, False, True, False, True], + [True, False, True, False, False, False, True, False, True], + [True, False, True, True, True, True, True, False, True], + [True, False, False, False, False, False, False, False, True], + [True, True, True, True, True, True, True, True, True], + ], + (8, 7), + ), + ], +) def test__shape_native_masked_pixels__various_unmasked_regions__returns_bounding_box_shape( mask_values, expected_shape ): @@ -542,10 +553,13 @@ def test__rescaled_from__5x5_mask_with_one_masked_pixel__rescaled_mask_matches_u assert (mask_rescaled == mask_rescaled_manual).all() -@pytest.mark.parametrize("new_shape,expected_masked_position", [ - ((7, 7), (3, 3)), - ((3, 3), (1, 1)), -]) +@pytest.mark.parametrize( + "new_shape,expected_masked_position", + [ + ((7, 7), (3, 3)), + ((3, 3), (1, 1)), + ], +) def test__resized_from__5x5_mask_with_center_masked__resized_mask_has_masked_pixel_at_center( new_shape, expected_masked_position ): @@ -661,11 +675,14 @@ def test__is_circular__non_circular_mask__returns_false(): assert mask.is_circular == False -@pytest.mark.parametrize("shape_native,radius", [ - ((5, 5), 1.0), - ((10, 10), 3.0), - ((10, 10), 4.0), -]) +@pytest.mark.parametrize( + "shape_native,radius", + [ + ((5, 5), 1.0), + ((10, 10), 3.0), + ((10, 10), 4.0), + ], +) def test__is_circular__circular_mask__returns_true(shape_native, radius): mask = aa.Mask2D.circular( shape_native=shape_native, radius=radius, pixel_scales=(1.0, 1.0) @@ -674,10 +691,13 @@ def test__is_circular__circular_mask__returns_true(shape_native, radius): assert mask.is_circular == True -@pytest.mark.parametrize("shape_native,radius,pixel_scales", [ - ((10, 10), 3.0, (1.0, 1.0)), - ((30, 30), 5.5, (0.5, 0.5)), -]) +@pytest.mark.parametrize( + "shape_native,radius,pixel_scales", + [ + ((10, 10), 3.0, (1.0, 1.0)), + ((30, 30), 5.5, (0.5, 0.5)), + ], +) def test__circular_radius__circular_mask__returns_radius_used_to_create_mask( shape_native, radius, pixel_scales ): diff --git a/test_autoarray/plot/mat_plot/test_mat_plot.py b/test_autoarray/plot/mat_plot/test_mat_plot.py index c87086fb0..c6c4bb68e 100644 --- a/test_autoarray/plot/mat_plot/test_mat_plot.py +++ b/test_autoarray/plot/mat_plot/test_mat_plot.py @@ -1,36 +1,2 @@ -import autoarray.plot as aplt - - -def test__add_mat_plot_objects_together(): - extent = [1.0, 2.0, 3.0, 4.0] - fontsize = 20 - - mat_plot_2d_0 = aplt.MatPlot2D(axis=aplt.Axis(extent=extent)) - - mat_plot_2d_1 = aplt.MatPlot2D(ylabel=aplt.YLabel(fontsize=fontsize)) - - mat_plot_2d = mat_plot_2d_0 + mat_plot_2d_1 - - assert mat_plot_2d.axis.config_dict["extent"] == extent - assert mat_plot_2d.ylabel.config_dict["fontsize"] == 20 - - mat_plot_2d = mat_plot_2d_1 + mat_plot_2d_0 - - assert mat_plot_2d.axis.config_dict["extent"] == extent - assert mat_plot_2d.ylabel.config_dict["fontsize"] == 20 - - units = aplt.Units() - output = aplt.Output(format="png") - - mat_plot_2d_0 = aplt.MatPlot2D( - axis=aplt.Axis(extent=extent), units=units, output=output - ) - mat_plot_2d_1 = aplt.MatPlot2D(ylabel=aplt.YLabel(fontsize=fontsize)) - - mat_plot_2d = mat_plot_2d_0 + mat_plot_2d_1 - - assert mat_plot_2d.output.format == "png" - - mat_plot_2d = mat_plot_2d_1 + mat_plot_2d_0 - - assert mat_plot_2d.output.format == "png" +# MatPlot1D and MatPlot2D have been removed. +# Configuration is now done via direct wrapper objects passed to plotters. diff --git a/test_autoarray/plot/test_abstract_plotters.py b/test_autoarray/plot/test_abstract_plotters.py deleted file mode 100644 index 05c905f7c..000000000 --- a/test_autoarray/plot/test_abstract_plotters.py +++ /dev/null @@ -1,107 +0,0 @@ -from os import path -import pytest -import matplotlib.pyplot as plt - -import autoarray as aa -import autoarray.plot as aplt -from autoarray.plot import abstract_plotters - -directory = path.dirname(path.realpath(__file__)) - - -def test__get_subplot_shape(): - plotter = abstract_plotters.AbstractPlotter(mat_plot_2d=aplt.MatPlot2D()) - - subplot_shape = plotter.mat_plot_2d.get_subplot_shape(number_subplots=1) - - assert subplot_shape == (1, 1) - - subplot_shape = plotter.mat_plot_2d.get_subplot_shape(number_subplots=3) - - assert subplot_shape == (2, 2) - - with pytest.raises(aa.exc.PlottingException): - plotter.mat_plot_2d.get_subplot_shape(number_subplots=1000) - - -# def test__get_subplot_figsize(): -# plotter = abstract_plotters.AbstractPlotter( -# mat_plot_2d=aplt.MatPlot2D(figure=aplt.Figure(figsize="auto")) -# ) -# -# figsize = plotter.get_subplot_figsize(number_subplots=1) -# -# assert figsize == (7, 7) -# -# figsize = plotter.get_subplot_figsize(number_subplots=4) -# -# assert figsize == (7, 7) -# -# figure = aplt.Figure(figsize=(20, 20)) -# -# plotter = abstract_plotters.AbstractPlotter( -# mat_plot_2d=aplt.MatPlot2D(figure=figure) -# ) -# -# figsize = plotter.get_subplot_figsize(number_subplots=4) -# -# assert figsize == (20, 20) - - -def test__open_and_close_subplot_figures(): - figure = aplt.Figure(figsize=(20, 20)) - - plotter = abstract_plotters.AbstractPlotter( - mat_plot_2d=aplt.MatPlot2D(figure=figure) - ) - - plotter.mat_plot_2d.figure.open() - - assert plt.fignum_exists(num=1) is True - - plotter.mat_plot_2d.figure.close() - - assert plt.fignum_exists(num=1) is False - - plotter = abstract_plotters.AbstractPlotter( - mat_plot_2d=aplt.MatPlot2D(figure=figure) - ) - - assert plt.fignum_exists(num=1) is False - - plotter.open_subplot_figure(number_subplots=4) - - assert plt.fignum_exists(num=1) is True - - plotter.mat_plot_2d.figure.close() - - assert plt.fignum_exists(num=1) is False - - -def test__uses_figure_or_subplot_configs_correctly(): - figure = aplt.Figure(figsize=(8, 8)) - cmap = aplt.Cmap(cmap="warm") - - mat_plot_2d = aplt.MatPlot2D(figure=figure, cmap=cmap) - - plotter = abstract_plotters.AbstractPlotter(mat_plot_2d=mat_plot_2d) - - assert plotter.mat_plot_2d.figure.config_dict["figsize"] == (8, 8) - assert plotter.mat_plot_2d.figure.config_dict["aspect"] == "square" - assert plotter.mat_plot_2d.cmap.config_dict["cmap"] == "warm" - assert plotter.mat_plot_2d.cmap.config_dict["norm"] == "linear" - - figure = aplt.Figure() - figure.is_for_subplot = True - - cmap = aplt.Cmap() - cmap.is_for_subplot = True - - mat_plot_2d = aplt.MatPlot2D(figure=figure, cmap=cmap) - - plotter = abstract_plotters.AbstractPlotter(mat_plot_2d=mat_plot_2d) - - assert plotter.mat_plot_2d.figure.config_dict["figsize"] == None - assert plotter.mat_plot_2d.figure.config_dict["aspect"] == "square" - assert plotter.mat_plot_2d.cmap.config_dict["cmap"] == "default" - assert plotter.mat_plot_2d.cmap.config_dict["norm"] == "linear" diff --git a/test_autoarray/plot/test_multi_plotters.py b/test_autoarray/plot/test_multi_plotters.py index 9c2048ac3..dc4cb2f12 100644 --- a/test_autoarray/plot/test_multi_plotters.py +++ b/test_autoarray/plot/test_multi_plotters.py @@ -1,100 +1,2 @@ -from os import path -import pytest -import autoarray as aa -import autoarray.plot as aplt - -import numpy as np - -directory = path.dirname(path.realpath(__file__)) - - -@pytest.fixture(name="plot_path") -def make_plot_path_setup(): - return path.join( - "{}".format(path.dirname(path.realpath(__file__))), "files", "plots", "imaging" - ) - - -def test__multi_plotter__subplot_of_plotter_list_figure( - imaging_7x7, plot_path, plot_patch -): - mat_plot_2d = aplt.MatPlot2D(output=aplt.Output(plot_path, format="png")) - - plotter_0 = aplt.ImagingPlotter(dataset=imaging_7x7, mat_plot_2d=mat_plot_2d) - plotter_1 = aplt.ImagingPlotter(dataset=imaging_7x7, mat_plot_2d=mat_plot_2d) - - plotter_list = [plotter_0, plotter_1] - - multi_plotter = aplt.MultiFigurePlotter(plotter_list=plotter_list) - multi_plotter.subplot_of_figure(func_name="figures_2d", figure_name="data") - - assert path.join(plot_path, "subplot_data.png") in plot_patch.paths - - plot_patch.paths = [] - - multi_plotter = aplt.MultiFigurePlotter(plotter_list=plotter_list) - multi_plotter.subplot_of_figure( - func_name="figures_2d", figure_name="data", noise_map=True - ) - - assert path.join(plot_path, "subplot_data.png") in plot_patch.paths - - -class MockYX1DPlotter(aplt.YX1DPlotter): - def __init__( - self, - y, - x, - mat_plot_1d: aplt.MatPlot1D = None, - visuals_1d: aplt.Visuals1D = None, - ): - super().__init__( - y=y, - x=x, - mat_plot_1d=mat_plot_1d, - visuals_1d=visuals_1d, - ) - - def figures_1d(self, figure_name=False): - if figure_name: - self.figure_1d() - - -def test__multi_yx_plotter__subplot_of_plotter_list_figure( - imaging_7x7, plot_path, plot_patch -): - mat_plot_1d = aplt.MatPlot1D(output=aplt.Output(plot_path, format="png")) - - plotter_0 = MockYX1DPlotter( - y=aa.Array1D.no_mask([1.0, 2.0, 3.0], pixel_scales=1.0), - x=aa.Array1D.no_mask([0.5, 1.0, 1.5], pixel_scales=0.5), - mat_plot_1d=mat_plot_1d, - ) - - plotter_1 = MockYX1DPlotter( - y=aa.Array1D.no_mask([1.0, 2.0, 4.0], pixel_scales=1.0), - x=aa.Array1D.no_mask([0.5, 1.0, 1.5], pixel_scales=0.5), - mat_plot_1d=mat_plot_1d, - ) - - multi_plotter = aplt.MultiYX1DPlotter(plotter_list=[plotter_0, plotter_1]) - multi_plotter.figure_1d(func_name="figures_1d", figure_name="figure_name") - - assert path.join(plot_path, "multi_figure_name.png") in plot_patch.paths - - -def test__multi_yx_plotter__xticks_span_all_plotter_ranges(): - plotter_0 = MockYX1DPlotter( - y=aa.Array1D.no_mask([1.0, 2.0, 3.0], pixel_scales=1.0), - x=aa.Array1D.no_mask([0.5, 1.0, 1.5], pixel_scales=0.5), - ) - - plotter_1 = MockYX1DPlotter( - y=aa.Array1D.no_mask([1.0, 2.0, 4.0], pixel_scales=1.0), - x=aa.Array1D.no_mask([0.25, 1.0, 1.5], pixel_scales=0.5), - ) - - multi_plotter = aplt.MultiYX1DPlotter(plotter_list=[plotter_0, plotter_1]) - - assert multi_plotter.xticks.manual_min_max_value == (0.25, 1.5) - assert multi_plotter.yticks.manual_min_max_value == (1.0, 4.0) +# MultiFigurePlotter and MultiYX1DPlotter have been removed. +# Users should write their own matplotlib code for multi-panel plots. diff --git a/test_autoarray/plot/visuals/test_visuals.py b/test_autoarray/plot/visuals/test_visuals.py index c5449089c..9976631c5 100644 --- a/test_autoarray/plot/visuals/test_visuals.py +++ b/test_autoarray/plot/visuals/test_visuals.py @@ -1,24 +1,2 @@ -import autoarray.plot as aplt - - -def test__add_visuals_together__replaces_nones(): - visuals_1 = aplt.Visuals2D(mask=1) - visuals_0 = aplt.Visuals2D(border=10) - - visuals = visuals_0 + visuals_1 - - assert visuals.mask == 1 - assert visuals.border == 10 - assert visuals_1.mask == 1 - assert visuals_1.border == 10 - assert visuals_0.border == 10 - assert visuals_0.mask == None - - visuals_0 = aplt.Visuals2D(mask=1) - visuals_1 = aplt.Visuals2D(mask=2) - - visuals = visuals_1 + visuals_0 - - assert visuals.mask == 1 - assert visuals.border == None - assert visuals_1.mask == 2 +# Visuals classes (Visuals1D, Visuals2D) have been removed. +# Overlay objects are now passed directly to Plotter constructors. diff --git a/test_autoarray/plot/wrap/base/test_abstract.py b/test_autoarray/plot/wrap/base/test_abstract.py deleted file mode 100644 index f118b967c..000000000 --- a/test_autoarray/plot/wrap/base/test_abstract.py +++ /dev/null @@ -1,27 +0,0 @@ -import autoarray.plot as aplt - - -def test__from_config_or_via_manual_input(): - # Testing for config loading, could be any matplot object but use GridScatter as example - - grid_scatter = aplt.GridScatter() - - assert grid_scatter.config_dict["marker"] == "x" - assert grid_scatter.config_dict["c"] == "y" - - grid_scatter = aplt.GridScatter(marker="x") - - assert grid_scatter.config_dict["marker"] == "x" - assert grid_scatter.config_dict["c"] == "y" - - grid_scatter = aplt.GridScatter() - grid_scatter.is_for_subplot = True - - assert grid_scatter.config_dict["marker"] == "." - assert grid_scatter.config_dict["c"] == "r" - - grid_scatter = aplt.GridScatter(c=["r", "b"]) - grid_scatter.is_for_subplot = True - - assert grid_scatter.config_dict["marker"] == "." - assert grid_scatter.config_dict["c"] == ["r", "b"] diff --git a/test_autoarray/plot/wrap/base/test_annotate.py b/test_autoarray/plot/wrap/base/test_annotate.py deleted file mode 100644 index 7ac5eb1e3..000000000 --- a/test_autoarray/plot/wrap/base/test_annotate.py +++ /dev/null @@ -1,21 +0,0 @@ -import autoarray.plot as aplt - - -def test__loads_values_from_config_if_not_manually_input(): - title = aplt.Annotate() - - assert title.config_dict["fontsize"] == 16 - - title = aplt.Annotate(fontsize=1) - - assert title.config_dict["fontsize"] == 1 - - title = aplt.Annotate() - title.is_for_subplot = True - - assert title.config_dict["fontsize"] == 10 - - title = aplt.Annotate(fontsize=2) - title.is_for_subplot = True - - assert title.config_dict["fontsize"] == 2 diff --git a/test_autoarray/plot/wrap/base/test_axis.py b/test_autoarray/plot/wrap/base/test_axis.py deleted file mode 100644 index f9739cb26..000000000 --- a/test_autoarray/plot/wrap/base/test_axis.py +++ /dev/null @@ -1,34 +0,0 @@ -import autoarray as aa -import autoarray.plot as aplt - - -def test__loads_values_from_config_if_not_manually_input(): - axis = aplt.Axis() - - assert axis.config_dict["emit"] is True - - axis = aplt.Axis(emit=False) - - assert axis.config_dict["emit"] is False - - axis = aplt.Axis() - axis.is_for_subplot = True - - assert axis.config_dict["emit"] is False - - axis = aplt.Axis(emit=True) - axis.is_for_subplot = True - - assert axis.config_dict["emit"] is True - - -def test__sets_axis_correct_for_different_settings(): - axis = aplt.Axis(symmetric_source_centre=False) - - axis.set(extent=[0.1, 0.2, 0.3, 0.4]) - - axis = aplt.Axis(symmetric_source_centre=True) - - grid = aa.Grid2D.uniform(shape_native=(3, 3), pixel_scales=2.0) - - axis.set(extent=[0.1, 0.2, 0.3, 0.4], grid=grid) diff --git a/test_autoarray/plot/wrap/base/test_cmap.py b/test_autoarray/plot/wrap/base/test_cmap.py deleted file mode 100644 index a666eb1b1..000000000 --- a/test_autoarray/plot/wrap/base/test_cmap.py +++ /dev/null @@ -1,108 +0,0 @@ -import autoarray as aa -import autoarray.plot as aplt - -import matplotlib.colors as colors - - -def test__loads_values_from_config_if_not_manually_input(): - cmap = aplt.Cmap() - - assert cmap.config_dict["cmap"] == "default" - assert cmap.config_dict["norm"] == "linear" - - cmap = aplt.Cmap(cmap="cold") - - assert cmap.config_dict["cmap"] == "cold" - assert cmap.config_dict["norm"] == "linear" - - cmap = aplt.Cmap() - cmap.is_for_subplot = True - - assert cmap.config_dict["cmap"] == "default" - assert cmap.config_dict["norm"] == "linear" - - cmap = aplt.Cmap(cmap="cold") - cmap.is_for_subplot = True - - assert cmap.config_dict["cmap"] == "cold" - assert cmap.config_dict["norm"] == "linear" - - -def test__norm_from__uses_input_vmin_and_max_if_input(): - cmap = aplt.Cmap(vmin=0.0, vmax=1.0, norm="linear") - - norm = cmap.norm_from(array=None) - - assert isinstance(norm, colors.Normalize) - assert norm.vmin == 0.0 - assert norm.vmax == 1.0 - - cmap = aplt.Cmap(vmin=0.0, vmax=1.0, norm="log") - - norm = cmap.norm_from(array=None) - - assert isinstance(norm, colors.LogNorm) - assert norm.vmin == 1.0e-4 # Increased from 0.0 to ensure min isn't inf - assert norm.vmax == 1.0 - - cmap = aplt.Cmap( - vmin=0.0, vmax=1.0, linthresh=2.0, linscale=3.0, norm="symmetric_log" - ) - - norm = cmap.norm_from(array=None) - - assert isinstance(norm, colors.SymLogNorm) - assert norm.vmin == 0.0 - assert norm.vmax == 1.0 - assert norm.linthresh == 2.0 - - -def test__norm_from__cmap_symmetric_true(): - cmap = aplt.Cmap(vmin=-0.5, vmax=1.0, norm="linear", symmetric=True) - - norm = cmap.norm_from(array=None) - - assert isinstance(norm, colors.Normalize) - assert norm.vmin == -1.0 - assert norm.vmax == 1.0 - - cmap = aplt.Cmap(vmin=-2.0, vmax=1.0, norm="linear") - cmap = cmap.symmetric_cmap_from() - - norm = cmap.norm_from(array=None) - - assert isinstance(norm, colors.Normalize) - assert norm.vmin == -2.0 - assert norm.vmax == 2.0 - - -def test__norm_from__uses_array_to_get_vmin_and_max_if_no_manual_input(): - array = aa.Array2D.ones(shape_native=(2, 2), pixel_scales=1.0) - array[0] = 0.0 - - cmap = aplt.Cmap(vmin=None, vmax=None, norm="linear") - - norm = cmap.norm_from(array=array) - - assert isinstance(norm, colors.Normalize) - assert norm.vmin == 0.0 - assert norm.vmax == 1.0 - - cmap = aplt.Cmap(vmin=None, vmax=None, norm="log") - - norm = cmap.norm_from(array=array) - - assert isinstance(norm, colors.LogNorm) - assert norm.vmin == 1.0e-4 # Increased from 0.0 to ensure min isn't inf - assert norm.vmax == 1.0 - - cmap = aplt.Cmap( - vmin=None, vmax=None, linthresh=2.0, linscale=3.0, norm="symmetric_log" - ) - - norm = cmap.norm_from(array=array) - - assert isinstance(norm, colors.SymLogNorm) - assert norm.vmin == 0.0 - assert norm.vmax == 1.0 - assert norm.linthresh == 2.0 diff --git a/test_autoarray/plot/wrap/base/test_colorbar.py b/test_autoarray/plot/wrap/base/test_colorbar.py deleted file mode 100644 index 44e53778f..000000000 --- a/test_autoarray/plot/wrap/base/test_colorbar.py +++ /dev/null @@ -1,58 +0,0 @@ -import autoarray.plot as aplt - -import matplotlib.pyplot as plt -import numpy as np - - -def test__loads_values_from_config_if_not_manually_input(): - colorbar = aplt.Colorbar() - - assert colorbar.config_dict["fraction"] == 3.0 - assert colorbar.manual_tick_values == None - assert colorbar.manual_tick_labels == None - - colorbar = aplt.Colorbar( - manual_tick_values=(1.0, 2.0), manual_tick_labels=(3.0, 4.0) - ) - - assert colorbar.manual_tick_values == (1.0, 2.0) - assert colorbar.manual_tick_labels == (3.0, 4.0) - - colorbar = aplt.Colorbar() - colorbar.is_for_subplot = True - - assert colorbar.config_dict["fraction"] == 0.1 - - colorbar = aplt.Colorbar(fraction=6.0) - colorbar.is_for_subplot = True - - assert colorbar.config_dict["fraction"] == 6.0 - - -def test__plot__works_for_reasonable_range_of_values(): - figure = aplt.Figure() - - fig, ax = figure.open() - plt.imshow(np.ones((2, 2))) - cb = aplt.Colorbar(fraction=1.0, pad=2.0) - cb.set(ax=ax, units=None) - figure.close() - - fig, ax = figure.open() - plt.imshow(np.ones((2, 2))) - cb = aplt.Colorbar( - fraction=0.1, - pad=0.5, - manual_tick_values=[0.25, 0.5, 0.75], - manual_tick_labels=[1.0, 2.0, 3.0], - ) - cb.set(ax=ax, units=aplt.Units()) - figure.close() - - fig, ax = figure.open() - plt.imshow(np.ones((2, 2))) - cb = aplt.Colorbar(fraction=0.1, pad=0.5) - cb.set_with_color_values( - cmap=aplt.Cmap().cmap, color_values=[1.0, 2.0, 3.0], ax=ax, units=None - ) - figure.close() diff --git a/test_autoarray/plot/wrap/base/test_colorbar_tickparams.py b/test_autoarray/plot/wrap/base/test_colorbar_tickparams.py deleted file mode 100644 index f5cc08208..000000000 --- a/test_autoarray/plot/wrap/base/test_colorbar_tickparams.py +++ /dev/null @@ -1,21 +0,0 @@ -import autoarray.plot as aplt - - -def test__loads_values_from_config_if_not_manually_input(): - colorbar_tickparams = aplt.ColorbarTickParams() - - assert colorbar_tickparams.config_dict["labelsize"] == 1 - - colorbar_tickparams = aplt.ColorbarTickParams(labelsize=20) - - assert colorbar_tickparams.config_dict["labelsize"] == 20 - - colorbar_tickparams = aplt.ColorbarTickParams() - colorbar_tickparams.is_for_subplot = True - - assert colorbar_tickparams.config_dict["labelsize"] == 1 - - colorbar_tickparams = aplt.ColorbarTickParams(labelsize=10) - colorbar_tickparams.is_for_subplot = True - - assert colorbar_tickparams.config_dict["labelsize"] == 10 diff --git a/test_autoarray/plot/wrap/base/test_figure.py b/test_autoarray/plot/wrap/base/test_figure.py deleted file mode 100644 index 7e0ff12d4..000000000 --- a/test_autoarray/plot/wrap/base/test_figure.py +++ /dev/null @@ -1,59 +0,0 @@ -import autoarray.plot as aplt - -from os import path - -import matplotlib.pyplot as plt - - -def test__loads_values_from_config_if_not_manually_input(): - figure = aplt.Figure() - - assert figure.config_dict["figsize"] == (7, 7) - assert figure.config_dict["aspect"] == "square" - - figure = aplt.Figure(aspect="auto") - - assert figure.config_dict["figsize"] == (7, 7) - assert figure.config_dict["aspect"] == "auto" - - figure = aplt.Figure() - figure.is_for_subplot = True - - assert figure.config_dict["figsize"] == None - assert figure.config_dict["aspect"] == "square" - - figure = aplt.Figure(figsize=(6, 6)) - figure.is_for_subplot = True - - assert figure.config_dict["figsize"] == (6, 6) - assert figure.config_dict["aspect"] == "square" - - -def test__aspect_from(): - figure = aplt.Figure(aspect="auto") - - aspect = figure.aspect_from(shape_native=(2, 2)) - - assert aspect == "auto" - - figure = aplt.Figure(aspect="square") - - aspect = figure.aspect_from(shape_native=(2, 2)) - - assert aspect == 1.0 - - aspect = figure.aspect_from(shape_native=(4, 2)) - - assert aspect == 0.5 - - -def test__open_and_close__open_and_close_figures_correct(): - figure = aplt.Figure() - - figure.open() - - assert plt.fignum_exists(num=1) is True - - figure.close() - - assert plt.fignum_exists(num=1) is False diff --git a/test_autoarray/plot/wrap/base/test_label.py b/test_autoarray/plot/wrap/base/test_label.py deleted file mode 100644 index 9a6202d66..000000000 --- a/test_autoarray/plot/wrap/base/test_label.py +++ /dev/null @@ -1,41 +0,0 @@ -import autoarray.plot as aplt - - -def test__ylabel__loads_values_from_config_if_not_manually_input(): - ylabel = aplt.YLabel() - - assert ylabel.config_dict["fontsize"] == 1 - - ylabel = aplt.YLabel(fontsize=11) - - assert ylabel.config_dict["fontsize"] == 11 - - ylabel = aplt.YLabel() - ylabel.is_for_subplot = True - - assert ylabel.config_dict["fontsize"] == 2 - - ylabel = aplt.YLabel(fontsize=12) - ylabel.is_for_subplot = True - - assert ylabel.config_dict["fontsize"] == 12 - - -def test__xlabel__loads_values_from_config_if_not_manually_input(): - xlabel = aplt.XLabel() - - assert xlabel.config_dict["fontsize"] == 3 - - xlabel = aplt.XLabel(fontsize=11) - - assert xlabel.config_dict["fontsize"] == 11 - - xlabel = aplt.XLabel() - xlabel.is_for_subplot = True - - assert xlabel.config_dict["fontsize"] == 4 - - xlabel = aplt.XLabel(fontsize=12) - xlabel.is_for_subplot = True - - assert xlabel.config_dict["fontsize"] == 12 diff --git a/test_autoarray/plot/wrap/base/test_legend.py b/test_autoarray/plot/wrap/base/test_legend.py deleted file mode 100644 index ceb275e45..000000000 --- a/test_autoarray/plot/wrap/base/test_legend.py +++ /dev/null @@ -1,19 +0,0 @@ -import autoarray.plot as aplt - - -def test__set_legend_works_for_plot(): - figure = aplt.Figure(aspect="auto") - - figure.open() - - line = aplt.YXPlot(linewidth=2, linestyle="-", c="k") - - line.plot_y_vs_x( - y=[1.0, 2.0, 3.0], x=[1.0, 2.0, 3.0], plot_axis_type="linear", label="hi" - ) - - legend = aplt.Legend(fontsize=1) - - legend.set() - - figure.close() diff --git a/test_autoarray/plot/wrap/base/test_text.py b/test_autoarray/plot/wrap/base/test_text.py deleted file mode 100644 index 463620255..000000000 --- a/test_autoarray/plot/wrap/base/test_text.py +++ /dev/null @@ -1,21 +0,0 @@ -import autoarray.plot as aplt - - -def test__loads_values_from_config_if_not_manually_input(): - title = aplt.Text() - - assert title.config_dict["fontsize"] == 16 - - title = aplt.Text(fontsize=1) - - assert title.config_dict["fontsize"] == 1 - - title = aplt.Text() - title.is_for_subplot = True - - assert title.config_dict["fontsize"] == 10 - - title = aplt.Text(fontsize=2) - title.is_for_subplot = True - - assert title.config_dict["fontsize"] == 2 diff --git a/test_autoarray/plot/wrap/base/test_tickparams.py b/test_autoarray/plot/wrap/base/test_tickparams.py deleted file mode 100644 index 54dd2d966..000000000 --- a/test_autoarray/plot/wrap/base/test_tickparams.py +++ /dev/null @@ -1,20 +0,0 @@ -import autoarray.plot as aplt - - -def test__loads_values_from_config_if_not_manually_input(): - tick_params = aplt.TickParams() - - assert tick_params.config_dict["labelsize"] == 16 - - tick_params = aplt.TickParams(labelsize=24) - assert tick_params.config_dict["labelsize"] == 24 - - tick_params = aplt.TickParams() - tick_params.is_for_subplot = True - - assert tick_params.config_dict["labelsize"] == 10 - - tick_params = aplt.TickParams(labelsize=25) - tick_params.is_for_subplot = True - - assert tick_params.config_dict["labelsize"] == 25 diff --git a/test_autoarray/plot/wrap/base/test_ticks.py b/test_autoarray/plot/wrap/base/test_ticks.py deleted file mode 100644 index 51f8174e8..000000000 --- a/test_autoarray/plot/wrap/base/test_ticks.py +++ /dev/null @@ -1,125 +0,0 @@ -import autoarray as aa -import autoarray.plot as aplt - -from autoarray.plot.wrap.base.ticks import LabelMaker - - -def test__labels_with_suffix_from(): - label_maker = LabelMaker( - tick_values=[1.0, 2.0, 3.0], - min_value=1.0, - max_value=3.0, - units=aplt.Units(use_scaled=False), - manual_suffix="", - ) - - labels = label_maker.with_appended_suffix(labels=["hi", "hello"]) - - assert labels == ["hi", "hello"] - - label_maker = LabelMaker( - tick_values=[1.0, 2.0, 3.0], - min_value=1.0, - max_value=3.0, - units=aplt.Units(use_scaled=False), - manual_suffix="11", - ) - - labels = label_maker.with_appended_suffix(labels=["hi", "hello"]) - - assert labels == ["hi11", "hello11"] - - -def test__yticks_loads_values_from_config_if_not_manually_input(): - yticks = aplt.YTicks() - - assert yticks.config_dict["fontsize"] == 16 - assert yticks.manual_values == None - assert yticks.manual_values == None - - yticks = aplt.YTicks(fontsize=24, manual_values=[1.0, 2.0]) - - assert yticks.config_dict["fontsize"] == 24 - assert yticks.manual_values == [1.0, 2.0] - - yticks = aplt.YTicks() - yticks.is_for_subplot = True - - assert yticks.config_dict["fontsize"] == 10 - assert yticks.manual_values == None - - yticks = aplt.YTicks(fontsize=25, manual_values=[1.0, 2.0]) - yticks.is_for_subplot = True - - assert yticks.config_dict["fontsize"] == 25 - assert yticks.manual_values == [1.0, 2.0] - - -def test__yticks__set(): - array = aa.Array2D.ones(shape_native=(2, 2), pixel_scales=1.0) - units = aplt.Units(use_scaled=True, ticks_convert_factor=None) - - yticks = aplt.YTicks(fontsize=34) - zoom = aa.Zoom2D(mask=array.mask) - array_zoom = zoom.array_2d_from(array=array, buffer=1) - extent = array_zoom.geometry.extent - yticks.set(min_value=extent[2], max_value=extent[3], units=units) - - yticks = aplt.YTicks(fontsize=34) - units = aplt.Units(use_scaled=False, ticks_convert_factor=None) - yticks.set(min_value=extent[2], max_value=extent[3], pixels=2, units=units) - - yticks = aplt.YTicks(fontsize=34) - units = aplt.Units(use_scaled=True, ticks_convert_factor=2.0) - yticks.set(min_value=extent[2], max_value=extent[3], units=units) - - yticks = aplt.YTicks(fontsize=34) - units = aplt.Units(use_scaled=False, ticks_convert_factor=2.0) - yticks.set(min_value=extent[2], max_value=extent[3], pixels=2, units=units) - - -def test__xticks_loads_values_from_config_if_not_manually_input(): - xticks = aplt.XTicks() - - assert xticks.config_dict["fontsize"] == 17 - assert xticks.manual_values == None - assert xticks.manual_values == None - - xticks = aplt.XTicks(fontsize=24, manual_values=[1.0, 2.0]) - - assert xticks.config_dict["fontsize"] == 24 - assert xticks.manual_values == [1.0, 2.0] - - xticks = aplt.XTicks() - xticks.is_for_subplot = True - - assert xticks.config_dict["fontsize"] == 11 - assert xticks.manual_values == None - - xticks = aplt.XTicks(fontsize=25, manual_values=[1.0, 2.0]) - xticks.is_for_subplot = True - - assert xticks.config_dict["fontsize"] == 25 - assert xticks.manual_values == [1.0, 2.0] - - -def test__xticks__set(): - array = aa.Array2D.ones(shape_native=(2, 2), pixel_scales=1.0) - units = aplt.Units(use_scaled=True, ticks_convert_factor=None) - xticks = aplt.XTicks(fontsize=34) - zoom = aa.Zoom2D(mask=array.mask) - array_zoom = zoom.array_2d_from(array=array, buffer=1) - extent = array_zoom.geometry.extent - xticks.set(min_value=extent[0], max_value=extent[1], units=units) - - xticks = aplt.XTicks(fontsize=34) - units = aplt.Units(use_scaled=False, ticks_convert_factor=None) - xticks.set(min_value=extent[0], max_value=extent[1], pixels=2, units=units) - - xticks = aplt.XTicks(fontsize=34) - units = aplt.Units(use_scaled=True, ticks_convert_factor=2.0) - xticks.set(min_value=extent[0], max_value=extent[1], units=units) - - xticks = aplt.XTicks(fontsize=34) - units = aplt.Units(use_scaled=False, ticks_convert_factor=2.0) - xticks.set(min_value=extent[0], max_value=extent[1], pixels=2, units=units) diff --git a/test_autoarray/plot/wrap/base/test_title.py b/test_autoarray/plot/wrap/base/test_title.py deleted file mode 100644 index c521e628c..000000000 --- a/test_autoarray/plot/wrap/base/test_title.py +++ /dev/null @@ -1,25 +0,0 @@ -import autoarray.plot as aplt - - -def test__loads_values_from_config_if_not_manually_input(): - title = aplt.Title() - - assert title.manual_label == None - assert title.config_dict["fontsize"] == 11 - - title = aplt.Title(label="OMG", fontsize=1) - - assert title.manual_label == "OMG" - assert title.config_dict["fontsize"] == 1 - - title = aplt.Title() - title.is_for_subplot = True - - assert title.manual_label == None - assert title.config_dict["fontsize"] == 15 - - title = aplt.Title(label="OMG2", fontsize=2) - title.is_for_subplot = True - - assert title.manual_label == "OMG2" - assert title.config_dict["fontsize"] == 2 diff --git a/test_autoarray/plot/wrap/base/test_units.py b/test_autoarray/plot/wrap/base/test_units.py deleted file mode 100644 index 85d8a1d24..000000000 --- a/test_autoarray/plot/wrap/base/test_units.py +++ /dev/null @@ -1,13 +0,0 @@ -import autoarray.plot as aplt - - -def test__loads_values_from_config_if_not_manually_input(): - units = aplt.Units() - - assert units.use_scaled is True - assert units.ticks_convert_factor == None - - units = aplt.Units(ticks_convert_factor=2.0) - - assert units.use_scaled is True - assert units.ticks_convert_factor == 2.0 diff --git a/test_autoarray/plot/wrap/one_d/test_axvline.py b/test_autoarray/plot/wrap/one_d/test_axvline.py deleted file mode 100644 index 738178f0c..000000000 --- a/test_autoarray/plot/wrap/one_d/test_axvline.py +++ /dev/null @@ -1,10 +0,0 @@ -import autoarray.plot as aplt - - -def test__plot_vertical_lines__works_for_reasonable_values(): - line = aplt.AXVLine(linewidth=2, linestyle="-", c="k") - - line.axvline_vertical_line(vertical_line=0.0, label="hi") - line.axvline_vertical_line( - vertical_line=0.0, vertical_errors=[-1.0, 1.0], label="hi" - ) diff --git a/test_autoarray/plot/wrap/one_d/test_fill_between.py b/test_autoarray/plot/wrap/one_d/test_fill_between.py deleted file mode 100644 index 500151b86..000000000 --- a/test_autoarray/plot/wrap/one_d/test_fill_between.py +++ /dev/null @@ -1,9 +0,0 @@ -import autoarray.plot as aplt - - -def test__plot_y_vs_x__works_for_reasonable_values(): - fill_between = aplt.FillBetween() - - fill_between.fill_between_shaded_regions( - x=[1, 2, 3], y1=[1.0, 2.0, 3.0], y2=[2.0, 3.0, 4.0] - ) diff --git a/test_autoarray/plot/wrap/one_d/test_yx_plot.py b/test_autoarray/plot/wrap/one_d/test_yx_plot.py deleted file mode 100644 index fc5334b3d..000000000 --- a/test_autoarray/plot/wrap/one_d/test_yx_plot.py +++ /dev/null @@ -1,30 +0,0 @@ -import autoarray.plot as aplt - - -def test__plot_y_vs_x__works_for_reasonable_values(): - line = aplt.YXPlot(linewidth=2, linestyle="-", c="k") - - line.plot_y_vs_x(y=[1.0, 2.0, 3.0], x=[1.0, 2.0, 3.0], plot_axis_type="linear") - line.plot_y_vs_x(y=[1.0, 2.0, 3.0], x=[1.0, 2.0, 3.0], plot_axis_type="semilogy") - line.plot_y_vs_x(y=[1.0, 2.0, 3.0], x=[1.0, 2.0, 3.0], plot_axis_type="loglog") - - line = aplt.YXPlot(c="k") - - line.plot_y_vs_x(y=[1.0, 2.0, 3.0], x=[1.0, 2.0, 3.0], plot_axis_type="scatter") - - line.plot_y_vs_x(y=[1.0, 2.0, 3.0], x=[1.0, 2.0, 3.0], plot_axis_type="errorbar") - - line.plot_y_vs_x( - y=[1.0, 2.0, 3.0], - x=[1.0, 2.0, 3.0], - plot_axis_type="errorbar", - y_errors=[1.0, 1.0, 1.0], - ) - - line.plot_y_vs_x( - y=[1.0, 2.0, 3.0], - x=[1.0, 2.0, 3.0], - plot_axis_type="errorbar", - y_errors=[1.0, 1.0, 1.0], - x_errors=[1.0, 1.0, 1.0], - ) diff --git a/test_autoarray/plot/wrap/one_d/test_yx_scatter.py b/test_autoarray/plot/wrap/one_d/test_yx_scatter.py deleted file mode 100644 index 9f4e85f36..000000000 --- a/test_autoarray/plot/wrap/one_d/test_yx_scatter.py +++ /dev/null @@ -1,7 +0,0 @@ -import autoarray.plot as aplt - - -def test__scatter_y_vs_x__works_for_reasonable_values(): - yx_scatter = aplt.YXScatter(linewidth=2, linestyle="-", c="k") - - yx_scatter.scatter_yx(y=[1.0, 2.0, 3.0], x=[1.0, 2.0, 3.0]) diff --git a/test_autoarray/plot/wrap/two_d/test_array_overlay.py b/test_autoarray/plot/wrap/two_d/test_array_overlay.py deleted file mode 100644 index ae43024d6..000000000 --- a/test_autoarray/plot/wrap/two_d/test_array_overlay.py +++ /dev/null @@ -1,14 +0,0 @@ -import autoarray as aa -import autoarray.plot as aplt - - -def test__overlay_array__works_for_reasonable_values(): - arr = aa.Array2D.no_mask( - values=[[1.0, 2.0], [3.0, 4.0]], pixel_scales=0.5, origin=(2.0, 2.0) - ) - - figure = aplt.Figure(aspect="auto") - - array_overlay = aplt.ArrayOverlay(alpha=0.5) - - array_overlay.overlay_array(array=arr, figure=figure) diff --git a/test_autoarray/plot/wrap/two_d/test_contour.py b/test_autoarray/plot/wrap/two_d/test_contour.py deleted file mode 100644 index b5decac65..000000000 --- a/test_autoarray/plot/wrap/two_d/test_contour.py +++ /dev/null @@ -1,12 +0,0 @@ -import autoarray as aa -import autoarray.plot as aplt - - -def test__contour__works_for_reasonable_values(): - arr = aa.Array2D.no_mask( - values=[[1.0, 2.0], [3.0, 4.0]], pixel_scales=0.5, origin=(2.0, 2.0) - ) - - contour = aplt.Contour() - - contour.set(array=arr, extent=[0.0, 1.0, 0.0, 1.0]) diff --git a/test_autoarray/plot/wrap/two_d/test_delaunay_drawer.py b/test_autoarray/plot/wrap/two_d/test_delaunay_drawer.py deleted file mode 100644 index f7f9e4521..000000000 --- a/test_autoarray/plot/wrap/two_d/test_delaunay_drawer.py +++ /dev/null @@ -1,26 +0,0 @@ -import autoarray.plot as aplt - -import numpy as np - - -def test__draws_delaunay_pixels_for_sensible_input(delaunay_mapper_9_3x3): - delaunay_drawer = aplt.DelaunayDrawer(linewidth=0.5, edgecolor="r", alpha=1.0) - - delaunay_drawer.draw_delaunay_pixels( - mapper=delaunay_mapper_9_3x3, - pixel_values=np.ones(9), - units=aplt.Units(), - cmap=aplt.Cmap(), - colorbar=None, - ) - - values = np.ones(9) - values[0] = 0.0 - - delaunay_drawer.draw_delaunay_pixels( - mapper=delaunay_mapper_9_3x3, - pixel_values=values, - units=aplt.Units(), - cmap=aplt.Cmap(), - colorbar=aplt.Colorbar(fraction=0.1, pad=0.05), - ) diff --git a/test_autoarray/plot/wrap/two_d/test_derived.py b/test_autoarray/plot/wrap/two_d/test_derived.py deleted file mode 100644 index 4a5eb2134..000000000 --- a/test_autoarray/plot/wrap/two_d/test_derived.py +++ /dev/null @@ -1,63 +0,0 @@ -import autoarray as aa -import autoarray.plot as aplt - - -def test__all_class_load_and_inherit_correctly(grid_2d_irregular_7x7_list): - origin_scatter = aplt.OriginScatter() - origin_scatter.scatter_grid( - grid=aa.Grid2D.uniform(shape_native=(3, 3), pixel_scales=1.0) - ) - - assert origin_scatter.config_dict["s"] == 80 - - mask_scatter = aplt.MaskScatter() - mask_scatter.scatter_grid( - grid=aa.Grid2D.uniform(shape_native=(3, 3), pixel_scales=1.0) - ) - - assert mask_scatter.config_dict["s"] == 12 - - border_scatter = aplt.BorderScatter() - border_scatter.scatter_grid( - grid=aa.Grid2D.uniform(shape_native=(3, 3), pixel_scales=1.0) - ) - - assert border_scatter.config_dict["s"] == 13 - - positions_scatter = aplt.PositionsScatter() - positions_scatter.scatter_grid(grid=grid_2d_irregular_7x7_list) - - assert positions_scatter.config_dict["s"] == 15 - - index_scatter = aplt.IndexScatter() - index_scatter.scatter_grid_list(grid_list=grid_2d_irregular_7x7_list) - - assert index_scatter.config_dict["s"] == 20 - - mesh_grid_scatter = aplt.MeshGridScatter() - mesh_grid_scatter.scatter_grid( - grid=aa.Grid2D.uniform(shape_native=(3, 3), pixel_scales=1.0) - ) - - assert mesh_grid_scatter.config_dict["s"] == 5 - - parallel_overscan_plot = aplt.ParallelOverscanPlot() - parallel_overscan_plot.plot_rectangular_grid_lines( - extent=[0.0, 1.0, 0.0, 1.0], shape_native=(3, 2) - ) - - assert parallel_overscan_plot.config_dict["linewidth"] == 1 - - serial_overscan_plot = aplt.SerialOverscanPlot() - serial_overscan_plot.plot_rectangular_grid_lines( - extent=[0.0, 1.0, 0.0, 1.0], shape_native=(3, 2) - ) - - assert serial_overscan_plot.config_dict["linewidth"] == 2 - - serial_prescan_plot = aplt.SerialPrescanPlot() - serial_prescan_plot.plot_rectangular_grid_lines( - extent=[0.0, 1.0, 0.0, 1.0], shape_native=(3, 2) - ) - - assert serial_prescan_plot.config_dict["linewidth"] == 3 diff --git a/test_autoarray/plot/wrap/two_d/test_grid_errorbar.py b/test_autoarray/plot/wrap/two_d/test_grid_errorbar.py deleted file mode 100644 index d165037a6..000000000 --- a/test_autoarray/plot/wrap/two_d/test_grid_errorbar.py +++ /dev/null @@ -1,28 +0,0 @@ -import autoarray as aa -import autoarray.plot as aplt - - -def test__errorbar_grid(): - errorbar = aplt.GridErrorbar(marker="x", c="k") - - errorbar.errorbar_grid( - grid=aa.Grid2D.uniform(shape_native=(3, 3), pixel_scales=1.0) - ) - - errorbar = aplt.GridErrorbar(marker="x", c="k") - - errorbar.errorbar_grid( - grid=aa.Grid2D.uniform(shape_native=(3, 3), pixel_scales=1.0), - y_errors=[1.0] * 9, - x_errors=[1.0] * 9, - ) - - -def test__errorbar_coordinates(): - errorbar = aplt.GridErrorbar(marker="x", c="k") - - errorbar.errorbar_grid_list( - grid_list=[aa.Grid2DIrregular([(1.0, 1.0), (2.0, 2.0)])], - y_errors=[1.0] * 2, - x_errors=[1.0] * 2, - ) diff --git a/test_autoarray/plot/wrap/two_d/test_grid_plot.py b/test_autoarray/plot/wrap/two_d/test_grid_plot.py deleted file mode 100644 index c297b8ddc..000000000 --- a/test_autoarray/plot/wrap/two_d/test_grid_plot.py +++ /dev/null @@ -1,49 +0,0 @@ -import autoarray as aa -import autoarray.plot as aplt - -import matplotlib.pyplot as plt -import numpy as np - - -def test__plot_rectangular_grid_lines__draws_for_valid_extent_and_shape(): - line = aplt.GridPlot(linewidth=2, linestyle="--", c="k") - - line.plot_rectangular_grid_lines(extent=[0.0, 1.0, 0.0, 1.0], shape_native=(3, 2)) - line.plot_rectangular_grid_lines( - extent=[-4.0, 8.0, -3.0, 10.0], shape_native=(8, 3) - ) - - -def test__plot_grid_list(): - line = aplt.GridPlot(linewidth=2, linestyle="--", c="k") - - line.plot_grid_list(grid_list=[aa.Grid2DIrregular([(1.0, 1.0), (2.0, 2.0)])]) - line.plot_grid_list( - grid_list=[ - aa.Grid2DIrregular([(1.0, 1.0), (2.0, 2.0)]), - aa.Grid2DIrregular([(3.0, 3.0)]), - ] - ) - - -def test__errorbar_colored_grid__lists_of_coordinates_or_equivalent_2d_grids__with_color_array(): - errorbar = aplt.GridErrorbar(marker="x", c="k") - - cmap = plt.get_cmap("jet") - - errorbar.errorbar_grid_colored( - grid=aa.Grid2DIrregular( - [(1.0, 1.0), (2.0, 2.0), (3.0, 3.0), (4.0, 4.0), (5.0, 5.0)] - ), - color_array=np.array([2.0, 2.0, 2.0, 2.0, 2.0]), - y_errors=[1.0] * 5, - x_errors=[1.0] * 5, - cmap=cmap, - ) - errorbar.errorbar_grid_colored( - grid=aa.Grid2D.uniform(shape_native=(3, 2), pixel_scales=1.0), - color_array=np.array([2.0, 2.0, 2.0, 2.0, 2.0, 2.0]), - cmap=cmap, - y_errors=[1.0] * 6, - x_errors=[1.0] * 6, - ) diff --git a/test_autoarray/plot/wrap/two_d/test_grid_scatter.py b/test_autoarray/plot/wrap/two_d/test_grid_scatter.py deleted file mode 100644 index 6822a8d37..000000000 --- a/test_autoarray/plot/wrap/two_d/test_grid_scatter.py +++ /dev/null @@ -1,79 +0,0 @@ -import autoarray as aa -import autoarray.plot as aplt - -import matplotlib.pyplot as plt -import numpy as np - - -def test__scatter_grid(): - scatter = aplt.GridScatter(s=2, marker="x", c="k") - - scatter.scatter_grid(grid=aa.Grid2D.uniform(shape_native=(3, 3), pixel_scales=1.0)) - - -def test__scatter_colored_grid__lists_of_coordinates_or_equivalent_2d_grids__with_color_array(): - scatter = aplt.GridScatter(s=2, marker="x", c="k") - - cmap = plt.get_cmap("jet") - - scatter.scatter_grid_colored( - grid=aa.Grid2DIrregular( - [(1.0, 1.0), (2.0, 2.0), (3.0, 3.0), (4.0, 4.0), (5.0, 5.0)] - ), - color_array=np.array([2.0, 2.0, 2.0, 2.0, 2.0]), - cmap=cmap, - ) - scatter.scatter_grid_colored( - grid=aa.Grid2D.uniform(shape_native=(3, 2), pixel_scales=1.0), - color_array=np.array([2.0, 2.0, 2.0, 2.0, 2.0, 2.0]), - cmap=cmap, - ) - - -def test__scatter_grid_indexes_1d__input_grid_is_ndarray_and_indexes_are_valid(): - scatter = aplt.GridScatter(s=2, marker="x", c="k") - - scatter.scatter_grid_indexes( - grid=aa.Grid2D.uniform(shape_native=(3, 3), pixel_scales=1.0), - indexes=[0, 1, 2], - ) - - scatter.scatter_grid_indexes( - grid=aa.Grid2D.uniform(shape_native=(3, 3), pixel_scales=1.0), - indexes=[[0, 1, 2]], - ) - - scatter.scatter_grid_indexes( - grid=aa.Grid2D.uniform(shape_native=(3, 3), pixel_scales=1.0), - indexes=[[0, 1], [2]], - ) - - -def test__scatter_grid_indexes_2d__input_grid_is_ndarray_and_indexes_are_valid(): - scatter = aplt.GridScatter(s=2, marker="x", c="k") - - scatter.scatter_grid_indexes( - grid=aa.Grid2D.uniform(shape_native=(3, 3), pixel_scales=1.0), - indexes=[(0, 0), (0, 1), (0, 2)], - ) - - scatter.scatter_grid_indexes( - grid=aa.Grid2D.uniform(shape_native=(3, 3), pixel_scales=1.0), - indexes=[[(0, 0), (0, 1), (0, 2)]], - ) - - scatter.scatter_grid_indexes( - grid=aa.Grid2D.uniform(shape_native=(3, 3), pixel_scales=1.0), - indexes=[[(0, 0), (0, 1)], [(0, 2)]], - ) - - scatter.scatter_grid_indexes( - grid=aa.Grid2D.uniform(shape_native=(3, 3), pixel_scales=1.0), - indexes=[[[0, 0], [0, 1]], [[0, 2]]], - ) - - -def test__scatter_coordinates(): - scatter = aplt.GridScatter(s=2, marker="x", c="k") - - scatter.scatter_grid_list(grid_list=[aa.Grid2DIrregular([(1.0, 1.0), (2.0, 2.0)])]) diff --git a/test_autoarray/plot/wrap/two_d/test_patcher.py b/test_autoarray/plot/wrap/two_d/test_patcher.py deleted file mode 100644 index 7062fc3ca..000000000 --- a/test_autoarray/plot/wrap/two_d/test_patcher.py +++ /dev/null @@ -1,12 +0,0 @@ -import autoarray.plot as aplt - -from matplotlib.patches import Ellipse - - -def test__add_patches(): - patch_overlay = aplt.PatchOverlay(facecolor="c", edgecolor="none") - - patch_0 = Ellipse(xy=(1.0, 2.0), height=1.0, width=2.0, angle=1.0) - patch_1 = Ellipse(xy=(1.0, 2.0), height=1.0, width=2.0, angle=1.0) - - patch_overlay.overlay_patches(patches=[patch_0, patch_1]) diff --git a/test_autoarray/plot/wrap/two_d/test_vector_yx_quiver.py b/test_autoarray/plot/wrap/two_d/test_vector_yx_quiver.py deleted file mode 100644 index cdcac6e91..000000000 --- a/test_autoarray/plot/wrap/two_d/test_vector_yx_quiver.py +++ /dev/null @@ -1,20 +0,0 @@ -import autoarray as aa -import autoarray.plot as aplt - - -def test__quiver_vectors(): - quiver = aplt.VectorYXQuiver( - headlength=5, - pivot="middle", - linewidth=3, - units="xy", - angles="xy", - headwidth=6, - alpha=1.0, - ) - - vectors = aa.VectorYX2DIrregular( - values=[(1.0, 2.0), (2.0, 1.0)], grid=[(-1.0, 0.0), (-2.0, 0.0)] - ) - - quiver.quiver_vectors(vectors=vectors) diff --git a/test_autoarray/structures/arrays/test_uniform_2d.py b/test_autoarray/structures/arrays/test_uniform_2d.py index b306bbbeb..9615f8cd2 100644 --- a/test_autoarray/structures/arrays/test_uniform_2d.py +++ b/test_autoarray/structures/arrays/test_uniform_2d.py @@ -247,27 +247,32 @@ def test__constructor__1d_values_too_few_for_mask__raises_array_exception(): aa.Array2D(values=[1.0, 2.0], mask=mask) -@pytest.mark.parametrize("new_shape,expected_native", [ - ( - (7, 7), - np.array( - [ - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], - [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], - [0.0, 1.0, 1.0, 2.0, 1.0, 1.0, 0.0], - [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], - [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - ] +@pytest.mark.parametrize( + "new_shape,expected_native", + [ + ( + (7, 7), + np.array( + [ + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + [0.0, 1.0, 1.0, 2.0, 1.0, 1.0, 0.0], + [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + ] + ), ), - ), - ( - (3, 3), - np.array([[1.0, 1.0, 1.0], [1.0, 2.0, 1.0], [1.0, 1.0, 1.0]]), - ), -]) -def test__resized_from__5x5_array_with_center_marked__resized_array_pads_or_crops_correctly(new_shape, expected_native): + ( + (3, 3), + np.array([[1.0, 1.0, 1.0], [1.0, 2.0, 1.0], [1.0, 1.0, 1.0]]), + ), + ], +) +def test__resized_from__5x5_array_with_center_marked__resized_array_pads_or_crops_correctly( + new_shape, expected_native +): array_2d = np.ones((5, 5)) array_2d[2, 2] = 2.0 @@ -280,11 +285,16 @@ def test__resized_from__5x5_array_with_center_marked__resized_array_pads_or_crop assert array_2d.mask.pixel_scales == (1.0, 1.0) -@pytest.mark.parametrize("kernel_shape,expected_shape", [ - ((3, 3), (7, 7)), - ((5, 5), (9, 9)), -]) -def test__padded_before_convolution_from__5x5_array__output_shape_padded_by_kernel_size(kernel_shape, expected_shape): +@pytest.mark.parametrize( + "kernel_shape,expected_shape", + [ + ((3, 3), (7, 7)), + ((5, 5), (9, 9)), + ], +) +def test__padded_before_convolution_from__5x5_array__output_shape_padded_by_kernel_size( + kernel_shape, expected_shape +): array_2d = np.ones((5, 5)) array_2d[2, 2] = 2.0 @@ -312,11 +322,16 @@ def test__padded_before_convolution_from__9x9_array__output_shape_padded_by_7x7_ assert new_arr.mask.pixel_scales == (1.0, 1.0) -@pytest.mark.parametrize("kernel_shape,expected_native", [ - ((3, 3), np.array([[1.0, 1.0, 1.0], [1.0, 2.0, 1.0], [1.0, 1.0, 1.0]])), - ((5, 5), np.array([[2.0]])), -]) -def test__trimmed_after_convolution_from__5x5_array_with_center_marked__trims_to_non_padded_region(kernel_shape, expected_native): +@pytest.mark.parametrize( + "kernel_shape,expected_native", + [ + ((3, 3), np.array([[1.0, 1.0, 1.0], [1.0, 2.0, 1.0], [1.0, 1.0, 1.0]])), + ((5, 5), np.array([[2.0]])), + ], +) +def test__trimmed_after_convolution_from__5x5_array_with_center_marked__trims_to_non_padded_region( + kernel_shape, expected_native +): array_2d = np.ones((5, 5)) array_2d[2, 2] = 2.0 diff --git a/test_autoarray/structures/plot/test_structure_plotters.py b/test_autoarray/structures/plot/test_structure_plotters.py index 53b796798..0f66bab97 100644 --- a/test_autoarray/structures/plot/test_structure_plotters.py +++ b/test_autoarray/structures/plot/test_structure_plotters.py @@ -1,184 +1,147 @@ -import autoarray as aa -import autoarray.plot as aplt -from os import path -import pytest -import numpy as np -import shutil - -directory = path.dirname(path.realpath(__file__)) - - -@pytest.fixture(name="plot_path") -def make_plot_path_setup(): - return path.join( - "{}".format(path.dirname(path.realpath(__file__))), "files", "structures" - ) - - -def test__plot_yx_line(plot_path, plot_patch): - visuals_1d = aplt.Visuals1D(vertical_line=1.0) - - mat_plot_1d = aplt.MatPlot1D( - yx_plot=aplt.YXPlot(plot_axis_type="loglog", c="k"), - vertical_line_axvline=aplt.AXVLine(c="k"), - output=aplt.Output(path=plot_path, filename="yx_1", format="png"), - ) - - yx_1d_plotter = aplt.YX1DPlotter( - y=aa.Array1D.no_mask([1.0, 2.0, 3.0], pixel_scales=1.0), - x=aa.Array1D.no_mask([0.5, 1.0, 1.5], pixel_scales=0.5), - mat_plot_1d=mat_plot_1d, - visuals_1d=visuals_1d, - ) - - yx_1d_plotter.figure_1d() - - assert path.join(plot_path, "yx_1.png") in plot_patch.paths - - -def test__array( - array_2d_7x7, - mask_2d_7x7, - grid_2d_7x7, - grid_2d_irregular_7x7_list, - plot_path, - plot_patch, -): - array_plotter = aplt.Array2DPlotter( - array=array_2d_7x7, - mat_plot_2d=aplt.MatPlot2D( - output=aplt.Output(path=plot_path, filename="array1", format="png") - ), - ) - - array_plotter.figure_2d() - - assert path.join(plot_path, "array1.png") in plot_patch.paths - - array_plotter = aplt.Array2DPlotter( - array=array_2d_7x7, - mat_plot_2d=aplt.MatPlot2D( - output=aplt.Output(path=plot_path, filename="array2", format="png") - ), - ) - - array_plotter.figure_2d() - - assert path.join(plot_path, "array2.png") in plot_patch.paths - - visuals_2d = aplt.Visuals2D( - origin=grid_2d_irregular_7x7_list, - mask=mask_2d_7x7, - border=mask_2d_7x7.derive_grid.border, - grid=grid_2d_7x7, - positions=grid_2d_irregular_7x7_list, - # lines=grid_2d_irregular_7x7_list, - array_overlay=array_2d_7x7, - ) - - array_plotter = aplt.Array2DPlotter( - array=array_2d_7x7, - visuals_2d=visuals_2d, - mat_plot_2d=aplt.MatPlot2D( - output=aplt.Output(path=plot_path, filename="array3", format="png") - ), - ) - - array_plotter.figure_2d() - - assert path.join(plot_path, "array3.png") in plot_patch.paths - - -def test__array__fits_files_output_correctly(array_2d_7x7, plot_path): - plot_path = path.join(plot_path, "fits") - - array_plotter = aplt.Array2DPlotter( - array=array_2d_7x7, - mat_plot_2d=aplt.MatPlot2D( - output=aplt.Output(path=plot_path, filename="array", format="fits") - ), - ) - - if path.exists(plot_path): - shutil.rmtree(plot_path) - - array_plotter.figure_2d() - - arr = aa.ndarray_via_fits_from(file_path=path.join(plot_path, "array.fits"), hdu=0) - - assert (arr == array_2d_7x7.native).all() - - -def test__grid( - array_2d_7x7, - grid_2d_7x7, - mask_2d_7x7, - grid_2d_irregular_7x7_list, - plot_path, - plot_patch, -): - grid_2d_plotter = aplt.Grid2DPlotter( - grid=grid_2d_7x7, - visuals_2d=aplt.Visuals2D(indexes=[0, 1, 2]), - mat_plot_2d=aplt.MatPlot2D( - output=aplt.Output(path=plot_path, filename="grid1", format="png") - ), - ) - - color_array = np.linspace(start=0.0, stop=1.0, num=grid_2d_7x7.shape_slim) - - grid_2d_plotter.figure_2d(color_array=color_array) - - assert path.join(plot_path, "grid1.png") in plot_patch.paths - - grid_2d_plotter = aplt.Grid2DPlotter( - grid=grid_2d_7x7, - visuals_2d=aplt.Visuals2D(indexes=[0, 1, 2]), - mat_plot_2d=aplt.MatPlot2D( - output=aplt.Output(path=plot_path, filename="grid2", format="png") - ), - ) - - grid_2d_plotter.figure_2d(color_array=color_array) - - assert path.join(plot_path, "grid2.png") in plot_patch.paths - - visuals_2d = aplt.Visuals2D( - origin=grid_2d_irregular_7x7_list, - mask=mask_2d_7x7, - border=mask_2d_7x7.derive_grid.border, - grid=grid_2d_7x7, - positions=grid_2d_irregular_7x7_list, - lines=grid_2d_irregular_7x7_list, - array_overlay=array_2d_7x7, - indexes=[0, 1, 2], - ) - - grid_2d_plotter = aplt.Grid2DPlotter( - grid=grid_2d_7x7, - mat_plot_2d=aplt.MatPlot2D( - output=aplt.Output(path=plot_path, filename="grid3", format="png") - ), - visuals_2d=visuals_2d, - ) - - grid_2d_plotter.figure_2d(color_array=color_array) - - assert path.join(plot_path, "grid3.png") in plot_patch.paths - - -def test__array_rgb( - array_2d_rgb_7x7, - plot_path, - plot_patch, -): - array_plotter = aplt.Array2DPlotter( - array=array_2d_rgb_7x7, - mat_plot_2d=aplt.MatPlot2D( - output=aplt.Output(path=plot_path, filename="array_rgb", format="png") - ), - ) - - array_plotter.figure_2d() - - assert path.join(plot_path, "array_rgb.png") in plot_patch.paths +import autoarray as aa +import autoarray.plot as aplt +from os import path +import pytest +import numpy as np +import shutil + +directory = path.dirname(path.realpath(__file__)) + + +@pytest.fixture(name="plot_path") +def make_plot_path_setup(): + return path.join( + "{}".format(path.dirname(path.realpath(__file__))), "files", "structures" + ) + + +def test__plot_yx_line(plot_path, plot_patch): + aplt.plot_yx_1d( + y=aa.Array1D.no_mask([1.0, 2.0, 3.0], pixel_scales=1.0), + x=aa.Array1D.no_mask([0.5, 1.0, 1.5], pixel_scales=0.5), + output_path=plot_path, + output_filename="yx_1", + output_format="png", + plot_axis_type="loglog", + ) + + assert path.join(plot_path, "yx_1.png") in plot_patch.paths + + +def test__array( + array_2d_7x7, + mask_2d_7x7, + grid_2d_7x7, + grid_2d_irregular_7x7_list, + plot_path, + plot_patch, +): + aplt.plot_array_2d( + array=array_2d_7x7, + output_path=plot_path, + output_filename="array1", + output_format="png", + ) + + assert path.join(plot_path, "array1.png") in plot_patch.paths + + aplt.plot_array_2d( + array=array_2d_7x7, + output_path=plot_path, + output_filename="array2", + output_format="png", + ) + + assert path.join(plot_path, "array2.png") in plot_patch.paths + + aplt.plot_array_2d( + array=array_2d_7x7, + origin=grid_2d_irregular_7x7_list, + border=mask_2d_7x7.derive_grid.border, + grid=grid_2d_7x7, + positions=grid_2d_irregular_7x7_list, + array_overlay=array_2d_7x7, + output_path=plot_path, + output_filename="array3", + output_format="png", + ) + + assert path.join(plot_path, "array3.png") in plot_patch.paths + + +def test__array__fits_files_output_correctly(array_2d_7x7, plot_path): + plot_path = path.join(plot_path, "fits") + + if path.exists(plot_path): + shutil.rmtree(plot_path) + + aplt.plot_array_2d( + array=array_2d_7x7, + output_path=plot_path, + output_filename="array", + output_format="fits", + ) + + arr = aa.ndarray_via_fits_from(file_path=path.join(plot_path, "array.fits"), hdu=0) + + assert (arr == array_2d_7x7.native).all() + + +def test__grid( + array_2d_7x7, + grid_2d_7x7, + mask_2d_7x7, + grid_2d_irregular_7x7_list, + plot_path, + plot_patch, +): + color_array = np.linspace(start=0.0, stop=1.0, num=grid_2d_7x7.shape_slim) + + aplt.plot_grid_2d( + grid=grid_2d_7x7, + indexes=[0, 1, 2], + output_path=plot_path, + output_filename="grid1", + output_format="png", + color_array=color_array, + ) + + assert path.join(plot_path, "grid1.png") in plot_patch.paths + + aplt.plot_grid_2d( + grid=grid_2d_7x7, + indexes=[0, 1, 2], + output_path=plot_path, + output_filename="grid2", + output_format="png", + color_array=color_array, + ) + + assert path.join(plot_path, "grid2.png") in plot_patch.paths + + aplt.plot_grid_2d( + grid=grid_2d_7x7, + lines=grid_2d_irregular_7x7_list, + indexes=[0, 1, 2], + output_path=plot_path, + output_filename="grid3", + output_format="png", + color_array=color_array, + ) + + assert path.join(plot_path, "grid3.png") in plot_patch.paths + + +def test__array_rgb( + array_2d_rgb_7x7, + plot_path, + plot_patch, +): + aplt.plot_array_2d( + array=array_2d_rgb_7x7, + output_path=plot_path, + output_filename="array_rgb", + output_format="png", + ) + + assert path.join(plot_path, "array_rgb.png") in plot_patch.paths