diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 93f335e625b..a72317b7fc0 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -14,6 +14,10 @@ v2026.05.0 (unreleased) New Features ~~~~~~~~~~~~ +- Support reading and writing Zarr V3 arrays with rectilinear (variable-sized) + chunk grids. Requires zarr-python >= 3.2 with + ``zarr.config.set({"array.rectilinear_chunks": True})``. (:pull:`11279`). + By `Max Jones `_. Breaking Changes ~~~~~~~~~~~~~~~~ diff --git a/xarray/backends/chunks.py b/xarray/backends/chunks.py index c255c7db591..4a4a2363073 100644 --- a/xarray/backends/chunks.py +++ b/xarray/backends/chunks.py @@ -1,3 +1,5 @@ +import itertools + import numpy as np from xarray.core.datatree import Variable @@ -133,9 +135,12 @@ def align_nd_chunks( def build_grid_chunks( size: int, - chunk_size: int, + chunk_size: int | tuple[int, ...], region: slice | None = None, ) -> tuple[int, ...]: + if isinstance(chunk_size, (list, tuple)): + return _build_rectilinear_grid_chunks(chunk_size, region) + if region is None: region = slice(0, size) @@ -153,9 +158,39 @@ def build_grid_chunks( return tuple(chunks_on_region) +def _build_rectilinear_grid_chunks( + chunk_sizes: tuple[int, ...], + region: slice | None = None, +) -> tuple[int, ...]: + """Build grid chunks for a rectilinear dimension within a region.""" + if region is None or region == slice(None): + return tuple(chunk_sizes) + + region_start = region.start or 0 + region_stop = region.stop or sum(chunk_sizes) + + boundaries = [0] + for cs in chunk_sizes: + boundaries.append(boundaries[-1] + cs) + + result = [] + for i in range(len(chunk_sizes)): + chunk_start = boundaries[i] + chunk_end = boundaries[i + 1] + + if chunk_end <= region_start or chunk_start >= region_stop: + continue + + effective_start = max(chunk_start, region_start) + effective_end = min(chunk_end, region_stop) + result.append(effective_end - effective_start) + + return tuple(result) + + def grid_rechunk( v: Variable, - enc_chunks: tuple[int, ...], + encoding_chunks: tuple[int, ...] | tuple[int | tuple[int, ...], ...], region: tuple[slice, ...], ) -> Variable: nd_v_chunks = v.chunks @@ -169,7 +204,7 @@ def grid_rechunk( chunk_size=chunk_size, ) for v_size, chunk_size, interval in zip( - v.shape, enc_chunks, region, strict=True + v.shape, encoding_chunks, region, strict=True ) ) @@ -181,9 +216,36 @@ def grid_rechunk( return v +def _validate_rectilinear_chunk_alignment( + dask_chunks: tuple[int, ...], + encoding_chunks: tuple[int, ...], + axis: int, + name: str, + region: slice = slice(None), +) -> None: + """Validate dask chunks align with rectilinear encoding chunk boundaries.""" + encoding_stops = set(itertools.accumulate(encoding_chunks)) + region_start = region.start or 0 + dask_stops = {region_start + s for s in itertools.accumulate(dask_chunks)} + # The final stop (total size) always matches — exclude it + total = sum(encoding_chunks) + encoding_stops.discard(total) + dask_stops.discard(total) + bad = dask_stops - encoding_stops + if bad: + raise ValueError( + f"Specified rectilinear encoding chunks {encoding_chunks!r} for variable " + f"named {name!r} would overlap multiple Dask chunks on axis {axis}. " + f"Dask chunk boundaries at positions {sorted(bad)} do not align with " + f"encoding chunk boundaries at {sorted(encoding_stops)}. " + "Writing this array in parallel with Dask could lead to corrupted data. " + "Consider rechunking using `chunk()` or setting `safe_chunks=False`." + ) + + def validate_grid_chunks_alignment( nd_v_chunks: tuple[tuple[int, ...], ...] | None, - enc_chunks: tuple[int, ...], + enc_chunks: tuple[int | tuple[int, ...], ...], backend_shape: tuple[int, ...], region: tuple[slice, ...], allow_partial_chunks: bool, @@ -205,7 +267,7 @@ def validate_grid_chunks_alignment( "- Enable automatic chunks alignment with `align_chunks=True`." ) - for axis, chunk_size, v_chunks, interval, size in zip( + for axis, enc_chunk, v_chunks, interval, size in zip( range(len(enc_chunks)), enc_chunks, nd_v_chunks, @@ -213,6 +275,19 @@ def validate_grid_chunks_alignment( backend_shape, strict=True, ): + if isinstance(enc_chunk, (list, tuple)): + # Rectilinear dimension — use boundary-based validation + _validate_rectilinear_chunk_alignment( + dask_chunks=v_chunks, + encoding_chunks=enc_chunk, + axis=axis, + name=name, + region=interval, + ) + continue + + # Regular dimension — existing validation logic + chunk_size = enc_chunk for i, chunk in enumerate(v_chunks[1:-1]): if chunk % chunk_size: raise ValueError( diff --git a/xarray/backends/zarr.py b/xarray/backends/zarr.py index d9279dc2de9..6491e97cf5e 100644 --- a/xarray/backends/zarr.py +++ b/xarray/backends/zarr.py @@ -1,6 +1,8 @@ from __future__ import annotations import base64 +import functools +import importlib.util import json import os import struct @@ -46,6 +48,28 @@ from xarray.core.types import ZarrArray, ZarrGroup +@functools.cache +def _has_unified_chunk_grid() -> bool: + """Check if zarr has the unified ChunkGrid with is_regular support. + + Defers the actual import so zarr stays lazy at module load time. + """ + if importlib.util.find_spec("zarr.core.chunk_grids") is None: + return False + from zarr.core.chunk_grids import ChunkGrid + + return hasattr(ChunkGrid, "is_regular") + + +def _is_regular_chunk_spec(chunks: tuple) -> bool: + """True when *chunks* is a flat tuple of ints (regular chunk grid). + + Returns False for rectilinear specs where at least one element is a + sequence of per-chunk edge lengths. + """ + return all(isinstance(c, int) for c in chunks) + + def _get_mappers(*, storage_options, store, chunk_store): # expand str and path-like arguments store = _normalize_path(store) @@ -333,7 +357,7 @@ async def async_getitem(self, key): ) -def _determine_zarr_chunks(enc_chunks, var_chunks, ndim, name): +def _determine_zarr_chunks(enc_chunks, var_chunks, ndim, name, zarr_format): """ Given encoding chunks (possibly None or []) and variable chunks (possibly None or []). @@ -355,18 +379,34 @@ def _determine_zarr_chunks(enc_chunks, var_chunks, ndim, name): # while dask chunks can be variable sized # https://dask.pydata.org/en/latest/array-design.html#chunks if var_chunks and not enc_chunks: + if zarr_format == 3 and _has_unified_chunk_grid(): + # Check if dask chunks are regular (uniform except for last chunk) + has_varying_interior = any( + len(set(chunks[:-1])) > 1 for chunks in var_chunks + ) + has_larger_final = any(chunks[0] < chunks[-1] for chunks in var_chunks) + if has_varying_interior or has_larger_final: + # Truly rectilinear — return dask-style tuples of per-chunk sizes. + # Requires zarr config: array.rectilinear_chunks = True + return tuple(var_chunks) + # Regular chunks — return the first chunk size per dimension + return tuple(chunk[0] for chunk in var_chunks) + if any(len(set(chunks[:-1])) > 1 for chunks in var_chunks): raise ValueError( - "Zarr requires uniform chunk sizes except for final chunk. " + "Zarr v2 requires uniform chunk sizes except for the final chunk. " f"Variable named {name!r} has incompatible dask chunks: {var_chunks!r}. " - "Consider rechunking using `chunk()`." + "Consider rechunking using `chunk()`, or switching to the " + "zarr v3 format with zarr-python>=3.2." ) if any((chunks[0] < chunks[-1]) for chunks in var_chunks): raise ValueError( - "Final chunk of Zarr array must be the same size or smaller " - f"than the first. Variable named {name!r} has incompatible Dask chunks {var_chunks!r}." - "Consider either rechunking using `chunk()` or instead deleting " - "or modifying `encoding['chunks']`." + "The final chunk of a Zarr v2 array or a Zarr v3 array without the " + "rectilinear chunks extension must be the same size or smaller " + f"than the first. Variable named {name!r} has incompatible Dask " + f"chunks {var_chunks!r}. " + "Consider switching to Zarr v3 with the rectilinear chunks extension, " + "rechunking using `chunk()` or deleting or modifying `encoding['chunks']`." ) # return the first chunk for each dimension return tuple(chunk[0] for chunk in var_chunks) @@ -389,8 +429,17 @@ def _determine_zarr_chunks(enc_chunks, var_chunks, ndim, name): var_chunks, ndim, name, + zarr_format, ) + # Rectilinear chunks: each element is a sequence of per-chunk edge lengths + if ( + zarr_format == 3 + and _has_unified_chunk_grid() + and any(not isinstance(x, int) for x in enc_chunks_tuple) + ): + return enc_chunks_tuple + for x in enc_chunks_tuple: if not isinstance(x, int): raise TypeError( @@ -532,6 +581,7 @@ def extract_zarr_variable_encoding( var_chunks=variable.chunks, ndim=variable.ndim, name=name, + zarr_format=zarr_format, ) if _zarr_v3() and chunks is None: chunks = "auto" @@ -910,9 +960,27 @@ def open_store_variable(self, name): ) attributes = dict(attributes) + if _has_unified_chunk_grid() and zarr_array.metadata.zarr_format == 3: + from zarr.core.metadata.v3 import ( + RectilinearChunkGridMetadata, + RegularChunkGridMetadata, + ) + + chunk_grid = zarr_array.metadata.chunk_grid + if isinstance(chunk_grid, RegularChunkGridMetadata): + chunks = chunk_grid.chunk_shape + elif isinstance(chunk_grid, RectilinearChunkGridMetadata): + chunks = chunk_grid.chunk_shapes + else: + chunks = tuple(zarr_array.chunks) + preferred_chunks = dict(zip(dimensions, chunks, strict=True)) + else: + chunks = tuple(zarr_array.chunks) + preferred_chunks = dict(zip(dimensions, chunks, strict=True)) + encoding = { - "chunks": zarr_array.chunks, - "preferred_chunks": dict(zip(dimensions, zarr_array.chunks, strict=True)), + "chunks": chunks, + "preferred_chunks": preferred_chunks, } if _zarr_v3(): @@ -1300,7 +1368,7 @@ def set_variables( if self._align_chunks and isinstance(effective_write_chunks, tuple): v = grid_rechunk( v=v, - enc_chunks=effective_write_chunks, + encoding_chunks=effective_write_chunks, region=region, ) diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index e42bfc2cd9f..ad1e881557c 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -2973,9 +2973,18 @@ def test_chunk_encoding_with_dask(self) -> None: # should fail if dask_chunks are irregular... ds_chunk_irreg = ds.chunk({"x": (5, 4, 3)}) - with pytest.raises(ValueError, match=r"uniform chunk sizes."): - with self.roundtrip(ds_chunk_irreg) as actual: - pass + if ( + backends.zarr._has_unified_chunk_grid() + and zarr.config.config["default_zarr_format"] == 3 + ): + # zarr v3 with unified chunk grid supports rectilinear chunks + with zarr.config.set({"array.rectilinear_chunks": True}): + with self.roundtrip(ds_chunk_irreg) as actual: + pass + else: + with pytest.raises(ValueError, match=r"uniform chunk sizes."): + with self.roundtrip(ds_chunk_irreg) as actual: + pass # should fail if encoding["chunks"] clashes with dask_chunks badenc = ds.chunk({"x": 4}) @@ -7299,6 +7308,281 @@ def test_extract_zarr_variable_encoding() -> None: ) +@requires_zarr_v3 +@requires_dask +def test_rectilinear_chunks_encoding_roundtrip(tmp_path: Path) -> None: + """Rectilinear chunk sizes in encoding are passed through to zarr v3.""" + + import zarr + + if not backends.zarr._has_unified_chunk_grid(): + pytest.skip("zarr does not have unified ChunkGrid support") + + chunk_sizes = [10, 20, 30] + data = np.arange(60, dtype="float32") + ds = xr.Dataset({"var": xr.Variable("x", data)}).chunk({"x": tuple(chunk_sizes)}) + + store_path = tmp_path / "rectilinear.zarr" + encoding = {"var": {"chunks": [chunk_sizes]}} + + with zarr.config.set({"array.rectilinear_chunks": True}): + ds.to_zarr(store_path, zarr_format=3, mode="w", encoding=encoding) + + roundtrip = xr.open_zarr(store_path, zarr_format=3) + assert roundtrip.chunks["x"] == tuple(chunk_sizes) + np.testing.assert_array_equal(roundtrip["var"].values, data) + + +@requires_zarr_v3 +@requires_dask +def test_rectilinear_chunks_no_encoding(tmp_path: Path) -> None: + """Variable dask chunks are written as rectilinear when no encoding is given.""" + import zarr + + if not backends.zarr._has_unified_chunk_grid(): + pytest.skip("zarr does not have unified ChunkGrid support") + + chunk_sizes = [15, 25, 20] + data = np.arange(60, dtype="float32") + ds = xr.Dataset({"var": xr.Variable("x", data)}).chunk({"x": tuple(chunk_sizes)}) + + store_path = tmp_path / "rectilinear_no_enc.zarr" + + with zarr.config.set({"array.rectilinear_chunks": True}): + ds.to_zarr(store_path, zarr_format=3, mode="w") + + roundtrip = xr.open_zarr(store_path, zarr_format=3) + assert roundtrip.chunks["x"] == tuple(chunk_sizes) + np.testing.assert_array_equal(roundtrip["var"].values, data) + + +@requires_zarr_v3 +@requires_dask +def test_rectilinear_chunks_multidim(tmp_path: Path) -> None: + """Rectilinear chunks on a multi-dimensional array.""" + import zarr + + if not backends.zarr._has_unified_chunk_grid(): + pytest.skip("zarr does not have unified ChunkGrid support") + + data = np.arange(120, dtype="float64").reshape(6, 20) + ds = xr.Dataset({"var": xr.Variable(("x", "y"), data)}).chunk( + {"x": (2, 4), "y": (5, 10, 5)} + ) + + store_path = tmp_path / "rectilinear_2d.zarr" + + with zarr.config.set({"array.rectilinear_chunks": True}): + ds.to_zarr(store_path, zarr_format=3, mode="w") + + roundtrip = xr.open_zarr(store_path, zarr_format=3) + assert roundtrip.chunks["x"] == (2, 4) + assert roundtrip.chunks["y"] == (5, 10, 5) + np.testing.assert_array_equal(roundtrip["var"].values, data) + + +@requires_zarr_v3 +@requires_dask +def test_rectilinear_chunks_mixed_dims(tmp_path: Path) -> None: + """One dimension regular, another rectilinear.""" + import zarr + + if not backends.zarr._has_unified_chunk_grid(): + pytest.skip("zarr does not have unified ChunkGrid support") + + data = np.arange(60, dtype="float32").reshape(3, 20) + ds = xr.Dataset({"var": xr.Variable(("x", "y"), data)}).chunk( + {"x": 3, "y": (5, 10, 5)} + ) + + store_path = tmp_path / "mixed_chunks.zarr" + + with zarr.config.set({"array.rectilinear_chunks": True}): + ds.to_zarr(store_path, zarr_format=3, mode="w") + + roundtrip = xr.open_zarr(store_path, zarr_format=3) + assert roundtrip.chunks["x"] == (3,) + assert roundtrip.chunks["y"] == (5, 10, 5) + np.testing.assert_array_equal(roundtrip["var"].values, data) + + +@requires_zarr_v3 +@requires_dask +def test_rectilinear_chunks_interop(tmp_path: Path) -> None: + """Read rectilinear array created directly by zarr.""" + import zarr + + if not backends.zarr._has_unified_chunk_grid(): + pytest.skip("zarr does not have unified ChunkGrid support") + + store_path = tmp_path / "zarr_native.zarr" + data = np.arange(60, dtype="float32") + + with zarr.config.set({"array.rectilinear_chunks": True}): + root = zarr.open_group(store_path, mode="w", zarr_format=3) + arr = root.create( + "var", + shape=(60,), + # zarr stubs don't include rectilinear chunk types yet + chunks=((10, 20, 30),), # type: ignore[arg-type] + dtype="float32", + dimension_names=("x",), + ) + arr[:] = data + + roundtrip = xr.open_zarr(store_path, zarr_format=3, consolidated=False) + assert roundtrip.chunks["x"] == (10, 20, 30) + np.testing.assert_array_equal(roundtrip["var"].values, data) + + +@requires_zarr_v3 +@requires_dask +def test_rectilinear_chunks_safe_chunks_fail(tmp_path: Path) -> None: + """Misaligned dask chunks should raise when safe_chunks=True.""" + import zarr + + if not backends.zarr._has_unified_chunk_grid(): + pytest.skip("zarr does not have unified ChunkGrid support") + + data = np.arange(60, dtype="float32") + ds = xr.Dataset({"var": xr.Variable("x", data)}).chunk({"x": (15, 15, 30)}) + + store_path = tmp_path / "safe_chunks_fail.zarr" + encoding = {"var": {"chunks": [(10, 20, 30)]}} + + with zarr.config.set({"array.rectilinear_chunks": True}): + with pytest.raises(ValueError, match=r"rectilinear.*overlap"): + ds.to_zarr( + store_path, + zarr_format=3, + mode="w", + encoding=encoding, + safe_chunks=True, + ) + + +@requires_zarr_v3 +@requires_dask +def test_rectilinear_chunks_region_write(tmp_path: Path) -> None: + """Write to a region of a rectilinear chunked array.""" + import zarr + + if not backends.zarr._has_unified_chunk_grid(): + pytest.skip("zarr does not have unified ChunkGrid support") + + chunk_sizes = (10, 20, 30) + data = np.zeros(60, dtype="float32") + ds = xr.Dataset({"var": xr.Variable("x", data)}).chunk({"x": chunk_sizes}) + + store_path = tmp_path / "region.zarr" + + with zarr.config.set({"array.rectilinear_chunks": True}): + ds.to_zarr(store_path, zarr_format=3, mode="w") + + # Overwrite just the second chunk (positions 10..30) + update = np.arange(20, dtype="float32") + 100 + ds_update = xr.Dataset({"var": xr.Variable("x", update)}).chunk({"x": (20,)}) + ds_update.to_zarr( + store_path, zarr_format=3, region={"x": slice(10, 30)}, mode="r+" + ) + + roundtrip = xr.open_zarr(store_path, zarr_format=3) + expected = data.copy() + expected[10:30] = update + np.testing.assert_array_equal(roundtrip["var"].values, expected) + assert roundtrip.chunks["x"] == chunk_sizes + + +@requires_zarr_v3 +@requires_dask +def test_rectilinear_chunks_encoding_roundtrip_rewrite(tmp_path: Path) -> None: + """Read a rectilinear array and write it back preserving chunks.""" + import zarr + + if not backends.zarr._has_unified_chunk_grid(): + pytest.skip("zarr does not have unified ChunkGrid support") + + chunk_sizes = (10, 20, 30) + data = np.arange(60, dtype="float32") + ds = xr.Dataset({"var": xr.Variable("x", data)}).chunk({"x": chunk_sizes}) + + path1 = tmp_path / "source.zarr" + path2 = tmp_path / "dest.zarr" + + with zarr.config.set({"array.rectilinear_chunks": True}): + ds.to_zarr(path1, zarr_format=3, mode="w") + + loaded = xr.open_zarr(path1, zarr_format=3) + loaded.to_zarr(path2, zarr_format=3, mode="w") + + roundtrip = xr.open_zarr(path2, zarr_format=3) + assert roundtrip.chunks["x"] == chunk_sizes + np.testing.assert_array_equal(roundtrip["var"].values, data) + + +def test_validate_grid_chunks_alignment_rectilinear_pass() -> None: + """Dask chunks that align with rectilinear zarr boundaries should pass.""" + from xarray.backends.chunks import validate_grid_chunks_alignment + + validate_grid_chunks_alignment( + nd_v_chunks=((10, 20, 30),), + enc_chunks=((10, 20, 30),), + region=(slice(None),), + allow_partial_chunks=True, + name="var", + backend_shape=(60,), + ) + + # Dask chunks are coarser (merging zarr chunks is fine) + validate_grid_chunks_alignment( + nd_v_chunks=((30, 30),), + enc_chunks=((10, 20, 30),), + region=(slice(None),), + allow_partial_chunks=True, + name="var", + backend_shape=(60,), + ) + + +def test_validate_grid_chunks_alignment_rectilinear_fail() -> None: + """Dask chunks that split a rectilinear zarr chunk should raise.""" + from xarray.backends.chunks import validate_grid_chunks_alignment + + with pytest.raises(ValueError, match=r"rectilinear.*overlap"): + validate_grid_chunks_alignment( + nd_v_chunks=((15, 15, 30),), + enc_chunks=((10, 20, 30),), + region=(slice(None),), + allow_partial_chunks=True, + name="var", + backend_shape=(60,), + ) + + +def test_build_grid_chunks_rectilinear_full() -> None: + """build_grid_chunks with rectilinear spec and no region returns the spec.""" + from xarray.backends.chunks import build_grid_chunks + + result = build_grid_chunks(size=60, chunk_size=(10, 20, 30)) + assert result == (10, 20, 30) + + +def test_build_grid_chunks_rectilinear_region() -> None: + """build_grid_chunks with rectilinear spec and a region clips to region.""" + from xarray.backends.chunks import build_grid_chunks + + result = build_grid_chunks(size=45, chunk_size=(10, 20, 30), region=slice(15, 60)) + assert result == (15, 30) + + +def test_build_grid_chunks_rectilinear_region_mid() -> None: + """Region that starts and ends mid-chunk.""" + from xarray.backends.chunks import build_grid_chunks + + result = build_grid_chunks(size=40, chunk_size=(10, 20, 30), region=slice(5, 45)) + assert result == (5, 20, 15) + + @requires_zarr @requires_fsspec @pytest.mark.filterwarnings("ignore:deallocating CachingFileManager") diff --git a/xarray/tests/test_backends_chunks.py b/xarray/tests/test_backends_chunks.py index bb1297d0db3..684f142e7ad 100644 --- a/xarray/tests/test_backends_chunks.py +++ b/xarray/tests/test_backends_chunks.py @@ -110,7 +110,7 @@ def test_grid_rechunk(enc_chunks, region, nd_v_chunks, expected_chunks): result = grid_rechunk( arr.variable, - enc_chunks=enc_chunks, + encoding_chunks=enc_chunks, region=region, ) assert result.chunks == expected_chunks