diff --git a/src/spatialdata_plot/pl/basic.py b/src/spatialdata_plot/pl/basic.py index d6216341..a881cd5f 100644 --- a/src/spatialdata_plot/pl/basic.py +++ b/src/spatialdata_plot/pl/basic.py @@ -7,6 +7,7 @@ from pathlib import Path from typing import Any, Literal, cast +import matplotlib import matplotlib.pyplot as plt import numpy as np import pandas as pd @@ -789,6 +790,7 @@ def show( ax: list[Axes] | Axes | None = None, return_ax: bool = False, save: str | Path | None = None, + show: bool | None = None, ) -> sd.SpatialData: """ Plot the images in the SpatialData object. @@ -813,6 +815,12 @@ def show( Number of columns in the figure. Default is 4. return_ax : Whether to return the axes object created. False by default. + show : + Whether to call ``plt.show()`` at the end. If ``None`` (default), the plot is shown + automatically when running in non-interactive mode (scripts) and suppressed in + interactive sessions (e.g. Jupyter). Set to ``False`` to prevent ``plt.show()`` + from being called, which is useful when you want to save or further modify the + figure after calling this method. colorbar : Global switch to enable/disable all colorbars. Per-layer settings are ignored when this is False. colorbar_params : @@ -856,6 +864,7 @@ def show( ax, return_ax, save, + show, ) sdata = self._copy() @@ -1211,8 +1220,12 @@ def _draw_colorbar( if fig_params.fig is not None and save is not None: save_fig(fig_params.fig, path=save) - # Manually show plot if we're not in interactive mode - # https://stackoverflow.com/a/64523765 - if not hasattr(sys, "ps1"): + # Show the plot unless the caller opted out. + # Default (show=None): display in non-interactive mode (scripts), suppress in interactive + # sessions. We check both sys.ps1 (standard REPL) and matplotlib.is_interactive() + # (covers IPython, Jupyter, plt.ion(), and IDE consoles like PyCharm). + if show is None: + show = not hasattr(sys, "ps1") and not matplotlib.is_interactive() + if show: plt.show() return (fig_params.ax if fig_params.axs is None else fig_params.axs) if return_ax else None # shuts up ruff diff --git a/src/spatialdata_plot/pl/utils.py b/src/spatialdata_plot/pl/utils.py index 11d5d10d..9d3b2954 100644 --- a/src/spatialdata_plot/pl/utils.py +++ b/src/spatialdata_plot/pl/utils.py @@ -2000,6 +2000,7 @@ def _validate_show_parameters( ax: list[Axes] | Axes | None, return_ax: bool, save: str | Path | None, + show: bool | None, ) -> None: if coordinate_systems is not None and not isinstance(coordinate_systems, list | str): raise TypeError("Parameter 'coordinate_systems' must be a string or a list of strings.") @@ -2089,6 +2090,9 @@ def _validate_show_parameters( if save is not None and not isinstance(save, str | Path): raise TypeError("Parameter 'save' must be a string or a pathlib.Path.") + if show is not None and not isinstance(show, bool): + raise TypeError("Parameter 'show' must be a boolean or None.") + def _type_check_params(param_dict: dict[str, Any], element_type: str) -> dict[str, Any]: colorbar = param_dict.get("colorbar", "auto")