Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 30 additions & 13 deletions xarray/testing/assertions.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,24 +85,41 @@ def assert_isomorphic(a: DataTree, b: DataTree):


def maybe_transpose_dims(a, b, check_dim_order: bool):
    """Helper for assert_equal/allclose/identical.

    Returns an ``(a, b)`` tuple. When ``check_dim_order`` is true the inputs
    are returned untouched. Otherwise the dimensions shared by ``a`` and ``b``
    are moved to the front of both objects, in the same order, so that
    dimension order does not affect the comparison made by the caller.

    Parameters
    ----------
    a, b
        Objects being compared. Only ``Variable``, ``DataArray``, ``Dataset``
        and ``DataTree`` inputs are transposed; anything else is passed
        through unchanged.
    check_dim_order : bool
        If True, dimension order is significant and nothing is transposed.

    Returns
    -------
    tuple
        The ``(a, b)`` pair, possibly transposed.
    """

    __tracebackhide__ = True

    if check_dim_order:
        return a, b

    def _maybe_transpose_dims(a, b):
        if not isinstance(a, Variable | DataArray | Dataset):
            return a, b

        shared = set(a.dims) & set(b.dims)
        if not shared:
            return a, b

        # Follow ``a``'s own ordering instead of iterating the set: set order
        # is arbitrary, which would make the transposed layout (and any
        # failure output built from it) differ from run to run.
        canonical_order = [dim for dim in a.dims if dim in shared]
        # The trailing ellipsis stands in for dimensions present on only one
        # side, so transpose cannot fail on them; that mismatch is left for
        # the caller's comparison to catch.
        return (
            a.transpose(*canonical_order, ...),
            b.transpose(*canonical_order, ...),
        )

    if isinstance(a, DataTree):
        # map_over_datasets applies a function over corresponding datasets.
        # Each pass keeps one element of the pair so that the mapped function
        # returns a single dataset per node.
        transposed_a = map_over_datasets(
            lambda a_ds, b_ds: _maybe_transpose_dims(a_ds, b_ds)[0], a, b
        )
        transposed_b = map_over_datasets(
            lambda a_ds, b_ds: _maybe_transpose_dims(a_ds, b_ds)[1], a, b
        )
        return transposed_a, transposed_b

    return _maybe_transpose_dims(a, b)

