Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions src/spatialdata_plot/_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,27 @@
import re
from collections.abc import Iterator
from contextlib import contextmanager
from contextvars import ContextVar
from typing import TYPE_CHECKING

if TYPE_CHECKING: # pragma: no cover
from _pytest.logging import LogCaptureFixture

# Holds the public-facing function name (e.g. "render_shapes") for log messages.
# Set at the top of each _render_* entry point so that all downstream helpers
# report the user-visible origin rather than internal function names.
_log_context: ContextVar[str] = ContextVar("_log_context", default="")


class _ContextFilter(logging.Filter):
"""Inject the public function name from ``_log_context`` into log records."""

def filter(self, record: logging.LogRecord) -> bool:
ctx = _log_context.get()
if ctx:
record.funcName = ctx
return True


def _setup_logger() -> "logging.Logger":
from rich.console import Console
Expand All @@ -20,6 +36,8 @@ def _setup_logger() -> "logging.Logger":
if console.is_jupyter is True:
console.is_jupyter = False
ch = RichHandler(show_path=False, console=console, show_time=False)
ch.setFormatter(logging.Formatter("%(funcName)s: %(message)s"))
ch.addFilter(_ContextFilter())
logger.addHandler(ch)

# this prevents double outputs
Expand Down
3 changes: 2 additions & 1 deletion src/spatialdata_plot/pl/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from xarray import DataArray, DataTree

from spatialdata_plot._accessor import register_spatial_data_accessor
from spatialdata_plot._logging import logger
from spatialdata_plot._logging import _log_context, logger
from spatialdata_plot.pl.render import (
_render_images,
_render_labels,
Expand Down Expand Up @@ -826,6 +826,7 @@ def show(
sd.SpatialData
A SpatialData object.
"""
_log_context.set("show")
# copy the SpatialData object so we don't modify the original
try:
plotting_tree = self._sdata.plotting_tree
Expand Down
6 changes: 5 additions & 1 deletion src/spatialdata_plot/pl/render.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from spatialdata.transformations.transformations import Identity
from xarray import DataTree

from spatialdata_plot._logging import logger
from spatialdata_plot._logging import _log_context, logger
from spatialdata_plot.pl.render_params import (
Color,
ColorbarSpec,
Expand Down Expand Up @@ -121,6 +121,7 @@ def _render_shapes(
legend_params: LegendParams,
colorbar_requests: list[ColorbarSpec] | None = None,
) -> None:
_log_context.set("render_shapes")
element = render_params.element
col_for_color = render_params.col_for_color
groups = render_params.groups
Expand Down Expand Up @@ -608,6 +609,7 @@ def _render_points(
legend_params: LegendParams,
colorbar_requests: list[ColorbarSpec] | None = None,
) -> None:
_log_context.set("render_points")
element = render_params.element
col_for_color = render_params.col_for_color
table_name = render_params.table_name
Expand Down Expand Up @@ -998,6 +1000,7 @@ def _render_images(
rasterize: bool,
colorbar_requests: list[ColorbarSpec] | None = None,
) -> None:
_log_context.set("render_images")
sdata_filt = sdata.filter_by_coordinate_system(
coordinate_system=coordinate_system,
filter_tables=False,
Expand Down Expand Up @@ -1254,6 +1257,7 @@ def _render_labels(
rasterize: bool,
colorbar_requests: list[ColorbarSpec] | None = None,
) -> None:
_log_context.set("render_labels")
element = render_params.element
table_name = render_params.table_name
table_layer = render_params.table_layer
Expand Down