Skip to content

Commit 62551c7

Browse files
authored
fix: *-like creation routines take kwargs (#2992)
* ensure that user-provided array creation kwargs can pass through array-like creation routines
* test for kwarg propagation through array-like routines
* propagate fill value if unspecified
* changelog
* Update 2992.fix.rst
* add test for open_like
* Update 2992.fix.rst
* lint
* add likeargs typeddict
* explicitly iterate over functions in test
* add test cases for fill_value in test_array_like_creation
* use correct type: ignore statement
* remove test that made no sense after allowing dtype inference in full_like
1 parent 6805332 commit 62551c7

File tree

5 files changed

+183
-23
lines changed

5 files changed

+183
-23
lines changed

changes/2992.bugfix.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
Fix a bug preventing ``ones_like``, ``full_like``, ``empty_like``, ``zeros_like`` and ``open_like`` functions from accepting
2+
an explicit specification of array attributes such as shape, dtype, and chunks. The functions ``full_like``,
3+
``empty_like``, and ``open_like`` now also more consistently infer a ``fill_value`` parameter from the provided array.

src/zarr/api/asynchronous.py

Lines changed: 32 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import asyncio
44
import dataclasses
55
import warnings
6-
from typing import TYPE_CHECKING, Any, Literal, cast
6+
from typing import TYPE_CHECKING, Any, Literal, NotRequired, TypedDict, cast
77

88
import numpy as np
99
import numpy.typing as npt
@@ -56,6 +56,7 @@
5656
from zarr.abc.numcodec import Numcodec
5757
from zarr.core.buffer import NDArrayLikeOrScalar
5858
from zarr.core.chunk_key_encodings import ChunkKeyEncoding
59+
from zarr.core.metadata.v2 import CompressorLikev2
5960
from zarr.storage import StoreLike
6061

6162
# TODO: this type could use some more thought
@@ -124,10 +125,20 @@ def _get_shape_chunks(a: ArrayLike | Any) -> tuple[tuple[int, ...] | None, tuple
124125
return shape, chunks
125126

126127

127-
def _like_args(a: ArrayLike, kwargs: dict[str, Any]) -> dict[str, Any]:
128+
class _LikeArgs(TypedDict):
129+
shape: NotRequired[tuple[int, ...]]
130+
chunks: NotRequired[tuple[int, ...]]
131+
dtype: NotRequired[np.dtype[np.generic]]
132+
order: NotRequired[Literal["C", "F"]]
133+
filters: NotRequired[tuple[Numcodec, ...] | None]
134+
compressor: NotRequired[CompressorLikev2]
135+
codecs: NotRequired[tuple[Codec, ...]]
136+
137+
138+
def _like_args(a: ArrayLike) -> _LikeArgs:
128139
"""Set default values for shape and chunks if they are not present in the array-like object"""
129140

130-
new = kwargs.copy()
141+
new: _LikeArgs = {}
131142

132143
shape, chunks = _get_shape_chunks(a)
133144
if shape is not None:
@@ -138,9 +149,9 @@ def _like_args(a: ArrayLike, kwargs: dict[str, Any]) -> dict[str, Any]:
138149
if hasattr(a, "dtype"):
139150
new["dtype"] = a.dtype
140151

141-
if isinstance(a, AsyncArray):
142-
new["order"] = a.order
152+
if isinstance(a, AsyncArray | Array):
143153
if isinstance(a.metadata, ArrayV2Metadata):
154+
new["order"] = a.order
144155
new["compressor"] = a.metadata.compressor
145156
new["filters"] = a.metadata.filters
146157
else:
@@ -1087,7 +1098,7 @@ async def empty(
10871098
shape: tuple[int, ...], **kwargs: Any
10881099
) -> AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata]:
10891100
"""Create an empty array with the specified shape. The contents will be filled with the
1090-
array's fill value or zeros if no fill value is provided.
1101+
specified fill value or zeros if no fill value is provided.
10911102
10921103
Parameters
10931104
----------
@@ -1102,8 +1113,7 @@ async def empty(
11021113
retrieve data from an empty Zarr array, any values may be returned,
11031114
and these are not guaranteed to be stable from one access to the next.
11041115
"""
1105-
1106-
return await create(shape=shape, fill_value=None, **kwargs)
1116+
return await create(shape=shape, **kwargs)
11071117

11081118

11091119
async def empty_like(
@@ -1130,8 +1140,10 @@ async def empty_like(
11301140
retrieve data from an empty Zarr array, any values may be returned,
11311141
and these are not guaranteed to be stable from one access to the next.
11321142
"""
1133-
like_kwargs = _like_args(a, kwargs)
1134-
return await empty(**like_kwargs)
1143+
like_kwargs = _like_args(a) | kwargs
1144+
if isinstance(a, (AsyncArray | Array)):
1145+
like_kwargs.setdefault("fill_value", a.metadata.fill_value)
1146+
return await empty(**like_kwargs) # type: ignore[arg-type]
11351147

11361148

11371149
# TODO: add type annotations for fill_value and kwargs
@@ -1176,10 +1188,10 @@ async def full_like(
11761188
Array
11771189
The new array.
11781190
"""
1179-
like_kwargs = _like_args(a, kwargs)
1180-
if isinstance(a, AsyncArray):
1191+
like_kwargs = _like_args(a) | kwargs
1192+
if isinstance(a, (AsyncArray | Array)):
11811193
like_kwargs.setdefault("fill_value", a.metadata.fill_value)
1182-
return await full(**like_kwargs)
1194+
return await full(**like_kwargs) # type: ignore[arg-type]
11831195

11841196

11851197
async def ones(
@@ -1220,8 +1232,8 @@ async def ones_like(
12201232
Array
12211233
The new array.
12221234
"""
1223-
like_kwargs = _like_args(a, kwargs)
1224-
return await ones(**like_kwargs)
1235+
like_kwargs = _like_args(a) | kwargs
1236+
return await ones(**like_kwargs) # type: ignore[arg-type]
12251237

12261238

12271239
async def open_array(
@@ -1300,10 +1312,10 @@ async def open_like(
13001312
AsyncArray
13011313
The opened array.
13021314
"""
1303-
like_kwargs = _like_args(a, kwargs)
1315+
like_kwargs = _like_args(a) | kwargs
13041316
if isinstance(a, (AsyncArray | Array)):
1305-
kwargs.setdefault("fill_value", a.metadata.fill_value)
1306-
return await open_array(path=path, **like_kwargs)
1317+
like_kwargs.setdefault("fill_value", a.metadata.fill_value)
1318+
return await open_array(path=path, **like_kwargs) # type: ignore[arg-type]
13071319

13081320

13091321
async def zeros(
@@ -1344,5 +1356,5 @@ async def zeros_like(
13441356
Array
13451357
The new array.
13461358
"""
1347-
like_kwargs = _like_args(a, kwargs)
1348-
return await zeros(**like_kwargs)
1359+
like_kwargs = _like_args(a) | kwargs
1360+
return await zeros(**like_kwargs) # type: ignore[arg-type]

tests/test_api.py

Lines changed: 86 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import inspect
44
import re
5-
from typing import TYPE_CHECKING
5+
from typing import TYPE_CHECKING, Any
66

77
import zarr.codecs
88
import zarr.storage
@@ -82,6 +82,91 @@ def test_create(memory_store: Store) -> None:
8282
z = create(shape=(400, 100), chunks=(16, 16.5), store=store, overwrite=True) # type: ignore[arg-type]
8383

8484

85+
@pytest.mark.parametrize(
86+
"func",
87+
[
88+
zarr.api.asynchronous.zeros_like,
89+
zarr.api.asynchronous.ones_like,
90+
zarr.api.asynchronous.empty_like,
91+
zarr.api.asynchronous.full_like,
92+
zarr.api.asynchronous.open_like,
93+
],
94+
)
95+
@pytest.mark.parametrize("out_shape", ["keep", (10, 10)])
96+
@pytest.mark.parametrize("out_chunks", ["keep", (10, 10)])
97+
@pytest.mark.parametrize("out_dtype", ["keep", "int8"])
98+
@pytest.mark.parametrize("out_fill", ["keep", 4])
99+
async def test_array_like_creation(
100+
zarr_format: ZarrFormat,
101+
func: Callable[[Any], Any],
102+
out_shape: Literal["keep"] | tuple[int, ...],
103+
out_chunks: Literal["keep"] | tuple[int, ...],
104+
out_dtype: str,
105+
out_fill: Literal["keep"] | int,
106+
) -> None:
107+
"""
108+
Test zeros_like, ones_like, empty_like, full_like, ensuring that we can override the
109+
shape, chunks, dtype and fill_value of the array-like object provided to these functions with
110+
appropriate keyword arguments
111+
"""
112+
ref_fill = 100
113+
ref_arr = zarr.create_array(
114+
store={},
115+
shape=(11, 12),
116+
dtype="uint8",
117+
chunks=(11, 12),
118+
zarr_format=zarr_format,
119+
fill_value=ref_fill,
120+
)
121+
kwargs: dict[str, object] = {}
122+
if func is zarr.api.asynchronous.full_like:
123+
if out_fill == "keep":
124+
expect_fill = ref_fill
125+
else:
126+
expect_fill = out_fill
127+
kwargs["fill_value"] = expect_fill
128+
elif func is zarr.api.asynchronous.zeros_like:
129+
expect_fill = 0
130+
elif func is zarr.api.asynchronous.ones_like:
131+
expect_fill = 1
132+
elif func is zarr.api.asynchronous.empty_like:
133+
if out_fill == "keep":
134+
expect_fill = ref_fill
135+
else:
136+
kwargs["fill_value"] = out_fill
137+
expect_fill = out_fill
138+
elif func is zarr.api.asynchronous.open_like: # type: ignore[comparison-overlap]
139+
if out_fill == "keep":
140+
expect_fill = ref_fill
141+
else:
142+
kwargs["fill_value"] = out_fill
143+
expect_fill = out_fill
144+
kwargs["mode"] = "w"
145+
else:
146+
raise AssertionError
147+
if out_shape != "keep":
148+
kwargs["shape"] = out_shape
149+
expect_shape = out_shape
150+
else:
151+
expect_shape = ref_arr.shape
152+
if out_chunks != "keep":
153+
kwargs["chunks"] = out_chunks
154+
expect_chunks = out_chunks
155+
else:
156+
expect_chunks = ref_arr.chunks
157+
if out_dtype != "keep":
158+
kwargs["dtype"] = out_dtype
159+
expect_dtype = out_dtype
160+
else:
161+
expect_dtype = ref_arr.dtype # type: ignore[assignment]
162+
163+
new_arr = await func(ref_arr, path="foo", zarr_format=zarr_format, **kwargs) # type: ignore[call-arg]
164+
assert new_arr.shape == expect_shape
165+
assert new_arr.chunks == expect_chunks
166+
assert new_arr.dtype == expect_dtype
167+
assert np.all(Array(new_arr)[:] == expect_fill)
168+
169+
85170
# TODO: parametrize over everything this function takes
86171
@pytest.mark.parametrize("store", ["memory"], indirect=True)
87172
def test_create_array(store: Store, zarr_format: ZarrFormat) -> None:

tests/test_api/test_asynchronous.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ def test_like_args(
8888
"""
8989
Test the like_args function
9090
"""
91-
assert _like_args(observed, {}) == expected
91+
assert _like_args(observed) == expected
9292

9393

9494
async def test_open_no_array() -> None:

tests/test_group.py

Lines changed: 61 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import re
99
import time
1010
import warnings
11-
from typing import TYPE_CHECKING, Any, Literal
11+
from typing import TYPE_CHECKING, Any, Literal, get_args
1212

1313
import numpy as np
1414
import pytest
@@ -761,6 +761,66 @@ def test_group_create_array(
761761
assert np.array_equal(array[:], data)
762762

763763

764+
LikeMethodName = Literal["zeros_like", "ones_like", "empty_like", "full_like"]
765+
766+
767+
@pytest.mark.parametrize("method_name", get_args(LikeMethodName))
768+
@pytest.mark.parametrize("out_shape", ["keep", (10, 10)])
769+
@pytest.mark.parametrize("out_chunks", ["keep", (10, 10)])
770+
@pytest.mark.parametrize("out_dtype", ["keep", "int8"])
771+
def test_group_array_like_creation(
772+
zarr_format: ZarrFormat,
773+
method_name: LikeMethodName,
774+
out_shape: Literal["keep"] | tuple[int, ...],
775+
out_chunks: Literal["keep"] | tuple[int, ...],
776+
out_dtype: str,
777+
) -> None:
778+
"""
779+
Test Group.{zeros_like, ones_like, empty_like, full_like}, ensuring that we can override the
780+
shape, chunks, and dtype of the array-like object provided to these functions with
781+
appropriate keyword arguments
782+
"""
783+
ref_arr = zarr.ones(store={}, shape=(11, 12), dtype="uint8", chunks=(11, 12))
784+
group = Group.from_store({}, zarr_format=zarr_format)
785+
kwargs = {}
786+
if method_name == "full_like":
787+
expect_fill = 4
788+
kwargs["fill_value"] = expect_fill
789+
meth = group.full_like
790+
elif method_name == "zeros_like":
791+
expect_fill = 0
792+
meth = group.zeros_like
793+
elif method_name == "ones_like":
794+
expect_fill = 1
795+
meth = group.ones_like
796+
elif method_name == "empty_like":
797+
expect_fill = ref_arr.fill_value
798+
meth = group.empty_like
799+
else:
800+
raise AssertionError
801+
if out_shape != "keep":
802+
kwargs["shape"] = out_shape
803+
expect_shape = out_shape
804+
else:
805+
expect_shape = ref_arr.shape
806+
if out_chunks != "keep":
807+
kwargs["chunks"] = out_chunks
808+
expect_chunks = out_chunks
809+
else:
810+
expect_chunks = ref_arr.chunks
811+
if out_dtype != "keep":
812+
kwargs["dtype"] = out_dtype
813+
expect_dtype = out_dtype
814+
else:
815+
expect_dtype = ref_arr.dtype
816+
817+
new_arr = meth(name="foo", data=ref_arr, **kwargs)
818+
assert new_arr.shape == expect_shape
819+
assert new_arr.chunks == expect_chunks
820+
assert new_arr.dtype == expect_dtype
821+
assert np.all(new_arr[:] == expect_fill)
822+
823+
764824
def test_group_array_creation(
765825
store: Store,
766826
zarr_format: ZarrFormat,

0 commit comments

Comments
 (0)