From cc227dd68391bfba84c66b23781ddc059892d1e3 Mon Sep 17 00:00:00 2001 From: ArneDefauw Date: Tue, 23 Sep 2025 13:53:09 +0200 Subject: [PATCH 1/2] spatial query rechunk --- src/spatialdata/_core/query/_utils.py | 22 ++++------------------ 1 file changed, 4 insertions(+), 18 deletions(-) diff --git a/src/spatialdata/_core/query/_utils.py b/src/spatialdata/_core/query/_utils.py index e011d318c..f4232dee2 100644 --- a/src/spatialdata/_core/query/_utils.py +++ b/src/spatialdata/_core/query/_utils.py @@ -135,25 +135,11 @@ def _process_data_tree_query_result(query_result: DataTree) -> DataTree | None: result = DataTree.from_dict(d) # rechunk the data to avoid irregular chunks - coords = list(result["scale0"].coords.keys()) - result = result.chunk(dict.fromkeys(coords, "auto")) - - from dask.array.core import _check_regular_chunks - - # check that the rechunking into regular chunks worked - chunks_still_irregular = False for scale in result: - data = result[scale]["image"].data - chunks_still_irregular = chunks_still_irregular or not _check_regular_chunks(data.chunks) + result[scale]["image"].data = result[scale]["image"].data.rechunk(result[scale]["image"].data.chunksize) - if chunks_still_irregular: - # reported here: https://github.com/scverse/spatialdata/issues/821#issuecomment-2632201695 - # seemingly due to this bug: https://github.com/dask/dask/issues/11713 - CHUNK_SIZE = 1024 - rechunk_strategy = dict.fromkeys(coords, CHUNK_SIZE) - if "c" in coords: - rechunk_strategy["c"] = result["scale0"]["image"].chunks[0][0] - result = result.chunk(rechunk_strategy) + # sanity check + from dask.array.core import _check_regular_chunks for scale in result: data = result[scale]["image"].data @@ -172,7 +158,7 @@ def _process_query_result( if 0 in result.shape: return None # rechunk the data to avoid irregular chunks - result = result.chunk("auto") + result.data = result.data.rechunk(result.data.chunksize) elif isinstance(result, DataTree): result = _process_data_tree_query_result(result) if result is None: From 8c1bb52cda12eb2bb0f00b2500df687527ac218f Mon Sep 17 00:00:00 2001 From: ArneDefauw Date: Mon, 29 Sep 2025 14:06:25 +0200 Subject: [PATCH 2/2] check regular chunks before rechunking --- src/spatialdata/_core/query/_utils.py | 26 ++++++++++++++++---------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/src/spatialdata/_core/query/_utils.py b/src/spatialdata/_core/query/_utils.py index f4232dee2..5ab904d96 100644 --- a/src/spatialdata/_core/query/_utils.py +++ b/src/spatialdata/_core/query/_utils.py @@ -134,31 +134,37 @@ def _process_data_tree_query_result(query_result: DataTree) -> DataTree | None: d = {k: Dataset({"image": d[k]}) for k in scales_to_keep} result = DataTree.from_dict(d) - # rechunk the data to avoid irregular chunks - for scale in result: - result[scale]["image"].data = result[scale]["image"].data.rechunk(result[scale]["image"].data.chunksize) - - # sanity check from dask.array.core import _check_regular_chunks + # rechunk to avoid irregular chunks for scale in result: data = result[scale]["image"].data - assert _check_regular_chunks(data.chunks), ( - f"Chunks are not regular for the {scale} of the queried data: {data.chunks}. Please report this bug." - ) + chunks = data.chunks + if not _check_regular_chunks(chunks): + data = data.rechunk(data.chunksize) + if not _check_regular_chunks(data.chunks): + raise ValueError( + f"Chunks are not regular for {scale} of the queried data: {chunks} " + "and could also not be rechunked regularly. Please report this bug." + ) + result[scale]["image"].data = data + return result def _process_query_result( result: DataArray | DataTree, translation_vector: ArrayLike, axes: tuple[str, ...] ) -> DataArray | DataTree | None: + from dask.array.core import _check_regular_chunks + from spatialdata.transformations import get_transformation, set_transformation if isinstance(result, DataArray): if 0 in result.shape: return None - # rechunk the data to avoid irregular chunks - result.data = result.data.rechunk(result.data.chunksize) + # rechunk to avoid irregular chunks + if not _check_regular_chunks(result.data.chunks): + result.data = result.data.rechunk(result.data.chunksize) elif isinstance(result, DataTree): result = _process_data_tree_query_result(result) if result is None: