diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 1003fddf1..ec753d95f 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -4,7 +4,7 @@ on: push: branches: [main] tags: - - "v*" # Push events to matching v*, i.e. v1.0, v20.15.10 + - "v*" pull_request: branches: "*" @@ -13,26 +13,24 @@ jobs: runs-on: ${{ matrix.os }} defaults: run: - shell: bash -e {0} # -e to fail on error + shell: bash -e {0} strategy: fail-fast: false matrix: - python: ["3.11", "3.13"] - os: [ubuntu-latest] include: - - os: macos-latest - python: "3.11" - - os: macos-latest - python: "3.12" - pip-flags: "--pre" - name: "Python 3.12 (pre-release)" - - os: windows-latest - python: "3.11" - + - {os: windows-latest, python: "3.11", dask-version: "2025.2.0", name: "Dask 2025.2.0"} + - {os: windows-latest, python: "3.11", dask-version: "latest", name: "Dask latest"} + - {os: ubuntu-latest, python: "3.11", dask-version: "2025.2.0", name: "Dask 2025.2.0"} + - {os: ubuntu-latest, python: "3.11", dask-version: "latest", name: "Dask latest"} + - {os: ubuntu-latest, python: "3.13", dask-version: "latest", name: "Dask latest"} + - {os: macos-latest, python: "3.11", dask-version: "2025.2.0", name: "Dask 2025.2.0"} + - {os: macos-latest, python: "3.11", dask-version: "latest", name: "Dask latest"} + - {os: macos-latest, python: "3.12", pip-flags: "--pre", name: "Python 3.12 (pre-release)"} env: OS: ${{ matrix.os }} PYTHON: ${{ matrix.python }} + DASK_VERSION: ${{ matrix.dask-version }} steps: - uses: actions/checkout@v2 @@ -42,7 +40,15 @@ jobs: version: "latest" python-version: ${{ matrix.python }} - name: Install dependencies - run: "uv sync --extra test" + run: | + uv sync --extra test + if [[ -n "${DASK_VERSION}" ]]; then + if [[ "${DASK_VERSION}" == "latest" ]]; then + uv pip install --upgrade dask + else + uv pip install dask==${DASK_VERSION} + fi + fi - name: Test env: MPLBACKEND: agg diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index f181578ea..1e54869b6 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -12,6 +12,7 @@ repos: rev: v3.5.3 hooks: - id: prettier + exclude: ^.github/workflows/test.yaml - repo: https://github.com/pre-commit/mirrors-mypy rev: v1.15.0 hooks: diff --git a/pyproject.toml b/pyproject.toml index 766f75152..1e1c6d426 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,7 +25,7 @@ dependencies = [ "anndata>=0.9.1", "click", "dask-image", - "dask>=2024.10.0,<=2024.11.2", + "dask>=2025.2.0", "datashader", "fsspec[s3,http]", "geopandas>=0.14", diff --git a/src/spatialdata/__init__.py b/src/spatialdata/__init__.py index 5d84e172b..2fb483505 100644 --- a/src/spatialdata/__init__.py +++ b/src/spatialdata/__init__.py @@ -1,20 +1,7 @@ -import dask - -dask.config.set({"dataframe.query-planning": False}) -import dask.dataframe as dd - -# Setting `dataframe.query-planning` to False is effective only if run before `dask.dataframe` is initialized. In -# the case in which the user had initilized `dask.dataframe` before, we would have DASK_EXPER_ENABLED set to `True`. -# Here we check that this does not happen. -if hasattr(dd, "DASK_EXPR_ENABLED") and dd.DASK_EXPR_ENABLED: - raise RuntimeError( - "Unsupported backend: dask-expr has been detected as the backend of dask.dataframe. Please " - "use:\nimport dask\ndask.config.set({'dataframe.query-planning': False})\nbefore importing " - "dask.dataframe to disable dask-expr. The support is being worked on, for more information please see" - "https://github.com/scverse/spatialdata/pull/570" - ) from importlib.metadata import version +import spatialdata.models._accessor # noqa: F401 + __version__ = version("spatialdata") __all__ = [ diff --git a/src/spatialdata/_core/_deepcopy.py b/src/spatialdata/_core/_deepcopy.py index 6a5b43367..8b8c0b5ce 100644 --- a/src/spatialdata/_core/_deepcopy.py +++ b/src/spatialdata/_core/_deepcopy.py @@ -94,9 +94,12 @@ def _(gdf: GeoDataFrame) -> GeoDataFrame: @deepcopy.register(DaskDataFrame) def _(df: DaskDataFrame) -> DaskDataFrame: # bug: the parser may change the order of the columns - new_ddf = PointsModel.parse(df.compute().copy(deep=True)) + compute_df = df.compute().copy(deep=True) + new_ddf = PointsModel.parse(compute_df) # the problem is not .copy(deep=True), but the parser, which discards some metadata https://github.com/scverse/spatialdata/issues/503#issuecomment-2015275322 - new_ddf.attrs = _deepcopy(df.attrs) + # We need to use the compute_df here as with deepcopy, df._attrs does not exist anymore. + # print(type(new_ddf.attrs)) + new_ddf.attrs.update(_deepcopy(compute_df.attrs)) return new_ddf diff --git a/src/spatialdata/_core/operations/rasterize.py b/src/spatialdata/_core/operations/rasterize.py index a075aeb38..6da0a7cc8 100644 --- a/src/spatialdata/_core/operations/rasterize.py +++ b/src/spatialdata/_core/operations/rasterize.py @@ -653,12 +653,14 @@ def rasterize_shapes_points( table_name = table_name if table_name is not None else "table" + index = False if value_key is not None: kwargs = {"sdata": sdata, "element_name": element_name} if element_name is not None else {"element": data} data[VALUES_COLUMN] = get_values(value_key, table_name=table_name, **kwargs).iloc[:, 0] # type: ignore[arg-type, union-attr] elif isinstance(data, GeoDataFrame) or isinstance(data, DaskDataFrame) and return_regions_as_labels is True: value_key = VALUES_COLUMN data[VALUES_COLUMN] = data.index.astype("category") + index = True else: value_key = VALUES_COLUMN data[VALUES_COLUMN] = 1 @@ -666,7 +668,13 @@ def rasterize_shapes_points( label_index_to_category = None if VALUES_COLUMN in data and data[VALUES_COLUMN].dtype == "category": if isinstance(data, DaskDataFrame): - data[VALUES_COLUMN] = data[VALUES_COLUMN].cat.as_known() + # We have to do this because as_known() does not preserve the order anymore in latest dask versions + # TODO discuss whether we can always expect the index from before to be monotonically increasing, because + # then we don't have to check order. + if index: + data[VALUES_COLUMN] = data[VALUES_COLUMN].cat.set_categories(data.index, ordered=True) + else: + data[VALUES_COLUMN] = data[VALUES_COLUMN].cat.as_known() label_index_to_category = dict(enumerate(data[VALUES_COLUMN].cat.categories, start=1)) if return_single_channel is None: diff --git a/src/spatialdata/_core/operations/transform.py b/src/spatialdata/_core/operations/transform.py index da56dc391..8340f23e7 100644 --- a/src/spatialdata/_core/operations/transform.py +++ b/src/spatialdata/_core/operations/transform.py @@ -3,11 +3,12 @@ import itertools import warnings from functools import singledispatch -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, cast import dask.array as da import dask_image.ndinterp import numpy as np +import pandas as pd from dask.array.core import Array as DaskArray from dask.dataframe import DataFrame as DaskDataFrame from geopandas import GeoDataFrame @@ -432,18 +433,20 @@ def _( xtransformed = transformation._transform_coordinates(xdata) transformed = data.drop(columns=list(axes)).copy() # dummy transformation that will be replaced by _adjust_transformation() - transformed.attrs[TRANSFORM_KEY] = {DEFAULT_COORDINATE_SYSTEM: Identity()} - # TODO: the following line, used in place of the line before, leads to an incorrect aggregation result. Look into - # this! Reported here: ... - # transformed.attrs = {TRANSFORM_KEY: {DEFAULT_COORDINATE_SYSTEM: Identity()}} - assert isinstance(transformed, DaskDataFrame) + default_cs = {DEFAULT_COORDINATE_SYSTEM: Identity()} + transformed.attrs[TRANSFORM_KEY] = default_cs + for ax in axes: indices = xtransformed["dim"] == ax new_ax = xtransformed[:, indices] - transformed[ax] = new_ax.data.flatten() + # TODO: discuss with dask team + # This is not nice, but otherwise there is a problem with the joint graph of new_ax and transformed, causing + # a getattr missing dependency of dependent from_dask_array. + new_col = pd.Series(new_ax.data.flatten().compute(), index=transformed.index) + transformed[ax] = new_col + + old_transformations = cast(dict[str, Any], get_transformation(data, get_all=True)) - old_transformations = get_transformation(data, get_all=True) - assert isinstance(old_transformations, dict) _set_transformation_for_transformed_elements( transformed, old_transformations, diff --git a/src/spatialdata/_core/query/spatial_query.py b/src/spatialdata/_core/query/spatial_query.py index e6dccb458..36b40748a 100644 --- a/src/spatialdata/_core/query/spatial_query.py +++ b/src/spatialdata/_core/query/spatial_query.py @@ -672,14 +672,24 @@ def _( max_coordinate=max_coordinate_intrinsic, ) - # assert that the number of bounding boxes is correct - assert len(in_intrinsic_bounding_box) == len(min_coordinate) + if not (len_df := len(in_intrinsic_bounding_box)) == (len_bb := len(min_coordinate)): + raise ValueError(f"Number of dataframes `{len_df}` is not equal to the number of bounding boxes `{len_bb}`.") points_in_intrinsic_bounding_box: list[DaskDataFrame | None] = [] + points_pd = points.compute() + attrs = points.attrs.copy() for mask in in_intrinsic_bounding_box: if mask.sum() == 0: points_in_intrinsic_bounding_box.append(None) else: - points_in_intrinsic_bounding_box.append(points.loc[mask]) + # TODO there is a problem when mixing dask dataframe graph with dask array graph. Need to compute for now. + # we can't compute either mask or points as when we calculate either one of them + # test_query_points_multiple_partitions will fail as the mask will be used to index each partition. + # However, if we compute and then create the dask array again we get the mixed dask graph problem. + mask_np = mask.compute() + filtered_pd = points_pd[mask_np] + points_filtered = dd.from_pandas(filtered_pd, npartitions=points.npartitions) + points_filtered.attrs.update(attrs) + points_in_intrinsic_bounding_box.append(points_filtered) if len(points_in_intrinsic_bounding_box) == 0: return None diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py index 3d6a9ed06..f92bc9f54 100644 --- a/src/spatialdata/_core/spatialdata.py +++ b/src/spatialdata/_core/spatialdata.py @@ -13,8 +13,7 @@ import zarr from anndata import AnnData from dask.dataframe import DataFrame as DaskDataFrame -from dask.dataframe import read_parquet -from dask.delayed import Delayed +from dask.dataframe import Scalar, read_parquet from geopandas import GeoDataFrame from shapely import MultiPolygon, Polygon from xarray import DataArray, DataTree @@ -1985,9 +1984,7 @@ def h(s: str) -> str: else: shape_str = ( "(" - + ", ".join( - [(str(dim) if not isinstance(dim, Delayed) else "") for dim in v.shape] - ) + + ", ".join([(str(dim) if not isinstance(dim, Scalar) else "") for dim in v.shape]) + ")" ) descr += f"{h(attr + 'level1.1')}{k!r}: {descr_class} with shape: {shape_str} {dim_string}" diff --git a/src/spatialdata/_io/_utils.py b/src/spatialdata/_io/_utils.py index 20c236275..a8e194a7b 100644 --- a/src/spatialdata/_io/_utils.py +++ b/src/spatialdata/_io/_utils.py @@ -14,6 +14,7 @@ import zarr from anndata import AnnData +from dask._task_spec import Task from dask.array import Array as DaskArray from dask.dataframe import DataFrame as DaskDataFrame from geopandas import GeoDataFrame @@ -301,6 +302,19 @@ def _get_backing_files(element: DaskArray | DaskDataFrame) -> list[str]: return files +def _find_piece_dict(obj: dict[str, tuple[str | None]] | Task) -> dict[str, tuple[str | None | None]] | None: + """Recursively search for dict containing the key 'piece' in Dask task specs containing the parquet file path.""" + if isinstance(obj, dict): + if "piece" in obj: + return obj + elif hasattr(obj, "args"): # Handles dask._task_spec.* objects like Task and List + for v in obj.args: + result = _find_piece_dict(v) + if result is not None: + return result + return None + + def _search_for_backing_files_recursively(subgraph: Any, files: list[str]) -> None: # see the types allowed for the dask graph here: https://docs.dask.org/en/stable/spec.html @@ -327,25 +341,31 @@ def _search_for_backing_files_recursively(subgraph: Any, files: list[str]) -> No path = getattr(v.store, "path", None) if getattr(v.store, "path", None) else v.store.root files.append(str(UPath(path).resolve())) elif name.startswith("read-parquet") or name.startswith("read_parquet"): - if hasattr(v, "creation_info"): - # https://github.com/dask/dask/blob/ff2488aec44d641696e0b7aa41ed9e995c710705/dask/dataframe/io/parquet/core.py#L625 - t = v.creation_info["args"] - if not isinstance(t, tuple) or len(t) != 1: - raise ValueError( - f"Unable to parse the parquet file from the dask subgraph {subgraph}. Please " - f"report this bug." - ) - parquet_file = t[0] - files.append(str(UPath(parquet_file).resolve())) - elif isinstance(v, tuple) and len(v) > 1 and isinstance(v[1], dict) and "piece" in v[1]: + # Here v is a read_parquet task with arguments and the only value is a dictionary. + if "piece" in v.args[0]: # https://github.com/dask/dask/blob/ff2488aec44d641696e0b7aa41ed9e995c710705/dask/dataframe/io/parquet/core.py#L870 - parquet_file, check0, check1 = v[1]["piece"] + parquet_file, check0, check1 = v.args[0]["piece"] if not parquet_file.endswith(".parquet") or check0 is not None or check1 is not None: raise ValueError( f"Unable to parse the parquet file from the dask subgraph {subgraph}. Please " f"report this bug." ) files.append(os.path.realpath(parquet_file)) + else: + # This occurs when for example points and images are mixed, the main task still starts with + # read_parquet, but the execution happens through a subgraph which we iterate over to get the + # actual read_parquet task. + for task in v.args[0].values(): + # Recursively go through tasks, this is required because differences between dask versions. + piece_dict = _find_piece_dict(task) + if isinstance(piece_dict, dict) and "piece" in piece_dict: + parquet_file, check0, check1 = piece_dict["piece"] # type: ignore[misc] + if not parquet_file.endswith(".parquet") or check0 is not None or check1 is not None: + raise ValueError( + f"Unable to parse the parquet file from the dask subgraph {subgraph}. Please " + f"report this bug." + ) + files.append(os.path.realpath(parquet_file)) def _backed_elements_contained_in_path(path: Path, object: SpatialData | SpatialElement | AnnData) -> list[bool]: diff --git a/src/spatialdata/_io/io_raster.py b/src/spatialdata/_io/io_raster.py index ad2b22274..bc8206db1 100644 --- a/src/spatialdata/_io/io_raster.py +++ b/src/spatialdata/_io/io_raster.py @@ -347,7 +347,9 @@ def _write_raster_datatree( compute=False, ) # Compute all pyramid levels at once to allow Dask to optimize the computational graph. - da.compute(*dask_delayed) + # Optimize_graph is set to False for now as this causes permission denied errors when during atomic writes + # os.replace is called. These can also be alleviated by using 'single-threaded' scheduler. + da.compute(*dask_delayed, optimize_graph=False) trans_group = group["labels"][element_name] if raster_type == "labels" else group overwrite_coordinate_transformations_raster( diff --git a/src/spatialdata/datasets.py b/src/spatialdata/datasets.py index 63c137cdc..ea38d739b 100644 --- a/src/spatialdata/datasets.py +++ b/src/spatialdata/datasets.py @@ -365,7 +365,7 @@ def blobs_annotating_element(name: BlobsTypes) -> SpatialData: instance_id = get_element_instances(sdata[name]).tolist() else: index = sdata[name].index - instance_id = index.compute().tolist() if isinstance(index, dask.dataframe.core.Index) else index.tolist() + instance_id = index.compute().tolist() if isinstance(index, dask.dataframe.Index) else index.tolist() n = len(instance_id) new_table = AnnData(shape=(n, 0), obs={"region": pd.Categorical([name] * n), "instance_id": instance_id}) new_table = TableModel.parse(new_table, region=name, region_key="region", instance_key="instance_id") diff --git a/src/spatialdata/models/_accessor.py b/src/spatialdata/models/_accessor.py new file mode 100644 index 000000000..a8b19653d --- /dev/null +++ b/src/spatialdata/models/_accessor.py @@ -0,0 +1,140 @@ +from collections.abc import Iterator, MutableMapping +from typing import Any, Literal + +from dask.dataframe import DataFrame as DaskDataFrame +from dask.dataframe import Series as DaskSeries +from dask.dataframe.extensions import ( + register_dataframe_accessor, + register_series_accessor, +) + + +@register_dataframe_accessor("attrs") +@register_series_accessor("attrs") +class AttrsAccessor(MutableMapping[str, str | dict[str, Any]]): + """Accessor that stores a dict of arbitrary metadata on Dask objects.""" + + def __init__(self, dask_obj: DaskDataFrame | DaskSeries): + self._obj = dask_obj + if not hasattr(dask_obj, "_attrs"): + dask_obj._attrs = {} + + def __getitem__(self, key: str) -> Any: + return self._obj._attrs[key] + + def __setitem__(self, key: str, value: str | dict[str, Any]) -> None: + self._obj._attrs[key] = value + + def __delitem__(self, key: str) -> None: + del self._obj._attrs[key] + + def __iter__(self) -> Iterator[str]: + return iter(self._obj._attrs) + + def __len__(self) -> int: + return len(self._obj._attrs) + + def __repr__(self) -> str: + return repr(self._obj._attrs) + + def __str__(self) -> str: + return str(self._obj._attrs) + + def copy(self) -> Any: + return self._obj._attrs.copy() + + @property + def data(self) -> Any: + """Access the raw internal attrs dict.""" + return self._obj._attrs + + +def wrap_method_with_attrs(method_name: str, dask_class: type[DaskDataFrame] | type[DaskSeries]) -> None: + """Wrap a Dask DataFrame method to preserve _attrs. + + Copies _attrs from self before calling method, then assigns to result. + Safe for lazy operations like set_index, assign, map_partitions. + """ + original_method = getattr(dask_class, method_name) + + def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: + if not isinstance(self.attrs, AttrsAccessor): + raise RuntimeError( + "Invalid .attrs: expected an accessor (`AttrsAccessor`), " + f"got {type(self.attrs).__name__}. A common cause is assigning a dict, e.g. " + "my_dd_object.attrs = {...}. Do not assign to 'attrs'; use " + "my_dd_object.attrs.update(...) instead." + ) + + old_attrs = self.attrs.copy() + result = original_method(self, *args, **kwargs) + # the pandas Index do not have attrs, but dd.Index, since they are a subclass of dd.Series, do have attrs + # thanks to our accessor. Here we ensure that we do not assign attrs to pd.Index objects. + if hasattr(result, "attrs"): + result.attrs.update(old_attrs) + + return result + + setattr(dask_class, method_name, wrapper) + + +def wrap_indexer_with_attrs( + indexer_name: Literal["loc", "iloc"], dask_class: type[DaskDataFrame] | type[DaskSeries] +) -> None: + """Patch dd.DataFrame or dd.Series loc or iloc to preserve _attrs. + + Reason for having this separate from methods is because both loc and iloc are a property that return an indexer. + Therefore, they have to be wrapped differently from methods in order to preserve attrs. + """ + original_property = getattr(dask_class, indexer_name) # this is a property + + def indexer_with_attrs(self: DaskDataFrame | DaskSeries) -> Any: + parent_obj = self + indexer = original_property.fget(parent_obj) + + class IndexerWrapper: + def __init__(self, parent_indexer: Any, parent_obj: DaskDataFrame | DaskSeries) -> None: + self._parent_indexer = parent_indexer + self._parent_obj = parent_obj + + def __getitem__(self, key: str) -> Any: + result = self._parent_indexer[key] + if hasattr(self._parent_obj, "attrs"): + result._attrs = self._parent_obj.attrs.copy() + return result + + def __setitem__(self, key: str, value: Any) -> DaskDataFrame | DaskSeries: + # preserve attrs even if user assigns via .loc + self._parent_indexer[key] = value + return self._parent_obj + + def __repr__(self) -> str: + return repr(self._parent_indexer) + + return IndexerWrapper(indexer, parent_obj) + + setattr(dask_class, indexer_name, property(indexer_with_attrs)) + + +for method_name in [ + "__getitem__", + "compute", + "copy", + "drop", + "map_partitions", + "set_index", +]: + wrap_method_with_attrs(method_name=method_name, dask_class=DaskDataFrame) + +for method_name in [ + "__getitem__", + "compute", + "copy", + "map_partitions", +]: + wrap_method_with_attrs(method_name=method_name, dask_class=DaskSeries) + +wrap_indexer_with_attrs(indexer_name="loc", dask_class=DaskDataFrame) +wrap_indexer_with_attrs(indexer_name="iloc", dask_class=DaskDataFrame) +wrap_indexer_with_attrs(indexer_name="loc", dask_class=DaskSeries) +# DaskSeries do not have iloc diff --git a/src/spatialdata/models/models.py b/src/spatialdata/models/models.py index 60f4ee205..bed33ff1d 100644 --- a/src/spatialdata/models/models.py +++ b/src/spatialdata/models/models.py @@ -808,9 +808,7 @@ def _( sort=sort, **kwargs, ) - # we cannot compute the divisions whne the index is not monotonically increasing and npartitions > 1 - if not table.known_divisions and (sort or table.npartitions == 1): - table.divisions = table.compute_current_divisions() + # TODO: dask does not allow for setting divisions directly anymore. We have to decide on forcing the user. if feature_key is not None: feature_categ = dd.from_pandas( data[feature_key].astype(str).astype("category"), diff --git a/src/spatialdata/transformations/_utils.py b/src/spatialdata/transformations/_utils.py index 44f9998b9..6d3b2c1a4 100644 --- a/src/spatialdata/transformations/_utils.py +++ b/src/spatialdata/transformations/_utils.py @@ -54,8 +54,6 @@ def _set_transformations_to_element(element: Any, transformations: MappingToCoor if TRANSFORM_KEY not in attrs: attrs[TRANSFORM_KEY] = {} attrs[TRANSFORM_KEY] = transformations - # this calls an eventual setter in the element class; modifying the attrs directly would not trigger the setter - element.attrs = attrs @singledispatch diff --git a/tests/conftest.py b/tests/conftest.py index 775721253..c931b3f88 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,3 @@ -import dask - -dask.config.set({"dataframe.query-planning": False}) from collections.abc import Sequence from pathlib import Path from typing import Any diff --git a/tests/core/operations/test_rasterize.py b/tests/core/operations/test_rasterize.py index 25f3c3d0f..a2ffde3d4 100644 --- a/tests/core/operations/test_rasterize.py +++ b/tests/core/operations/test_rasterize.py @@ -52,10 +52,10 @@ def _rasterize(element: DataArray | DataTree, element_name: str, **kwargs) -> Da def _get_data_of_largest_scale(raster): if isinstance(raster, DataArray): - return raster.data.compute() + return raster.data xdata = get_pyramid_levels(raster, n=0) - return xdata.data.compute() + return xdata.data for element_name, raster in rasters.items(): dims = get_axes_names(raster) @@ -63,6 +63,9 @@ def _get_data_of_largest_scale(raster): slices = [all_slices[d] for d in dims] data = _get_data_of_largest_scale(raster) + # The line above before returned a numpy array. Setting the indices of the slice to 1 would previously update + # also raster, but since dask 2025.2.0 this does not happen anymore. However, we can just set the slice to 1 + # on the dask array. data[tuple(slices)] = 1 for kwargs in [ diff --git a/tests/core/query/test_spatial_query.py b/tests/core/query/test_spatial_query.py index fc59d0698..d7147dbfb 100644 --- a/tests/core/query/test_spatial_query.py +++ b/tests/core/query/test_spatial_query.py @@ -648,7 +648,10 @@ def _query( @pytest.mark.parametrize("with_polygon_query", [True, False]) def test_query_points_multiple_partitions(points, with_polygon_query: bool): p0 = points["points_0"] - p1 = PointsModel.parse(dd.from_pandas(p0.compute(), npartitions=10)) + attrs = p0.attrs.copy() + ddf = dd.from_pandas(p0.compute(), npartitions=10) + ddf.attrs.update(attrs) + p1 = PointsModel.parse(ddf) def _query(p: DaskDataFrame) -> DaskDataFrame: if with_polygon_query: @@ -669,7 +672,6 @@ def _query(p: DaskDataFrame) -> DaskDataFrame: q0 = _query(p0) q1 = _query(p1) assert np.array_equal(q0.index.compute(), q1.index.compute()) - pass @pytest.mark.parametrize("with_polygon_query", [True, False]) diff --git a/tests/io/test_pyramids_performance.py b/tests/io/test_pyramids_performance.py index 7f234800e..875879541 100644 --- a/tests/io/test_pyramids_performance.py +++ b/tests/io/test_pyramids_performance.py @@ -83,5 +83,9 @@ def test_write_image_multiscale_performance(sdata_with_image: SpatialData, tmp_p actual_num_chunk_writes = zarr_chunk_write_spy.call_count actual_num_chunk_reads = zarr_chunk_read_spy.call_count - assert actual_num_chunk_writes == num_chunks_all_scales.item() + # https://github.com/dask/dask/pull/11736 introduces an extra write of the last chunk when finalizing. + assert actual_num_chunk_writes in { + num_chunks_all_scales.item(), + num_chunks_all_scales.item() + 1, + } assert actual_num_chunk_reads == num_chunks_scale0.item() diff --git a/tests/io/test_readwrite.py b/tests/io/test_readwrite.py index 8501687ca..6e948f519 100644 --- a/tests/io/test_readwrite.py +++ b/tests/io/test_readwrite.py @@ -229,8 +229,7 @@ def _workaround1_dask_backed( del sdata[new_name] sdata.delete_element_from_disk(new_name) - # @pytest.mark.parametrize("dask_backed", [True, False]) - @pytest.mark.parametrize("dask_backed", [True]) + @pytest.mark.parametrize("dask_backed", [True, False]) @pytest.mark.parametrize("workaround", [1, 2]) def test_incremental_io_on_disk( self, diff --git a/tests/io/test_utils.py b/tests/io/test_utils.py index 0a430704f..57bfe6e42 100644 --- a/tests/io/test_utils.py +++ b/tests/io/test_utils.py @@ -107,6 +107,7 @@ def test_backing_files_combining_points_and_images(points, images): images1 = read_zarr(f1) p0 = points0.points["points_0"] + im1 = images1.images["image2d"] v = p0["x"].loc[0].values v.compute_chunk_sizes() diff --git a/tests/models/test_accessor.py b/tests/models/test_accessor.py new file mode 100644 index 000000000..7356f52cc --- /dev/null +++ b/tests/models/test_accessor.py @@ -0,0 +1,207 @@ +import dask.dataframe as dd +import pandas as pd +import pytest + +from spatialdata.models._accessor import AttrsAccessor + +# ============================================================================ +# General tests +# ============================================================================ + + +def test_dataframe_attrs_is_accessor(): + """Test that DataFrame.attrs is an AttrsAccessor, not a dict.""" + df = dd.from_pandas(pd.DataFrame({"a": [1, 2, 3]}), npartitions=2) + assert isinstance(df.attrs, AttrsAccessor) + + +def test_series_attrs_is_accessor(): + """Test that Series.attrs is an AttrsAccessor, not a dict.""" + s = dd.from_pandas(pd.Series([1, 2, 3], name="test"), npartitions=2) + assert isinstance(s.attrs, AttrsAccessor) + + +def test_attrs_setitem_getitem(): + """Test setting and getting attrs.""" + df = dd.from_pandas(pd.DataFrame({"a": [1, 2, 3]}), npartitions=2) + df.attrs["key"] = "value" + assert df.attrs["key"] == "value" + + +def test_attrs_update(): + """Test that attrs.update() works.""" + df = dd.from_pandas(pd.DataFrame({"a": [1, 2, 3]}), npartitions=2) + df.attrs.update({"key1": "value1", "key2": "value2"}) + assert df.attrs["key1"] == "value1" + assert df.attrs["key2"] == "value2" + + +def test_invalid_attrs_assignment_raises(): + """Test that assigning a dict to attrs raises an error on next operation.""" + df = dd.from_pandas(pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}), npartitions=2) + + # This is the wrong way to do it + df.attrs = {"key": "value"} + + # Should raise RuntimeError on next wrapped operation + with pytest.raises(RuntimeError, match="Invalid .attrs.*expected an accessor"): + df.set_index("a") + + +def test_chained_operations(): + """Test that attrs survive chained operations.""" + df = dd.from_pandas( + pd.DataFrame({"a": [1, 2, 3, 4], "b": [5, 6, 7, 8], "c": [9, 10, 11, 12]}), + npartitions=2, + ) + df.attrs["experiment"] = "test123" + + result = df.set_index("a").drop("c", axis=1)[["b"]].copy() + + assert result.attrs["experiment"] == "test123" + assert isinstance(result.attrs, AttrsAccessor) + + +# ============================================================================ +# DataFrame wrapped methods tests +# ============================================================================ + + +def test_dataframe_getitem_preserves_attrs(): + """Test that DataFrame.__getitem__ preserves attrs.""" + df = dd.from_pandas(pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}), npartitions=2) + df.attrs["key"] = "value" + + # Single column (returns Series) + result = df["a"] + assert result.attrs["key"] == "value" + assert isinstance(result.attrs, AttrsAccessor) + + # Multiple columns (returns DataFrame) + result = df[["a", "b"]] + assert result.attrs["key"] == "value" + assert isinstance(result.attrs, AttrsAccessor) + + +def test_dataframe_compute_preserves_attrs(): + """Test that DataFrame.compute preserves attrs.""" + df = dd.from_pandas(pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}), npartitions=2) + df.attrs["key"] = "value" + result = df.compute() + # compute returns a pandas DataFrame, which has attrs as a dict + assert result.attrs["key"] == "value" + + +def test_dataframe_copy_preserves_attrs(): + """Test that DataFrame.copy preserves attrs.""" + df = dd.from_pandas(pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}), npartitions=2) + df.attrs["key"] = "value" + result = df.copy() + assert result.attrs["key"] == "value" + assert isinstance(result.attrs, AttrsAccessor) + + +def test_dataframe_drop_preserves_attrs(): + """Test that DataFrame.drop preserves attrs.""" + df = dd.from_pandas(pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}), npartitions=2) + df.attrs["key"] = "value" + result = df.drop("b", axis=1) + assert result.attrs["key"] == "value" + assert isinstance(result.attrs, AttrsAccessor) + + +def test_dataframe_map_partitions_preserves_attrs(): + """Test that DataFrame.map_partitions preserves attrs.""" + df = dd.from_pandas(pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}), npartitions=2) + df.attrs["key"] = "value" + result = df.map_partitions(lambda x: x * 2) + assert result.attrs["key"] == "value" + assert isinstance(result.attrs, AttrsAccessor) + + +def test_dataframe_set_index_preserves_attrs(): + """Test that DataFrame.set_index preserves attrs.""" + df = dd.from_pandas(pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}), npartitions=2) + df.attrs["key"] = "value" + result = df.set_index("a") + assert result.attrs["key"] == "value" + assert isinstance(result.attrs, AttrsAccessor) + + +# ============================================================================ +# Series wrapped methods tests +# ============================================================================ + + +def test_series_getitem_preserves_attrs(): + """Test that Series.__getitem__ preserves attrs.""" + s = dd.from_pandas(pd.Series([1, 2, 3, 4, 5], name="test"), npartitions=2) + s.attrs["key"] = "value" + result = s[1:3] + assert result.attrs["key"] == "value" + assert isinstance(result.attrs, AttrsAccessor) + + +def test_series_compute_preserves_attrs(): + """Test that Series.compute preserves attrs.""" + s = dd.from_pandas(pd.Series([1, 2, 3], name="test"), npartitions=2) + s.attrs["key"] = "value" + result = s.compute() + # compute returns a pandas Series, which has attrs as a dict + assert result.attrs["key"] == "value" + + +def test_series_copy_preserves_attrs(): + """Test that Series.copy preserves attrs.""" + s = dd.from_pandas(pd.Series([1, 2, 3], name="test"), npartitions=2) + s.attrs["key"] = "value" + result = s.copy() + assert result.attrs["key"] == "value" + assert isinstance(result.attrs, AttrsAccessor) + + +def test_series_map_partitions_preserves_attrs(): + """Test that Series.map_partitions preserves attrs.""" + s = dd.from_pandas(pd.Series([1, 2, 3], name="test"), npartitions=2) + s.attrs["key"] = "value" + result = s.map_partitions(lambda x: x * 2) + assert result.attrs["key"] == "value" + assert isinstance(result.attrs, AttrsAccessor) + + +# ============================================================================ +# Indexer tests +# ============================================================================ + + +def test_dataframe_loc_preserves_attrs(): + """Test that DataFrame.loc preserves attrs.""" + df = dd.from_pandas( + pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}, index=[10, 20, 30]), + npartitions=2, + ) + df.attrs["key"] = "value" + result = df.loc[10:20] + assert result.attrs["key"] == "value" + + +def test_dataframe_iloc_preserves_attrs(): + """Test that DataFrame.iloc preserves attrs.""" + df = dd.from_pandas(pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}), npartitions=2) + df.attrs["key"] = "value" + result = df.iloc[:, 0:1] + assert result.attrs["key"] == "value" + + +def test_series_loc_preserves_attrs(): + """Test that Series.loc preserves attrs.""" + s = dd.from_pandas( + pd.Series([1, 2, 3, 4, 5], index=[10, 20, 30, 40, 50], name="test"), + npartitions=2, + ) + s.attrs["key"] = "value" + result = s.loc[10:30] + assert result.attrs["key"] == "value" + + +# dd.Series do not have .iloc, hence there is no test_series_iloc_preserves_attrs() test