Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ v2026.05.0 (unreleased)
New Features
~~~~~~~~~~~~

- Change behavior of ``chunks="auto"`` to guarantee that chunks in xarray
match on-disk chunks or multiples of them. No automatic chunk splitting allowed.
(:pull:`11060`).
By `Julia Signell <https://github.com/jsignell>`_.

Breaking Changes
~~~~~~~~~~~~~~~~
Expand Down
71 changes: 71 additions & 0 deletions properties/test_parallelcompat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import numpy as np
import pytest

pytest.importorskip("hypothesis")
# isort: split

from hypothesis import given

import xarray.testing.strategies as xrst
from xarray.namedarray.parallelcompat import ChunkManagerEntrypoint


class TestPreserveChunks:
    """Property-based checks that ``preserve_chunks`` never splits on-disk chunks."""

    @given(xrst.shape_and_chunks())
    def test_preserve_all_chunks(
        self, shape_and_chunks: tuple[tuple[int, ...], tuple[int, ...]]
    ) -> None:
        shape, previous_chunks = shape_and_chunks
        typesize = 8
        target = 1024 * 1024

        result = ChunkManagerEntrypoint.preserve_chunks(
            chunks=("auto",) * len(shape),
            shape=shape,
            target=target,
            typesize=typesize,
            previous_chunks=previous_chunks,
        )
        for dim_len, on_disk, got in zip(shape, previous_chunks, result):
            if got != dim_len:
                # Anything short of the full dimension must be a whole
                # multiple of the on-disk chunk.
                assert got >= on_disk
                assert got % on_disk == 0
            assert got <= dim_len

        if result != shape:
            # Unless the whole array fits in one chunk, the chunks should
            # reach at least half of the byte target.
            assert np.prod(result) * typesize >= 0.5 * target

    @pytest.mark.parametrize("first_chunk", [-1, (), 1])
    @given(xrst.shape_and_chunks(min_dims=2))
    def test_preserve_some_chunks(
        self,
        first_chunk: int | tuple[int, ...],
        shape_and_chunks: tuple[tuple[int, ...], tuple[int, ...]],
    ) -> None:
        shape, previous_chunks = shape_and_chunks
        typesize = 4
        target = 2 * 1024 * 1024

        result = ChunkManagerEntrypoint.preserve_chunks(
            chunks=(first_chunk, *("auto",) * (len(shape) - 1)),
            shape=shape,
            target=target,
            typesize=typesize,
            previous_chunks=previous_chunks,
        )
        for axis, got in enumerate(result):
            on_disk = previous_chunks[axis]
            if axis == 0:
                # The explicitly requested first chunk must be honored.
                if first_chunk == ():
                    assert got == on_disk
                elif first_chunk == -1:
                    assert got == shape[axis]
                else:
                    assert got == 1
            elif got != shape[axis]:
                assert got >= on_disk
                assert got % on_disk == 0
            assert got <= shape[axis]

        # With more than one chunk over the "auto" dimensions, the chunks
        # should reach at least half of the byte target.
        if result[1:] != shape[1:]:
            assert np.prod(result) * typesize >= 0.5 * target
26 changes: 15 additions & 11 deletions xarray/backends/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,11 +259,11 @@ def _chunk_ds(
name,
var,
var_chunks,
chunkmanager,
overwrite_encoded_chunks=overwrite_encoded_chunks,
name_prefix=name_prefix,
token=token,
inline_array=inline_array,
chunked_array_type=chunkmanager,
from_array_kwargs=from_array_kwargs.copy(),
just_use_token=True,
)
Expand Down Expand Up @@ -294,7 +294,7 @@ def _dataset_from_backend_dataset(
):
if not isinstance(chunks, int | dict) and chunks not in {None, "auto"}:
raise ValueError(
f"chunks must be an int, dict, 'auto', or None. Instead found {chunks}."
f"chunks must be an int, dict, 'auto' or None. Instead found {chunks}."
)

_protect_dataset_variables_inplace(backend_ds, cache)
Expand Down Expand Up @@ -344,7 +344,7 @@ def _datatree_from_backend_datatree(
):
if not isinstance(chunks, int | dict) and chunks not in {None, "auto"}:
raise ValueError(
f"chunks must be an int, dict, 'auto', or None. Instead found {chunks}."
f"chunks must be an int, dict, 'auto' or None. Instead found {chunks}."
)

_protect_datatree_variables_inplace(backend_tree, cache)
Expand Down Expand Up @@ -433,8 +433,9 @@ class (a subclass of ``BackendEntrypoint``) can also be used.
chunks : int, dict, 'auto' or None, default: None
If provided, used to load the data into dask arrays.

- ``chunks="auto"`` will use dask ``auto`` chunking taking into account the
engine preferred chunks.
- ``chunks="auto"`` will use a chunking scheme that never splits encoded
chunks. If encoded chunks are small then "auto" takes multiples of them
over the largest dimension.
- ``chunks=None`` skips using dask. This uses xarray's internally private
:ref:`lazy indexing classes <internal design.lazy indexing>`,
but data is eagerly loaded into memory as numpy arrays when accessed.
Expand Down Expand Up @@ -677,8 +678,9 @@ class (a subclass of ``BackendEntrypoint``) can also be used.
chunks : int, dict, 'auto' or None, default: None
If provided, used to load the data into dask arrays.

