diff --git a/src/spatialdata_plot/_logging.py b/src/spatialdata_plot/_logging.py index 364cba27..454df9ae 100644 --- a/src/spatialdata_plot/_logging.py +++ b/src/spatialdata_plot/_logging.py @@ -4,11 +4,27 @@ import re from collections.abc import Iterator from contextlib import contextmanager +from contextvars import ContextVar from typing import TYPE_CHECKING if TYPE_CHECKING: # pragma: no cover from _pytest.logging import LogCaptureFixture +# Holds the public-facing function name (e.g. "render_shapes") for log messages. +# Set at the top of each _render_* entry point so that all downstream helpers +# report the user-visible origin rather than internal function names. +_log_context: ContextVar[str] = ContextVar("_log_context", default="") + + +class _ContextFilter(logging.Filter): + """Inject the public function name from ``_log_context`` into log records.""" + + def filter(self, record: logging.LogRecord) -> bool: + ctx = _log_context.get() + if ctx: + record.funcName = ctx + return True + def _setup_logger() -> "logging.Logger": from rich.console import Console @@ -20,6 +36,8 @@ def _setup_logger() -> "logging.Logger": if console.is_jupyter is True: console.is_jupyter = False ch = RichHandler(show_path=False, console=console, show_time=False) + ch.setFormatter(logging.Formatter("%(funcName)s: %(message)s")) + ch.addFilter(_ContextFilter()) logger.addHandler(ch) # this prevents double outputs diff --git a/src/spatialdata_plot/pl/basic.py b/src/spatialdata_plot/pl/basic.py index d6216341..52950a12 100644 --- a/src/spatialdata_plot/pl/basic.py +++ b/src/spatialdata_plot/pl/basic.py @@ -25,7 +25,7 @@ from xarray import DataArray, DataTree from spatialdata_plot._accessor import register_spatial_data_accessor -from spatialdata_plot._logging import logger +from spatialdata_plot._logging import _log_context, logger from spatialdata_plot.pl.render import ( _render_images, _render_labels, @@ -826,6 +826,7 @@ def show( sd.SpatialData A SpatialData object. """ + _log_context.set("show") # copy the SpatialData object so we don't modify the original try: plotting_tree = self._sdata.plotting_tree diff --git a/src/spatialdata_plot/pl/render.py b/src/spatialdata_plot/pl/render.py index 73786d3c..014c3cc5 100644 --- a/src/spatialdata_plot/pl/render.py +++ b/src/spatialdata_plot/pl/render.py @@ -23,7 +23,7 @@ from spatialdata.transformations.transformations import Identity from xarray import DataTree -from spatialdata_plot._logging import logger +from spatialdata_plot._logging import _log_context, logger from spatialdata_plot.pl.render_params import ( Color, ColorbarSpec, @@ -121,6 +121,7 @@ def _render_shapes( legend_params: LegendParams, colorbar_requests: list[ColorbarSpec] | None = None, ) -> None: + _log_context.set("render_shapes") element = render_params.element col_for_color = render_params.col_for_color groups = render_params.groups @@ -608,6 +609,7 @@ def _render_points( legend_params: LegendParams, colorbar_requests: list[ColorbarSpec] | None = None, ) -> None: + _log_context.set("render_points") element = render_params.element col_for_color = render_params.col_for_color table_name = render_params.table_name @@ -998,6 +1000,7 @@ def _render_images( rasterize: bool, colorbar_requests: list[ColorbarSpec] | None = None, ) -> None: + _log_context.set("render_images") sdata_filt = sdata.filter_by_coordinate_system( coordinate_system=coordinate_system, filter_tables=False, @@ -1254,6 +1257,7 @@ def _render_labels( rasterize: bool, colorbar_requests: list[ColorbarSpec] | None = None, ) -> None: + _log_context.set("render_labels") element = render_params.element table_name = render_params.table_name table_layer = render_params.table_layer