diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 8d52557a5c6..668fbad067b 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -20,6 +20,9 @@ New Features All of Xarray's netCDF backends now support in-memory reads and writes (:pull:`10624`). By `Stephan Hoyer <https://github.com/shoyer>`_. +- :py:func:`merge` now supports merging :py:class:`DataTree` objects + (:issue:`9790`). + By `Stephan Hoyer <https://github.com/shoyer>`_. Breaking changes ~~~~~~~~~~~~~~~~ diff --git a/xarray/core/datatree.py b/xarray/core/datatree.py index d0b4f429b08..c2cfb7435bf 100644 --- a/xarray/core/datatree.py +++ b/xarray/core/datatree.py @@ -37,7 +37,7 @@ from xarray.core.dataset import Dataset from xarray.core.dataset_variables import DataVariables from xarray.core.datatree_mapping import ( - _handle_errors_with_path_context, + add_path_context_to_errors, map_over_datasets, ) from xarray.core.formatting import ( @@ -2213,8 +2213,8 @@ def _selective_indexing( result = {} for path, node in self.subtree_with_keys: node_indexers = {k: v for k, v in indexers.items() if k in node.dims} - func_with_error_context = _handle_errors_with_path_context(path)(func) - node_result = func_with_error_context(node.dataset, node_indexers) + with add_path_context_to_errors(path): + node_result = func(node.dataset, node_indexers) # Indexing datasets corresponding to each node results in redundant # coordinates when indexes from a parent node are inherited.
# Ideally, we would avoid creating such coordinates in the first diff --git a/xarray/core/datatree_mapping.py b/xarray/core/datatree_mapping.py index dc2fc591f44..b905545c064 100644 --- a/xarray/core/datatree_mapping.py +++ b/xarray/core/datatree_mapping.py @@ -1,6 +1,7 @@ from __future__ import annotations from collections.abc import Callable, Mapping +from contextlib import contextmanager from typing import TYPE_CHECKING, Any, cast, overload from xarray.core.dataset import Dataset @@ -112,9 +113,8 @@ def map_over_datasets( for i, arg in enumerate(args): if not isinstance(arg, DataTree): node_dataset_args.insert(i, arg) - - func_with_error_context = _handle_errors_with_path_context(path)(func) - results = func_with_error_context(*node_dataset_args, **kwargs) + with add_path_context_to_errors(path): + results = func(*node_dataset_args, **kwargs) out_data_objects[path] = results num_return_values = _check_all_return_values(out_data_objects) @@ -138,27 +138,14 @@ def map_over_datasets( ) -def _handle_errors_with_path_context(path: str): - """Wraps given function so that if it fails it also raises path to node on which it failed.""" - - def decorator(func): - def wrapper(*args, **kwargs): - try: - return func(*args, **kwargs) - except Exception as e: - # Add the context information to the error message - add_note( - e, f"Raised whilst mapping function over node with path {path!r}" - ) - raise - - return wrapper - - return decorator - - -def add_note(err: BaseException, msg: str) -> None: - err.add_note(msg) +@contextmanager +def add_path_context_to_errors(path: str): + """Add path context to any errors.""" + try: + yield + except Exception as e: + e.add_note(f"Raised whilst mapping function over node(s) with path {path!r}") + raise def _check_single_set_return_values(path_to_node: str, obj: Any) -> int | None: diff --git a/xarray/core/treenode.py b/xarray/core/treenode.py index 58c0efafbdb..9bc52637a5c 100644 --- a/xarray/core/treenode.py +++ b/xarray/core/treenode.py 
@@ -4,11 +4,7 @@ import sys from collections.abc import Iterator, Mapping from pathlib import PurePosixPath -from typing import ( - TYPE_CHECKING, - Any, - TypeVar, -) +from typing import TYPE_CHECKING, Any, TypeVar from xarray.core.types import Self from xarray.core.utils import Frozen, is_dict_like diff --git a/xarray/structure/merge.py b/xarray/structure/merge.py index 5bb53036042..d18ece884cc 100644 --- a/xarray/structure/merge.py +++ b/xarray/structure/merge.py @@ -3,7 +3,7 @@ from collections import defaultdict from collections.abc import Hashable, Iterable, Mapping, Sequence from collections.abc import Set as AbstractSet -from typing import TYPE_CHECKING, Any, NamedTuple, Union +from typing import TYPE_CHECKING, Any, NamedTuple, Union, cast, overload import pandas as pd @@ -34,6 +34,7 @@ from xarray.core.coordinates import Coordinates from xarray.core.dataarray import DataArray from xarray.core.dataset import Dataset + from xarray.core.datatree import DataTree from xarray.core.types import ( CombineAttrsOptions, CompatOptions, @@ -793,18 +794,99 @@ def merge_core( return _MergeResult(variables, coord_names, dims, out_indexes, attrs) +def merge_trees( + trees: Iterable[DataTree], + compat: CompatOptions | CombineKwargDefault = _COMPAT_DEFAULT, + join: JoinOptions | CombineKwargDefault = _JOIN_DEFAULT, + fill_value: object = dtypes.NA, + combine_attrs: CombineAttrsOptions = "override", +) -> DataTree: + """Merge specialized to DataTree objects.""" + from xarray.core.dataset import Dataset + from xarray.core.datatree import DataTree + from xarray.core.datatree_mapping import add_path_context_to_errors + + if fill_value is not dtypes.NA: + # fill_value support dicts, which probably should be mapped to sub-groups? 
+ raise NotImplementedError( + "fill_value is not yet supported for DataTree objects in merge" + ) + + node_lists: defaultdict[str, list[DataTree]] = defaultdict(list) + for tree in trees: + for key, node in tree.subtree_with_keys: + node_lists[key].append(node) + + root_datasets = [node.dataset for node in node_lists.pop(".")] + with add_path_context_to_errors("."): + root_ds = merge( + root_datasets, compat=compat, join=join, combine_attrs=combine_attrs + ) + result = DataTree(dataset=root_ds) + + def depth(kv): + return kv[0].count("/") + + for key, nodes in sorted(node_lists.items(), key=depth): + # Merge datasets, including inherited indexes to ensure alignment. + datasets = [node.dataset for node in nodes] + with add_path_context_to_errors(key): + merge_result = merge_core( + datasets, + compat=compat, + join=join, + combine_attrs=combine_attrs, + ) + # Remove inherited coordinates/indexes/dimensions. + for var_name in list(merge_result.coord_names): + if not any(var_name in node._coord_variables for node in nodes): + del merge_result.variables[var_name] + merge_result.coord_names.remove(var_name) + for index_name in list(merge_result.indexes): + if not any(index_name in node._node_indexes for node in nodes): + del merge_result.indexes[index_name] + for dim in list(merge_result.dims): + if not any(dim in node._node_dims for node in nodes): + del merge_result.dims[dim] + + merged_ds = Dataset._construct_direct(**merge_result._asdict()) + result[key] = DataTree(dataset=merged_ds) + + return result + + +@overload +def merge( + objects: Iterable[DataTree], + compat: CompatOptions | CombineKwargDefault = ..., + join: JoinOptions | CombineKwargDefault = ..., + fill_value: object = ..., + combine_attrs: CombineAttrsOptions = ..., +) -> DataTree: ... 
+ + +@overload +def merge( + objects: Iterable[DataArray | Dataset | Coordinates | dict], + compat: CompatOptions | CombineKwargDefault = ..., + join: JoinOptions | CombineKwargDefault = ..., + fill_value: object = ..., + combine_attrs: CombineAttrsOptions = ..., +) -> Dataset: ... + + def merge( - objects: Iterable[DataArray | CoercibleMapping], + objects: Iterable[DataTree | DataArray | Dataset | Coordinates | dict], compat: CompatOptions | CombineKwargDefault = _COMPAT_DEFAULT, join: JoinOptions | CombineKwargDefault = _JOIN_DEFAULT, fill_value: object = dtypes.NA, combine_attrs: CombineAttrsOptions = "override", -) -> Dataset: +) -> DataTree | Dataset: """Merge any number of xarray objects into a single Dataset as variables. Parameters ---------- - objects : iterable of Dataset or iterable of DataArray or iterable of dict-like + objects : iterable of DataArray, Dataset, DataTree or dict Merge together all variables from these objects. If any of them are DataArray objects, they must have a name. compat : {"identical", "equals", "broadcast_equals", "no_conflicts", \ @@ -859,8 +941,9 @@ def merge( Returns ------- - Dataset - Dataset with combined variables from each object. + Dataset or DataTree + Objects with combined variables from the inputs. If any inputs are a + DataTree, this will also be a DataTree. Otherwise it will be a Dataset. 
Examples -------- @@ -1023,13 +1106,31 @@ def merge( from xarray.core.coordinates import Coordinates from xarray.core.dataarray import DataArray from xarray.core.dataset import Dataset + from xarray.core.datatree import DataTree + + objects = list(objects) + + if any(isinstance(obj, DataTree) for obj in objects): + if not all(isinstance(obj, DataTree) for obj in objects): + raise TypeError( + "merge does not support mixed type arguments when one argument " + f"is a DataTree: {objects}" + ) + trees = cast(list[DataTree], objects) + return merge_trees( + trees, + compat=compat, + join=join, + combine_attrs=combine_attrs, + fill_value=fill_value, + ) dict_like_objects = [] for obj in objects: if not isinstance(obj, DataArray | Dataset | Coordinates | dict): raise TypeError( - "objects must be an iterable containing only " - "Dataset(s), DataArray(s), and dictionaries." + "objects must be an iterable containing only DataTree(s), " + f"Dataset(s), DataArray(s), and dictionaries: {objects}" ) if isinstance(obj, DataArray): diff --git a/xarray/tests/test_datatree.py b/xarray/tests/test_datatree.py index a368c56dee9..2edf79c62ba 100644 --- a/xarray/tests/test_datatree.py +++ b/xarray/tests/test_datatree.py @@ -2219,13 +2219,13 @@ def test_sel_isel_error_has_node_info(self) -> None: with pytest.raises( KeyError, - match="Raised whilst mapping function over node with path 'second'", + match=r"Raised whilst mapping function over node\(s\) with path 'second'", ): tree.sel(x=1) with pytest.raises( IndexError, - match="Raised whilst mapping function over node with path 'first'", + match=r"Raised whilst mapping function over node\(s\) with path 'first'", ): tree.isel(x=4) diff --git a/xarray/tests/test_datatree_mapping.py b/xarray/tests/test_datatree_mapping.py index 277a19887eb..2956802fe1f 100644 --- a/xarray/tests/test_datatree_mapping.py +++ b/xarray/tests/test_datatree_mapping.py @@ -192,7 +192,7 @@ def fail_on_specific_node(ds): with pytest.raises( ValueError, match=re.escape( -
r"Raised whilst mapping function over node with path 'set1'" + r"Raised whilst mapping function over node(s) with path 'set1'" ), ): dt.map_over_datasets(fail_on_specific_node) diff --git a/xarray/tests/test_merge.py b/xarray/tests/test_merge.py index 427d827e54c..68db0babb04 100644 --- a/xarray/tests/test_merge.py +++ b/xarray/tests/test_merge.py @@ -1,5 +1,6 @@ from __future__ import annotations +import re import warnings import numpy as np @@ -867,3 +868,95 @@ def test_merge_auto_align(self): with set_options(use_new_combine_kwarg_defaults=True): with pytest.raises(ValueError, match="might be related to new default"): expected.identical(ds2.merge(ds1)) + + +class TestMergeDataTree: + def test_mixed(self) -> None: + tree = xr.DataTree() + ds = xr.Dataset() + with pytest.raises( + TypeError, + match="merge does not support mixed type arguments when one argument is a DataTree", + ): + xr.merge([tree, ds]) # type: ignore[list-item] + + def test_distinct(self) -> None: + tree1 = xr.DataTree.from_dict({"/a/b/c": 1}) + tree2 = xr.DataTree.from_dict({"/a/d/e": 2}) + expected = xr.DataTree.from_dict({"/a/b/c": 1, "/a/d/e": 2}) + merged = xr.merge([tree1, tree2]) + assert_equal(merged, expected) + + def test_overlap(self) -> None: + tree1 = xr.DataTree.from_dict({"/a/b": 1}) + tree2 = xr.DataTree.from_dict({"/a/c": 2}) + tree3 = xr.DataTree.from_dict({"/a/d": 3}) + expected = xr.DataTree.from_dict({"/a/b": 1, "/a/c": 2, "/a/d": 3}) + merged = xr.merge([tree1, tree2, tree3]) + assert_equal(merged, expected) + + def test_inherited(self) -> None: + tree1 = xr.DataTree.from_dict({"/a/b": ("x", [1])}, coords={"x": [0]}) + tree2 = xr.DataTree.from_dict({"/a/c": ("x", [2])}) + expected = xr.DataTree.from_dict( + {"/a/b": ("x", [1]), "a/c": ("x", [2])}, coords={"x": [0]} + ) + merged = xr.merge([tree1, tree2]) + assert_equal(merged, expected) + + def test_inherited_join(self) -> None: + tree1 = xr.DataTree.from_dict({"/a/b": ("x", [0, 1])}, coords={"x": [0, 1]}) + tree2 = 
xr.DataTree.from_dict({"/a/c": ("x", [1, 2])}, coords={"x": [1, 2]}) + + expected = xr.DataTree.from_dict( + {"/a/b": ("x", [0, 1]), "a/c": ("x", [np.nan, 1])}, coords={"x": [0, 1]} + ) + merged = xr.merge([tree1, tree2], join="left") + assert_equal(merged, expected) + + expected = xr.DataTree.from_dict( + {"/a/b": ("x", [1, np.nan]), "a/c": ("x", [1, 2])}, coords={"x": [1, 2]} + ) + merged = xr.merge([tree1, tree2], join="right") + assert_equal(merged, expected) + + expected = xr.DataTree.from_dict( + {"/a/b": ("x", [1]), "a/c": ("x", [1])}, coords={"x": [1]} + ) + merged = xr.merge([tree1, tree2], join="inner") + assert_equal(merged, expected) + + expected = xr.DataTree.from_dict( + {"/a/b": ("x", [0, 1, np.nan]), "a/c": ("x", [np.nan, 1, 2])}, + coords={"x": [0, 1, 2]}, + ) + merged = xr.merge([tree1, tree2], join="outer") + assert_equal(merged, expected) + + with pytest.raises( + xr.AlignmentError, + match=re.escape("cannot align objects with join='exact'"), + ): + xr.merge([tree1, tree2], join="exact") + + def test_merge_error_includes_path(self) -> None: + tree1 = xr.DataTree.from_dict({"/a/b": ("x", [0, 1])}) + tree2 = xr.DataTree.from_dict({"/a/b": ("x", [1, 2])}) + with pytest.raises( + xr.MergeError, + match=re.escape( + "Raised whilst mapping function over node(s) with path 'a'" + ), + ): + xr.merge([tree1, tree2], join="exact") + + def test_fill_value_errors(self) -> None: + trees = [xr.DataTree(), xr.DataTree()] + + with pytest.raises( + NotImplementedError, + match=re.escape( + "fill_value is not yet supported for DataTree objects in merge" + ), + ): + xr.merge(trees, fill_value=None)