Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,14 @@ Deprecations
Bug Fixes
~~~~~~~~~

- Fix multi-coordinate indexes being dropped in :py:meth:`DataArray._replace_maybe_drop_dims`
(e.g. after reducing over an unrelated dimension) and in :py:meth:`Dataset._copy_listed`
(e.g. when subsetting a Dataset by variable names). Both paths now consult
:py:meth:`Index.should_add_coord_to_array`, consistent with
:py:meth:`Dataset._construct_dataarray`. Also simplify :py:meth:`Dataset.to_dataarray`
to pass all indexes through unfiltered: every coordinate is retained after the
variables are broadcast, so no index needs to be dropped (:issue:`11215`, :pull:`11286`).
By `Rich Signell <https://github.com/rsignell>`_.
- Allow writing ``StringDType`` variables to netCDF files (:issue:`11199`).
By `Kristian Kollsgård <https://github.com/kkollsga>`_.
- Fix ``Source`` link in api docs (:pull:`11187`)
Expand Down
10 changes: 7 additions & 3 deletions xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,9 +538,13 @@ def _replace_maybe_drop_dims(
indexes = filter_indexes_from_coords(self._indexes, set(coords))
else:
allowed_dims = set(variable.dims)
coords = {
k: v for k, v in self._coords.items() if set(v.dims) <= allowed_dims
}
coords = {}
for k, v in self._coords.items():
if k in self._indexes:
if self._indexes[k].should_add_coord_to_array(k, v, allowed_dims):
coords[k] = v
elif set(v.dims) <= allowed_dims:
coords[k] = v
indexes = filter_indexes_from_coords(self._indexes, set(coords))
return self._replace(variable, coords, name, indexes=indexes)

Expand Down
10 changes: 8 additions & 2 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1221,7 +1221,13 @@ def _copy_listed(self, names: Iterable[Hashable]) -> Self:
if k not in self._coord_names:
continue

if set(self.variables[k].dims) <= needed_dims:
if k in self._indexes:
if self._indexes[k].should_add_coord_to_array(
k, self._variables[k], set(needed_dims)
):
variables[k] = self._variables[k]
coord_names.add(k)
elif set(self.variables[k].dims) <= needed_dims:
variables[k] = self._variables[k]
coord_names.add(k)

Expand Down Expand Up @@ -7155,7 +7161,7 @@ def to_dataarray(
variable = Variable(dims, data, self.attrs, fastpath=True)

coords = {k: v.variable for k, v in self.coords.items()}
indexes = filter_indexes_from_coords(self._indexes, set(coords))
indexes = dict(self._indexes)
new_dim_index = PandasIndex(list(self.data_vars), dim)
indexes[dim] = new_dim_index
coords.update(new_dim_index.create_variables())
Expand Down
33 changes: 33 additions & 0 deletions xarray/tests/test_dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -555,6 +555,39 @@ def should_add_coord_to_array(self, name, var, dims):
assert_identical(actual.coords, coords, check_default_indexes=False)
assert "x_bnds" not in actual.dims

def test_replace_maybe_drop_dims_preserves_multi_coord_index(self) -> None:
    # Regression test for https://github.com/pydata/xarray/issues/11215
    # A reduction over an unrelated dimension must keep every coordinate
    # (and its index) of a multi-coordinate index that spans several dims.
    class MultiDimIndex(Index):
        # Opt every coordinate in, whatever dims the target array has.
        def should_add_coord_to_array(self, name, var, dims):
            return True

    coord_names = ["node_x", "node_y", "face_x", "face_y"]
    shared_index = MultiDimIndex()
    coords = Coordinates(
        coords={
            "node_x": ("nodes", [0.0, 1.0, 2.0]),
            "node_y": ("nodes", [0.0, 0.0, 1.0]),
            "face_x": ("faces", [0.5, 1.5]),
            "face_y": ("faces", [0.5, 0.5]),
        },
        # One index object is shared by all four coordinates.
        indexes=dict.fromkeys(coord_names, shared_index),
    )

    # Build one array per mesh dimension, each with an extra dim to reduce.
    arrays = [
        DataArray(np.random.rand(size, 4), dims=(dim, "extra"), coords=coords)
        for dim, size in (("nodes", 3), ("faces", 2))
    ]
    reduced_arrays = [arr.mean("extra") for arr in arrays]

    for reduced in reduced_arrays:
        for coord_name in coord_names:
            assert coord_name in reduced.coords
            assert isinstance(reduced.xindexes[coord_name], MultiDimIndex)

def test_equals_and_identical(self) -> None:
orig = DataArray(np.arange(5.0), {"a": 42}, dims="x")

Expand Down
65 changes: 65 additions & 0 deletions xarray/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4553,6 +4553,71 @@ def should_add_coord_to_array(self, name, var, dims):
assert_identical(actual.coords, coords, check_default_indexes=False)
assert "x_bnds" not in actual.dims

def test_copy_listed_preserves_multi_coord_index(self) -> None:
    # Regression test for https://github.com/pydata/xarray/issues/11215
    # Subsetting a Dataset by variable names (ds[["var"]]) must keep every
    # coordinate (and its index) of a multi-coordinate index that spans
    # several dims.
    class MultiDimIndex(Index):
        # Opt every coordinate in, whatever dims the target has.
        def should_add_coord_to_array(self, name, var, dims):
            return True

    coord_names = ["node_x", "node_y", "face_x", "face_y"]
    shared_index = MultiDimIndex()
    coords = Coordinates(
        coords={
            "node_x": ("nodes", [0.0, 1.0, 2.0]),
            "node_y": ("nodes", [0.0, 0.0, 1.0]),
            "face_x": ("faces", [0.5, 1.5]),
            "face_y": ("faces", [0.5, 0.5]),
        },
        # One index object is shared by all four coordinates.
        indexes=dict.fromkeys(coord_names, shared_index),
    )
    data_vars = {
        "node_data": (("nodes",), [1.0, 2.0, 3.0]),
        "face_data": (("faces",), [10.0, 20.0]),
    }
    ds = Dataset(data_vars, coords=coords)

    # Each subset only uses one of the two dims covered by the index.
    subsets = [ds[[var_name]] for var_name in ("node_data", "face_data")]

    for subset in subsets:
        for coord_name in coord_names:
            assert coord_name in subset.coords
            assert isinstance(subset.xindexes[coord_name], MultiDimIndex)

def test_to_dataarray_preserves_multi_coord_index(self) -> None:
    # Regression test for https://github.com/pydata/xarray/issues/11215
    # Converting a Dataset with to_dataarray() must keep every coordinate
    # (and its index) of a multi-coordinate index that spans several dims.
    class MultiDimIndex(Index):
        # Opt every coordinate in, whatever dims the target array has.
        def should_add_coord_to_array(self, name, var, dims):
            return True

    coord_names = ["node_x", "node_y", "face_x", "face_y"]
    shared_index = MultiDimIndex()
    coords = Coordinates(
        coords={
            "node_x": ("nodes", [0.0, 1.0, 2.0]),
            "node_y": ("nodes", [0.0, 0.0, 1.0]),
            "face_x": ("faces", [0.5, 1.5]),
            "face_y": ("faces", [0.5, 0.5]),
        },
        # One index object is shared by all four coordinates.
        indexes=dict.fromkeys(coord_names, shared_index),
    )
    # The only data variable lives on "nodes"; the "faces" coords are unrelated.
    ds = Dataset({"node_data": (("nodes",), [1.0, 2.0, 3.0])}, coords=coords)

    converted = ds.to_dataarray()

    for coord_name in coord_names:
        assert coord_name in converted.coords
        assert isinstance(converted.xindexes[coord_name], MultiDimIndex)

def test_virtual_variables_default_coords(self) -> None:
dataset = Dataset({"foo": ("x", range(10))})
expected1 = DataArray(range(10), dims="x", name="x")
Expand Down
Loading