Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 15 additions & 23 deletions src/spatialdata/_core/query/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,45 +134,37 @@ def _process_data_tree_query_result(query_result: DataTree) -> DataTree | None:
d = {k: Dataset({"image": d[k]}) for k in scales_to_keep}
result = DataTree.from_dict(d)

# rechunk the data to avoid irregular chunks
coords = list(result["scale0"].coords.keys())
result = result.chunk(dict.fromkeys(coords, "auto"))

from dask.array.core import _check_regular_chunks

# check that the rechunking into regular chunks worked
chunks_still_irregular = False
# rechunk to avoid irregular chunks
for scale in result:
data = result[scale]["image"].data
chunks_still_irregular = chunks_still_irregular or not _check_regular_chunks(data.chunks)

if chunks_still_irregular:
# reported here: https://github.com/scverse/spatialdata/issues/821#issuecomment-2632201695
# seemingly due to this bug: https://github.com/dask/dask/issues/11713
CHUNK_SIZE = 1024
rechunk_strategy = dict.fromkeys(coords, CHUNK_SIZE)
if "c" in coords:
rechunk_strategy["c"] = result["scale0"]["image"].chunks[0][0]
result = result.chunk(rechunk_strategy)
chunks = data.chunks
if not _check_regular_chunks(chunks):
data = data.rechunk(data.chunksize)
if not _check_regular_chunks(data.chunks):
raise ValueError(
f"Chunks are not regular for {scale} of the queried data: {chunks} "
"and could also not be rechunked regularly. Please report this bug."
)
result[scale]["image"].data = data

for scale in result:
data = result[scale]["image"].data
assert _check_regular_chunks(data.chunks), (
f"Chunks are not regular for the {scale} of the queried data: {data.chunks}. Please report this bug."
)
return result


def _process_query_result(
result: DataArray | DataTree, translation_vector: ArrayLike, axes: tuple[str, ...]
) -> DataArray | DataTree | None:
from dask.array.core import _check_regular_chunks

from spatialdata.transformations import get_transformation, set_transformation

if isinstance(result, DataArray):
if 0 in result.shape:
return None
# rechunk the data to avoid irregular chunks
result = result.chunk("auto")
# rechunk to avoid irregular chunks
if not _check_regular_chunks(result.data.chunks):
result.data = result.data.rechunk(result.data.chunksize)
elif isinstance(result, DataTree):
result = _process_data_tree_query_result(result)
if result is None:
Expand Down
Loading