- ``chunks='auto'`` will use dask ``auto`` chunking taking into account the
engine preferred chunks.
- ``chunks="auto"`` will use a chunking scheme that never splits encoded
chunks. If encoded chunks are small then "auto" takes multiples of them
over the largest dimension.
- ``chunks=None`` skips using dask. This uses xarray's internally private
:ref:`lazy indexing classes <internal design.lazy indexing>`,
but data is eagerly loaded into memory as numpy arrays when accessed.
Expand Down Expand Up @@ -903,8 +905,9 @@ def open_datatree(
chunks : int, dict, 'auto' or None, default: None
If provided, used to load the data into dask arrays.

- ``chunks="auto"`` will use dask ``auto`` chunking taking into account the
engine preferred chunks.
- ``chunks="auto"`` will use a chunking scheme that never splits encoded
chunks. If encoded chunks are small then "auto" takes multiples of them
over the largest dimension.
- ``chunks=None`` skips using dask. This uses xarray's internally private
:ref:`lazy indexing classes <internal design.lazy indexing>`,
but data is eagerly loaded into memory as numpy arrays when accessed.
Expand Down Expand Up @@ -1149,8 +1152,9 @@ def open_groups(
chunks : int, dict, 'auto' or None, default: None
If provided, used to load the data into dask arrays.

- ``chunks="auto"`` will use dask ``auto`` chunking taking into account the
engine preferred chunks.
- ``chunks="auto"`` will use a chunking scheme that never splits encoded
chunks. If encoded chunks are small then "auto" takes multiples of them
over the largest dimension.
- ``chunks=None`` skips using dask. This uses xarray's internally private
:ref:`lazy indexing classes <internal design.lazy indexing>`,
but data is eagerly loaded into memory as numpy arrays when accessed.
Expand Down
7 changes: 4 additions & 3 deletions xarray/backends/zarr.py
Original file line number Diff line number Diff line change
Expand Up @@ -1499,12 +1499,13 @@ def open_zarr(
Array synchronizer provided to zarr
group : str, optional
Group path. (a.k.a. `path` in zarr terminology.)
chunks : int, dict, "auto" or None, optional
chunks : int, dict, "auto", or None, optional
Used to load the data into dask arrays. Default behavior is to use
``chunks={}`` if dask is available, otherwise ``chunks=None``.

- ``chunks='auto'`` will use dask ``auto`` chunking taking into account the
engine preferred chunks.
- ``chunks="auto"`` will use a chunking scheme that never splits encoded
chunks. If encoded chunks are small then "auto" takes multiples of them
over the largest dimension.
- ``chunks=None`` skips using dask. This uses xarray's internally private
:ref:`lazy indexing classes <internal design.lazy indexing>`,
but data is eagerly loaded into memory as numpy arrays when accessed.
Expand Down
2 changes: 1 addition & 1 deletion xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2639,11 +2639,11 @@ def _resolve_resampler(name: Hashable, resampler: Resampler) -> tuple[int, ...]:
k,
v,
chunks_mapping_ints,
chunkmanager,
token,
lock,
name_prefix,
inline_array=inline_array,
chunked_array_type=chunkmanager,
from_array_kwargs=from_array_kwargs.copy(),
)
for k, v in self.variables.items()
Expand Down
5 changes: 3 additions & 2 deletions xarray/namedarray/daskmanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

if TYPE_CHECKING:
from xarray.namedarray._typing import (
T_ChunkDim,
T_Chunks,
_DType_co,
_NormalizedChunks,
Expand Down Expand Up @@ -45,11 +46,11 @@ def chunks(self, data: Any) -> _NormalizedChunks:

def normalize_chunks(
self,
chunks: T_Chunks | _NormalizedChunks,
chunks: tuple[T_ChunkDim, ...] | _NormalizedChunks,
shape: tuple[int, ...] | None = None,
limit: int | None = None,
dtype: _DType_co | None = None,
previous_chunks: _NormalizedChunks | None = None,
previous_chunks: tuple[int, ...] | _NormalizedChunks | None = None,
) -> Any:
"""Called by open_dataset"""
from dask.array.core import normalize_chunks
Expand Down
118 changes: 118 additions & 0 deletions xarray/namedarray/parallelcompat.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

if TYPE_CHECKING:
from xarray.namedarray._typing import (
T_ChunkDim,
T_Chunks,
_Chunks,
_DType,
Expand Down Expand Up @@ -784,3 +785,120 @@ def get_auto_chunk_size(
raise NotImplementedError(
"For 'auto' rechunking of cftime arrays, get_auto_chunk_size must be implemented by the chunk manager"
)

@staticmethod
def preserve_chunks(
chunks: tuple[T_ChunkDim, ...],
shape: tuple[int, ...],
target: int,
typesize: int,
previous_chunks: tuple[int, ...] | _NormalizedChunks,
) -> tuple[T_ChunkDim, ...]:
"""Quickly determine optimal chunks close to target size but never splitting
previous_chunks.

This takes in a chunks argument potentially containing ``"auto"`` for several
dimensions. This function replaces ``"auto"`` with concrete dimension sizes that
try to get chunks to be close to certain size in bytes, provided by the ``target=``
keyword. Any dimensions marked as ``"auto"`` will potentially be multiplied
by some factor to get close to the byte target, while never splitting
``previous_chunks``. If chunks are non-uniform along a particular dimension
then that dimension will always use exactly ``previous_chunks``.

Examples
--------
>>> ChunkManagerEntrypoint.preserve_chunks(
... chunks=("auto", "auto", "auto"),
... shape=(1280, 1280, 20),
... target=500 * 1024,
... typesize=8,
... previous_chunks=(128, 128, 1),
... )
(128, 128, 2)

>>> ChunkManagerEntrypoint.preserve_chunks(
... chunks=("auto", "auto", 1),
... shape=(1280, 1280, 20),
... target=1 * 1024 * 1024,
... typesize=8,
... previous_chunks=(128, 128, 1),
... )
(128, 1024, 1)

>>> ChunkManagerEntrypoint.preserve_chunks(
... chunks=("auto", "auto", 1),
... shape=(1280, 1280, 20),
... target=1 * 1024 * 1024,
... typesize=8,
... previous_chunks=((128,) * 10, (128, 256, 256, 512), (1,) * 20),
... )
(256, (128, 256, 256, 512), 1)

Parameters
----------
chunks: tuple[int | str | tuple[int], ...]
A tuple of either dimensions or tuples of explicit chunk dimensions
Some entries should be "auto".
shape: tuple[int]
The shape of the array
target: int
The target size of the chunk in bytes.
typesize: int
The size, in bytes, of each element of the chunk.
previous_chunks: tuple[int | tuple[int], ...]
Size of chunks being preserved. Expressed as a tuple of ints or tuple
of tuple of ints.
"""
new_chunks = [*previous_chunks]
auto_dims = [c == "auto" for c in chunks]
max_chunks = np.array(shape)
for i, previous_chunk in enumerate(previous_chunks):
chunk = chunks[i]
if chunk == -1:
# -1 means whole dim is in one chunk
new_chunks[i] = shape[i]
else:
if isinstance(previous_chunk, tuple):
# For uniform chunks just take the first item
if previous_chunk[1:-1] == previous_chunk[:-2]:
new_chunks[i] = previous_chunk[0]
previous_chunk = previous_chunk[0]
# For non-uniform chunks, leave them alone
else:
auto_dims[i] = False
max_chunks[i] = max(previous_chunk)

if isinstance(previous_chunk, int):
# auto, None or () means we want to track previous chunk
if chunk == "auto" or not chunk:
max_chunks[i] = previous_chunk
# otherwise use the explicitly provided chunk
else:
new_chunks[i] = chunk
max_chunks[i] = chunk if isinstance(chunk, int) else max(chunk)

if not any(auto_dims):
return chunks

while True:
# Repeatedly look for the last dim with more than one chunk and multiply it by 2.
# Stop when:
# 1a. we are larger than the target chunk size OR
# 1b. we are within 50% of the target chunk size OR
# 2. the chunk covers the entire array

num_chunks = np.array(shape) / max_chunks * auto_dims
chunk_bytes = np.prod(max_chunks) * typesize

if chunk_bytes > target or abs(chunk_bytes - target) / target < 0.5:
break

if (num_chunks <= 1).all():
break

idx = int(np.nonzero(num_chunks > 1)[0][-1])

new_chunks[idx] = min(new_chunks[idx] * 2, shape[idx])
max_chunks[idx] = new_chunks[idx]

return tuple(new_chunks)
11 changes: 10 additions & 1 deletion xarray/namedarray/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ def _get_chunk( # type: ignore[no-untyped-def]
preferred_chunk_shape = tuple(
itertools.starmap(preferred_chunks.get, zip(dims, shape, strict=True))
)
if isinstance(chunks, Number) or (chunks == "auto"):
if isinstance(chunks, (Number, str)):
chunks = dict.fromkeys(dims, chunks)
chunk_shape = tuple(
chunks.get(dim, None) or preferred_chunk_sizes
Expand All @@ -236,6 +236,15 @@ def _get_chunk( # type: ignore[no-untyped-def]
limit = None
dtype = data.dtype

if shape and preferred_chunk_shape and any(c == "auto" for c in chunk_shape):
chunk_shape = chunkmanager.preserve_chunks(
chunk_shape,
shape=shape,
target=chunkmanager.get_auto_chunk_size(),
typesize=getattr(dtype, "itemsize", 8),
previous_chunks=preferred_chunk_shape,
)

chunk_shape = chunkmanager.normalize_chunks(
chunk_shape,
shape=shape,
Expand Down
Loading
Loading