From 71c2110b21df2bd1ee54d66d95918869814c254e Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Thu, 28 Mar 2024 09:05:13 +0100 Subject: [PATCH 01/12] simplify conditional --- src/napari_spatialdata/_view.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/src/napari_spatialdata/_view.py b/src/napari_spatialdata/_view.py index 27b7fbea..39ca4291 100644 --- a/src/napari_spatialdata/_view.py +++ b/src/napari_spatialdata/_view.py @@ -2,8 +2,6 @@ import napari from anndata import AnnData -from dask.dataframe.core import DataFrame as DaskDataFrame -from geopandas.geodataframe import GeoDataFrame from loguru import logger from napari._qt.qt_resources import get_stylesheet from napari._qt.utils import QImg2array @@ -317,11 +315,7 @@ def _select_layer(self) -> None: self.var_widget.clear() self.obsm_widget.clear() self.color_by.clear() - if ( - isinstance(layer, (Points, Shapes)) - and isinstance(layer.metadata["sdata"][layer.metadata["name"]], (DaskDataFrame, GeoDataFrame)) - and (cols_df := layer.metadata["_columns_df"]) is not None - ): + if isinstance(layer, (Points, Shapes)) and (cols_df := layer.metadata.get("_columns_df")) is not None: self.dataframe_columns_widget.addItems(map(str, cols_df.columns)) self.model.system_name = layer.metadata.get("name", None) return From b66f1abb78e6fbb7af79b892d75c43acdf7d9e23 Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Mon, 2 Jun 2025 13:00:43 +0200 Subject: [PATCH 02/12] add channel_widget, remove colorbar --- src/napari_spatialdata/_sdata_widgets.py | 380 ++++++++++++++++++++++- src/napari_spatialdata/_view.py | 9 - src/napari_spatialdata/_viewer.py | 20 +- src/napari_spatialdata/_widgets.py | 194 +----------- src/napari_spatialdata/utils/_utils.py | 8 + 5 files changed, 393 insertions(+), 218 deletions(-) diff --git a/src/napari_spatialdata/_sdata_widgets.py b/src/napari_spatialdata/_sdata_widgets.py index 9915706a..4288d283 100644 --- a/src/napari_spatialdata/_sdata_widgets.py +++ b/src/napari_spatialdata/_sdata_widgets.py @@ -1,3 +1,11 @@ +"""Widgets for displaying and interacting with SpatialData objects in napari. + +This module provides a set of Qt widgets for visualizing and interacting with +SpatialData objects within the napari viewer. It includes widgets for selecting +coordinate systems, browsing elements within SpatialData objects, and handling +channel selection for multidimensional image data. +""" + from __future__ import annotations import platform @@ -18,6 +26,7 @@ from qtpy.QtWidgets import QLabel, QListWidget, QListWidgetItem, QProgressBar, QVBoxLayout, QWidget from spatialdata import SpatialData from spatialdata.models._utils import DEFAULT_COORDINATE_SYSTEM +from xarray import DataArray, DataTree from napari_spatialdata._viewer import SpatialDataViewer from napari_spatialdata.constants.config import N_CIRCLES_WARNING_THRESHOLD, N_SHAPES_WARNING_THRESHOLD @@ -47,7 +56,27 @@ class ElementWidget(QListWidget): + """Widget for displaying and selecting elements from SpatialData objects. + + This widget shows a list of available elements (images, labels, points, shapes) + from the SpatialData objects, with warnings for elements that might be slow to render. + + Attributes + ---------- + _icon: Icon used for warning indicators. + _sdata: List of SpatialData objects. + _duplicate_element_names: Dictionary of duplicate element names. + _elements: Dictionary mapping element names to their metadata. + """ + def __init__(self, sdata: EventedList): + """Initialize the ElementWidget. + + Parameters + ---------- + sdata : EventedList + List of SpatialData objects to display elements from. + """ super().__init__() self._icon = QIcon(str(icon_path)) self._sdata = sdata @@ -55,12 +84,30 @@ def __init__(self, sdata: EventedList): self._elements: None | dict[str, dict[str, str | int]] = None def _onItemChange(self, selected_coordinate_system: QListWidgetItem | int | Iterable[str]) -> None: + """Update the element list when the coordinate system selection changes. + + Parameters + ---------- + selected_coordinate_system : QListWidgetItem or int or Iterable[str] + The newly selected coordinate system. + Can be a QListWidgetItem, an index, or an iterable of strings. + """ self.clear() elements, _ = get_elements_meta_mapping(self._sdata, selected_coordinate_system, self._duplicate_element_names) self._set_element_widget_items(elements) self._elements = elements def _set_element_widget_items(self, elements: dict[str, dict[str, str | int]]) -> None: + """Populate the widget with element items. + + Adds each element as an item in the list widget, with warning icons for elements + that might be slow to render (e.g., many circles or shapes). + + Parameters + ---------- + elements : dict[str, dict[str, str | int]] + Dictionary mapping element names to their metadata. + """ for key, dict_val in sorted(elements.items(), key=itemgetter(0)): sdata = self._sdata[dict_val["sdata_index"]] element_type = dict_val["element_type"] @@ -91,27 +138,170 @@ def _set_element_widget_items(self, elements: dict[str, dict[str, str | int]]) - class CoordinateSystemWidget(QListWidget): + """Widget for selecting coordinate systems from SpatialData objects. + + This widget displays a list of available coordinate systems from all SpatialData + objects, allowing the user to select one for visualization. + + Attributes + ---------- + _sdata + List of SpatialData objects. + _system + Currently selected coordinate system. + """ + def __init__(self, sdata: EventedList): + """Initialize the CoordinateSystemWidget. + + Parameters + ---------- + sdata : EventedList + List of SpatialData objects to extract coordinate systems from. + """ super().__init__() self._sdata = sdata self._system: None | str = None # Sort alphabetically, but keep default "global" at the top. - coordinate_systems = sorted(cs for sdata in self._sdata for cs in sdata.coordinate_systems) + coordinate_systems = sorted({cs for sdata in self._sdata for cs in sdata.coordinate_systems}) if DEFAULT_COORDINATE_SYSTEM in coordinate_systems: coordinate_systems.remove(DEFAULT_COORDINATE_SYSTEM) coordinate_systems.insert(0, DEFAULT_COORDINATE_SYSTEM) self.addItems(coordinate_systems) def _select_coord_sys(self, selected_coordinate_system: QListWidgetItem | int | Iterable[str]) -> None: + """Store the currently selected coordinate system. + + Parameters + ---------- + selected_coordinate_system : QListWidgetItem or int or Iterable[str] + The selected coordinate system. + Can be a QListWidgetItem, an index, or an iterable of strings. + """ self._system = str(selected_coordinate_system) +class ChannelWidget(QListWidget): + """Widget for selecting channels from multidimensional image data. + + This widget displays available channels for image elements, allowing users + to select individual channels for visualization. + + Attributes + ---------- + _sdata + List of SpatialData objects. + _element_widget_text + Text of the currently selected element. + _element_dict + Dictionary with metadata of the currently selected element. + _channels + List of available channels for the current element. + """ + + def __init__(self, sdata: EventedList): + """Initialize the ChannelWidget. + + Parameters + ---------- + sdata : EventedList + List of SpatialData objects. + """ + super().__init__() + self._sdata = sdata + self._element_widget_text: str | None = None + self._element_dict: dict[str, str | int] | None = None + self._channels: list[str] | None = None + + def _on_element_item_changed( + self, sdata: SpatialData, element_widget_text: str, element_dict: dict[str, str | int] + ) -> None: + """Update the channel list when the selected element changes. + + Clears the current channel list and populates it with channels from the + selected element if it's an image. + + Parameters + ---------- + sdata : SpatialData + The SpatialData object containing the selected element. + element_widget_text : str + Text of the selected element in the ElementWidget. + element_dict : dict + Dictionary with metadata of the selected element. + """ + self.clear() + self._element_dict = None + self._channels = None + self._element_widget_text = element_widget_text + if element_dict["element_type"] == "images": + element: DataArray | DataTree = sdata[element_dict["original_name"]] + self._element_dict = element_dict + self._element_widget_text = element_widget_text + self._set_channel_widget_items(element) + + def _set_channel_widget_items(self, element: DataArray | DataTree) -> None: + """Populate the widget with channel items from the selected image element. + + Adds each channel as an item in the list widget, except for RGB(A) channels + which are handled differently. + + Parameters + ---------- + element : object + The image element to extract channels from. + """ + channels = list(element.c.to_numpy()) + self._channels = channels + if channels not in [["r", "g", "b"], ["r", "g", "b", "a"]]: + for ch in channels: + item = QListWidgetItem(ch) + self.addItem(item) + + class DataLoadThread(QThread): + """Thread for asynchronously loading SpatialData elements. + + This thread handles loading different types of data (images, labels, points, shapes) + from SpatialData objects without blocking the UI. + + Parameters + ---------- + parent : SdataWidget + Parent SdataWidget that owns this thread. + + Attributes + ---------- + returned : Signal + Signal emitted when data loading is complete, carrying the created layer. + sdata_widget : SdataWidget + Parent SdataWidget that owns this thread. + _data_type : str + Type of data to load (images, labels, points, shapes). + _text : str + Name of the element to load. + _sdata : SpatialData + SpatialData object containing the element. + _selected_cs : str + Selected coordinate system. + _multi : bool + Boolean indicating if multiple SpatialData objects are present. + _channel_name : str, optional + Optional channel name for image data. + """ + returned = Signal(object) def __init__(self, parent: SdataWidget): + """Initialize the DataLoadThread. + + Parameters + ---------- + parent : SdataWidget + Parent SdataWidget that owns this thread. + """ super().__init__(parent=parent) self.sdata_widget = parent self._data_type = "" @@ -120,11 +310,42 @@ def __init__(self, parent: SdataWidget): self._selected_cs: str = "" self._multi: bool = False - def load_data(self, data_type: str, text: str, sdata: SpatialData, selected_cs: str, multi: bool) -> None: + def load_data( + self, + data_type: str, + text: str, + sdata: SpatialData, + selected_cs: str, + multi: bool, + channel_name: str | None = None, + ) -> None: + """Set up data loading parameters and start the thread. + + Parameters + ---------- + data_type : str + Type of data to load (images, labels, points, shapes). + text : str + Name of the element to load. + sdata : SpatialData + SpatialData object containing the element. + selected_cs : str + Selected coordinate system. + multi : bool + Boolean indicating if multiple SpatialData objects are present. + channel_name : str, optional + Optional channel name for image data. + + Raises + ------ + RuntimeError + If the thread is already running. + """ if self.isRunning(): raise RuntimeError("Thread is already running.") self._data_type = data_type self._text = text + self._channel_name = channel_name self._sdata = sdata self._selected_cs = selected_cs self._multi = multi @@ -135,6 +356,11 @@ def load_data(self, data_type: str, text: str, sdata: SpatialData, selected_cs: self.start() def run(self) -> None: + """Execute the data loading operation. + + Loads the specified data element based on its type and emits the + returned layer through the 'returned' signal. + """ if not self._data_type: return if self._data_type == "labels": @@ -143,7 +369,7 @@ def run(self) -> None: ) elif self._data_type == "images": layer = self.sdata_widget.viewer_model.get_sdata_image( - self._sdata, self._text, self._selected_cs, self._multi + self._sdata, self._text, self._selected_cs, self._multi, self._channel_name ) elif self._data_type == "points": layer = self.sdata_widget.viewer_model.get_sdata_points( @@ -158,7 +384,41 @@ def run(self) -> None: class SdataWidget(QWidget): + """Main widget for interacting with SpatialData objects in napari. + + This widget combines coordinate system selection, element browsing, and channel + selection into a unified interface for visualizing SpatialData objects in napari. + It manages the loading and display of different data types and handles coordinate + system transformations. + + Attributes + ---------- + _sdata + List of SpatialData objects. + viewer_model + SpatialDataViewer instance for interacting with napari. + worker_thread + Thread for asynchronous data loading. + coordinate_system_widget + Widget for selecting coordinate systems. + elements_widget + Widget for browsing and selecting elements. + channel_widget + Widget for selecting channels in image data. + slider + Progress bar shown during data loading. + """ + def __init__(self, viewer: Viewer, sdata: EventedList): + """Initialize the SdataWidget. + + Parameters + ---------- + viewer : Viewer + napari Viewer instance. + sdata : EventedList + List of SpatialData objects to visualize. + """ super().__init__() self._sdata = sdata self.viewer_model = SpatialDataViewer(viewer, self._sdata) @@ -170,6 +430,7 @@ def __init__(self, viewer: Viewer, sdata: EventedList): self.coordinate_system_widget = CoordinateSystemWidget(self._sdata) self.elements_widget = ElementWidget(self._sdata) + self.channel_widget = ChannelWidget(self._sdata) self.slider = QProgressBar(self) self.slider.setRange(0, 0) self.slider.setVisible(False) @@ -179,7 +440,11 @@ def __init__(self, viewer: Viewer, sdata: EventedList): self.layout().addWidget(self.coordinate_system_widget) self.layout().addWidget(QLabel("Elements:")) self.layout().addWidget(self.elements_widget) - self.elements_widget.itemDoubleClicked.connect(self._on_click_item) + self.layout().addWidget(QLabel("Channels:")) + self.layout().addWidget(self.channel_widget) + self.elements_widget.currentItemChanged.connect(self._on_element_item_changed) + self.elements_widget.itemDoubleClicked.connect(self._on_doubleclick_element_item) + self.channel_widget.itemDoubleClicked.connect(self._on_doubleclick_channel_item) self.coordinate_system_widget.currentItemChanged.connect( lambda item: self.elements_widget._onItemChange(item.text()) ) @@ -194,24 +459,79 @@ def __init__(self, viewer: Viewer, sdata: EventedList): self.viewer_model.viewer.layers.events.inserted.connect(self._on_insert_layer) def _on_insert_layer(self, event: Event) -> None: + """Connect visibility events for newly inserted layers. + + Parameters + ---------- + event : Event + Event containing the newly inserted layer. + """ layer = event.value layer.events.visible.connect(self._update_visible_in_coordinate_system) - def _on_click_item(self, item: QListWidgetItem) -> None: + def _on_element_item_changed(self, item: QListWidgetItem) -> None: + """Handle selection changes in the elements widget. + + Updates the channel widget with channels from the selected element. + + Parameters + ---------- + item : QListWidgetItem + The newly selected element item. + """ + if self.elements_widget._elements: + sdata, _ = _get_sdata_key(self._sdata, self.elements_widget._elements, item.text()) + self.channel_widget._on_element_item_changed( + sdata, item.text(), self.elements_widget._elements[item.text()] + ) + + def _on_doubleclick_channel_item(self, item: QListWidgetItem) -> None: + """Handle double-click events on channel items in the channel widget. + + Loads and displays the selected channel of the current element. + + Parameters + ---------- + item : QListWidgetItem + The double-clicked channel item. + """ + if self.channel_widget._element_widget_text: + self._onClick(self.channel_widget._element_widget_text, item.text()) + + def _on_doubleclick_element_item(self, item: QListWidgetItem) -> None: + """Handle double-click events on element items in the element widget. + + Loads and displays the selected element. + + Parameters + ---------- + item : QListWidgetItem + The double-clicked element item. + """ self._onClick(item.text()) def _hide_slider(self) -> None: + """Hide the progress slider when data loading is complete.""" self.slider.setVisible(False) - def _onClick(self, text: str) -> None: + def _onClick(self, element_name: str, channel_name: str | None = None) -> None: + """Handle click events to load and display data elements. + + Parameters + ---------- + element_name : str + Name of the element to load. + channel_name : str, optional + Name of the channel to load for image elements. + """ selected_cs = self.coordinate_system_widget._system if self.worker_thread.isRunning(): show_info("Please wait for the current operation to finish.") return if selected_cs and self.elements_widget._elements: - sdata, multi = _get_sdata_key(self._sdata, self.elements_widget._elements, text) - if (type_ := self.elements_widget._elements[text]["element_type"]) not in { + sdata, multi = _get_sdata_key(self._sdata, self.elements_widget._elements, element_name) + if (type_ := self.elements_widget._elements[element_name]["element_type"]) not in { "labels", "images", "shapes", @@ -221,12 +541,18 @@ def _onClick(self, text: str) -> None: type_ = cast(str, type_) - self.worker_thread.load_data(type_, text, sdata, selected_cs, multi) + self.worker_thread.load_data(type_, element_name, sdata, selected_cs, multi, channel_name) if not PROBLEMATIC_NUMPY_MACOS: self.slider.setVisible(True) def _update_visible_in_coordinate_system(self, event: Event) -> None: - """Toggle active in the coordinate system metadata when changing visibility of layer.""" + """Toggle active status in the coordinate system metadata when changing layer visibility. + + Parameters + ---------- + event : Event + Event triggered by changing layer visibility. + """ metadata = event.source.metadata layer_active = metadata.get("_active_in_cs") selected_coordinate_system = self.coordinate_system_widget._system @@ -240,7 +566,12 @@ def _update_visible_in_coordinate_system(self, event: Event) -> None: layer_active.remove(selected_coordinate_system) def _update_layers_visibility(self) -> None: - """Toggle layer visibility dependent on presence in currently selected coordinate system.""" + """Toggle layer visibility based on presence in the currently selected coordinate system. + + Updates the visibility of all layers based on whether they are active in the + currently selected coordinate system. Also updates layer metadata to track + coordinate system information. + """ elements = self.elements_widget._elements coordinate_system = self.coordinate_system_widget._system # No layer selected on first time coordinate system selection @@ -259,6 +590,33 @@ def _update_layers_visibility(self) -> None: layer.metadata["_current_cs"] = coordinate_system def _get_shapes(self, sdata: SpatialData, key: str, selected_cs: str, multi: bool) -> Shapes | Points: + """Load and create appropriate layer for shape data. + + Determines the geometry type of the shapes element and calls the appropriate + method to create either a Points layer (for Point geometries) or a Shapes + layer (for Polygon or MultiPolygon geometries). + + Parameters + ---------- + sdata : SpatialData + SpatialData object containing the shapes element. + key : str + Name of the shapes element to load. + selected_cs : str + Selected coordinate system. + multi : bool + Whether multiple SpatialData objects are present. + + Returns + ------- + Shapes or Points + The created napari layer. + + Raises + ------ + TypeError + If the geometry type is not Point, Polygon, or MultiPolygon. + """ original_name = key[: key.rfind("_")] if multi else key if type(sdata.shapes[original_name].iloc[0].geometry) is shapely.geometry.point.Point: diff --git a/src/napari_spatialdata/_view.py b/src/napari_spatialdata/_view.py index 12037729..d6e39dac 100644 --- a/src/napari_spatialdata/_view.py +++ b/src/napari_spatialdata/_view.py @@ -41,9 +41,7 @@ from napari_spatialdata._widgets import ( AListWidget, AnnDataSaveDialog, - CBarWidget, ComponentWidget, - RangeSliderWidget, SaveDialog, ScatterAnnotationDialog, ) @@ -497,13 +495,6 @@ def __init__(self, napari_viewer: Viewer, model: DataModel | None = None) -> Non self.color_by = QLabel("Colored by:") self.layout().addWidget(self.color_by) - # scalebar - colorbar = CBarWidget(model=self.model) - self.slider = RangeSliderWidget(self.viewer, self.model, colorbar=colorbar) - self._viewer.window.add_dock_widget(self.slider, area="left", name="slider") - self._viewer.window.add_dock_widget(colorbar, area="left", name="colorbar") - self.viewer.layers.selection.events.active.connect(self.slider._onLayerChange) - if (layer := self.viewer.layers.selection.active) is not None and layer.metadata.get("adata") is not None: self._on_layer_update() diff --git a/src/napari_spatialdata/_viewer.py b/src/napari_spatialdata/_viewer.py index de90a4de..fe503b88 100644 --- a/src/napari_spatialdata/_viewer.py +++ b/src/napari_spatialdata/_viewer.py @@ -29,6 +29,7 @@ _get_ellipses_from_circles, _get_init_metadata_adata, _get_transform, + _obtain_channel_image, _transform_coordinates, get_duplicate_element_names, get_napari_version, @@ -442,10 +443,14 @@ def clean_worker(self) -> None: """Clean the worker.""" self.worker = None - def add_sdata_image(self, sdata: SpatialData, key: str, selected_cs: str, multi: bool) -> None: - self.add_layer(self.get_sdata_image(sdata, key, selected_cs, multi)) + def add_sdata_image( + self, sdata: SpatialData, key: str, selected_cs: str, multi: bool, channel_name: str | None = None + ) -> None: + self.add_layer(self.get_sdata_image(sdata, key, selected_cs, multi, channel_name)) - def get_sdata_image(self, sdata: SpatialData, key: str, selected_cs: str, multi: bool) -> Image: + def get_sdata_image( + self, sdata: SpatialData, key: str, selected_cs: str, multi: bool, channel_name: str | None = None + ) -> Image: """ Add an image in a spatial data object to the viewer. @@ -465,7 +470,12 @@ def get_sdata_image(self, sdata: SpatialData, key: str, selected_cs: str, multi: original_name = original_name[: original_name.rfind("_")] affine = _get_transform(sdata.images[original_name], selected_cs) - rgb_image, rgb = _adjust_channels_order(element=sdata.images[original_name]) + if channel_name: + image = _obtain_channel_image(element=sdata.images[original_name], channel_name=channel_name) + rgb = False + key = key + f"_{channel_name}" + else: + image, rgb = _adjust_channels_order(element=sdata.images[original_name]) channels = ("RGB(A)",) if rgb else get_channels(sdata.images[original_name]) @@ -473,7 +483,7 @@ def get_sdata_image(self, sdata: SpatialData, key: str, selected_cs: str, multi: # TODO: type check return Image( - rgb_image, + image, rgb=rgb, name=key, affine=affine, diff --git a/src/napari_spatialdata/_widgets.py b/src/napari_spatialdata/_widgets.py index 5431036f..aff10673 100644 --- a/src/napari_spatialdata/_widgets.py +++ b/src/napari_spatialdata/_widgets.py @@ -19,17 +19,12 @@ from qtpy import QtCore, QtWidgets from qtpy.QtCore import Qt, Signal from scanpy.plotting._utils import _set_colors_for_categorical_obs -from sklearn.preprocessing import MinMaxScaler from spatialdata._types import ArrayLike -from superqt import QRangeSlider -from vispy import scene -from vispy.color.colormap import Colormap, MatplotlibColormap -from vispy.scene.widgets import ColorBarWidget from napari_spatialdata._model import DataModel from napari_spatialdata.utils._utils import _min_max_norm, get_napari_version -__all__ = ["AListWidget", "CBarWidget", "RangeSliderWidget", "ComponentWidget"] +__all__ = ["AListWidget", "ComponentWidget"] # label string: attribute name # TODO(giovp): remove since layer controls private? @@ -409,193 +404,6 @@ def attr(self, field: str | None) -> None: self._attr = field -class CBarWidget(QtWidgets.QWidget): - FORMAT = "{0:0.2f}" - - cmapChanged = Signal(str) - climChanged = Signal((float, float)) - - def __init__( - self, - model: DataModel, - cmap: str = "viridis", - label: str | None = None, - width: int | None = 250, - height: int | None = 50, - **kwargs: Any, - ): - super().__init__(**kwargs) - - self._model = model - - self._clim = (0.0, 1.0) - self._oclim = self._clim - - self._width = width - self._height = height - self._label = label - - self.__init_UI() - - def __init_UI(self) -> None: - self.setFixedWidth(self._width) - self.setFixedHeight(self._height) - - # use napari's BG color for dark mode - self._canvas = scene.SceneCanvas( - size=(self._width, self._height), bgcolor="#262930", parent=self, decorate=False, resizable=False, dpi=150 - ) - self._colorbar = ColorBarWidget( - self._create_colormap(self.cmap), - orientation="top", - label=self._label, - label_color="white", - clim=self.getClim(), - border_width=1.0, - border_color="black", - padding=(0.3, 0.167), - axis_ratio=0.05, - ) - - self._canvas.central_widget.add_widget(self._colorbar) - - self.climChanged.connect(self.onClimChanged) - self.cmapChanged.connect(self.onCmapChanged) - - def _create_colormap(self, cmap: str) -> Colormap: - ominn, omaxx = self.getOclim() - delta = omaxx - ominn + 1e-12 - - minn, maxx = self.getClim() - minn = (minn - ominn) / delta - maxx = (maxx - ominn) / delta - - assert 0 <= minn <= 1, f"Expected `min` to be in `[0, 1]`, found `{minn}`" - assert 0 <= maxx <= 1, f"Expected `maxx` to be in `[0, 1]`, found `{maxx}`" - - cm = MatplotlibColormap(cmap) - - return Colormap(cm[np.linspace(minn, maxx, len(cm.colors))], interpolation="linear") - - def getCmap(self) -> str: - return self.cmap - - def onCmapChanged(self, value: str) -> None: - # this does not trigger update for some reason... - self._colorbar.cmap = self._create_colormap(value) - self._colorbar._colorbar._update() - - def setClim(self, value: tuple[float, float]) -> None: - if value == self._clim: - return - - self._clim = value - self.climChanged.emit(*value) - - def getClim(self) -> tuple[float, float]: - return self._clim - - def getOclim(self) -> tuple[float, float]: - return self._oclim - - def setOclim(self, value: tuple[float, float]) -> None: - # original color limit used for 0-1 normalization - self._oclim = value - - def onClimChanged(self, minn: float, maxx: float) -> None: - # ticks are not working with vispy's colorbar - self._colorbar.cmap = self._create_colormap(self.cmap) - self._colorbar.clim = (self.FORMAT.format(minn), self.FORMAT.format(maxx)) - - def getCanvas(self) -> scene.SceneCanvas: - return self._canvas - - def getColorBar(self) -> ColorBarWidget: - return self._colorbar - - def setLayout(self, layout: QtWidgets.QLayout) -> None: - layout.addWidget(self.getCanvas().native) - super().setLayout(layout) - - def update_color(self) -> None: - # when changing selected layers that have the same limit - # could also trigger it as self._colorbar.clim = self.getClim() - # but the above option also updates geometry - # cbarwidget->cbar->cbarvisual - self._colorbar._colorbar._colorbar._update() - - @property - def cmap(self) -> str: - return self._model.cmap - - -class RangeSliderWidget(QRangeSlider): - def __init__(self, viewer: Viewer, model: DataModel, colorbar: CBarWidget, **kwargs: Any): - super().__init__(**kwargs) - - self._viewer = viewer - self._model = model - self._colorbar = colorbar - self._cmap = plt.get_cmap(self._colorbar.cmap) - self.setValue((0, 100)) - self.setSliderPosition((0, 100)) - self.setSingleStep(0.01) - self.setOrientation(Qt.Horizontal) - self.valueChanged.connect(self._onValueChange) - - def _onLayerChange(self) -> None: - layer = self.viewer.layers.selection.active - if layer is not None: - self._onValueChange((0, 100)) - - def _onValueChange(self, percentile: tuple[float, float]) -> None: - layer = self.viewer.layers.selection.active - # TODO(michalk8): use constants - if "data" not in layer.metadata: - return None # noqa: RET501 - v = layer.metadata["data"] - # this code is currently not used since the slider is not enabled; so I silenced the mypy error; 2. there is a - # mismatch for this error with the mypy in the CI, so I silenced the unused-ignore from the local mypy. - # when this code is re-enabled, let's fix mypy - clipped = np.clip(v, *np.percentile(v, percentile)) # type: ignore[misc,unused-ignore] - - if isinstance(layer, Points): - layer.metadata = {**layer.metadata, "perc": percentile} - layer.face_color = "value" - layer.properties = {"value": clipped} - layer.refresh_colors() - elif isinstance(layer, Labels): - norm_vec = self._scale_vec(clipped) - color_vec = self._cmap(norm_vec) - layer.color = dict(zip(layer.color.keys(), color_vec, strict=False)) - layer.properties = {"value": clipped} - layer.refresh() - - self._colorbar.setOclim(layer.metadata["minmax"]) - self._colorbar.setClim((np.min(layer.properties["value"]), np.max(layer.properties["value"]))) - self._colorbar.update_color() - - def _scale_vec(self, vec: ArrayLike) -> ArrayLike: - ominn, omaxx = self._colorbar.getOclim() - delta = omaxx - ominn + 1e-12 - - minn, maxx = self._colorbar.getClim() - minn = (minn - ominn) / delta - maxx = (maxx - ominn) / delta - scaler = MinMaxScaler(feature_range=(minn, maxx)) - return scaler.fit_transform(vec.reshape(-1, 1)) - - @property - def viewer(self) -> napari.Viewer: - """:mod:`napari` viewer.""" - return self._viewer - - @property - def model(self) -> DataModel: - """:mod:`napari` viewer.""" - return self._model - - class SaveDialog(QtWidgets.QDialog): def __init__(self, layer: Layer, table_name: str) -> None: super().__init__() diff --git a/src/napari_spatialdata/utils/_utils.py b/src/napari_spatialdata/utils/_utils.py index 237b84b7..34fc41ee 100644 --- a/src/napari_spatialdata/utils/_utils.py +++ b/src/napari_spatialdata/utils/_utils.py @@ -221,6 +221,14 @@ def _points_inside_triangles(points: ArrayLike, triangles: ArrayLike) -> ArrayLi return out +def _obtain_channel_image(element: DataArray | DataTree, channel_name: str) -> DataArray | list[DataArray]: + if isinstance(element, DataArray): + new_raster = element.sel(c=channel_name) + else: + pass + return new_raster + + def _adjust_channels_order(element: DataArray | DataTree) -> tuple[DataArray | list[DataArray], bool]: """Swap the axes to y, x, c and check if an image supports rgb(a) visualization. From a3a0761ec2bd19948a2e89cc41be0ce96620ff7c Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Mon, 2 Jun 2025 15:02:57 +0200 Subject: [PATCH 03/12] allow multiscale channel selection --- src/napari_spatialdata/_sdata_widgets.py | 7 ++++-- src/napari_spatialdata/_viewer.py | 2 +- src/napari_spatialdata/utils/_utils.py | 32 ++++++++++++++---------- 3 files changed, 25 insertions(+), 16 deletions(-) diff --git a/src/napari_spatialdata/_sdata_widgets.py b/src/napari_spatialdata/_sdata_widgets.py index 4288d283..59474616 100644 --- a/src/napari_spatialdata/_sdata_widgets.py +++ b/src/napari_spatialdata/_sdata_widgets.py @@ -253,11 +253,14 @@ def _set_channel_widget_items(self, element: DataArray | DataTree) -> None: element : object The image element to extract channels from. """ - channels = list(element.c.to_numpy()) + if isinstance(element, DataArray): + channels = list(element.c.to_numpy()) + else: + channels = list(element["scale0"].c.to_numpy()) self._channels = channels if channels not in [["r", "g", "b"], ["r", "g", "b", "a"]]: for ch in channels: - item = QListWidgetItem(ch) + item = QListWidgetItem(str(ch)) self.addItem(item) diff --git a/src/napari_spatialdata/_viewer.py b/src/napari_spatialdata/_viewer.py index fe503b88..cbfbc2d8 100644 --- a/src/napari_spatialdata/_viewer.py +++ b/src/napari_spatialdata/_viewer.py @@ -473,7 +473,7 @@ def get_sdata_image( if channel_name: image = _obtain_channel_image(element=sdata.images[original_name], channel_name=channel_name) rgb = False - key = key + f"_{channel_name}" + key = key + f"_ch:{channel_name}" else: image, rgb = _adjust_channels_order(element=sdata.images[original_name]) diff --git a/src/napari_spatialdata/utils/_utils.py b/src/napari_spatialdata/utils/_utils.py index 34fc41ee..7c2bb7f1 100644 --- a/src/napari_spatialdata/utils/_utils.py +++ b/src/napari_spatialdata/utils/_utils.py @@ -221,14 +221,27 @@ def _points_inside_triangles(points: ArrayLike, triangles: ArrayLike) -> ArrayLi return out -def _obtain_channel_image(element: DataArray | DataTree, channel_name: str) -> DataArray | list[DataArray]: - if isinstance(element, DataArray): - new_raster = element.sel(c=channel_name) - else: - pass +def _datatree_to_dataarray_list(new_raster: DataArray | DataTree) -> DataArray | list[DataArray]: + if isinstance(new_raster, DataTree): + list_of_xdata = [] + for k in new_raster: + v = new_raster[k].values() + assert len(v) == 1 + xdata = v.__iter__().__next__() + list_of_xdata.append(xdata) + return list_of_xdata return new_raster +def _obtain_channel_image(element: DataArray | DataTree, channel_name: str | int) -> DataArray | list[DataArray]: + if np.issubdtype(element["scale0"].c.to_numpy().dtype, np.integer) and isinstance(channel_name, str): + channel_name = int(channel_name) + + # works for both DataArray and DataTree + new_raster = element.sel(c=channel_name) + return _datatree_to_dataarray_list(new_raster) + + def _adjust_channels_order(element: DataArray | DataTree) -> tuple[DataArray | list[DataArray], bool]: """Swap the axes to y, x, c and check if an image supports rgb(a) visualization. @@ -272,14 +285,7 @@ def _adjust_channels_order(element: DataArray | DataTree) -> tuple[DataArray | l rgb = False new_raster = element - if isinstance(new_raster, DataTree): - list_of_xdata = [] - for k in new_raster: - v = new_raster[k].values() - assert len(v) == 1 - xdata = v.__iter__().__next__() - list_of_xdata.append(xdata) - new_raster = list_of_xdata + new_raster = _datatree_to_dataarray_list(new_raster) return new_raster, rgb From 40b0e861222e826ff4c30fd2d266ffcdfdbad8c4 Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Mon, 2 Jun 2025 17:22:38 +0200 Subject: [PATCH 04/12] add test and fix --- src/napari_spatialdata/utils/_utils.py | 6 ++- tests/conftest.py | 14 ++++++- tests/test_spatialdata.py | 53 ++++++++++++++++++++++++++ 3 files changed, 71 insertions(+), 2 deletions(-) diff --git a/src/napari_spatialdata/utils/_utils.py b/src/napari_spatialdata/utils/_utils.py index 7c2bb7f1..45866062 100644 --- a/src/napari_spatialdata/utils/_utils.py +++ b/src/napari_spatialdata/utils/_utils.py @@ -234,7 +234,11 @@ def _datatree_to_dataarray_list(new_raster: DataArray | DataTree) -> DataArray | def _obtain_channel_image(element: DataArray | DataTree, channel_name: str | int) -> DataArray | list[DataArray]: - if np.issubdtype(element["scale0"].c.to_numpy().dtype, np.integer) and isinstance(channel_name, str): + is_multiscale_int_ch = isinstance(element, DataTree) and np.issubdtype( + element["scale0"].c.to_numpy().dtype, np.integer + ) + is_int_ch = isinstance(element, DataArray) and np.issubdtype(element.c.to_numpy().dtype, np.integer) + if isinstance(channel_name, str) and (is_multiscale_int_ch or is_int_ch): channel_name = int(channel_name) # works for both DataArray and DataTree diff --git a/tests/conftest.py b/tests/conftest.py index bb658ec1..ac397b03 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -20,7 +20,7 @@ from spatialdata import SpatialData from spatialdata._types import ArrayLike from spatialdata.datasets import blobs -from spatialdata.models import TableModel +from spatialdata.models import Image2DModel, TableModel from napari_spatialdata.utils._test_utils import save_image, take_screenshot @@ -131,6 +131,18 @@ def sdata_blobs() -> SpatialData: return blobs() +@pytest.fixture() +def sdata_channel_images() -> SpatialData: + sdata = blobs() + sdata["blobs_image_str_ch"] = Image2DModel.parse( + sdata["blobs_image"], c_coords=["channel1", "channel2", "channel3"] + ) + sdata["blobs_multiscale_image_str_ch"] = Image2DModel.parse( + sdata["blobs_image"], c_coords=["channel1", "channel2", "channel3"], scale_factors=[2, 2] + ) + return sdata + + @pytest.fixture def image(): _, image = _get_blobs_galaxy() diff --git a/tests/test_spatialdata.py b/tests/test_spatialdata.py index 00c53007..fdf35ae7 100644 --- a/tests/test_spatialdata.py +++ b/tests/test_spatialdata.py @@ -73,6 +73,59 @@ def test_sdatawidget_images(make_napari_viewer: Any, blobs_extra_cs: SpatialData del blobs_extra_cs.images["image"] +@pytest.mark.parametrize( + "images", [["blobs_image", "blobs_image_str_ch"], ["blobs_multiscale_image", "blobs_multiscale_image_str_ch"]] +) +def test_channel_selection(qtbot, make_napari_viewer, sdata_channel_images, images): + """Test selecting a channel from an image with integer channel names.""" + # Create a viewer + viewer = make_napari_viewer() + + # Create the SdataWidget + widget = SdataWidget(viewer, EventedList([sdata_channel_images])) + + # Click on 'global' coordinate system + center_pos = get_center_pos_listitem(widget.coordinate_system_widget, "global") + click_list_widget_item(qtbot, widget.coordinate_system_widget, center_pos, "currentItemChanged") + + # Click on the image element to populate the channel widget + center_pos = get_center_pos_listitem(widget.elements_widget, images[0]) + click_list_widget_item(qtbot, widget.elements_widget, center_pos, "currentItemChanged") + + # Verify that the channel widget has been populated with the correct channels + assert widget.channel_widget.count() == 3 + assert widget.channel_widget.item(0).text() == "0" + assert widget.channel_widget.item(1).text() == "1" + assert widget.channel_widget.item(2).text() == "2" + + # Double-click on a channel to add it as a layer + center_pos = get_center_pos_listitem(widget.channel_widget, "1") + click_list_widget_item(qtbot, widget.channel_widget, center_pos, "currentItemChanged", "double") + widget._onClick(images[0], "1") + + # Verify that the layer has been added with the correct name and data + assert len(viewer.layers) == 1 + assert viewer.layers[0].name == f"{images[0]}_ch:1" + + # Verify that the layer contains only the selected channel + assert viewer.layers[0].data.shape == (512, 512) + + center_pos = get_center_pos_listitem(widget.elements_widget, images[1]) + click_list_widget_item(qtbot, widget.elements_widget, center_pos, "currentItemChanged") + + assert widget.channel_widget.count() == 3 + assert widget.channel_widget.item(0).text() == "channel1" + assert widget.channel_widget.item(1).text() == "channel2" + assert widget.channel_widget.item(2).text() == "channel3" + + center_pos = get_center_pos_listitem(widget.channel_widget, "channel2") + click_list_widget_item(qtbot, widget.channel_widget, center_pos, "currentItemChanged", "double") + widget._onClick(images[1], "channel2") + + assert len(viewer.layers) == 2 + assert viewer.layers[1].name == f"{images[1]}_ch:channel2" + + def test_sdatawidget_labels(qtbot, make_napari_viewer: Any, blobs_extra_cs: SpatialData): viewer = make_napari_viewer() widget = SdataWidget(viewer, EventedList([blobs_extra_cs])) From 2959acd2c1d962815b74019f507a8df3bd378281 Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Mon, 2 Jun 2025 17:46:12 +0200 Subject: [PATCH 05/12] copy dask pin from spatialdata --- setup.cfg | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.cfg b/setup.cfg index 84213d3f..799f6e4c 100644 --- a/setup.cfg +++ b/setup.cfg @@ -37,7 +37,7 @@ install_requires = anndata click cycler - dask>=2024.4.1 + dask>=2024.4.1,<=2024.11.2" geopandas loguru matplotlib From ee067adb4d7e118cef138853a46e29f852637666 Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Mon, 2 Jun 2025 17:50:08 +0200 Subject: [PATCH 06/12] remove quote --- setup.cfg | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.cfg b/setup.cfg index 799f6e4c..3d0fd099 100644 --- a/setup.cfg +++ b/setup.cfg @@ -37,7 +37,7 @@ install_requires = anndata click cycler - dask>=2024.4.1,<=2024.11.2" + dask>=2024.4.1,<=2024.11.2 geopandas loguru matplotlib From abb01c260f649b71ceda3bdc546a6b6cbdbebc5a Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Mon, 2 Jun 2025 21:30:29 +0200 Subject: [PATCH 07/12] merge three widgets --- src/napari_spatialdata/_sdata_widgets.py | 132 ++++++++--------------- src/napari_spatialdata/utils/_utils.py | 42 +++++++- tests/test_spatialdata.py | 16 +-- 3 files changed, 89 insertions(+), 101 deletions(-) diff --git a/src/napari_spatialdata/_sdata_widgets.py b/src/napari_spatialdata/_sdata_widgets.py index 59474616..2617ded5 100644 --- a/src/napari_spatialdata/_sdata_widgets.py +++ b/src/napari_spatialdata/_sdata_widgets.py @@ -13,7 +13,7 @@ from importlib.metadata import version from operator import itemgetter from pathlib import Path -from typing import TYPE_CHECKING, cast +from typing import TYPE_CHECKING, Literal, cast import numpy as np import shapely @@ -30,7 +30,12 @@ from napari_spatialdata._viewer import SpatialDataViewer from napari_spatialdata.constants.config import N_CIRCLES_WARNING_THRESHOLD, N_SHAPES_WARNING_THRESHOLD -from napari_spatialdata.utils._utils import _get_sdata_key, get_duplicate_element_names, get_elements_meta_mapping +from napari_spatialdata.utils._utils import ( + _get_sdata_key, + get_duplicate_element_names, + get_elements_meta_mapping, + requires_widget_type, +) if TYPE_CHECKING: from napari import Viewer @@ -55,11 +60,12 @@ PROBLEMATIC_NUMPY_MACOS = False -class ElementWidget(QListWidget): - """Widget for displaying and selecting elements from SpatialData objects. +class ListWidget(QListWidget): + """Widget for displaying and selecting coordinate systems or elements from SpatialData objects or channels. - This widget shows a list of available elements (images, labels, points, shapes) - from the SpatialData objects, with warnings for elements that might be slow to render. + This widget can show a list of coordinate systems or available elements (images, labels, points, shapes) + from the SpatialData objects, with warnings for elements that might be slow to render. A third option is to + let it show channels from image elements. Attributes ---------- @@ -69,22 +75,37 @@ class ElementWidget(QListWidget): _elements: Dictionary mapping element names to their metadata. """ - def __init__(self, sdata: EventedList): - """Initialize the ElementWidget. + def __init__(self, sdata: EventedList, widget_type: Literal["coordinate_system", "element", "channel"]): + """Initialize the Widget. Parameters ---------- sdata : EventedList List of SpatialData objects to display elements from. + widget_type: Literal["coordinate_system", "element", "channel"] + The type of the widget. This determines what kind of items it will show. """ super().__init__() + self._widget_type = widget_type self._icon = QIcon(str(icon_path)) self._sdata = sdata self._duplicate_element_names, _ = get_duplicate_element_names(self._sdata) self._elements: None | dict[str, dict[str, str | int]] = None + self._element_widget_text: str | None = None + self._element_dict: dict[str, str | int] | None = None + self._system: None | str = None - def _onItemChange(self, selected_coordinate_system: QListWidgetItem | int | Iterable[str]) -> None: - """Update the element list when the coordinate system selection changes. + if widget_type == "coordinate_system": + # Sort alphabetically, but keep default "global" at the top. + coordinate_systems = sorted({cs for sdata in self._sdata for cs in sdata.coordinate_systems}) + if DEFAULT_COORDINATE_SYSTEM in coordinate_systems: + coordinate_systems.remove(DEFAULT_COORDINATE_SYSTEM) + coordinate_systems.insert(0, DEFAULT_COORDINATE_SYSTEM) + self.addItems(coordinate_systems) + + @requires_widget_type("element") + def _onCsItemChange(self, selected_coordinate_system: QListWidgetItem | int | Iterable[str]) -> None: + """Update the element list of an element widget when the coordinate system selection changes. Parameters ---------- @@ -97,8 +118,9 @@ def _onItemChange(self, selected_coordinate_system: QListWidgetItem | int | Iter self._set_element_widget_items(elements) self._elements = elements + @requires_widget_type("element") def _set_element_widget_items(self, elements: dict[str, dict[str, str | int]]) -> None: - """Populate the widget with element items. + """Populate an element widget with element items. Adds each element as an item in the list widget, with warning icons for elements that might be slow to render (e.g., many circles or shapes). @@ -136,41 +158,7 @@ def _set_element_widget_items(self, elements: dict[str, dict[str, str | int]]) - ) self.addItem(item) - -class CoordinateSystemWidget(QListWidget): - """Widget for selecting coordinate systems from SpatialData objects. - - This widget displays a list of available coordinate systems from all SpatialData - objects, allowing the user to select one for visualization. - - Attributes - ---------- - _sdata - List of SpatialData objects. - _system - Currently selected coordinate system. - """ - - def __init__(self, sdata: EventedList): - """Initialize the CoordinateSystemWidget. - - Parameters - ---------- - sdata : EventedList - List of SpatialData objects to extract coordinate systems from. - """ - super().__init__() - - self._sdata = sdata - self._system: None | str = None - - # Sort alphabetically, but keep default "global" at the top. - coordinate_systems = sorted({cs for sdata in self._sdata for cs in sdata.coordinate_systems}) - if DEFAULT_COORDINATE_SYSTEM in coordinate_systems: - coordinate_systems.remove(DEFAULT_COORDINATE_SYSTEM) - coordinate_systems.insert(0, DEFAULT_COORDINATE_SYSTEM) - self.addItems(coordinate_systems) - + @requires_widget_type("coordinate_system") def _select_coord_sys(self, selected_coordinate_system: QListWidgetItem | int | Iterable[str]) -> None: """Store the currently selected coordinate system. @@ -182,43 +170,11 @@ def _select_coord_sys(self, selected_coordinate_system: QListWidgetItem | int | """ self._system = str(selected_coordinate_system) - -class ChannelWidget(QListWidget): - """Widget for selecting channels from multidimensional image data. - - This widget displays available channels for image elements, allowing users - to select individual channels for visualization. - - Attributes - ---------- - _sdata - List of SpatialData objects. - _element_widget_text - Text of the currently selected element. - _element_dict - Dictionary with metadata of the currently selected element. - _channels - List of available channels for the current element. - """ - - def __init__(self, sdata: EventedList): - """Initialize the ChannelWidget. - - Parameters - ---------- - sdata : EventedList - List of SpatialData objects. - """ - super().__init__() - self._sdata = sdata - self._element_widget_text: str | None = None - self._element_dict: dict[str, str | int] | None = None - self._channels: list[str] | None = None - + @requires_widget_type("channel") def _on_element_item_changed( self, sdata: SpatialData, element_widget_text: str, element_dict: dict[str, str | int] ) -> None: - """Update the channel list when the selected element changes. + """Update the channel items in the channel widget when the selected element changes. Clears the current channel list and populates it with channels from the selected element if it's an image. @@ -234,7 +190,6 @@ def _on_element_item_changed( """ self.clear() self._element_dict = None - self._channels = None self._element_widget_text = element_widget_text if element_dict["element_type"] == "images": element: DataArray | DataTree = sdata[element_dict["original_name"]] @@ -242,8 +197,9 @@ def _on_element_item_changed( self._element_widget_text = element_widget_text self._set_channel_widget_items(element) + @requires_widget_type("channel") def _set_channel_widget_items(self, element: DataArray | DataTree) -> None: - """Populate the widget with channel items from the selected image element. + """Populate a channel widget with channel items from the selected image element. Adds each channel as an item in the list widget, except for RGB(A) channels which are handled differently. @@ -257,7 +213,7 @@ def _set_channel_widget_items(self, element: DataArray | DataTree) -> None: channels = list(element.c.to_numpy()) else: channels = list(element["scale0"].c.to_numpy()) - self._channels = channels + if channels not in [["r", "g", "b"], ["r", "g", "b", "a"]]: for ch in channels: item = QListWidgetItem(str(ch)) @@ -431,9 +387,9 @@ def __init__(self, viewer: Viewer, sdata: EventedList): self.setLayout(QVBoxLayout()) - self.coordinate_system_widget = CoordinateSystemWidget(self._sdata) - self.elements_widget = ElementWidget(self._sdata) - self.channel_widget = ChannelWidget(self._sdata) + self.coordinate_system_widget = ListWidget(self._sdata, "coordinate_system") + self.elements_widget = ListWidget(self._sdata, "element") + self.channel_widget = ListWidget(self._sdata, "channel") self.slider = QProgressBar(self) self.slider.setRange(0, 0) self.slider.setVisible(False) @@ -449,12 +405,12 @@ def __init__(self, viewer: Viewer, sdata: EventedList): self.elements_widget.itemDoubleClicked.connect(self._on_doubleclick_element_item) self.channel_widget.itemDoubleClicked.connect(self._on_doubleclick_channel_item) self.coordinate_system_widget.currentItemChanged.connect( - lambda item: self.elements_widget._onItemChange(item.text()) + lambda item: self.elements_widget._onCsItemChange(item.text()) ) self.coordinate_system_widget.currentItemChanged.connect( lambda item: self.coordinate_system_widget._select_coord_sys(item.text()) ) - self.viewer_model.layer_saved.connect(self.elements_widget._onItemChange) + self.viewer_model.layer_saved.connect(self.elements_widget._onCsItemChange) self.coordinate_system_widget.currentItemChanged.connect(self._update_layers_visibility) self.coordinate_system_widget.currentItemChanged.connect( lambda item: self.viewer_model._affine_transform_layers(item.text()) diff --git a/src/napari_spatialdata/utils/_utils.py b/src/napari_spatialdata/utils/_utils.py index 45866062..95ef9af7 100644 --- a/src/napari_spatialdata/utils/_utils.py +++ b/src/napari_spatialdata/utils/_utils.py @@ -5,7 +5,7 @@ from contextlib import contextmanager from functools import wraps from random import randint -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Literal, TypeVar, cast import numpy as np import packaging.version @@ -44,7 +44,7 @@ from napari.utils.events import EventedList from qtpy.QtWidgets import QListWidgetItem - from napari_spatialdata._sdata_widgets import CoordinateSystemWidget, ElementWidget + from napari_spatialdata._sdata_widgets import ListWidget from spatialdata._types import ArrayLike @@ -405,9 +405,7 @@ def _get_init_metadata_adata(sdata: SpatialData, table_name: str | None, element return adata -def get_itemindex_by_text( - list_widget: CoordinateSystemWidget | ElementWidget, item_text: str -) -> None | QListWidgetItem: +def get_itemindex_by_text(list_widget: ListWidget, item_text: str) -> None | QListWidgetItem: """ Get the item in a listwidget based on its text. @@ -511,3 +509,37 @@ def block_signals(widget: QObject) -> Generator[None, None, None]: yield finally: widget.blockSignals(False) + + +WidgetType = Literal["coordinate_system", "element", "channel"] +F = TypeVar("F", bound=Callable[..., Any]) + + +def requires_widget_type(*allowed_types: WidgetType) -> Callable[[F], F]: + """ + Restrict method execution to specific widget types. + + Parameters + ---------- + *allowed_types + The widget types for which the decorated method is valid. + + Returns + ------- + Callable + A decorator function that wraps the method. + """ + + def decorator(method: F) -> F: + @wraps(method) + def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: + if getattr(self, "_widget_type", None) not in allowed_types: + raise RuntimeError( + f"{method.__name__} is only valid when _widget_type is one of {allowed_types}, " + f"but got '{self._widget_type}'" + ) + return method(self, *args, **kwargs) + + return cast(F, wrapper) + + return decorator diff --git a/tests/test_spatialdata.py b/tests/test_spatialdata.py index fdf35ae7..2efcc02e 100644 --- a/tests/test_spatialdata.py +++ b/tests/test_spatialdata.py @@ -21,7 +21,7 @@ from xarray import DataArray, DataTree from napari_spatialdata import QtAdataViewWidget -from napari_spatialdata._sdata_widgets import CoordinateSystemWidget, ElementWidget, SdataWidget +from napari_spatialdata._sdata_widgets import ListWidget, SdataWidget from napari_spatialdata.constants import config from napari_spatialdata.utils._test_utils import click_list_widget_item, get_center_pos_listitem @@ -30,10 +30,10 @@ def test_elementwidget(make_napari_viewer: Any, blobs_extra_cs: SpatialData): _ = make_napari_viewer() - widget = ElementWidget(EventedList([blobs_extra_cs])) + widget = ListWidget(EventedList([blobs_extra_cs]), "element") assert widget._sdata is not None assert not widget._elements - widget._onItemChange("global") + widget._onCsItemChange("global") assert widget._elements for name in blobs_extra_cs.images: assert widget._elements[name]["element_type"] == "images" @@ -47,7 +47,7 @@ def test_elementwidget(make_napari_viewer: Any, blobs_extra_cs: SpatialData): def test_coordinatewidget(make_napari_viewer: Any, blobs_extra_cs: SpatialData): _ = make_napari_viewer() - widget = CoordinateSystemWidget(EventedList([blobs_extra_cs])) + widget = ListWidget(EventedList([blobs_extra_cs]), "coordinate_system") items = [widget.item(x).text() for x in range(widget.count())] assert len(items) == len(blobs_extra_cs.coordinate_systems) for item in items: @@ -59,13 +59,13 @@ def test_sdatawidget_images(make_napari_viewer: Any, blobs_extra_cs: SpatialData widget = SdataWidget(viewer, EventedList([blobs_extra_cs])) assert len(widget.viewer_model.viewer.layers) == 0 widget.coordinate_system_widget._select_coord_sys("global") - widget.elements_widget._onItemChange("global") + widget.elements_widget._onCsItemChange("global") widget._onClick(list(blobs_extra_cs.images.keys())[0]) assert len(widget.viewer_model.viewer.layers) == 1 assert isinstance(widget.viewer_model.viewer.layers[0], Image) assert widget.viewer_model.viewer.layers[0].name == list(blobs_extra_cs.images.keys())[0] blobs_extra_cs.images["image"] = to_multiscale(blobs_extra_cs.images["blobs_image"], [2, 4]) - widget.elements_widget._onItemChange("global") + widget.elements_widget._onCsItemChange("global") widget._onClick("image") assert len(widget.viewer_model.viewer.layers) == 2 @@ -131,7 +131,7 @@ def test_sdatawidget_labels(qtbot, make_napari_viewer: Any, blobs_extra_cs: Spat widget = SdataWidget(viewer, EventedList([blobs_extra_cs])) assert len(widget.viewer_model.viewer.layers) == 0 widget.coordinate_system_widget._select_coord_sys("global") - widget.elements_widget._onItemChange("global") + widget.elements_widget._onCsItemChange("global") widget._onClick(list(blobs_extra_cs.labels.keys())[0]) assert len(widget.viewer_model.viewer.layers) == 1 assert widget.viewer_model.viewer.layers[0].name == list(blobs_extra_cs.labels.keys())[0] @@ -172,7 +172,7 @@ def test_sdatawidget_points(caplog, make_napari_viewer: Any, blobs_extra_cs: Spa widget = SdataWidget(viewer, EventedList([blobs_extra_cs])) assert len(widget.viewer_model.viewer.layers) == 0 widget.coordinate_system_widget._select_coord_sys("global") - widget.elements_widget._onItemChange("global") + widget.elements_widget._onCsItemChange("global") widget._onClick(list(blobs_extra_cs.points.keys())[0]) assert len(widget.viewer_model.viewer.layers) == 1 assert widget.viewer_model.viewer.layers[0].name == list(blobs_extra_cs.points.keys())[0] From 3484cf4bdb7c3f08ef5ff2e366488bcd888d0859 Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Mon, 2 Jun 2025 22:02:38 +0200 Subject: [PATCH 08/12] update docstring --- src/napari_spatialdata/_sdata_widgets.py | 27 +++++++++++++++++++----- 1 file changed, 22 insertions(+), 5 deletions(-) diff --git a/src/napari_spatialdata/_sdata_widgets.py b/src/napari_spatialdata/_sdata_widgets.py index 2617ded5..76f74fc1 100644 --- a/src/napari_spatialdata/_sdata_widgets.py +++ b/src/napari_spatialdata/_sdata_widgets.py @@ -1,7 +1,7 @@ """Widgets for displaying and interacting with SpatialData objects in napari. This module provides a set of Qt widgets for visualizing and interacting with -SpatialData objects within the napari viewer. It includes widgets for selecting +SpatialData objects within the napari viewer. It includes a ListWidget for selecting coordinate systems, browsing elements within SpatialData objects, and handling channel selection for multidimensional image data. """ @@ -67,12 +67,29 @@ class ListWidget(QListWidget): from the SpatialData objects, with warnings for elements that might be slow to render. A third option is to let it show channels from image elements. + The widget's behavior is determined by the `widget_type` parameter passed during initialization: + - "coordinate_system": Displays available coordinate systems from SpatialData objects + - "element": Displays available elements (images, labels, points, shapes) from SpatialData objects + - "channel": Displays available channels from selected image elements + Attributes ---------- - _icon: Icon used for warning indicators. - _sdata: List of SpatialData objects. - _duplicate_element_names: Dictionary of duplicate element names. - _elements: Dictionary mapping element names to their metadata. + _widget_type : str + Type of widget ("coordinate_system", "element", or "channel") determining its behavior. + _icon : QIcon + Icon used for warning indicators for elements that might be slow to render. + _sdata : EventedList + List of SpatialData objects. + _duplicate_element_names : dict + Dictionary of duplicate element names across SpatialData objects. + _elements : dict or None + Dictionary mapping element names to their metadata. + _element_widget_text : str or None + Text of the currently selected element in the ElementWidget. + _element_dict : dict or None + Dictionary with metadata of the currently selected element. + _system : str or None + Currently selected coordinate system. """ def __init__(self, sdata: EventedList, widget_type: Literal["coordinate_system", "element", "channel"]): From d9dd413f61c8ca683d066cf78e1e41d75d5ccfaa Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Tue, 3 Jun 2025 21:04:10 +0200 Subject: [PATCH 09/12] remove require_widget wrapper Given that the code is private and we know which function is used for which widget I removed the overhead that the wrapper creates --- src/napari_spatialdata/_sdata_widgets.py | 6 ----- src/napari_spatialdata/utils/_utils.py | 32 +----------------------- 2 files changed, 1 insertion(+), 37 deletions(-) diff --git a/src/napari_spatialdata/_sdata_widgets.py b/src/napari_spatialdata/_sdata_widgets.py index 76f74fc1..36497314 100644 --- a/src/napari_spatialdata/_sdata_widgets.py +++ b/src/napari_spatialdata/_sdata_widgets.py @@ -34,7 +34,6 @@ _get_sdata_key, get_duplicate_element_names, get_elements_meta_mapping, - requires_widget_type, ) if TYPE_CHECKING: @@ -120,7 +119,6 @@ def __init__(self, sdata: EventedList, widget_type: Literal["coordinate_system", coordinate_systems.insert(0, DEFAULT_COORDINATE_SYSTEM) self.addItems(coordinate_systems) - @requires_widget_type("element") def _onCsItemChange(self, selected_coordinate_system: QListWidgetItem | int | Iterable[str]) -> None: """Update the element list of an element widget when the coordinate system selection changes. @@ -135,7 +133,6 @@ def _onCsItemChange(self, selected_coordinate_system: QListWidgetItem | int | It self._set_element_widget_items(elements) self._elements = elements - @requires_widget_type("element") def _set_element_widget_items(self, elements: dict[str, dict[str, str | int]]) -> None: """Populate an element widget with element items. @@ -175,7 +172,6 @@ def _set_element_widget_items(self, elements: dict[str, dict[str, str | int]]) - ) self.addItem(item) - @requires_widget_type("coordinate_system") def _select_coord_sys(self, selected_coordinate_system: QListWidgetItem | int | Iterable[str]) -> None: """Store the currently selected coordinate system. @@ -187,7 +183,6 @@ def _select_coord_sys(self, selected_coordinate_system: QListWidgetItem | int | """ self._system = str(selected_coordinate_system) - @requires_widget_type("channel") def _on_element_item_changed( self, sdata: SpatialData, element_widget_text: str, element_dict: dict[str, str | int] ) -> None: @@ -214,7 +209,6 @@ def _on_element_item_changed( self._element_widget_text = element_widget_text self._set_channel_widget_items(element) - @requires_widget_type("channel") def _set_channel_widget_items(self, element: DataArray | DataTree) -> None: """Populate a channel widget with channel items from the selected image element. diff --git a/src/napari_spatialdata/utils/_utils.py b/src/napari_spatialdata/utils/_utils.py index 95ef9af7..50db2062 100644 --- a/src/napari_spatialdata/utils/_utils.py +++ b/src/napari_spatialdata/utils/_utils.py @@ -5,7 +5,7 @@ from contextlib import contextmanager from functools import wraps from random import randint -from typing import TYPE_CHECKING, Any, Literal, TypeVar, cast +from typing import TYPE_CHECKING, Any, Literal, TypeVar import numpy as np import packaging.version @@ -513,33 +513,3 @@ def block_signals(widget: QObject) -> Generator[None, None, None]: WidgetType = Literal["coordinate_system", "element", "channel"] F = TypeVar("F", bound=Callable[..., Any]) - - -def requires_widget_type(*allowed_types: WidgetType) -> Callable[[F], F]: - """ - Restrict method execution to specific widget types. - - Parameters - ---------- - *allowed_types - The widget types for which the decorated method is valid. - - Returns - ------- - Callable - A decorator function that wraps the method. - """ - - def decorator(method: F) -> F: - @wraps(method) - def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: - if getattr(self, "_widget_type", None) not in allowed_types: - raise RuntimeError( - f"{method.__name__} is only valid when _widget_type is one of {allowed_types}, " - f"but got '{self._widget_type}'" - ) - return method(self, *args, **kwargs) - - return cast(F, wrapper) - - return decorator From 4cc8410913104aa4e8f8ca4255be83fef2829879 Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Tue, 3 Jun 2025 21:06:50 +0200 Subject: [PATCH 10/12] remove unused element attribute --- src/napari_spatialdata/_sdata_widgets.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/napari_spatialdata/_sdata_widgets.py b/src/napari_spatialdata/_sdata_widgets.py index 36497314..9ee32b53 100644 --- a/src/napari_spatialdata/_sdata_widgets.py +++ b/src/napari_spatialdata/_sdata_widgets.py @@ -81,8 +81,6 @@ class ListWidget(QListWidget): List of SpatialData objects. _duplicate_element_names : dict Dictionary of duplicate element names across SpatialData objects. - _elements : dict or None - Dictionary mapping element names to their metadata. _element_widget_text : str or None Text of the currently selected element in the ElementWidget. _element_dict : dict or None @@ -106,7 +104,6 @@ def __init__(self, sdata: EventedList, widget_type: Literal["coordinate_system", self._icon = QIcon(str(icon_path)) self._sdata = sdata self._duplicate_element_names, _ = get_duplicate_element_names(self._sdata) - self._elements: None | dict[str, dict[str, str | int]] = None self._element_widget_text: str | None = None self._element_dict: dict[str, str | int] | None = None self._system: None | str = None @@ -131,7 +128,6 @@ def _onCsItemChange(self, selected_coordinate_system: QListWidgetItem | int | It self.clear() elements, _ = get_elements_meta_mapping(self._sdata, selected_coordinate_system, self._duplicate_element_names) self._set_element_widget_items(elements) - self._elements = elements def _set_element_widget_items(self, elements: dict[str, dict[str, str | int]]) -> None: """Populate an element widget with element items. From 02416a0bb6adf5d63bda182ecc8c636fb94ad4fc Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Tue, 3 Jun 2025 21:23:02 +0200 Subject: [PATCH 11/12] change widget type to bool param --- src/napari_spatialdata/_sdata_widgets.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/src/napari_spatialdata/_sdata_widgets.py b/src/napari_spatialdata/_sdata_widgets.py index 9ee32b53..8e289b36 100644 --- a/src/napari_spatialdata/_sdata_widgets.py +++ b/src/napari_spatialdata/_sdata_widgets.py @@ -13,7 +13,7 @@ from importlib.metadata import version from operator import itemgetter from pathlib import Path -from typing import TYPE_CHECKING, Literal, cast +from typing import TYPE_CHECKING, cast import numpy as np import shapely @@ -89,7 +89,7 @@ class ListWidget(QListWidget): Currently selected coordinate system. """ - def __init__(self, sdata: EventedList, widget_type: Literal["coordinate_system", "element", "channel"]): + def __init__(self, sdata: EventedList, coordinate_system: bool = False): """Initialize the Widget. Parameters @@ -100,7 +100,6 @@ def __init__(self, sdata: EventedList, widget_type: Literal["coordinate_system", The type of the widget. This determines what kind of items it will show. """ super().__init__() - self._widget_type = widget_type self._icon = QIcon(str(icon_path)) self._sdata = sdata self._duplicate_element_names, _ = get_duplicate_element_names(self._sdata) @@ -108,7 +107,7 @@ def __init__(self, sdata: EventedList, widget_type: Literal["coordinate_system", self._element_dict: dict[str, str | int] | None = None self._system: None | str = None - if widget_type == "coordinate_system": + if coordinate_system: # Sort alphabetically, but keep default "global" at the top. coordinate_systems = sorted({cs for sdata in self._sdata for cs in sdata.coordinate_systems}) if DEFAULT_COORDINATE_SYSTEM in coordinate_systems: @@ -394,9 +393,9 @@ def __init__(self, viewer: Viewer, sdata: EventedList): self.setLayout(QVBoxLayout()) - self.coordinate_system_widget = ListWidget(self._sdata, "coordinate_system") - self.elements_widget = ListWidget(self._sdata, "element") - self.channel_widget = ListWidget(self._sdata, "channel") + self.coordinate_system_widget = ListWidget(self._sdata, coordinate_system=True) + self.elements_widget = ListWidget(self._sdata) + self.channel_widget = ListWidget(self._sdata) self.slider = QProgressBar(self) self.slider.setRange(0, 0) self.slider.setVisible(False) From 003bc9446d6a8da7357e60c46b7e4c73d58aa6d9 Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Tue, 3 Jun 2025 22:08:04 +0200 Subject: [PATCH 12/12] readd elements for cache but remove dict --- src/napari_spatialdata/_sdata_widgets.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/napari_spatialdata/_sdata_widgets.py b/src/napari_spatialdata/_sdata_widgets.py index 8e289b36..5edce551 100644 --- a/src/napari_spatialdata/_sdata_widgets.py +++ b/src/napari_spatialdata/_sdata_widgets.py @@ -104,7 +104,7 @@ def __init__(self, sdata: EventedList, coordinate_system: bool = False): self._sdata = sdata self._duplicate_element_names, _ = get_duplicate_element_names(self._sdata) self._element_widget_text: str | None = None - self._element_dict: dict[str, str | int] | None = None + self._elements: dict[str, dict[str, str | int]] | None = None self._system: None | str = None if coordinate_system: @@ -127,6 +127,7 @@ def _onCsItemChange(self, selected_coordinate_system: QListWidgetItem | int | It self.clear() elements, _ = get_elements_meta_mapping(self._sdata, selected_coordinate_system, self._duplicate_element_names) self._set_element_widget_items(elements) + self._elements = elements def _set_element_widget_items(self, elements: dict[str, dict[str, str | int]]) -> None: """Populate an element widget with element items. @@ -196,11 +197,9 @@ def _on_element_item_changed( Dictionary with metadata of the selected element. """ self.clear() - self._element_dict = None self._element_widget_text = element_widget_text if element_dict["element_type"] == "images": element: DataArray | DataTree = sdata[element_dict["original_name"]] - self._element_dict = element_dict self._element_widget_text = element_widget_text self._set_channel_widget_items(element)