Expand Down Expand Up @@ -139,7 +156,7 @@ def assert_equal(a, b, check_dim_order: bool = True):
assert type(a) is type(b) or (
isinstance(a, Coordinates) and isinstance(b, Coordinates)
)
b = maybe_transpose_dims(a, b, check_dim_order)
a, b = maybe_transpose_dims(a, b, check_dim_order)
if isinstance(a, Variable | DataArray):
assert a.equals(b), formatting.diff_array_repr(a, b, "equals")
elif isinstance(a, Dataset):
Expand Down Expand Up @@ -227,7 +244,7 @@ def assert_allclose(
"""
__tracebackhide__ = True
assert type(a) is type(b)
b = maybe_transpose_dims(a, b, check_dim_order)
a, b = maybe_transpose_dims(a, b, check_dim_order)

equiv = functools.partial(
_data_allclose_or_equiv, rtol=rtol, atol=atol, decode_bytes=decode_bytes
Expand Down
123 changes: 123 additions & 0 deletions xarray/tests/test_assertions.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,28 @@ def test_assert_equal_transpose_datatree() -> None:

xr.testing.assert_equal(a, b, check_dim_order=False)

# Test with mixed dimension orders in datasets (the tricky case)
import numpy as np

ds_mixed = xr.Dataset(
{
"foo": xr.DataArray(np.zeros([4, 5]), dims=("a", "b")),
"bar": xr.DataArray(np.ones([5, 4]), dims=("b", "a")),
}
)
ds_mixed2 = xr.Dataset(
{
"foo": xr.DataArray(np.zeros([5, 4]), dims=("b", "a")),
"bar": xr.DataArray(np.ones([4, 5]), dims=("a", "b")),
}
)

tree1 = xr.DataTree.from_dict({"node": ds_mixed})
tree2 = xr.DataTree.from_dict({"node": ds_mixed2})

# Should work with check_dim_order=False
xr.testing.assert_equal(tree1, tree2, check_dim_order=False)


@pytest.mark.filterwarnings("error")
@pytest.mark.parametrize(
Expand Down Expand Up @@ -224,3 +246,104 @@ def __array__(
getattr(xr.testing, func)(a, b)

assert len(w) == 0


def test_assert_equal_dataset_check_dim_order():
    """Regression test for issue #10704.

    ``check_dim_order=False`` must accept Datasets whose variables use
    different dimension orders from one another.
    """
    import numpy as np

    def build_pair(first, second):
        # Both datasets hold "foo" (zeros) and "bar" (ones); the per-variable
        # dimension order is swapped between the two datasets.
        forward = (first, second)
        backward = (second, first)
        left = xr.Dataset(
            {
                "foo": xr.DataArray(np.zeros([4, 5]), dims=forward),
                "bar": xr.DataArray(np.ones([5, 4]), dims=backward),
            }
        )
        right = xr.Dataset(
            {
                "foo": xr.DataArray(np.zeros([5, 4]), dims=backward),
                "bar": xr.DataArray(np.ones([4, 5]), dims=forward),
            }
        )
        return left, right

    checks = (xr.testing.assert_equal, xr.testing.assert_allclose)

    dataset_1, dataset_2 = build_pair("a", "b")

    # Equal once dimension order is ignored.
    for check in checks:
        check(dataset_1, dataset_2, check_dim_order=False)

    # A dataset always matches itself.
    for check in checks:
        check(dataset_1, dataset_1, check_dim_order=False)

    # With strict ordering the swapped layouts must be rejected.
    for check in checks:
        with pytest.raises(AssertionError):
            check(dataset_1, dataset_2, check_dim_order=True)

    # Dimension names that cannot be sorted against each other (int and str).
    dataset_mixed_1, dataset_mixed_2 = build_pair(1, "b")
    xr.testing.assert_equal(dataset_mixed_1, dataset_mixed_2, check_dim_order=False)
    xr.testing.assert_equal(dataset_mixed_1, dataset_mixed_1, check_dim_order=False)


def test_assert_equal_no_common_dims():
    """assert_equal must reject objects that share no dimension names."""
    import numpy as np

    def two_var_dataset(foo_len, foo_dim, bar_len, bar_dim):
        # "foo" is all zeros and "bar" all ones, each along its own 1-D dim.
        return xr.Dataset(
            {
                "foo": xr.DataArray(np.zeros([foo_len]), dims=(foo_dim,)),
                "bar": xr.DataArray(np.ones([bar_len]), dims=(bar_dim,)),
            }
        )

    # DataArrays whose dimension names are entirely disjoint.
    da1 = xr.DataArray(np.zeros([4, 5]), dims=("x", "y"))
    da2 = xr.DataArray(np.zeros([3, 2]), dims=("a", "b"))

    # Ignoring dimension order does not make different dimensions equal.
    with pytest.raises(AssertionError):
        xr.testing.assert_equal(da1, da2, check_dim_order=False)

    # Same expectation for Datasets with disjoint dimension names.
    ds1 = two_var_dataset(4, "x", 5, "y")
    ds2 = two_var_dataset(3, "a", 2, "b")

    with pytest.raises(AssertionError):
        xr.testing.assert_equal(ds1, ds2, check_dim_order=False)


def test_assert_equal_variable_transpose():
    """Transposed Variables compare equal only when dim order is ignored."""
    import numpy as np

    # Same all-zero data, laid out as (x, y) and as (y, x).
    straight = xr.Variable(("x", "y"), np.zeros([4, 5]))
    flipped = xr.Variable(("y", "x"), np.zeros([5, 4]))

    # Strict dimension ordering: the two layouts differ.
    with pytest.raises(AssertionError):
        xr.testing.assert_equal(straight, flipped, check_dim_order=True)

    # Relaxed ordering: both assertion helpers accept the pair.
    for check in (xr.testing.assert_equal, xr.testing.assert_allclose):
        check(straight, flipped, check_dim_order=False)
Loading