Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ v2026.05.0 (unreleased)
New Features
~~~~~~~~~~~~

- Support reading and writing Zarr V3 arrays with rectilinear (variable-sized)
chunk grids. Requires zarr-python >= 3.2 with
``zarr.config.set({"array.rectilinear_chunks": True})``. (:pull:`11279`).
By `Max Jones <https://github.com/maxrjones>`_.

Breaking Changes
~~~~~~~~~~~~~~~~
Expand Down
85 changes: 80 additions & 5 deletions xarray/backends/chunks.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import itertools

import numpy as np

from xarray.core.datatree import Variable
Expand Down Expand Up @@ -133,9 +135,12 @@ def align_nd_chunks(

def build_grid_chunks(
size: int,
chunk_size: int,
chunk_size: int | tuple[int, ...],
region: slice | None = None,
) -> tuple[int, ...]:
if isinstance(chunk_size, (list, tuple)):
return _build_rectilinear_grid_chunks(chunk_size, region)

if region is None:
region = slice(0, size)

Expand All @@ -153,9 +158,39 @@ def build_grid_chunks(
return tuple(chunks_on_region)


def _build_rectilinear_grid_chunks(
chunk_sizes: tuple[int, ...],
region: slice | None = None,
) -> tuple[int, ...]:
"""Build grid chunks for a rectilinear dimension within a region."""
if region is None or region == slice(None):
return tuple(chunk_sizes)

region_start = region.start or 0
region_stop = region.stop or sum(chunk_sizes)

boundaries = [0]
for cs in chunk_sizes:
boundaries.append(boundaries[-1] + cs)

result = []
for i in range(len(chunk_sizes)):
chunk_start = boundaries[i]
chunk_end = boundaries[i + 1]

if chunk_end <= region_start or chunk_start >= region_stop:
continue

effective_start = max(chunk_start, region_start)
effective_end = min(chunk_end, region_stop)
result.append(effective_end - effective_start)

return tuple(result)


def grid_rechunk(
v: Variable,
enc_chunks: tuple[int, ...],
encoding_chunks: tuple[int, ...] | tuple[int | tuple[int, ...], ...],
region: tuple[slice, ...],
) -> Variable:
nd_v_chunks = v.chunks
Expand All @@ -169,7 +204,7 @@ def grid_rechunk(
chunk_size=chunk_size,
)
for v_size, chunk_size, interval in zip(
v.shape, enc_chunks, region, strict=True
v.shape, encoding_chunks, region, strict=True
)
)

Expand All @@ -181,9 +216,36 @@ def grid_rechunk(
return v


def _validate_rectilinear_chunk_alignment(
dask_chunks: tuple[int, ...],
encoding_chunks: tuple[int, ...],
axis: int,
name: str,
region: slice = slice(None),
) -> None:
"""Validate dask chunks align with rectilinear encoding chunk boundaries."""
encoding_stops = set(itertools.accumulate(encoding_chunks))
region_start = region.start or 0
dask_stops = {region_start + s for s in itertools.accumulate(dask_chunks)}
# The final stop (total size) always matches — exclude it
total = sum(encoding_chunks)
encoding_stops.discard(total)
dask_stops.discard(total)
bad = dask_stops - encoding_stops
if bad:
raise ValueError(
f"Specified rectilinear encoding chunks {encoding_chunks!r} for variable "
f"named {name!r} would overlap multiple Dask chunks on axis {axis}. "
f"Dask chunk boundaries at positions {sorted(bad)} do not align with "
f"encoding chunk boundaries at {sorted(encoding_stops)}. "
"Writing this array in parallel with Dask could lead to corrupted data. "
"Consider rechunking using `chunk()` or setting `safe_chunks=False`."
)


def validate_grid_chunks_alignment(
nd_v_chunks: tuple[tuple[int, ...], ...] | None,
enc_chunks: tuple[int, ...],
enc_chunks: tuple[int | tuple[int, ...], ...],
backend_shape: tuple[int, ...],
region: tuple[slice, ...],
allow_partial_chunks: bool,
Expand All @@ -205,14 +267,27 @@ def validate_grid_chunks_alignment(
"- Enable automatic chunks alignment with `align_chunks=True`."
)

for axis, chunk_size, v_chunks, interval, size in zip(
for axis, enc_chunk, v_chunks, interval, size in zip(
range(len(enc_chunks)),
enc_chunks,
nd_v_chunks,
region,
backend_shape,
strict=True,
):
if isinstance(enc_chunk, (list, tuple)):
# Rectilinear dimension — use boundary-based validation
_validate_rectilinear_chunk_alignment(
dask_chunks=v_chunks,
encoding_chunks=enc_chunk,
axis=axis,
name=name,
region=interval,
)
continue

# Regular dimension — existing validation logic
chunk_size = enc_chunk
for i, chunk in enumerate(v_chunks[1:-1]):
if chunk % chunk_size:
raise ValueError(
Expand Down
88 changes: 78 additions & 10 deletions xarray/backends/zarr.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from __future__ import annotations

import base64
import functools
import importlib.util
import json
import os
import struct
Expand Down Expand Up @@ -46,6 +48,28 @@
from xarray.core.types import ZarrArray, ZarrGroup


@functools.cache
def _has_unified_chunk_grid() -> bool:
"""Check if zarr has the unified ChunkGrid with is_regular support.

Defers the actual import so zarr stays lazy at module load time.
"""
if importlib.util.find_spec("zarr.core.chunk_grids") is None:
return False
from zarr.core.chunk_grids import ChunkGrid

return hasattr(ChunkGrid, "is_regular")


def _is_regular_chunk_spec(chunks: tuple) -> bool:
"""True when *chunks* is a flat tuple of ints (regular chunk grid).

Returns False for rectilinear specs where at least one element is a
sequence of per-chunk edge lengths.
"""
return all(isinstance(c, int) for c in chunks)


def _get_mappers(*, storage_options, store, chunk_store):
# expand str and path-like arguments
store = _normalize_path(store)
Expand Down Expand Up @@ -333,7 +357,7 @@ async def async_getitem(self, key):
)


def _determine_zarr_chunks(enc_chunks, var_chunks, ndim, name):
def _determine_zarr_chunks(enc_chunks, var_chunks, ndim, name, zarr_format):
"""
Given encoding chunks (possibly None or []) and variable chunks
(possibly None or []).
Expand All @@ -355,18 +379,34 @@ def _determine_zarr_chunks(enc_chunks, var_chunks, ndim, name):
# while dask chunks can be variable sized
# https://dask.pydata.org/en/latest/array-design.html#chunks
if var_chunks and not enc_chunks:
if zarr_format == 3 and _has_unified_chunk_grid():
# Check if dask chunks are regular (uniform except for last chunk)
has_varying_interior = any(
len(set(chunks[:-1])) > 1 for chunks in var_chunks
)
has_larger_final = any(chunks[0] < chunks[-1] for chunks in var_chunks)
if has_varying_interior or has_larger_final:
# Truly rectilinear — return dask-style tuples of per-chunk sizes.
# Requires zarr config: array.rectilinear_chunks = True
return tuple(var_chunks)
# Regular chunks — return the first chunk size per dimension
return tuple(chunk[0] for chunk in var_chunks)

if any(len(set(chunks[:-1])) > 1 for chunks in var_chunks):
raise ValueError(
"Zarr requires uniform chunk sizes except for final chunk. "
"Zarr v2 requires uniform chunk sizes except for the final chunk. "
f"Variable named {name!r} has incompatible dask chunks: {var_chunks!r}. "
"Consider rechunking using `chunk()`."
"Consider rechunking using `chunk()`, or switching to the "
"zarr v3 format with zarr-python>=3.2."
)
if any((chunks[0] < chunks[-1]) for chunks in var_chunks):
raise ValueError(
"Final chunk of Zarr array must be the same size or smaller "
f"than the first. Variable named {name!r} has incompatible Dask chunks {var_chunks!r}."
"Consider either rechunking using `chunk()` or instead deleting "
"or modifying `encoding['chunks']`."
"The final chunk of a Zarr v2 array or a Zarr v3 array without the "
"rectilinear chunks extension must be the same size or smaller "
f"than the first. Variable named {name!r} has incompatible Dask "
f"chunks {var_chunks!r}. "
"Consider switching to Zarr v3 with the rectilinear chunks extension, "
"rechunking using `chunk()` or deleting or modifying `encoding['chunks']`."
)
# return the first chunk for each dimension
return tuple(chunk[0] for chunk in var_chunks)
Expand All @@ -389,8 +429,17 @@ def _determine_zarr_chunks(enc_chunks, var_chunks, ndim, name):
var_chunks,
ndim,
name,
zarr_format,
)

# Rectilinear chunks: each element is a sequence of per-chunk edge lengths
if (
zarr_format == 3
and _has_unified_chunk_grid()
and any(not isinstance(x, int) for x in enc_chunks_tuple)
):
return enc_chunks_tuple

for x in enc_chunks_tuple:
if not isinstance(x, int):
raise TypeError(
Expand Down Expand Up @@ -532,6 +581,7 @@ def extract_zarr_variable_encoding(
var_chunks=variable.chunks,
ndim=variable.ndim,
name=name,
zarr_format=zarr_format,
)
if _zarr_v3() and chunks is None:
chunks = "auto"
Expand Down Expand Up @@ -910,9 +960,27 @@ def open_store_variable(self, name):
)
attributes = dict(attributes)

if _has_unified_chunk_grid() and zarr_array.metadata.zarr_format == 3:
from zarr.core.metadata.v3 import (
RectilinearChunkGridMetadata,
RegularChunkGridMetadata,
)

chunk_grid = zarr_array.metadata.chunk_grid
if isinstance(chunk_grid, RegularChunkGridMetadata):
chunks = chunk_grid.chunk_shape
elif isinstance(chunk_grid, RectilinearChunkGridMetadata):
chunks = chunk_grid.chunk_shapes
else:
chunks = tuple(zarr_array.chunks)
preferred_chunks = dict(zip(dimensions, chunks, strict=True))
else:
chunks = tuple(zarr_array.chunks)
preferred_chunks = dict(zip(dimensions, chunks, strict=True))

encoding = {
"chunks": zarr_array.chunks,
"preferred_chunks": dict(zip(dimensions, zarr_array.chunks, strict=True)),
"chunks": chunks,
"preferred_chunks": preferred_chunks,
}

if _zarr_v3():
Expand Down Expand Up @@ -1300,7 +1368,7 @@ def set_variables(
if self._align_chunks and isinstance(effective_write_chunks, tuple):
v = grid_rechunk(
v=v,
enc_chunks=effective_write_chunks,
encoding_chunks=effective_write_chunks,
region=region,
)

Expand Down
Loading
Loading