From f3a4d98692d6f56167418cb66499b5f455cc78ee Mon Sep 17 00:00:00 2001 From: Justus Magin Date: Fri, 24 Oct 2025 19:42:21 +0200 Subject: [PATCH 1/6] refactor the comparison for arg reductions --- xarray_array_testing/reduction.py | 66 ++++++++++++++++++++++++++----- 1 file changed, 57 insertions(+), 9 deletions(-) diff --git a/xarray_array_testing/reduction.py b/xarray_array_testing/reduction.py index 1fa0f23..845673e 100644 --- a/xarray_array_testing/reduction.py +++ b/xarray_array_testing/reduction.py @@ -1,10 +1,12 @@ +import itertools from contextlib import nullcontext import hypothesis.strategies as st import numpy as np import pytest +import xarray as xr import xarray.testing.strategies as xrst -from hypothesis import given +from hypothesis import given, note from xarray_array_testing.base import DuckArrayTestMixin @@ -60,17 +62,63 @@ def test_variable_order_reduce(self, op, data): @given(st.data()) def test_variable_order_reduce_index(self, op, data): variable = data.draw(xrst.variables(array_strategy_fn=self.array_strategy_fn)) + possible_dims = [..., list(variable.dims), *variable.dims] + list( + itertools.chain.from_iterable( + map(list, itertools.combinations(variable.dims, length)) + for length in range(1, len(variable.dims)) + ) + ) + dim = data.draw(st.sampled_from(possible_dims)) with self.expected_errors(op, variable=variable): # compute using xr.Variable.() - actual = {k: v.item() for k, v in getattr(variable, op)(dim=...).items()} - - # compute using xp.(array) - index = getattr(self.xp, op)(variable.data) - unraveled = np.unravel_index(index, variable.shape) - expected = dict(zip(variable.dims, unraveled)) - - self.assert_equal(actual, expected) + actual = getattr(variable, op)(dim=dim) + + if dim is not ... and not isinstance(dim, list): + # compute using xp.(array) + note(dim) + axis = variable.get_axis_num(dim) + expected = getattr(self.xp, op)(variable.data, axis=axis) + self.assert_equal(actual.data, expected) + elif dim is ... 
or len(dim) == len(variable.dims): + # compute using xp.(array) + index = getattr(self.xp, op)(variable.data) + + unraveled = np.unravel_index(index, variable.shape) + expected = dict(zip(variable.dims, unraveled)) + + # all elements are 0D + assert actual == expected + else: + if len(dim) == 1: + dim_ = dim[0] + axis = variable.get_axis_num(dim_) + index = getattr(self.xp, op)(variable.data, axis=axis) + + result_dims = [d for d in variable.dims if d != dim_] + expected = {dim_: xr.Variable(result_dims, index)} + else: + # move the relevant dims together and flatten + dim_name = object() + stacked = variable.stack({dim_name: dim}) + + result_dims = stacked.dims[:-1] + reduce_shape = tuple(variable.sizes[d] for d in dim) + index = getattr(self.xp, op)(stacked.data, axis=-1) + + unravelled = np.unravel_index(index, reduce_shape) + + expected = { + d: xr.Variable(result_dims, idx) + for d, idx in zip(dim, unravelled, strict=True) + } + + note(f"original: {variable}") + note(f"actual: {actual}") + note(f"expected: {expected}") + + assert actual.keys() == expected.keys(), "Reduction dims are not equal" + assert all(actual[k].equals(expected[k]) for k in actual) @pytest.mark.parametrize( "op", From 0ecc82de119d99bd46deb817db1256da40181098 Mon Sep 17 00:00:00 2001 From: Justus Magin Date: Fri, 24 Oct 2025 20:53:44 +0200 Subject: [PATCH 2/6] add a static method to compare indexing operators --- xarray_array_testing/base.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/xarray_array_testing/base.py b/xarray_array_testing/base.py index 945e7e3..5a27c20 100644 --- a/xarray_array_testing/base.py +++ b/xarray_array_testing/base.py @@ -2,7 +2,9 @@ from abc import ABC from types import ModuleType +import numpy as np import numpy.testing as npt +import xarray as xr from xarray.namedarray._typing import duckarray @@ -24,3 +26,23 @@ def array_strategy_fn(*, shape, dtype): @staticmethod def assert_equal(a, b): npt.assert_equal(a, b) + + @staticmethod + def
assert_dimension_indexers_equal(a, b): + assert type(a) is type(b), f"types don't match: {type(a)} vs {type(b)}" + + if isinstance(a, dict): + assert a.keys() == b.keys(), f"Different dimensions: {list(a)} vs {list(b)}" + + values = ((a[k], b[k]) for k in a) + assert all( + ( + isinstance(v1, xr.Variable) + and isinstance(v2, xr.Variable) + and v1.dims == v2.dims + and np.equal(v1.data, v2.data) + ) + for v1, v2 in values + ), "Differing indexers" + else: + npt.assert_equal(a, b) From dcd4ca1bb1dfb42198bf169a941f30ab4547bec9 Mon Sep 17 00:00:00 2001 From: Justus Magin Date: Fri, 24 Oct 2025 20:55:08 +0200 Subject: [PATCH 3/6] simplify --- xarray_array_testing/base.py | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) diff --git a/xarray_array_testing/base.py b/xarray_array_testing/base.py index 5a27c20..37915fe 100644 --- a/xarray_array_testing/base.py +++ b/xarray_array_testing/base.py @@ -4,7 +4,6 @@ import numpy as np import numpy.testing as npt -import xarray as xr from xarray.namedarray._typing import duckarray @@ -34,15 +33,6 @@ def assert_dimension_indexers_equal(a, b): if isinstance(a, dict): assert a.keys() == b.keys(), f"Different dimensions: {list(a)} vs {list(b)}" - values = ((a[k], b[k]) for k in a) - assert all( - ( - isinstance(v1, xr.Variable) - and isinstance(v2, xr.Variable) - and v1.dims == v2.dims - and np.equal(v1.data, v2.data) - ) - for v1, v2 in values - ), "Differing indexers" + assert all(np.equal(a[k], b[k]) for k in a), "Differing indexers" else: npt.assert_equal(a, b) From 0f336d3029bd11857ac1273990a4149b014abed5 Mon Sep 17 00:00:00 2001 From: Justus Magin Date: Sat, 25 Oct 2025 00:33:35 +0200 Subject: [PATCH 4/6] parametrize by `xp` This requires converting the static method to an actual method. 
--- xarray_array_testing/base.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/xarray_array_testing/base.py b/xarray_array_testing/base.py index 37915fe..fc69ba0 100644 --- a/xarray_array_testing/base.py +++ b/xarray_array_testing/base.py @@ -2,7 +2,6 @@ from abc import ABC from types import ModuleType -import numpy as np import numpy.testing as npt from xarray.namedarray._typing import duckarray @@ -26,13 +25,14 @@ def array_strategy_fn(*, shape, dtype): def assert_equal(a, b): npt.assert_equal(a, b) - @staticmethod - def assert_dimension_indexers_equal(a, b): + def assert_dimension_indexers_equal(self, a, b): assert type(a) is type(b), f"types don't match: {type(a)} vs {type(b)}" if isinstance(a, dict): assert a.keys() == b.keys(), f"Different dimensions: {list(a)} vs {list(b)}" - assert all(np.equal(a[k], b[k]) for k in a), "Differing indexers" + assert all( + self.xp.all(self.xp.equal(a[k], b[k])) for k in a + ), "Differing indexers" else: npt.assert_equal(a, b) From 7cbbd56da9abed74b43d8760d4b9a637fe477e18 Mon Sep 17 00:00:00 2001 From: Justus Magin Date: Sat, 25 Oct 2025 00:34:31 +0200 Subject: [PATCH 5/6] refactor to always use the indexer equals function --- xarray_array_testing/reduction.py | 71 +++++++++++++++---------------- 1 file changed, 35 insertions(+), 36 deletions(-) diff --git a/xarray_array_testing/reduction.py b/xarray_array_testing/reduction.py index 845673e..6ce67c2 100644 --- a/xarray_array_testing/reduction.py +++ b/xarray_array_testing/reduction.py @@ -4,7 +4,6 @@ import hypothesis.strategies as st import numpy as np import pytest -import xarray as xr import xarray.testing.strategies as xrst from hypothesis import given, note @@ -73,52 +72,52 @@ def test_variable_order_reduce_index(self, op, data): with self.expected_errors(op, variable=variable): # compute using xr.Variable.() actual = getattr(variable, op)(dim=dim) + if dim is ... 
or isinstance(dim, list): + actual_ = {dim_: var.data for dim_, var in actual.items()} + else: + actual_ = actual.data if dim is not ... and not isinstance(dim, list): # compute using xp.(array) note(dim) axis = variable.get_axis_num(dim) - expected = getattr(self.xp, op)(variable.data, axis=axis) - self.assert_equal(actual.data, expected) + indices = getattr(self.xp, op)(variable.data, axis=axis) + + expected = self.xp.asarray(indices) elif dim is ... or len(dim) == len(variable.dims): # compute using xp.(array) index = getattr(self.xp, op)(variable.data) unraveled = np.unravel_index(index, variable.shape) - expected = dict(zip(variable.dims, unraveled)) - - # all elements are 0D - assert actual == expected + expected = { + k: self.xp.asarray(v) for k, v in zip(variable.dims, unraveled) + } + elif len(dim) == 1: + dim_ = dim[0] + axis = variable.get_axis_num(dim_) + index = getattr(self.xp, op)(variable.data, axis=axis) + + expected = {dim_: self.xp.asarray(index)} else: - if len(dim) == 1: - dim_ = dim[0] - axis = variable.get_axis_num(dim_) - index = getattr(self.xp, op)(variable.data, axis=axis) - - result_dims = [d for d in variable.dims if d != dim_] - expected = {dim_: xr.Variable(result_dims, index)} - else: - # move the relevant dims together and flatten - dim_name = object() - stacked = variable.stack({dim_name: dim}) - - result_dims = stacked.dims[:-1] - reduce_shape = tuple(variable.sizes[d] for d in dim) - index = getattr(self.xp, op)(stacked.data, axis=-1) - - unravelled = np.unravel_index(index, reduce_shape) - - expected = { - d: xr.Variable(result_dims, idx) - for d, idx in zip(dim, unravelled, strict=True) - } - - note(f"original: {variable}") - note(f"actual: {actual}") - note(f"expected: {expected}") - - assert actual.keys() == expected.keys(), "Reduction dims are not equal" - assert all(actual[k].equals(expected[k]) for k in actual) + # move the relevant dims together and flatten + dim_name = object() + stacked = variable.stack({dim_name: dim}) 
+ + reduce_shape = tuple(variable.sizes[d] for d in dim) + index = getattr(self.xp, op)(stacked.data, axis=-1) + + unravelled = np.unravel_index(index, reduce_shape) + + expected = { + d: self.xp.asarray(idx) + for d, idx in zip(dim, unravelled, strict=True) + } + + note(f"original: {variable}") + note(f"actual: {repr(actual_)}") + note(f"expected: {repr(expected)}") + + self.assert_dimension_indexers_equal(actual_, expected) @pytest.mark.parametrize( "op", From c45bd14c43fe73e21ca763f0aa9c38f05b1a34cc Mon Sep 17 00:00:00 2001 From: Justus Magin Date: Tue, 23 Dec 2025 13:47:20 +0100 Subject: [PATCH 6/6] move the note displaying the value of `dim` --- xarray_array_testing/reduction.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray_array_testing/reduction.py b/xarray_array_testing/reduction.py index 6ce67c2..6833c58 100644 --- a/xarray_array_testing/reduction.py +++ b/xarray_array_testing/reduction.py @@ -77,9 +77,9 @@ def test_variable_order_reduce_index(self, op, data): else: actual_ = actual.data + note(f"dim: {dim}") if dim is not ... and not isinstance(dim, list): # compute using xp.(array) - note(dim) axis = variable.get_axis_num(dim) indices = getattr(self.xp, op)(variable.data, axis=axis)