diff --git a/autoarray/config/visualize/general.yaml b/autoarray/config/visualize/general.yaml index 0c841906..e9ce3f6e 100644 --- a/autoarray/config/visualize/general.yaml +++ b/autoarray/config/visualize/general.yaml @@ -15,13 +15,27 @@ units: cb_unit: $\,\,\mathrm{e^{-}}\,\mathrm{s^{-1}}$ # The string or latex unit label used for the colorbar of the image, for example electrons per second. scaled_symbol: '"' # The symbol used when plotting spatial coordinates computed via the pixel_scale (e.g. for Astronomy data this is arc-seconds). unscaled_symbol: pix # The symbol used when plotting spatial coordinates in unscaled pixel units. +colormap: autoarray # Default colormap for 2D plots. Use any matplotlib name to override (e.g. jet, viridis). +ticks: + extent_factor_2d: 0.75 # Fraction of half-extent used for 2D tick positions (< 1.0 pulls ticks inward from edges). + number_of_ticks_2d: 3 # Number of ticks on each spatial axis of 2D plots. +colorbar: + fraction: 0.047 # Fraction of original axes to use for the colorbar. + pad: 0.01 # Padding between colorbar and axes. + labelrotation: 90 # Rotation of colorbar tick labels in degrees. + labelsize: 22 # Font size of colorbar tick labels for single-panel figures. + labelsize_subplot: 24 # Font size of colorbar tick labels for subplot panels. mat_plot: figure: figsize: (7, 7) # Default figure size. Override via aplt.Figure(figsize=(...)). yticks: fontsize: 22 # Default y-tick font size. Override via aplt.YTicks(fontsize=...). + yticks_subplot: + fontsize: 22 # Default y-tick font size for subplot panels. xticks: fontsize: 22 # Default x-tick font size. Override via aplt.XTicks(fontsize=...). + xticks_subplot: + fontsize: 22 # Default x-tick font size for subplot panels. title: fontsize: 24 # Default title font size. Override via aplt.Title(fontsize=...). ylabel: diff --git a/autoarray/dataset/plot/imaging_plots.py b/autoarray/dataset/plot/imaging_plots.py index 8ea93c0b..fa0f19a4 100644 --- a/autoarray/dataset/plot/imaging_plots.py +++ b/autoarray/dataset/plot/imaging_plots.py @@ -92,6 +92,7 @@ def subplot_imaging_dataset( title="Point Spread Function", colormap=colormap, use_log10=use_log10, + cb_unit="", ) plot_array( dataset.psf.kernel, @@ -99,6 +100,7 @@ def subplot_imaging_dataset( title="PSF (log10)", colormap=colormap, use_log10=True, + cb_unit="", ) plot_array( @@ -107,6 +109,7 @@ def subplot_imaging_dataset( title="Signal-To-Noise Map", colormap=colormap, use_log10=use_log10, + cb_unit="", grid=grid, positions=positions, lines=lines, @@ -120,6 +123,7 @@ def subplot_imaging_dataset( title="Over Sample Size (Light Profiles)", colormap=colormap, use_log10=use_log10, + cb_unit="", ) over_sample_size_pix = getattr(getattr(dataset, "grids", None), "over_sample_size_pixelization", None) @@ -130,8 +134,11 @@ def subplot_imaging_dataset( title="Over Sample Size (Pixelization)", colormap=colormap, use_log10=use_log10, + cb_unit="", ) + from autoarray.plot.utils import hide_unused_axes + hide_unused_axes(axes) plt.tight_layout() subplot_save(fig, output_path, output_filename, output_format) diff --git a/autoarray/dataset/plot/interferometer_plots.py b/autoarray/dataset/plot/interferometer_plots.py index 7841ea50..d3cd10e8 100644 --- a/autoarray/dataset/plot/interferometer_plots.py +++ b/autoarray/dataset/plot/interferometer_plots.py @@ -6,7 +6,7 @@ from autoarray.plot.array import plot_array from autoarray.plot.grid import plot_grid from autoarray.plot.yx import plot_yx -from autoarray.plot.utils import subplot_save +from autoarray.plot.utils import subplot_save, hide_unused_axes from autoarray.structures.grids.irregular_2d import Grid2DIrregular @@ -84,6 +84,7 @@ def subplot_interferometer_dataset( use_log10=use_log10, ) + hide_unused_axes(axes) plt.tight_layout() subplot_save(fig, output_path, output_filename, output_format) @@ -138,5 +139,6 @@ def subplot_interferometer_dirty_images( use_log10=use_log10, ) + hide_unused_axes(axes) plt.tight_layout() subplot_save(fig, output_path, output_filename, output_format) diff --git a/autoarray/fit/plot/fit_imaging_plots.py b/autoarray/fit/plot/fit_imaging_plots.py index 4a314b4c..0733bbc0 100644 --- a/autoarray/fit/plot/fit_imaging_plots.py +++ b/autoarray/fit/plot/fit_imaging_plots.py @@ -3,7 +3,7 @@ import matplotlib.pyplot as plt from autoarray.plot.array import plot_array -from autoarray.plot.utils import subplot_save, symmetric_vmin_vmax +from autoarray.plot.utils import subplot_save, symmetric_vmin_vmax, hide_unused_axes def subplot_fit_imaging( @@ -118,5 +118,6 @@ def subplot_fit_imaging( lines=lines, ) + hide_unused_axes(axes) plt.tight_layout() subplot_save(fig, output_path, output_filename, output_format) diff --git a/autoarray/fit/plot/fit_interferometer_plots.py b/autoarray/fit/plot/fit_interferometer_plots.py index aaceec08..af4312ed 100644 --- a/autoarray/fit/plot/fit_interferometer_plots.py +++ b/autoarray/fit/plot/fit_interferometer_plots.py @@ -5,7 +5,7 @@ from autoarray.plot.array import plot_array from autoarray.plot.yx import plot_yx -from autoarray.plot.utils import subplot_save, symmetric_vmin_vmax +from autoarray.plot.utils import subplot_save, symmetric_vmin_vmax, hide_unused_axes def subplot_fit_interferometer( @@ -98,6 +98,7 @@ def subplot_fit_interferometer( plot_axis_type="scatter", ) + hide_unused_axes(axes) plt.tight_layout() subplot_save(fig, output_path, output_filename, output_format) @@ -191,5 +192,6 @@ def subplot_fit_interferometer_dirty_images( use_log10=use_log10, ) + hide_unused_axes(axes) plt.tight_layout() subplot_save(fig, output_path, output_filename, output_format) diff --git a/autoarray/inversion/plot/inversion_plots.py b/autoarray/inversion/plot/inversion_plots.py index e3f521a6..6b0d30ce 100644 --- a/autoarray/inversion/plot/inversion_plots.py +++ b/autoarray/inversion/plot/inversion_plots.py @@ -1,13 +1,15 @@ +import csv import logging import numpy as np -from typing import Optional +from pathlib import Path +from typing import Optional, Union import matplotlib.pyplot as plt from autoconf import conf from autoarray.inversion.mappers.abstract import Mapper from autoarray.plot.array import plot_array -from autoarray.plot.utils import numpy_grid, numpy_lines, numpy_positions, subplot_save +from autoarray.plot.utils import numpy_grid, numpy_lines, numpy_positions, subplot_save, hide_unused_axes from autoarray.inversion.plot.mapper_plots import plot_mapper from autoarray.structures.arrays.uniform_2d import Array2D @@ -215,6 +217,7 @@ def _recon_array(): except (TypeError, Exception): pass + hide_unused_axes(axes) plt.tight_layout() subplot_save(fig, output_path, f"{output_filename}_{mapper_index}", output_format) @@ -332,7 +335,40 @@ def subplot_mappings( lines=lines, ) + hide_unused_axes(axes) plt.tight_layout() subplot_save( fig, output_path, f"{output_filename}_{pixelization_index}", output_format ) + + +def save_reconstruction_csv( + inversion, + output_path: Union[str, Path], +) -> None: + """Write a CSV of each mapper's reconstruction and noise map to *output_path*. + + One file is written per mapper: ``source_plane_reconstruction_{i}.csv``, + with columns ``y``, ``x``, ``reconstruction``, ``noise_map``. + + Parameters + ---------- + inversion + An ``AbstractInversion`` instance. + output_path + Directory in which to write the CSV files. + """ + output_path = Path(output_path) + mapper_list = inversion.cls_list_from(cls=Mapper) + + for i, mapper in enumerate(mapper_list): + y = mapper.source_plane_mesh_grid[:, 0] + x = mapper.source_plane_mesh_grid[:, 1] + reconstruction = inversion.reconstruction_dict[mapper] + noise_map = inversion.reconstruction_noise_map_dict[mapper] + + with open(output_path / f"source_plane_reconstruction_{i}.csv", mode="w", newline="") as f: + writer = csv.writer(f) + writer.writerow(["y", "x", "reconstruction", "noise_map"]) + for j in range(len(x)): + writer.writerow([float(y[j]), float(x[j]), float(reconstruction[j]), float(noise_map[j])]) diff --git a/autoarray/plot/array.py b/autoarray/plot/array.py index c23028ed..9ea94960 100644 --- a/autoarray/plot/array.py +++ b/autoarray/plot/array.py @@ -47,11 +47,11 @@ def plot_array( title: str = "", xlabel: str = "", ylabel: str = "", - colormap: str = "jet", + colormap: Optional[str] = None, vmin: Optional[float] = None, vmax: Optional[float] = None, use_log10: bool = False, - aspect: str = "auto", + cb_unit: Optional[str] = None, origin_imshow: str = "upper", # --- figure control (used only when ax is None) ----------------------------- figsize: Optional[Tuple[int, int]] = None, @@ -105,8 +105,6 @@ def plot_array( Explicit color scale limits. use_log10 When ``True`` a ``LogNorm`` is applied. - aspect - Passed directly to ``imshow``. origin_imshow Passed directly to ``imshow`` (``"upper"`` or ``"lower"``). figsize @@ -135,6 +133,10 @@ def plot_array( if array is None or array.size == 0: return + if colormap is None: + from autoarray.plot.utils import _default_colormap + colormap = _default_colormap() + # convert overlay params (safe for None and already-numpy inputs) border = numpy_grid(border) origin = numpy_grid(origin) @@ -180,16 +182,33 @@ def plot_array( else: norm = None + # Compute the axes-box aspect ratio from the data extent so that the + # physical cell is correctly shaped and tight_layout has no whitespace + # to absorb. This reproduces the old "square" subplot behaviour where + # ratio = x_range / y_range was passed to plt.subplot(aspect=ratio). + if extent is not None: + x_range = abs(extent[1] - extent[0]) + y_range = abs(extent[3] - extent[2]) + _box_aspect = (x_range / y_range) if y_range > 0 else 1.0 + else: + h, w = array.shape[:2] + _box_aspect = (w / h) if h > 0 else 1.0 + im = ax.imshow( array, cmap=colormap, norm=norm, extent=extent, - aspect=aspect, + aspect="auto", # image fills the axes box; box shape set below origin=origin_imshow, ) - plt.colorbar(im, ax=ax) + # Shape the axes box to match the data so there is no surrounding + # whitespace when the panel is embedded in a subplot grid. + ax.set_aspect(_box_aspect, adjustable="box") + + from autoarray.plot.utils import _apply_colorbar + _apply_colorbar(im, ax, cb_unit=cb_unit, is_subplot=not owns_figure) # --- overlays -------------------------------------------------------------- if array_overlay is not None: @@ -198,7 +217,7 @@ def plot_array( cmap="Greys", alpha=0.5, extent=extent, - aspect=aspect, + aspect="auto", origin=origin_imshow, ) @@ -223,7 +242,7 @@ def plot_array( ax.scatter(mesh_grid[:, 1], mesh_grid[:, 0], s=1, c="w", alpha=0.5) if positions is not None: - colors = ["r", "g", "b", "m", "c", "y"] + colors = ["k", "g", "b", "m", "c", "y"] for i, pos in enumerate(positions): ax.scatter(pos[:, 1], pos[:, 0], s=20, c=colors[i % len(colors)], zorder=5) @@ -263,7 +282,7 @@ def plot_array( pass # --- labels / ticks -------------------------------------------------------- - apply_labels(ax, title=title, xlabel=xlabel, ylabel=ylabel) + apply_labels(ax, title=title, xlabel=xlabel, ylabel=ylabel, is_subplot=not owns_figure) if extent is not None: apply_extent(ax, extent) diff --git a/autoarray/plot/grid.py b/autoarray/plot/grid.py index 88cbac5f..33287202 100644 --- a/autoarray/plot/grid.py +++ b/autoarray/plot/grid.py @@ -32,7 +32,7 @@ def plot_grid( title: str = "", xlabel: str = 'x (")', ylabel: str = 'y (")', - colormap: str = "jet", + colormap: Optional[str] = None, buffer: float = 0.1, extent: Optional[Tuple[float, float, float, float]] = None, force_symmetric_extent: bool = True, @@ -101,6 +101,10 @@ def plot_grid( lines = numpy_lines(lines) + if colormap is None: + from autoarray.plot.utils import _default_colormap + colormap = _default_colormap() + owns_figure = ax is None if owns_figure: figsize = figsize or conf_figsize("figures") @@ -126,7 +130,8 @@ def plot_grid( ecolor=colors, ) - plt.colorbar(sc, ax=ax) + from autoarray.plot.utils import _apply_colorbar + _apply_colorbar(sc, ax) else: if y_errors is None and x_errors is None: ax.scatter(grid[:, 1], grid[:, 0], s=1, c="k") diff --git a/autoarray/plot/inversion.py b/autoarray/plot/inversion.py index baa01436..67767fb8 100644 --- a/autoarray/plot/inversion.py +++ b/autoarray/plot/inversion.py @@ -21,7 +21,7 @@ def plot_inversion_reconstruction( title: str = "Reconstruction", xlabel: str = 'x (")', ylabel: str = 'y (")', - colormap: str = "jet", + colormap: Optional[str] = None, vmin: Optional[float] = None, vmax: Optional[float] = None, use_log10: bool = False, @@ -76,6 +76,10 @@ def plot_inversion_reconstruction( from autoarray.inversion.mesh.interpolator.delaunay import InterpolatorDelaunay from autoarray.inversion.mesh.interpolator.knn import InterpolatorKNearestNeighbor + if colormap is None: + from autoarray.plot.utils import _default_colormap + colormap = _default_colormap() + owns_figure = ax is None if owns_figure: figsize = figsize or conf_figsize("figures") @@ -192,7 +196,8 @@ def _plot_rectangular(ax, pixel_values, mapper, norm, colormap, extent): aspect="auto", origin="upper", ) - plt.colorbar(im, ax=ax) + from autoarray.plot.utils import _apply_colorbar + _apply_colorbar(im, ax) else: y_edges, x_edges = mapper.mesh_geometry.edges_transformed.T Y, X = np.meshgrid(y_edges, x_edges, indexing="ij") @@ -204,7 +209,8 @@ def _plot_rectangular(ax, pixel_values, mapper, norm, colormap, extent): norm=norm, cmap=colormap, ) - plt.colorbar(im, ax=ax) + from autoarray.plot.utils import _apply_colorbar + _apply_colorbar(im, ax) def _plot_delaunay(ax, pixel_values, mapper, norm, colormap): @@ -240,4 +246,5 @@ def _plot_delaunay(ax, pixel_values, mapper, norm, colormap): vals = pixel_values tc = ax.tripcolor(x, y, vals, cmap=colormap, norm=norm, shading="gouraud") - plt.colorbar(tc, ax=ax) + from autoarray.plot.utils import _apply_colorbar + _apply_colorbar(tc, ax) diff --git a/autoarray/plot/segmentdata.py b/autoarray/plot/segmentdata.py index 5f30e524..5c641873 100644 --- a/autoarray/plot/segmentdata.py +++ b/autoarray/plot/segmentdata.py @@ -1042,3 +1042,15 @@ ] ), } + +COLORMAP_NAME = "autoarray" + + +def register(): + """Register the autoarray segmentdata colormap with matplotlib (idempotent).""" + import matplotlib + import matplotlib.colors as mcolors + + if COLORMAP_NAME not in matplotlib.colormaps: + cmap = mcolors.LinearSegmentedColormap(COLORMAP_NAME, segmentdata) + matplotlib.colormaps.register(cmap) diff --git a/autoarray/plot/utils.py b/autoarray/plot/utils.py index caa2c7ec..be7e538a 100644 --- a/autoarray/plot/utils.py +++ b/autoarray/plot/utils.py @@ -235,7 +235,7 @@ def symmetric_cmap_from(array, symmetric_value=None): return colors.Normalize(vmin=-abs_max, vmax=abs_max) -def set_with_color_values(ax, cmap, color_values, norm=None, fraction=0.047, pad=0.01): +def set_with_color_values(ax, cmap, color_values, norm=None): """Attach a colorbar to *ax* driven by *color_values* rather than a plotted artist. Useful for Delaunay mapper visualisation where ``ax.tripcolor`` already draws @@ -252,8 +252,6 @@ def set_with_color_values(ax, cmap, color_values, norm=None, fraction=0.047, pad norm A ``matplotlib.colors.Normalize`` instance. If ``None`` a default ``Normalize(vmin, vmax)`` is created from *color_values*. - fraction, pad - Passed directly to ``plt.colorbar``. """ import matplotlib.cm as cm import matplotlib.colors as mcolors @@ -264,7 +262,7 @@ def set_with_color_values(ax, cmap, color_values, norm=None, fraction=0.047, pad mappable = cm.ScalarMappable(norm=norm, cmap=cmap) mappable.set_array(color_values) - return plt.colorbar(mappable=mappable, ax=ax, fraction=fraction, pad=pad) + return _apply_colorbar(mappable, ax) def subplot_save(fig, output_path, output_filename, output_format): @@ -369,36 +367,32 @@ def apply_labels( title: str = "", xlabel: str = "", ylabel: str = "", + is_subplot: bool = False, ) -> None: """Apply title, axis labels, and tick font sizes to *ax* from config. Reads font sizes from the ``mat_plot`` section of - ``visualize/general.yaml`` so that users can override them globally - without touching call sites. Falls back to the values that the - old ``MatWrap`` system used when the config is unavailable. - - Parameters - ---------- - ax - The matplotlib axes to configure. - title - Title string. - xlabel - X-axis label string. - ylabel - Y-axis label string. + ``visualize/general.yaml``. When *is_subplot* is ``True``, reads + ``*_subplot`` keys (defaulting to the single-figure values / 10 for ticks). """ - title_fs = conf_mat_plot_fontsize("title", default=16) - xlabel_fs = conf_mat_plot_fontsize("xlabel", default=14) - ylabel_fs = conf_mat_plot_fontsize("ylabel", default=14) - xticks_fs = conf_mat_plot_fontsize("xticks", default=12) - yticks_fs = conf_mat_plot_fontsize("yticks", default=12) + if is_subplot: + title_fs = conf_mat_plot_fontsize("title_subplot", default=conf_mat_plot_fontsize("title", default=16)) + xlabel_fs = conf_mat_plot_fontsize("xlabel_subplot", default=conf_mat_plot_fontsize("xlabel", default=14)) + ylabel_fs = conf_mat_plot_fontsize("ylabel_subplot", default=conf_mat_plot_fontsize("ylabel", default=14)) + xticks_fs = conf_mat_plot_fontsize("xticks_subplot", default=22) + yticks_fs = conf_mat_plot_fontsize("yticks_subplot", default=22) + else: + title_fs = conf_mat_plot_fontsize("title", default=16) + xlabel_fs = conf_mat_plot_fontsize("xlabel", default=14) + ylabel_fs = conf_mat_plot_fontsize("ylabel", default=14) + xticks_fs = conf_mat_plot_fontsize("xticks", default=12) + yticks_fs = conf_mat_plot_fontsize("yticks", default=12) ax.set_title(title, fontsize=title_fs) ax.set_xlabel(xlabel, fontsize=xlabel_fs) ax.set_ylabel(ylabel, fontsize=ylabel_fs) ax.tick_params(axis="x", labelsize=xticks_fs) - ax.tick_params(axis="y", labelsize=yticks_fs) + ax.tick_params(axis="y", labelsize=yticks_fs, labelrotation=90) def save_figure( @@ -494,34 +488,161 @@ def plot_visibilities_1d(vis, ax: plt.Axes, title: str = "") -> None: ax.legend(fontsize=8) +def _conf_colorbar(key: str, default): + try: + from autoconf import conf + return conf.instance["visualize"]["general"]["colorbar"][key] + except Exception: + return default + + +def _colorbar_tick_values(norm) -> Optional[List[float]]: + """Return [min, mid, max] tick positions from *norm*, with mid in log-space for LogNorm.""" + if norm is None or norm.vmin is None or norm.vmax is None: + return None + import matplotlib.colors as mcolors + lo, hi = float(norm.vmin), float(norm.vmax) + if isinstance(norm, mcolors.LogNorm): + mid = 10 ** ((np.log10(lo) + np.log10(hi)) / 2.0) + else: + mid = (lo + hi) / 2.0 + return [lo, mid, hi] + + +def _fmt_tick(v: float) -> str: + """Format a single tick value to 2 decimal places without scientific notation.""" + return f"{v:.2f}" + + +def _colorbar_tick_labels(tick_values: List[float], cb_unit: Optional[str] = None) -> List[str]: + """Format tick values without scientific notation, appending *cb_unit* to the middle label. + + If *cb_unit* is ``None`` the unit is read from config; pass ``""`` for unitless panels. + """ + if cb_unit is None: + try: + from autoconf import conf + cb_unit = conf.instance["visualize"]["general"]["units"]["cb_unit"] + except Exception: + cb_unit = "" + labels = [_fmt_tick(v) for v in tick_values] + mid = len(labels) // 2 + labels[mid] = f"{labels[mid]}{cb_unit}" + return labels + + +def _apply_colorbar( + mappable, + ax: plt.Axes, + cb_unit: Optional[str] = None, + is_subplot: bool = False, +) -> None: + """Create a colorbar with 3 ticks (min/mid/max), unit on middle label, config styling. + + Parameters + ---------- + cb_unit + Override the unit string on the middle tick. Pass ``""`` for unitless panels. + ``None`` reads the unit from config. + is_subplot + When ``True`` uses ``labelsize_subplot`` from config (default 22) instead of + the single-figure ``labelsize`` (default 22). + """ + tick_values = _colorbar_tick_values(getattr(mappable, "norm", None)) + + cb = plt.colorbar( + mappable, + ax=ax, + fraction=float(_conf_colorbar("fraction", 0.047)), + pad=float(_conf_colorbar("pad", 0.01)), + ticks=tick_values, + ) + labelsize_key = "labelsize_subplot" if is_subplot else "labelsize" + labelsize_default = 24 if is_subplot else 22 + labelsize = float(_conf_colorbar(labelsize_key, labelsize_default)) + if tick_values is not None: + cb.ax.set_yticklabels( + _colorbar_tick_labels(tick_values, cb_unit=cb_unit), + va="center", + fontsize=labelsize, + ) + cb.ax.tick_params( + labelrotation=float(_conf_colorbar("labelrotation", 90)), + labelsize=labelsize, + ) + + +def hide_unused_axes(axes) -> None: + """Turn off any axes in the flattened *axes* array that have no plotted data.""" + for ax in axes: + if not ax.has_data(): + ax.axis("off") + + +def _default_colormap() -> str: + """Return the colormap name from config, registering the custom one if needed.""" + try: + from autoconf import conf + name = conf.instance["visualize"]["general"]["colormap"] + except Exception: + name = "autoarray" + if name == "autoarray": + from autoarray.plot.segmentdata import register + register() + return name + + +def _conf_ticks(key: str, default: float) -> float: + try: + from autoconf import conf + return float(conf.instance["visualize"]["general"]["ticks"][key]) + except Exception: + return default + + +def _inward_ticks(lo: float, hi: float, factor: float, n: int) -> np.ndarray: + """Return *n* tick positions pulled inward from the extent edges by *factor*.""" + centre = (lo + hi) / 2.0 + return np.linspace( + centre + (lo - centre) * factor, + centre + (hi - centre) * factor, + n, + ) + + +def _round_ticks(values: np.ndarray, sig: int = 2) -> np.ndarray: + """Round *values* to *sig* significant figures.""" + with np.errstate(divide="ignore", invalid="ignore"): + nonzero = np.where(values != 0, np.abs(values), 1.0) + mags = np.where(values != 0, 10 ** (sig - 1 - np.floor(np.log10(nonzero))), 1.0) + return np.round(values * mags) / mags + + def _arcsec_labels(ticks) -> List[str]: - """Format tick values as arcsecond strings, e.g. ``-1``, ``0``, ``1"``.""" + """Format tick values as arcsecond strings, e.g. ``-1"``, ``0"``, ``1"``.""" return [f'{v:g}"' for v in ticks] def apply_extent( ax: plt.Axes, extent: Tuple[float, float, float, float], - n_ticks: int = 3, ) -> None: """ - Apply axis limits and evenly spaced linear ticks to *ax*. + Apply axis limits and inward-pulled, rounded, arcsecond-labelled ticks to *ax*. - Parameters - ---------- - ax - The matplotlib axes to configure. - extent - ``[xmin, xmax, ymin, ymax]`` limits. - n_ticks - Number of ticks on each axis. ``3`` produces ``[-R, 0, R]`` for - a symmetric extent, matching the reference ``plot_grid`` example. + Tick count and inward factor are read from ``visualize/general.yaml`` + (``ticks.number_of_ticks_2d`` and ``ticks.extent_factor_2d``), defaulting + to 3 ticks and factor 0.75. """ + factor = _conf_ticks("extent_factor_2d", 0.75) + n = int(_conf_ticks("number_of_ticks_2d", 3)) + xmin, xmax, ymin, ymax = extent ax.set_xlim(xmin, xmax) ax.set_ylim(ymin, ymax) - xticks = np.linspace(xmin, xmax, n_ticks) - yticks = np.linspace(ymin, ymax, n_ticks) + + xticks = _round_ticks(_inward_ticks(xmin, xmax, factor, n)) + yticks = _round_ticks(_inward_ticks(ymin, ymax, factor, n)) ax.set_xticks(xticks) ax.set_yticks(yticks) ax.set_xticklabels(_arcsec_labels(xticks))