diff --git a/src/spatialdata/_core/query/_utils.py b/src/spatialdata/_core/query/_utils.py index e011d318c..5ab904d96 100644 --- a/src/spatialdata/_core/query/_utils.py +++ b/src/spatialdata/_core/query/_utils.py @@ -134,45 +134,37 @@ def _process_data_tree_query_result(query_result: DataTree) -> DataTree | None: d = {k: Dataset({"image": d[k]}) for k in scales_to_keep} result = DataTree.from_dict(d) - # rechunk the data to avoid irregular chunks - coords = list(result["scale0"].coords.keys()) - result = result.chunk(dict.fromkeys(coords, "auto")) - from dask.array.core import _check_regular_chunks - # check that the rechunking into regular chunks worked - chunks_still_irregular = False + # rechunk to avoid irregular chunks for scale in result: data = result[scale]["image"].data - chunks_still_irregular = chunks_still_irregular or not _check_regular_chunks(data.chunks) - - if chunks_still_irregular: - # reported here: https://github.com/scverse/spatialdata/issues/821#issuecomment-2632201695 - # seemingly due to this bug: https://github.com/dask/dask/issues/11713 - CHUNK_SIZE = 1024 - rechunk_strategy = dict.fromkeys(coords, CHUNK_SIZE) - if "c" in coords: - rechunk_strategy["c"] = result["scale0"]["image"].chunks[0][0] - result = result.chunk(rechunk_strategy) + chunks = data.chunks + if not _check_regular_chunks(chunks): + data = data.rechunk(data.chunksize) + if not _check_regular_chunks(data.chunks): + raise ValueError( + f"Chunks are not regular for {scale} of the queried data: {chunks} " + "and could also not be rechunked regularly. Please report this bug." + ) + result[scale]["image"].data = data - for scale in result: - data = result[scale]["image"].data - assert _check_regular_chunks(data.chunks), ( - f"Chunks are not regular for the {scale} of the queried data: {data.chunks}. Please report this bug." - ) return result def _process_query_result( result: DataArray | DataTree, translation_vector: ArrayLike, axes: tuple[str, ...] ) -> DataArray | DataTree | None: + from dask.array.core import _check_regular_chunks + from spatialdata.transformations import get_transformation, set_transformation if isinstance(result, DataArray): if 0 in result.shape: return None - # rechunk the data to avoid irregular chunks - result = result.chunk("auto") + # rechunk to avoid irregular chunks + if not _check_regular_chunks(result.data.chunks): + result.data = result.data.rechunk(result.data.chunksize) elif isinstance(result, DataTree): result = _process_data_tree_query_result(result) if result is None: