diff --git a/setup.cfg b/setup.cfg index 84213d3f..3d0fd099 100644 --- a/setup.cfg +++ b/setup.cfg @@ -37,7 +37,7 @@ install_requires = anndata click cycler - dask>=2024.4.1 + dask>=2024.4.1,<=2024.11.2 geopandas loguru matplotlib diff --git a/src/napari_spatialdata/_sdata_widgets.py b/src/napari_spatialdata/_sdata_widgets.py index 9915706a..5edce551 100644 --- a/src/napari_spatialdata/_sdata_widgets.py +++ b/src/napari_spatialdata/_sdata_widgets.py @@ -1,3 +1,11 @@ +"""Widgets for displaying and interacting with SpatialData objects in napari. + +This module provides a set of Qt widgets for visualizing and interacting with +SpatialData objects within the napari viewer. It includes a ListWidget for selecting +coordinate systems, browsing elements within SpatialData objects, and handling +channel selection for multidimensional image data. +""" + from __future__ import annotations import platform @@ -18,10 +26,15 @@ from qtpy.QtWidgets import QLabel, QListWidget, QListWidgetItem, QProgressBar, QVBoxLayout, QWidget from spatialdata import SpatialData from spatialdata.models._utils import DEFAULT_COORDINATE_SYSTEM +from xarray import DataArray, DataTree from napari_spatialdata._viewer import SpatialDataViewer from napari_spatialdata.constants.config import N_CIRCLES_WARNING_THRESHOLD, N_SHAPES_WARNING_THRESHOLD -from napari_spatialdata.utils._utils import _get_sdata_key, get_duplicate_element_names, get_elements_meta_mapping +from napari_spatialdata.utils._utils import ( + _get_sdata_key, + get_duplicate_element_names, + get_elements_meta_mapping, +) if TYPE_CHECKING: from napari import Viewer @@ -46,21 +59,87 @@ PROBLEMATIC_NUMPY_MACOS = False -class ElementWidget(QListWidget): - def __init__(self, sdata: EventedList): +class ListWidget(QListWidget): + """Widget for displaying and selecting coordinate systems or elements from SpatialData objects or channels. + + This widget can show a list of coordinate systems or available elements (images, labels, points, shapes) + from the SpatialData objects, with warnings for elements that might be slow to render. A third option is to + let it show channels from image elements. + + The widget's behavior is determined by the `widget_type` parameter passed during initialization: + - "coordinate_system": Displays available coordinate systems from SpatialData objects + - "element": Displays available elements (images, labels, points, shapes) from SpatialData objects + - "channel": Displays available channels from selected image elements + + Attributes + ---------- + _widget_type : str + Type of widget ("coordinate_system", "element", or "channel") determining its behavior. + _icon : QIcon + Icon used for warning indicators for elements that might be slow to render. + _sdata : EventedList + List of SpatialData objects. + _duplicate_element_names : dict + Dictionary of duplicate element names across SpatialData objects. + _element_widget_text : str or None + Text of the currently selected element in the ElementWidget. + _element_dict : dict or None + Dictionary with metadata of the currently selected element. + _system : str or None + Currently selected coordinate system. + """ + + def __init__(self, sdata: EventedList, coordinate_system: bool = False): + """Initialize the Widget. + + Parameters + ---------- + sdata : EventedList + List of SpatialData objects to display elements from. + widget_type: Literal["coordinate_system", "element", "channel"] + The type of the widget. This determines what kind of items it will show. + """ super().__init__() self._icon = QIcon(str(icon_path)) self._sdata = sdata self._duplicate_element_names, _ = get_duplicate_element_names(self._sdata) - self._elements: None | dict[str, dict[str, str | int]] = None + self._element_widget_text: str | None = None + self._elements: dict[str, dict[str, str | int]] | None = None + self._system: None | str = None - def _onItemChange(self, selected_coordinate_system: QListWidgetItem | int | Iterable[str]) -> None: + if coordinate_system: + # Sort alphabetically, but keep default "global" at the top. + coordinate_systems = sorted({cs for sdata in self._sdata for cs in sdata.coordinate_systems}) + if DEFAULT_COORDINATE_SYSTEM in coordinate_systems: + coordinate_systems.remove(DEFAULT_COORDINATE_SYSTEM) + coordinate_systems.insert(0, DEFAULT_COORDINATE_SYSTEM) + self.addItems(coordinate_systems) + + def _onCsItemChange(self, selected_coordinate_system: QListWidgetItem | int | Iterable[str]) -> None: + """Update the element list of an element widget when the coordinate system selection changes. + + Parameters + ---------- + selected_coordinate_system : QListWidgetItem or int or Iterable[str] + The newly selected coordinate system. + Can be a QListWidgetItem, an index, or an iterable of strings. + """ self.clear() elements, _ = get_elements_meta_mapping(self._sdata, selected_coordinate_system, self._duplicate_element_names) self._set_element_widget_items(elements) self._elements = elements def _set_element_widget_items(self, elements: dict[str, dict[str, str | int]]) -> None: + """Populate an element widget with element items. + + Adds each element as an item in the list widget, with warning icons for elements + that might be slow to render (e.g., many circles or shapes). + + Parameters + ---------- + elements : dict[str, dict[str, str | int]] + Dictionary mapping element names to their metadata. + """ for key, dict_val in sorted(elements.items(), key=itemgetter(0)): sdata = self._sdata[dict_val["sdata_index"]] element_type = dict_val["element_type"] @@ -89,29 +168,104 @@ def _set_element_widget_items(self, elements: dict[str, dict[str, str | int]]) - ) self.addItem(item) - -class CoordinateSystemWidget(QListWidget): - def __init__(self, sdata: EventedList): - super().__init__() - - self._sdata = sdata - self._system: None | str = None - - # Sort alphabetically, but keep default "global" at the top. - coordinate_systems = sorted(cs for sdata in self._sdata for cs in sdata.coordinate_systems) - if DEFAULT_COORDINATE_SYSTEM in coordinate_systems: - coordinate_systems.remove(DEFAULT_COORDINATE_SYSTEM) - coordinate_systems.insert(0, DEFAULT_COORDINATE_SYSTEM) - self.addItems(coordinate_systems) - def _select_coord_sys(self, selected_coordinate_system: QListWidgetItem | int | Iterable[str]) -> None: + """Store the currently selected coordinate system. + + Parameters + ---------- + selected_coordinate_system : QListWidgetItem or int or Iterable[str] + The selected coordinate system. + Can be a QListWidgetItem, an index, or an iterable of strings. + """ self._system = str(selected_coordinate_system) + def _on_element_item_changed( + self, sdata: SpatialData, element_widget_text: str, element_dict: dict[str, str | int] + ) -> None: + """Update the channel items in the channel widget when the selected element changes. + + Clears the current channel list and populates it with channels from the + selected element if it's an image. + + Parameters + ---------- + sdata : SpatialData + The SpatialData object containing the selected element. + element_widget_text : str + Text of the selected element in the ElementWidget. + element_dict : dict + Dictionary with metadata of the selected element. + """ + self.clear() + self._element_widget_text = element_widget_text + if element_dict["element_type"] == "images": + element: DataArray | DataTree = sdata[element_dict["original_name"]] + self._element_widget_text = element_widget_text + self._set_channel_widget_items(element) + + def _set_channel_widget_items(self, element: DataArray | DataTree) -> None: + """Populate a channel widget with channel items from the selected image element. + + Adds each channel as an item in the list widget, except for RGB(A) channels + which are handled differently. + + Parameters + ---------- + element : object + The image element to extract channels from. + """ + if isinstance(element, DataArray): + channels = list(element.c.to_numpy()) + else: + channels = list(element["scale0"].c.to_numpy()) + + if channels not in [["r", "g", "b"], ["r", "g", "b", "a"]]: + for ch in channels: + item = QListWidgetItem(str(ch)) + self.addItem(item) + class DataLoadThread(QThread): + """Thread for asynchronously loading SpatialData elements. + + This thread handles loading different types of data (images, labels, points, shapes) + from SpatialData objects without blocking the UI. + + Parameters + ---------- + parent : SdataWidget + Parent SdataWidget that owns this thread. + + Attributes + ---------- + returned : Signal + Signal emitted when data loading is complete, carrying the created layer. + sdata_widget : SdataWidget + Parent SdataWidget that owns this thread. + _data_type : str + Type of data to load (images, labels, points, shapes). + _text : str + Name of the element to load. + _sdata : SpatialData + SpatialData object containing the element. + _selected_cs : str + Selected coordinate system. + _multi : bool + Boolean indicating if multiple SpatialData objects are present. + _channel_name : str, optional + Optional channel name for image data. + """ + returned = Signal(object) def __init__(self, parent: SdataWidget): + """Initialize the DataLoadThread. + + Parameters + ---------- + parent : SdataWidget + Parent SdataWidget that owns this thread. + """ super().__init__(parent=parent) self.sdata_widget = parent self._data_type = "" @@ -120,11 +274,42 @@ def __init__(self, parent: SdataWidget): self._selected_cs: str = "" self._multi: bool = False - def load_data(self, data_type: str, text: str, sdata: SpatialData, selected_cs: str, multi: bool) -> None: + def load_data( + self, + data_type: str, + text: str, + sdata: SpatialData, + selected_cs: str, + multi: bool, + channel_name: str | None = None, + ) -> None: + """Set up data loading parameters and start the thread. + + Parameters + ---------- + data_type : str + Type of data to load (images, labels, points, shapes). + text : str + Name of the element to load. + sdata : SpatialData + SpatialData object containing the element. + selected_cs : str + Selected coordinate system. + multi : bool + Boolean indicating if multiple SpatialData objects are present. + channel_name : str, optional + Optional channel name for image data. + + Raises + ------ + RuntimeError + If the thread is already running. + """ if self.isRunning(): raise RuntimeError("Thread is already running.") self._data_type = data_type self._text = text + self._channel_name = channel_name self._sdata = sdata self._selected_cs = selected_cs self._multi = multi @@ -135,6 +320,11 @@ def load_data(self, data_type: str, text: str, sdata: SpatialData, selected_cs: self.start() def run(self) -> None: + """Execute the data loading operation. + + Loads the specified data element based on its type and emits the + returned layer through the 'returned' signal. + """ if not self._data_type: return if self._data_type == "labels": @@ -143,7 +333,7 @@ def run(self) -> None: ) elif self._data_type == "images": layer = self.sdata_widget.viewer_model.get_sdata_image( - self._sdata, self._text, self._selected_cs, self._multi + self._sdata, self._text, self._selected_cs, self._multi, self._channel_name ) elif self._data_type == "points": layer = self.sdata_widget.viewer_model.get_sdata_points( @@ -158,7 +348,41 @@ def run(self) -> None: class SdataWidget(QWidget): + """Main widget for interacting with SpatialData objects in napari. + + This widget combines coordinate system selection, element browsing, and channel + selection into a unified interface for visualizing SpatialData objects in napari. + It manages the loading and display of different data types and handles coordinate + system transformations. + + Attributes + ---------- + _sdata + List of SpatialData objects. + viewer_model + SpatialDataViewer instance for interacting with napari. + worker_thread + Thread for asynchronous data loading. + coordinate_system_widget + Widget for selecting coordinate systems. + elements_widget + Widget for browsing and selecting elements. + channel_widget + Widget for selecting channels in image data. + slider + Progress bar shown during data loading. + """ + def __init__(self, viewer: Viewer, sdata: EventedList): + """Initialize the SdataWidget. + + Parameters + ---------- + viewer : Viewer + napari Viewer instance. + sdata : EventedList + List of SpatialData objects to visualize. + """ super().__init__() self._sdata = sdata self.viewer_model = SpatialDataViewer(viewer, self._sdata) @@ -168,8 +392,9 @@ def __init__(self, viewer: Viewer, sdata: EventedList): self.setLayout(QVBoxLayout()) - self.coordinate_system_widget = CoordinateSystemWidget(self._sdata) - self.elements_widget = ElementWidget(self._sdata) + self.coordinate_system_widget = ListWidget(self._sdata, coordinate_system=True) + self.elements_widget = ListWidget(self._sdata) + self.channel_widget = ListWidget(self._sdata) self.slider = QProgressBar(self) self.slider.setRange(0, 0) self.slider.setVisible(False) @@ -179,14 +404,18 @@ def __init__(self, viewer: Viewer, sdata: EventedList): self.layout().addWidget(self.coordinate_system_widget) self.layout().addWidget(QLabel("Elements:")) self.layout().addWidget(self.elements_widget) - self.elements_widget.itemDoubleClicked.connect(self._on_click_item) + self.layout().addWidget(QLabel("Channels:")) + self.layout().addWidget(self.channel_widget) + self.elements_widget.currentItemChanged.connect(self._on_element_item_changed) + self.elements_widget.itemDoubleClicked.connect(self._on_doubleclick_element_item) + self.channel_widget.itemDoubleClicked.connect(self._on_doubleclick_channel_item) self.coordinate_system_widget.currentItemChanged.connect( - lambda item: self.elements_widget._onItemChange(item.text()) + lambda item: self.elements_widget._onCsItemChange(item.text()) ) self.coordinate_system_widget.currentItemChanged.connect( lambda item: self.coordinate_system_widget._select_coord_sys(item.text()) ) - self.viewer_model.layer_saved.connect(self.elements_widget._onItemChange) + self.viewer_model.layer_saved.connect(self.elements_widget._onCsItemChange) self.coordinate_system_widget.currentItemChanged.connect(self._update_layers_visibility) self.coordinate_system_widget.currentItemChanged.connect( lambda item: self.viewer_model._affine_transform_layers(item.text()) @@ -194,24 +423,79 @@ def __init__(self, viewer: Viewer, sdata: EventedList): self.viewer_model.viewer.layers.events.inserted.connect(self._on_insert_layer) def _on_insert_layer(self, event: Event) -> None: + """Connect visibility events for newly inserted layers. + + Parameters + ---------- + event : Event + Event containing the newly inserted layer. + """ layer = event.value layer.events.visible.connect(self._update_visible_in_coordinate_system) - def _on_click_item(self, item: QListWidgetItem) -> None: + def _on_element_item_changed(self, item: QListWidgetItem) -> None: + """Handle selection changes in the elements widget. + + Updates the channel widget with channels from the selected element. + + Parameters + ---------- + item : QListWidgetItem + The newly selected element item. + """ + if self.elements_widget._elements: + sdata, _ = _get_sdata_key(self._sdata, self.elements_widget._elements, item.text()) + self.channel_widget._on_element_item_changed( + sdata, item.text(), self.elements_widget._elements[item.text()] + ) + + def _on_doubleclick_channel_item(self, item: QListWidgetItem) -> None: + """Handle double-click events on channel items in the channel widget. + + Loads and displays the selected channel of the current element. + + Parameters + ---------- + item : QListWidgetItem + The double-clicked channel item. + """ + if self.channel_widget._element_widget_text: + self._onClick(self.channel_widget._element_widget_text, item.text()) + + def _on_doubleclick_element_item(self, item: QListWidgetItem) -> None: + """Handle double-click events on element items in the element widget. + + Loads and displays the selected element. + + Parameters + ---------- + item : QListWidgetItem + The double-clicked element item. + """ self._onClick(item.text()) def _hide_slider(self) -> None: + """Hide the progress slider when data loading is complete.""" self.slider.setVisible(False) - def _onClick(self, text: str) -> None: + def _onClick(self, element_name: str, channel_name: str | None = None) -> None: + """Handle click events to load and display data elements. + + Parameters + ---------- + element_name : str + Name of the element to load. + channel_name : str, optional + Name of the channel to load for image elements. + """ selected_cs = self.coordinate_system_widget._system if self.worker_thread.isRunning(): show_info("Please wait for the current operation to finish.") return if selected_cs and self.elements_widget._elements: - sdata, multi = _get_sdata_key(self._sdata, self.elements_widget._elements, text) - if (type_ := self.elements_widget._elements[text]["element_type"]) not in { + sdata, multi = _get_sdata_key(self._sdata, self.elements_widget._elements, element_name) + if (type_ := self.elements_widget._elements[element_name]["element_type"]) not in { "labels", "images", "shapes", @@ -221,12 +505,18 @@ def _onClick(self, text: str) -> None: type_ = cast(str, type_) - self.worker_thread.load_data(type_, text, sdata, selected_cs, multi) + self.worker_thread.load_data(type_, element_name, sdata, selected_cs, multi, channel_name) if not PROBLEMATIC_NUMPY_MACOS: self.slider.setVisible(True) def _update_visible_in_coordinate_system(self, event: Event) -> None: - """Toggle active in the coordinate system metadata when changing visibility of layer.""" + """Toggle active status in the coordinate system metadata when changing layer visibility. + + Parameters + ---------- + event : Event + Event triggered by changing layer visibility. + """ metadata = event.source.metadata layer_active = metadata.get("_active_in_cs") selected_coordinate_system = self.coordinate_system_widget._system @@ -240,7 +530,12 @@ def _update_visible_in_coordinate_system(self, event: Event) -> None: layer_active.remove(selected_coordinate_system) def _update_layers_visibility(self) -> None: - """Toggle layer visibility dependent on presence in currently selected coordinate system.""" + """Toggle layer visibility based on presence in the currently selected coordinate system. + + Updates the visibility of all layers based on whether they are active in the + currently selected coordinate system. Also updates layer metadata to track + coordinate system information. + """ elements = self.elements_widget._elements coordinate_system = self.coordinate_system_widget._system # No layer selected on first time coordinate system selection @@ -259,6 +554,33 @@ def _update_layers_visibility(self) -> None: layer.metadata["_current_cs"] = coordinate_system def _get_shapes(self, sdata: SpatialData, key: str, selected_cs: str, multi: bool) -> Shapes | Points: + """Load and create appropriate layer for shape data. + + Determines the geometry type of the shapes element and calls the appropriate + method to create either a Points layer (for Point geometries) or a Shapes + layer (for Polygon or MultiPolygon geometries). + + Parameters + ---------- + sdata : SpatialData + SpatialData object containing the shapes element. + key : str + Name of the shapes element to load. + selected_cs : str + Selected coordinate system. + multi : bool + Whether multiple SpatialData objects are present. + + Returns + ------- + Shapes or Points + The created napari layer. + + Raises + ------ + TypeError + If the geometry type is not Point, Polygon, or MultiPolygon. + """ original_name = key[: key.rfind("_")] if multi else key if type(sdata.shapes[original_name].iloc[0].geometry) is shapely.geometry.point.Point: diff --git a/src/napari_spatialdata/_view.py b/src/napari_spatialdata/_view.py index 12037729..d6e39dac 100644 --- a/src/napari_spatialdata/_view.py +++ b/src/napari_spatialdata/_view.py @@ -41,9 +41,7 @@ from napari_spatialdata._widgets import ( AListWidget, AnnDataSaveDialog, - CBarWidget, ComponentWidget, - RangeSliderWidget, SaveDialog, ScatterAnnotationDialog, ) @@ -497,13 +495,6 @@ def __init__(self, napari_viewer: Viewer, model: DataModel | None = None) -> Non self.color_by = QLabel("Colored by:") self.layout().addWidget(self.color_by) - # scalebar - colorbar = CBarWidget(model=self.model) - self.slider = RangeSliderWidget(self.viewer, self.model, colorbar=colorbar) - self._viewer.window.add_dock_widget(self.slider, area="left", name="slider") - self._viewer.window.add_dock_widget(colorbar, area="left", name="colorbar") - self.viewer.layers.selection.events.active.connect(self.slider._onLayerChange) - if (layer := self.viewer.layers.selection.active) is not None and layer.metadata.get("adata") is not None: self._on_layer_update() diff --git a/src/napari_spatialdata/_viewer.py b/src/napari_spatialdata/_viewer.py index de90a4de..cbfbc2d8 100644 --- a/src/napari_spatialdata/_viewer.py +++ b/src/napari_spatialdata/_viewer.py @@ -29,6 +29,7 @@ _get_ellipses_from_circles, _get_init_metadata_adata, _get_transform, + _obtain_channel_image, _transform_coordinates, get_duplicate_element_names, get_napari_version, @@ -442,10 +443,14 @@ def clean_worker(self) -> None: """Clean the worker.""" self.worker = None - def add_sdata_image(self, sdata: SpatialData, key: str, selected_cs: str, multi: bool) -> None: - self.add_layer(self.get_sdata_image(sdata, key, selected_cs, multi)) + def add_sdata_image( + self, sdata: SpatialData, key: str, selected_cs: str, multi: bool, channel_name: str | None = None + ) -> None: + self.add_layer(self.get_sdata_image(sdata, key, selected_cs, multi, channel_name)) - def get_sdata_image(self, sdata: SpatialData, key: str, selected_cs: str, multi: bool) -> Image: + def get_sdata_image( + self, sdata: SpatialData, key: str, selected_cs: str, multi: bool, channel_name: str | None = None + ) -> Image: """ Add an image in a spatial data object to the viewer. @@ -465,7 +470,12 @@ def get_sdata_image(self, sdata: SpatialData, key: str, selected_cs: str, multi: original_name = original_name[: original_name.rfind("_")] affine = _get_transform(sdata.images[original_name], selected_cs) - rgb_image, rgb = _adjust_channels_order(element=sdata.images[original_name]) + if channel_name: + image = _obtain_channel_image(element=sdata.images[original_name], channel_name=channel_name) + rgb = False + key = key + f"_ch:{channel_name}" + else: + image, rgb = _adjust_channels_order(element=sdata.images[original_name]) channels = ("RGB(A)",) if rgb else get_channels(sdata.images[original_name]) @@ -473,7 +483,7 @@ def get_sdata_image(self, sdata: SpatialData, key: str, selected_cs: str, multi: # TODO: type check return Image( - rgb_image, + image, rgb=rgb, name=key, affine=affine, diff --git a/src/napari_spatialdata/_widgets.py b/src/napari_spatialdata/_widgets.py index 5431036f..aff10673 100644 --- a/src/napari_spatialdata/_widgets.py +++ b/src/napari_spatialdata/_widgets.py @@ -19,17 +19,12 @@ from qtpy import QtCore, QtWidgets from qtpy.QtCore import Qt, Signal from scanpy.plotting._utils import _set_colors_for_categorical_obs -from sklearn.preprocessing import MinMaxScaler from spatialdata._types import ArrayLike -from superqt import QRangeSlider -from vispy import scene -from vispy.color.colormap import Colormap, MatplotlibColormap -from vispy.scene.widgets import ColorBarWidget from napari_spatialdata._model import DataModel from napari_spatialdata.utils._utils import _min_max_norm, get_napari_version -__all__ = ["AListWidget", "CBarWidget", "RangeSliderWidget", "ComponentWidget"] +__all__ = ["AListWidget", "ComponentWidget"] # label string: attribute name # TODO(giovp): remove since layer controls private? @@ -409,193 +404,6 @@ def attr(self, field: str | None) -> None: self._attr = field -class CBarWidget(QtWidgets.QWidget): - FORMAT = "{0:0.2f}" - - cmapChanged = Signal(str) - climChanged = Signal((float, float)) - - def __init__( - self, - model: DataModel, - cmap: str = "viridis", - label: str | None = None, - width: int | None = 250, - height: int | None = 50, - **kwargs: Any, - ): - super().__init__(**kwargs) - - self._model = model - - self._clim = (0.0, 1.0) - self._oclim = self._clim - - self._width = width - self._height = height - self._label = label - - self.__init_UI() - - def __init_UI(self) -> None: - self.setFixedWidth(self._width) - self.setFixedHeight(self._height) - - # use napari's BG color for dark mode - self._canvas = scene.SceneCanvas( - size=(self._width, self._height), bgcolor="#262930", parent=self, decorate=False, resizable=False, dpi=150 - ) - self._colorbar = ColorBarWidget( - self._create_colormap(self.cmap), - orientation="top", - label=self._label, - label_color="white", - clim=self.getClim(), - border_width=1.0, - border_color="black", - padding=(0.3, 0.167), - axis_ratio=0.05, - ) - - self._canvas.central_widget.add_widget(self._colorbar) - - self.climChanged.connect(self.onClimChanged) - self.cmapChanged.connect(self.onCmapChanged) - - def _create_colormap(self, cmap: str) -> Colormap: - ominn, omaxx = self.getOclim() - delta = omaxx - ominn + 1e-12 - - minn, maxx = self.getClim() - minn = (minn - ominn) / delta - maxx = (maxx - ominn) / delta - - assert 0 <= minn <= 1, f"Expected `min` to be in `[0, 1]`, found `{minn}`" - assert 0 <= maxx <= 1, f"Expected `maxx` to be in `[0, 1]`, found `{maxx}`" - - cm = MatplotlibColormap(cmap) - - return Colormap(cm[np.linspace(minn, maxx, len(cm.colors))], interpolation="linear") - - def getCmap(self) -> str: - return self.cmap - - def onCmapChanged(self, value: str) -> None: - # this does not trigger update for some reason... - self._colorbar.cmap = self._create_colormap(value) - self._colorbar._colorbar._update() - - def setClim(self, value: tuple[float, float]) -> None: - if value == self._clim: - return - - self._clim = value - self.climChanged.emit(*value) - - def getClim(self) -> tuple[float, float]: - return self._clim - - def getOclim(self) -> tuple[float, float]: - return self._oclim - - def setOclim(self, value: tuple[float, float]) -> None: - # original color limit used for 0-1 normalization - self._oclim = value - - def onClimChanged(self, minn: float, maxx: float) -> None: - # ticks are not working with vispy's colorbar - self._colorbar.cmap = self._create_colormap(self.cmap) - self._colorbar.clim = (self.FORMAT.format(minn), self.FORMAT.format(maxx)) - - def getCanvas(self) -> scene.SceneCanvas: - return self._canvas - - def getColorBar(self) -> ColorBarWidget: - return self._colorbar - - def setLayout(self, layout: QtWidgets.QLayout) -> None: - layout.addWidget(self.getCanvas().native) - super().setLayout(layout) - - def update_color(self) -> None: - # when changing selected layers that have the same limit - # could also trigger it as self._colorbar.clim = self.getClim() - # but the above option also updates geometry - # cbarwidget->cbar->cbarvisual - self._colorbar._colorbar._colorbar._update() - - @property - def cmap(self) -> str: - return self._model.cmap - - -class RangeSliderWidget(QRangeSlider): - def __init__(self, viewer: Viewer, model: DataModel, colorbar: CBarWidget, **kwargs: Any): - super().__init__(**kwargs) - - self._viewer = viewer - self._model = model - self._colorbar = colorbar - self._cmap = plt.get_cmap(self._colorbar.cmap) - self.setValue((0, 100)) - self.setSliderPosition((0, 100)) - self.setSingleStep(0.01) - self.setOrientation(Qt.Horizontal) - self.valueChanged.connect(self._onValueChange) - - def _onLayerChange(self) -> None: - layer = self.viewer.layers.selection.active - if layer is not None: - self._onValueChange((0, 100)) - - def _onValueChange(self, percentile: tuple[float, float]) -> None: - layer = self.viewer.layers.selection.active - # TODO(michalk8): use constants - if "data" not in layer.metadata: - return None # noqa: RET501 - v = layer.metadata["data"] - # this code is currently not used since the slider is not enabled; so I silenced the mypy error; 2. there is a - # mismatch for this error with the mypy in the CI, so I silenced the unused-ignore from the local mypy. - # when this code is re-enabled, let's fix mypy - clipped = np.clip(v, *np.percentile(v, percentile)) # type: ignore[misc,unused-ignore] - - if isinstance(layer, Points): - layer.metadata = {**layer.metadata, "perc": percentile} - layer.face_color = "value" - layer.properties = {"value": clipped} - layer.refresh_colors() - elif isinstance(layer, Labels): - norm_vec = self._scale_vec(clipped) - color_vec = self._cmap(norm_vec) - layer.color = dict(zip(layer.color.keys(), color_vec, strict=False)) - layer.properties = {"value": clipped} - layer.refresh() - - self._colorbar.setOclim(layer.metadata["minmax"]) - self._colorbar.setClim((np.min(layer.properties["value"]), np.max(layer.properties["value"]))) - self._colorbar.update_color() - - def _scale_vec(self, vec: ArrayLike) -> ArrayLike: - ominn, omaxx = self._colorbar.getOclim() - delta = omaxx - ominn + 1e-12 - - minn, maxx = self._colorbar.getClim() - minn = (minn - ominn) / delta - maxx = (maxx - ominn) / delta - scaler = MinMaxScaler(feature_range=(minn, maxx)) - return scaler.fit_transform(vec.reshape(-1, 1)) - - @property - def viewer(self) -> napari.Viewer: - """:mod:`napari` viewer.""" - return self._viewer - - @property - def model(self) -> DataModel: - """:mod:`napari` viewer.""" - return self._model - - class SaveDialog(QtWidgets.QDialog): def __init__(self, layer: Layer, table_name: str) -> None: super().__init__() diff --git a/src/napari_spatialdata/utils/_utils.py b/src/napari_spatialdata/utils/_utils.py index 237b84b7..50db2062 100644 --- a/src/napari_spatialdata/utils/_utils.py +++ b/src/napari_spatialdata/utils/_utils.py @@ -5,7 +5,7 @@ from contextlib import contextmanager from functools import wraps from random import randint -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Literal, TypeVar import numpy as np import packaging.version @@ -44,7 +44,7 @@ from napari.utils.events import EventedList from qtpy.QtWidgets import QListWidgetItem - from napari_spatialdata._sdata_widgets import CoordinateSystemWidget, ElementWidget + from napari_spatialdata._sdata_widgets import ListWidget from spatialdata._types import ArrayLike @@ -221,6 +221,31 @@ def _points_inside_triangles(points: ArrayLike, triangles: ArrayLike) -> ArrayLi return out +def _datatree_to_dataarray_list(new_raster: DataArray | DataTree) -> DataArray | list[DataArray]: + if isinstance(new_raster, DataTree): + list_of_xdata = [] + for k in new_raster: + v = new_raster[k].values() + assert len(v) == 1 + xdata = v.__iter__().__next__() + list_of_xdata.append(xdata) + return list_of_xdata + return new_raster + + +def _obtain_channel_image(element: DataArray | DataTree, channel_name: str | int) -> DataArray | list[DataArray]: + is_multiscale_int_ch = isinstance(element, DataTree) and np.issubdtype( + element["scale0"].c.to_numpy().dtype, np.integer + ) + is_int_ch = isinstance(element, DataArray) and np.issubdtype(element.c.to_numpy().dtype, np.integer) + if isinstance(channel_name, str) and (is_multiscale_int_ch or is_int_ch): + channel_name = int(channel_name) + + # works for both DataArray and DataTree + new_raster = element.sel(c=channel_name) + return _datatree_to_dataarray_list(new_raster) + + def _adjust_channels_order(element: DataArray | DataTree) -> tuple[DataArray | list[DataArray], bool]: """Swap the axes to y, x, c and check if an image supports rgb(a) visualization. @@ -264,14 +289,7 @@ def _adjust_channels_order(element: DataArray | DataTree) -> tuple[DataArray | l rgb = False new_raster = element - if isinstance(new_raster, DataTree): - list_of_xdata = [] - for k in new_raster: - v = new_raster[k].values() - assert len(v) == 1 - xdata = v.__iter__().__next__() - list_of_xdata.append(xdata) - new_raster = list_of_xdata + new_raster = _datatree_to_dataarray_list(new_raster) return new_raster, rgb @@ -387,9 +405,7 @@ def _get_init_metadata_adata(sdata: SpatialData, table_name: str | None, element return adata -def get_itemindex_by_text( - list_widget: CoordinateSystemWidget | ElementWidget, item_text: str -) -> None | QListWidgetItem: +def get_itemindex_by_text(list_widget: ListWidget, item_text: str) -> None | QListWidgetItem: """ Get the item in a listwidget based on its text. @@ -493,3 +509,7 @@ def block_signals(widget: QObject) -> Generator[None, None, None]: yield finally: widget.blockSignals(False) + + +WidgetType = Literal["coordinate_system", "element", "channel"] +F = TypeVar("F", bound=Callable[..., Any]) diff --git a/tests/conftest.py b/tests/conftest.py index bb658ec1..ac397b03 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -20,7 +20,7 @@ from spatialdata import SpatialData from spatialdata._types import ArrayLike from spatialdata.datasets import blobs -from spatialdata.models import TableModel +from spatialdata.models import Image2DModel, TableModel from napari_spatialdata.utils._test_utils import save_image, take_screenshot @@ -131,6 +131,18 @@ def sdata_blobs() -> SpatialData: return blobs() +@pytest.fixture() +def sdata_channel_images() -> SpatialData: + sdata = blobs() + sdata["blobs_image_str_ch"] = Image2DModel.parse( + sdata["blobs_image"], c_coords=["channel1", "channel2", "channel3"] + ) + sdata["blobs_multiscale_image_str_ch"] = Image2DModel.parse( + sdata["blobs_image"], c_coords=["channel1", "channel2", "channel3"], scale_factors=[2, 2] + ) + return sdata + + @pytest.fixture def image(): _, image = _get_blobs_galaxy() diff --git a/tests/test_spatialdata.py b/tests/test_spatialdata.py index 00c53007..2efcc02e 100644 --- a/tests/test_spatialdata.py +++ b/tests/test_spatialdata.py @@ -21,7 +21,7 @@ from xarray import DataArray, DataTree from napari_spatialdata import QtAdataViewWidget -from napari_spatialdata._sdata_widgets import CoordinateSystemWidget, ElementWidget, SdataWidget +from napari_spatialdata._sdata_widgets import ListWidget, SdataWidget from napari_spatialdata.constants import config from napari_spatialdata.utils._test_utils import click_list_widget_item, get_center_pos_listitem @@ -30,10 +30,10 @@ def test_elementwidget(make_napari_viewer: Any, blobs_extra_cs: SpatialData): _ = make_napari_viewer() - widget = ElementWidget(EventedList([blobs_extra_cs])) + widget = ListWidget(EventedList([blobs_extra_cs]), "element") assert widget._sdata is not None assert not widget._elements - widget._onItemChange("global") + widget._onCsItemChange("global") assert widget._elements for name in blobs_extra_cs.images: assert widget._elements[name]["element_type"] == "images" @@ -47,7 +47,7 @@ def test_elementwidget(make_napari_viewer: Any, blobs_extra_cs: SpatialData): def test_coordinatewidget(make_napari_viewer: Any, blobs_extra_cs: SpatialData): _ = make_napari_viewer() - widget = CoordinateSystemWidget(EventedList([blobs_extra_cs])) + widget = ListWidget(EventedList([blobs_extra_cs]), "coordinate_system") items = [widget.item(x).text() for x in range(widget.count())] assert len(items) == len(blobs_extra_cs.coordinate_systems) for item in items: @@ -59,13 +59,13 @@ def test_sdatawidget_images(make_napari_viewer: Any, blobs_extra_cs: SpatialData widget = SdataWidget(viewer, EventedList([blobs_extra_cs])) assert len(widget.viewer_model.viewer.layers) == 0 widget.coordinate_system_widget._select_coord_sys("global") - widget.elements_widget._onItemChange("global") + widget.elements_widget._onCsItemChange("global") widget._onClick(list(blobs_extra_cs.images.keys())[0]) assert len(widget.viewer_model.viewer.layers) == 1 assert isinstance(widget.viewer_model.viewer.layers[0], Image) assert widget.viewer_model.viewer.layers[0].name == list(blobs_extra_cs.images.keys())[0] blobs_extra_cs.images["image"] = to_multiscale(blobs_extra_cs.images["blobs_image"], [2, 4]) - widget.elements_widget._onItemChange("global") + widget.elements_widget._onCsItemChange("global") widget._onClick("image") assert len(widget.viewer_model.viewer.layers) == 2 @@ -73,12 +73,65 @@ def test_sdatawidget_images(make_napari_viewer: Any, blobs_extra_cs: SpatialData del blobs_extra_cs.images["image"] +@pytest.mark.parametrize( + "images", [["blobs_image", "blobs_image_str_ch"], ["blobs_multiscale_image", "blobs_multiscale_image_str_ch"]] +) +def test_channel_selection(qtbot, make_napari_viewer, sdata_channel_images, images): + """Test selecting a channel from an image with integer channel names.""" + # Create a viewer + viewer = make_napari_viewer() + + # Create the SdataWidget + widget = SdataWidget(viewer, EventedList([sdata_channel_images])) + + # Click on 'global' coordinate system + center_pos = get_center_pos_listitem(widget.coordinate_system_widget, "global") + click_list_widget_item(qtbot, widget.coordinate_system_widget, center_pos, "currentItemChanged") + + # Click on the image element to populate the channel widget + center_pos = get_center_pos_listitem(widget.elements_widget, images[0]) + click_list_widget_item(qtbot, widget.elements_widget, center_pos, "currentItemChanged") + + # Verify that the channel widget has been populated with the correct channels + assert widget.channel_widget.count() == 3 + assert widget.channel_widget.item(0).text() == "0" + assert widget.channel_widget.item(1).text() == "1" + assert widget.channel_widget.item(2).text() == "2" + + # Double-click on a channel to add it as a layer + center_pos = get_center_pos_listitem(widget.channel_widget, "1") + click_list_widget_item(qtbot, widget.channel_widget, center_pos, "currentItemChanged", "double") + widget._onClick(images[0], "1") + + # Verify that the layer has been added with the correct name and data + assert len(viewer.layers) == 1 + assert viewer.layers[0].name == f"{images[0]}_ch:1" + + # Verify that the layer contains only the selected channel + assert viewer.layers[0].data.shape == (512, 512) + + center_pos = get_center_pos_listitem(widget.elements_widget, images[1]) + click_list_widget_item(qtbot, widget.elements_widget, center_pos, "currentItemChanged") + + assert widget.channel_widget.count() == 3 + assert widget.channel_widget.item(0).text() == "channel1" + assert widget.channel_widget.item(1).text() == "channel2" + assert widget.channel_widget.item(2).text() == "channel3" + + center_pos = get_center_pos_listitem(widget.channel_widget, "channel2") + click_list_widget_item(qtbot, widget.channel_widget, center_pos, "currentItemChanged", "double") + widget._onClick(images[1], "channel2") + + assert len(viewer.layers) == 2 + assert viewer.layers[1].name == f"{images[1]}_ch:channel2" + + def test_sdatawidget_labels(qtbot, make_napari_viewer: Any, blobs_extra_cs: SpatialData): viewer = make_napari_viewer() widget = SdataWidget(viewer, EventedList([blobs_extra_cs])) assert len(widget.viewer_model.viewer.layers) == 0 widget.coordinate_system_widget._select_coord_sys("global") - widget.elements_widget._onItemChange("global") + widget.elements_widget._onCsItemChange("global") widget._onClick(list(blobs_extra_cs.labels.keys())[0]) assert len(widget.viewer_model.viewer.layers) == 1 assert widget.viewer_model.viewer.layers[0].name == list(blobs_extra_cs.labels.keys())[0] @@ -119,7 +172,7 @@ def test_sdatawidget_points(caplog, make_napari_viewer: Any, blobs_extra_cs: Spa widget = SdataWidget(viewer, EventedList([blobs_extra_cs])) assert len(widget.viewer_model.viewer.layers) == 0 widget.coordinate_system_widget._select_coord_sys("global") - widget.elements_widget._onItemChange("global") + widget.elements_widget._onCsItemChange("global") widget._onClick(list(blobs_extra_cs.points.keys())[0]) assert len(widget.viewer_model.viewer.layers) == 1 assert widget.viewer_model.viewer.layers[0].name == list(blobs_extra_cs.points.keys())[0]