Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ New Features
All of Xarray's netCDF backends now support in-memory reads and writes
(:pull:`10624`).
By `Stephan Hoyer <https://github.com/shoyer>`_.
- :py:func:`merge` now supports merging :py:class:`DataTree` objects
(:issue:`9790`).
By `Stephan Hoyer <https://github.com/shoyer>`_.

Breaking changes
~~~~~~~~~~~~~~~~
Expand Down
6 changes: 3 additions & 3 deletions xarray/core/datatree.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
from xarray.core.dataset import Dataset
from xarray.core.dataset_variables import DataVariables
from xarray.core.datatree_mapping import (
_handle_errors_with_path_context,
add_path_context_to_errors,
map_over_datasets,
)
from xarray.core.formatting import (
Expand Down Expand Up @@ -2213,8 +2213,8 @@ def _selective_indexing(
result = {}
for path, node in self.subtree_with_keys:
node_indexers = {k: v for k, v in indexers.items() if k in node.dims}
func_with_error_context = _handle_errors_with_path_context(path)(func)
node_result = func_with_error_context(node.dataset, node_indexers)
with add_path_context_to_errors(path):
node_result = func(node.dataset, node_indexers)
# Indexing datasets corresponding to each node results in redundant
# coordinates when indexes from a parent node are inherited.
# Ideally, we would avoid creating such coordinates in the first
Expand Down
35 changes: 11 additions & 24 deletions xarray/core/datatree_mapping.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

from collections.abc import Callable, Mapping
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, cast, overload

from xarray.core.dataset import Dataset
Expand Down Expand Up @@ -112,9 +113,8 @@ def map_over_datasets(
for i, arg in enumerate(args):
if not isinstance(arg, DataTree):
node_dataset_args.insert(i, arg)

func_with_error_context = _handle_errors_with_path_context(path)(func)
results = func_with_error_context(*node_dataset_args, **kwargs)
with add_path_context_to_errors(path):
results = func(*node_dataset_args, **kwargs)
out_data_objects[path] = results

num_return_values = _check_all_return_values(out_data_objects)
Expand All @@ -138,27 +138,14 @@ def map_over_datasets(
)


def _handle_errors_with_path_context(path: str):
"""Wraps given function so that if it fails it also raises path to node on which it failed."""

def decorator(func):
def wrapper(*args, **kwargs):
try:
return func(*args, **kwargs)
except Exception as e:
# Add the context information to the error message
add_note(
e, f"Raised whilst mapping function over node with path {path!r}"
)
raise

return wrapper

return decorator


def add_note(err: BaseException, msg: str) -> None:
err.add_note(msg)
@contextmanager
def add_path_context_to_errors(path: str):
"""Add path context to any errors."""
try:
yield
except Exception as e:
e.add_note(f"Raised whilst mapping function over node(s) with path {path!r}")
raise


def _check_single_set_return_values(path_to_node: str, obj: Any) -> int | None:
Expand Down
6 changes: 1 addition & 5 deletions xarray/core/treenode.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,7 @@
import sys
from collections.abc import Iterator, Mapping
from pathlib import PurePosixPath
from typing import (
TYPE_CHECKING,
Any,
TypeVar,
)
from typing import TYPE_CHECKING, Any, TypeVar

from xarray.core.types import Self
from xarray.core.utils import Frozen, is_dict_like
Expand Down
117 changes: 109 additions & 8 deletions xarray/structure/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from collections import defaultdict
from collections.abc import Hashable, Iterable, Mapping, Sequence
from collections.abc import Set as AbstractSet
from typing import TYPE_CHECKING, Any, NamedTuple, Union
from typing import TYPE_CHECKING, Any, NamedTuple, Union, cast, overload

import pandas as pd

Expand Down Expand Up @@ -34,6 +34,7 @@
from xarray.core.coordinates import Coordinates
from xarray.core.dataarray import DataArray
from xarray.core.dataset import Dataset
from xarray.core.datatree import DataTree
from xarray.core.types import (
CombineAttrsOptions,
CompatOptions,
Expand Down Expand Up @@ -793,18 +794,99 @@ def merge_core(
return _MergeResult(variables, coord_names, dims, out_indexes, attrs)


def merge_trees(
    trees: Iterable[DataTree],
    compat: CompatOptions | CombineKwargDefault = _COMPAT_DEFAULT,
    join: JoinOptions | CombineKwargDefault = _JOIN_DEFAULT,
    fill_value: object = dtypes.NA,
    combine_attrs: CombineAttrsOptions = "override",
) -> DataTree:
    """Merge DataTree objects node by node into a single DataTree.

    Nodes are grouped by their path across all input trees. The root nodes
    (path ``"."``) are merged with ``merge``; every other group is merged
    with ``merge_core`` and assigned to the result under the same path,
    visiting paths with fewer ``"/"`` separators first.

    Parameters
    ----------
    trees : iterable of DataTree
        Trees to merge.
    compat, join, combine_attrs
        Forwarded unchanged to ``merge`` (root) and ``merge_core`` (other
        nodes).
    fill_value : object, default: dtypes.NA
        Only the default value is accepted.

    Returns
    -------
    DataTree
        Tree holding one merged dataset per distinct input path.

    Raises
    ------
    NotImplementedError
        If ``fill_value`` is anything other than ``dtypes.NA``.
    """
    from xarray.core.dataset import Dataset
    from xarray.core.datatree import DataTree
    from xarray.core.datatree_mapping import add_path_context_to_errors

    if fill_value is not dtypes.NA:
        # fill_value support dicts, which probably should be mapped to sub-groups?
        raise NotImplementedError(
            "fill_value is not yet supported for DataTree objects in merge"
        )

    # Collect, for every path, the nodes found at that path in any input tree.
    nodes_by_path: defaultdict[str, list[DataTree]] = defaultdict(list)
    for tree in trees:
        for path, node in tree.subtree_with_keys:
            nodes_by_path[path].append(node)

    root_nodes = nodes_by_path.pop(".")
    root_datasets = [node.dataset for node in root_nodes]
    with add_path_context_to_errors("."):
        root_ds = merge(
            root_datasets, compat=compat, join=join, combine_attrs=combine_attrs
        )
    result = DataTree(dataset=root_ds)

    # Shallow paths first -- presumably so that parent nodes are assigned
    # before their children (NOTE(review): confirm against DataTree.__setitem__).
    shallowest_first = sorted(
        nodes_by_path.items(), key=lambda item: item[0].count("/")
    )
    for path, group in shallowest_first:
        # Merge datasets, including inherited indexes to ensure alignment.
        datasets = [node.dataset for node in group]
        with add_path_context_to_errors(path):
            merged = merge_core(
                datasets,
                compat=compat,
                join=join,
                combine_attrs=combine_attrs,
            )

        # Remove inherited coordinates/indexes/dimensions: anything that none
        # of the merged nodes defines on the node itself is dropped again.
        inherited_coords = [
            name
            for name in merged.coord_names
            if not any(name in node._coord_variables for node in group)
        ]
        for name in inherited_coords:
            del merged.variables[name]
            merged.coord_names.remove(name)

        inherited_indexes = [
            name
            for name in merged.indexes
            if not any(name in node._node_indexes for node in group)
        ]
        for name in inherited_indexes:
            del merged.indexes[name]

        inherited_dims = [
            dim
            for dim in merged.dims
            if not any(dim in node._node_dims for node in group)
        ]
        for dim in inherited_dims:
            del merged.dims[dim]

        node_ds = Dataset._construct_direct(**merged._asdict())
        result[path] = DataTree(dataset=node_ds)

    return result


@overload
def merge(
objects: Iterable[DataTree],
compat: CompatOptions | CombineKwargDefault = ...,
join: JoinOptions | CombineKwargDefault = ...,
fill_value: object = ...,
combine_attrs: CombineAttrsOptions = ...,
) -> DataTree: ...


@overload
def merge(
objects: Iterable[DataArray | Dataset | Coordinates | dict],
compat: CompatOptions | CombineKwargDefault = ...,
join: JoinOptions | CombineKwargDefault = ...,
fill_value: object = ...,
combine_attrs: CombineAttrsOptions = ...,
) -> Dataset: ...


def merge(
objects: Iterable[DataArray | CoercibleMapping],
objects: Iterable[DataTree | DataArray | Dataset | Coordinates | dict],
compat: CompatOptions | CombineKwargDefault = _COMPAT_DEFAULT,
join: JoinOptions | CombineKwargDefault = _JOIN_DEFAULT,
fill_value: object = dtypes.NA,
combine_attrs: CombineAttrsOptions = "override",
) -> Dataset:
) -> DataTree | Dataset:
"""Merge any number of xarray objects into a single Dataset as variables.

Parameters
----------
objects : iterable of Dataset or iterable of DataArray or iterable of dict-like
objects : iterable of DataArray, Dataset, DataTree or dict
Merge together all variables from these objects. If any of them are
DataArray objects, they must have a name.
compat : {"identical", "equals", "broadcast_equals", "no_conflicts", \
Expand Down Expand Up @@ -859,8 +941,9 @@ def merge(

Returns
-------
Dataset
Dataset with combined variables from each object.
Dataset or DataTree
Objects with combined variables from the inputs. If any inputs are a
DataTree, this will also be a DataTree. Otherwise it will be a Dataset.

Examples
--------
Expand Down Expand Up @@ -1023,13 +1106,31 @@ def merge(
from xarray.core.coordinates import Coordinates
from xarray.core.dataarray import DataArray
from xarray.core.dataset import Dataset
from xarray.core.datatree import DataTree

objects = list(objects)

if any(isinstance(obj, DataTree) for obj in objects):
if not all(isinstance(obj, DataTree) for obj in objects):
raise TypeError(
"merge does not support mixed type arguments when one argument "
f"is a DataTree: {objects}"
)
trees = cast(list[DataTree], objects)
return merge_trees(
trees,
compat=compat,
join=join,
combine_attrs=combine_attrs,
fill_value=fill_value,
)

dict_like_objects = []
for obj in objects:
if not isinstance(obj, DataArray | Dataset | Coordinates | dict):
raise TypeError(
"objects must be an iterable containing only "
"Dataset(s), DataArray(s), and dictionaries."
"objects must be an iterable containing only DataTree(s), "
f"Dataset(s), DataArray(s), and dictionaries: {objects}"
)

if isinstance(obj, DataArray):
Expand Down
4 changes: 2 additions & 2 deletions xarray/tests/test_datatree.py
Original file line number Diff line number Diff line change
Expand Up @@ -2217,15 +2217,15 @@
}
)

with pytest.raises(

Check failure on line 2220 in xarray/tests/test_datatree.py

View workflow job for this annotation

GitHub Actions / ubuntu-latest py3.11 bare-minimum

TestIndexing.test_sel_isel_error_has_node_info AssertionError: Regex pattern did not match. Regex: "Raised whilst mapping function over node(s) with path 'second'" Input: '"not all values found in index \'x\'. Try setting the `method` keyword argument (example: method=\'nearest\')."\nRaised whilst mapping function over node(s) with path \'second\''

Check failure on line 2220 in xarray/tests/test_datatree.py

View workflow job for this annotation

GitHub Actions / ubuntu-latest py3.11 bare-min-and-scipy

TestIndexing.test_sel_isel_error_has_node_info AssertionError: Regex pattern did not match. Regex: "Raised whilst mapping function over node(s) with path 'second'" Input: '"not all values found in index \'x\'. Try setting the `method` keyword argument (example: method=\'nearest\')."\nRaised whilst mapping function over node(s) with path \'second\''

Check failure on line 2220 in xarray/tests/test_datatree.py

View workflow job for this annotation

GitHub Actions / ubuntu-latest py3.12 all-but-dask

TestIndexing.test_sel_isel_error_has_node_info AssertionError: Regex pattern did not match. Regex: "Raised whilst mapping function over node(s) with path 'second'" Input: '"not all values found in index \'x\'. Try setting the `method` keyword argument (example: method=\'nearest\')."\nRaised whilst mapping function over node(s) with path \'second\''

Check failure on line 2220 in xarray/tests/test_datatree.py

View workflow job for this annotation

GitHub Actions / ubuntu-latest py3.13 all-but-numba

TestIndexing.test_sel_isel_error_has_node_info AssertionError: Regex pattern did not match. Regex: "Raised whilst mapping function over node(s) with path 'second'" Input: '"not all values found in index \'x\'. Try setting the `method` keyword argument (example: method=\'nearest\')."\nRaised whilst mapping function over node(s) with path \'second\''

Check failure on line 2220 in xarray/tests/test_datatree.py

View workflow job for this annotation

GitHub Actions / macos-latest py3.11

TestIndexing.test_sel_isel_error_has_node_info AssertionError: Regex pattern did not match. Regex: "Raised whilst mapping function over node(s) with path 'second'" Input: '"not all values found in index \'x\'. Try setting the `method` keyword argument (example: method=\'nearest\')."\nRaised whilst mapping function over node(s) with path \'second\''

Check failure on line 2220 in xarray/tests/test_datatree.py

View workflow job for this annotation

GitHub Actions / ubuntu-latest py3.11 min-all-deps

TestIndexing.test_sel_isel_error_has_node_info AssertionError: Regex pattern did not match. Regex: "Raised whilst mapping function over node(s) with path 'second'" Input: '"not all values found in index \'x\'. Try setting the `method` keyword argument (example: method=\'nearest\')."\nRaised whilst mapping function over node(s) with path \'second\''

Check failure on line 2220 in xarray/tests/test_datatree.py

View workflow job for this annotation

GitHub Actions / ubuntu-latest py3.11

TestIndexing.test_sel_isel_error_has_node_info AssertionError: Regex pattern did not match. Regex: "Raised whilst mapping function over node(s) with path 'second'" Input: '"not all values found in index \'x\'. Try setting the `method` keyword argument (example: method=\'nearest\')."\nRaised whilst mapping function over node(s) with path \'second\''

Check failure on line 2220 in xarray/tests/test_datatree.py

View workflow job for this annotation

GitHub Actions / ubuntu-latest py3.13

TestIndexing.test_sel_isel_error_has_node_info AssertionError: Regex pattern did not match. Regex: "Raised whilst mapping function over node(s) with path 'second'" Input: '"not all values found in index \'x\'. Try setting the `method` keyword argument (example: method=\'nearest\')."\nRaised whilst mapping function over node(s) with path \'second\''

Check failure on line 2220 in xarray/tests/test_datatree.py

View workflow job for this annotation

GitHub Actions / macos-latest py3.13

TestIndexing.test_sel_isel_error_has_node_info AssertionError: Regex pattern did not match. Regex: "Raised whilst mapping function over node(s) with path 'second'" Input: '"not all values found in index \'x\'. Try setting the `method` keyword argument (example: method=\'nearest\')."\nRaised whilst mapping function over node(s) with path \'second\''
KeyError,
match="Raised whilst mapping function over node with path 'second'",
# pytest treats `match` as a regex, so "(s)" must be escaped to match literally.
match=r"Raised whilst mapping function over node\(s\) with path 'second'",
):
tree.sel(x=1)

with pytest.raises(
IndexError,
match="Raised whilst mapping function over node with path 'first'",
# pytest treats `match` as a regex, so "(s)" must be escaped to match literally.
match=r"Raised whilst mapping function over node\(s\) with path 'first'",
):
tree.isel(x=4)

Expand Down
2 changes: 1 addition & 1 deletion xarray/tests/test_datatree_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ def fail_on_specific_node(ds):
with pytest.raises(
ValueError,
match=re.escape(
r"Raised whilst mapping function over node with path 'set1'"
r"Raised whilst mapping function over node(s) with path 'set1'"
),
):
dt.map_over_datasets(fail_on_specific_node)
Expand Down
93 changes: 93 additions & 0 deletions xarray/tests/test_merge.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import re
import warnings

import numpy as np
Expand Down Expand Up @@ -867,3 +868,95 @@ def test_merge_auto_align(self):
with set_options(use_new_combine_kwarg_defaults=True):
with pytest.raises(ValueError, match="might be related to new default"):
expected.identical(ds2.merge(ds1))


class TestMergeDataTree:
    """Tests for ``xr.merge`` applied to ``DataTree`` inputs."""

    def test_mixed(self) -> None:
        """Passing a DataTree together with a Dataset raises TypeError."""
        tree = xr.DataTree()
        ds = xr.Dataset()
        message = (
            "merge does not support mixed type arguments "
            "when one argument is a DataTree"
        )
        with pytest.raises(TypeError, match=message):
            xr.merge([tree, ds])  # type: ignore[list-item]

    def test_distinct(self) -> None:
        """Two trees with different paths merge into one tree with both."""
        first = xr.DataTree.from_dict({"/a/b/c": 1})
        second = xr.DataTree.from_dict({"/a/d/e": 2})
        expected = xr.DataTree.from_dict({"/a/b/c": 1, "/a/d/e": 2})
        assert_equal(xr.merge([first, second]), expected)

    def test_overlap(self) -> None:
        """Three trees sharing the ``/a`` prefix merge into a single tree."""
        first = xr.DataTree.from_dict({"/a/b": 1})
        second = xr.DataTree.from_dict({"/a/c": 2})
        third = xr.DataTree.from_dict({"/a/d": 3})
        expected = xr.DataTree.from_dict({"/a/b": 1, "/a/c": 2, "/a/d": 3})
        assert_equal(xr.merge([first, second, third]), expected)

    def test_inherited(self) -> None:
        """A coordinate given via ``coords`` on one input is kept in the result."""
        with_coords = xr.DataTree.from_dict({"/a/b": ("x", [1])}, coords={"x": [0]})
        without_coords = xr.DataTree.from_dict({"/a/c": ("x", [2])})
        expected = xr.DataTree.from_dict(
            {"/a/b": ("x", [1]), "a/c": ("x", [2])}, coords={"x": [0]}
        )
        assert_equal(xr.merge([with_coords, without_coords]), expected)

    def test_inherited_join(self) -> None:
        """Check every ``join`` mode on trees whose ``x`` coordinates differ."""
        tree1 = xr.DataTree.from_dict({"/a/b": ("x", [0, 1])}, coords={"x": [0, 1]})
        tree2 = xr.DataTree.from_dict({"/a/c": ("x", [1, 2])}, coords={"x": [1, 2]})

        def expected_tree(b_values, c_values, x_values):
            # Expected result: both variables plus the joined ``x`` coordinate.
            return xr.DataTree.from_dict(
                {"/a/b": ("x", b_values), "a/c": ("x", c_values)},
                coords={"x": x_values},
            )

        expected = expected_tree([0, 1], [np.nan, 1], [0, 1])
        assert_equal(xr.merge([tree1, tree2], join="left"), expected)

        expected = expected_tree([1, np.nan], [1, 2], [1, 2])
        assert_equal(xr.merge([tree1, tree2], join="right"), expected)

        expected = expected_tree([1], [1], [1])
        assert_equal(xr.merge([tree1, tree2], join="inner"), expected)

        expected = expected_tree([0, 1, np.nan], [np.nan, 1, 2], [0, 1, 2])
        assert_equal(xr.merge([tree1, tree2], join="outer"), expected)

        exact_message = re.escape("cannot align objects with join='exact'")
        with pytest.raises(xr.AlignmentError, match=exact_message):
            xr.merge([tree1, tree2], join="exact")

    def test_merge_error_includes_path(self) -> None:
        """The raised MergeError mentions the path ``'a'`` of the failing node(s)."""
        tree1 = xr.DataTree.from_dict({"/a/b": ("x", [0, 1])})
        tree2 = xr.DataTree.from_dict({"/a/b": ("x", [1, 2])})
        path_note = re.escape(
            "Raised whilst mapping function over node(s) with path 'a'"
        )
        with pytest.raises(xr.MergeError, match=path_note):
            xr.merge([tree1, tree2], join="exact")

    def test_fill_value_errors(self) -> None:
        """A non-default ``fill_value`` with DataTree inputs is rejected."""
        trees = [xr.DataTree(), xr.DataTree()]
        unsupported = re.escape(
            "fill_value is not yet supported for DataTree objects in merge"
        )
        with pytest.raises(NotImplementedError, match=unsupported):
            xr.merge(trees, fill_value=None)
Loading