diff --git a/xarray/testing/assertions.py b/xarray/testing/assertions.py index 474a72da739..1eac42ba371 100644 --- a/xarray/testing/assertions.py +++ b/xarray/testing/assertions.py @@ -85,24 +85,41 @@ def assert_isomorphic(a: DataTree, b: DataTree): def maybe_transpose_dims(a, b, check_dim_order: bool): - """Helper for assert_equal/allclose/identical""" + """Helper for assert_equal/allclose/identical + + Returns (a, b) tuple with dimensions transposed to canonical order if needed. + """ __tracebackhide__ = True + if check_dim_order: + return a, b + def _maybe_transpose_dims(a, b): if not isinstance(a, Variable | DataArray | Dataset): - return b - if set(a.dims) == set(b.dims): - # Ensure transpose won't fail if a dimension is missing - # If this is the case, the difference will be caught by the caller - return b.transpose(*a.dims) - return b - - if check_dim_order: - return b + return a, b + + # Find common dimensions and transpose both to canonical order + common_dims = set(a.dims) & set(b.dims) + if common_dims: + # Use order from the intersection, with ellipsis for any unique dims + canonical_order = list(common_dims) + [...] + # For Datasets, we need to transpose both to the same order + # For Variable/DataArray, we could just transpose b, but for consistency + # and simplicity we transpose both + return a.transpose(*canonical_order), b.transpose(*canonical_order) + return a, b if isinstance(a, DataTree): - return map_over_datasets(_maybe_transpose_dims, a, b) + # DataTree needs special handling - transpose both trees + # map_over_datasets applies a function over corresponding datasets + transposed_a = map_over_datasets( + lambda a_ds, b_ds: _maybe_transpose_dims(a_ds, b_ds)[0], a, b + ) + transposed_b = map_over_datasets( + lambda a_ds, b_ds: _maybe_transpose_dims(a_ds, b_ds)[1], a, b + ) + return transposed_a, transposed_b return _maybe_transpose_dims(a, b) @@ -139,7 +156,7 @@ def assert_equal(a, b, check_dim_order: bool = True): assert type(a) is type(b) or ( isinstance(a, Coordinates) and isinstance(b, Coordinates) ) - b = maybe_transpose_dims(a, b, check_dim_order) + a, b = maybe_transpose_dims(a, b, check_dim_order) if isinstance(a, Variable | DataArray): assert a.equals(b), formatting.diff_array_repr(a, b, "equals") elif isinstance(a, Dataset): @@ -227,7 +244,7 @@ def assert_allclose( """ __tracebackhide__ = True assert type(a) is type(b) - b = maybe_transpose_dims(a, b, check_dim_order) + a, b = maybe_transpose_dims(a, b, check_dim_order) equiv = functools.partial( _data_allclose_or_equiv, rtol=rtol, atol=atol, decode_bytes=decode_bytes diff --git a/xarray/tests/test_assertions.py b/xarray/tests/test_assertions.py index 222a01a6628..290dda6e124 100644 --- a/xarray/tests/test_assertions.py +++ b/xarray/tests/test_assertions.py @@ -100,6 +100,28 @@ def test_assert_equal_transpose_datatree() -> None: xr.testing.assert_equal(a, b, check_dim_order=False) + # Test with mixed dimension orders in datasets (the tricky case) + import numpy as np + + ds_mixed = xr.Dataset( + { + "foo": xr.DataArray(np.zeros([4, 5]), dims=("a", "b")), + "bar": xr.DataArray(np.ones([5, 4]), dims=("b", "a")), + } + ) + ds_mixed2 = xr.Dataset( + { + "foo": xr.DataArray(np.zeros([5, 4]), dims=("b", "a")), + "bar": xr.DataArray(np.ones([4, 5]), dims=("a", "b")), + } + ) + + tree1 = xr.DataTree.from_dict({"node": ds_mixed}) + tree2 = xr.DataTree.from_dict({"node": ds_mixed2}) + + # Should work with check_dim_order=False + xr.testing.assert_equal(tree1, tree2, check_dim_order=False) + @pytest.mark.filterwarnings("error") @pytest.mark.parametrize( @@ -224,3 +246,104 @@ def __array__( getattr(xr.testing, func)(a, b) assert len(w) == 0 + + +def test_assert_equal_dataset_check_dim_order(): + """Test for issue #10704 - check_dim_order=False with Datasets containing mixed dimension orders.""" + import numpy as np + + # Dataset with variables having different dimension orders + dataset_1 = xr.Dataset( + { + "foo": xr.DataArray(np.zeros([4, 5]), dims=("a", "b")), + "bar": xr.DataArray(np.ones([5, 4]), dims=("b", "a")), + } + ) + + dataset_2 = xr.Dataset( + { + "foo": xr.DataArray(np.zeros([5, 4]), dims=("b", "a")), + "bar": xr.DataArray(np.ones([4, 5]), dims=("a", "b")), + } + ) + + # These should be equal when ignoring dimension order + xr.testing.assert_equal(dataset_1, dataset_2, check_dim_order=False) + xr.testing.assert_allclose(dataset_1, dataset_2, check_dim_order=False) + + # Should also work when comparing dataset to itself + xr.testing.assert_equal(dataset_1, dataset_1, check_dim_order=False) + xr.testing.assert_allclose(dataset_1, dataset_1, check_dim_order=False) + + # But should fail with check_dim_order=True + with pytest.raises(AssertionError): + xr.testing.assert_equal(dataset_1, dataset_2, check_dim_order=True) + with pytest.raises(AssertionError): + xr.testing.assert_allclose(dataset_1, dataset_2, check_dim_order=True) + + # Test with non-sortable dimension names (int and str) + dataset_mixed_1 = xr.Dataset( + { + "foo": xr.DataArray(np.zeros([4, 5]), dims=(1, "b")), + "bar": xr.DataArray(np.ones([5, 4]), dims=("b", 1)), + } + ) + + dataset_mixed_2 = xr.Dataset( + { + "foo": xr.DataArray(np.zeros([5, 4]), dims=("b", 1)), + "bar": xr.DataArray(np.ones([4, 5]), dims=(1, "b")), + } + ) + + # Should work with mixed types when ignoring dimension order + xr.testing.assert_equal(dataset_mixed_1, dataset_mixed_2, check_dim_order=False) + xr.testing.assert_equal(dataset_mixed_1, dataset_mixed_1, check_dim_order=False) + + +def test_assert_equal_no_common_dims(): + """Test assert_equal when objects have no common dimensions.""" + import numpy as np + + # DataArrays with completely different dimensions + da1 = xr.DataArray(np.zeros([4, 5]), dims=("x", "y")) + da2 = xr.DataArray(np.zeros([3, 2]), dims=("a", "b")) + + # Should fail even with check_dim_order=False since dims are different + with pytest.raises(AssertionError): + xr.testing.assert_equal(da1, da2, check_dim_order=False) + + # Datasets with no common dimensions + ds1 = xr.Dataset( + { + "foo": xr.DataArray(np.zeros([4]), dims=("x",)), + "bar": xr.DataArray(np.ones([5]), dims=("y",)), + } + ) + ds2 = xr.Dataset( + { + "foo": xr.DataArray(np.zeros([3]), dims=("a",)), + "bar": xr.DataArray(np.ones([2]), dims=("b",)), + } + ) + + # Should fail since dimensions are completely different + with pytest.raises(AssertionError): + xr.testing.assert_equal(ds1, ds2, check_dim_order=False) + + +def test_assert_equal_variable_transpose(): + """Test assert_equal with transposed Variable objects.""" + import numpy as np + + # Variables with transposed dimensions + var1 = xr.Variable(("x", "y"), np.zeros([4, 5])) + var2 = xr.Variable(("y", "x"), np.zeros([5, 4])) + + # Should fail with check_dim_order=True + with pytest.raises(AssertionError): + xr.testing.assert_equal(var1, var2, check_dim_order=True) + + # Should pass with check_dim_order=False + xr.testing.assert_equal(var1, var2, check_dim_order=False) + xr.testing.assert_allclose(var1, var2, check_dim_order=False)