Skip to content
3 changes: 3 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ v2026.03.0 (unreleased)
New Features
~~~~~~~~~~~~

- Added ``inherit='all_coords'`` option to :py:meth:`DataTree.to_dataset` to inherit
all parent coordinates, not just indexed ones (:issue:`10812`, :pull:`11230`).
By `Alfonso Ladino <https://github.com/aladinor>`_.

Breaking Changes
~~~~~~~~~~~~~~~~
Expand Down
57 changes: 46 additions & 11 deletions xarray/core/datatree.py
Original file line number Diff line number Diff line change
Expand Up @@ -588,6 +588,28 @@ def _coord_variables(self) -> ChainMap[Hashable, Variable]:
*(p._node_coord_variables_with_index for p in self.parents), # type: ignore[arg-type]
)

@property
def _coord_variables_all(self) -> ChainMap[Hashable, Variable]:
return ChainMap(
self._node_coord_variables,
*(p._node_coord_variables for p in self.parents),
)

def _resolve_inherit(
self, inherit: bool | Literal["all_coords", "indexes"]
) -> tuple[Mapping[Hashable, Variable], dict[Hashable, Index]]:
"""Resolve the inherit parameter to (coord_vars, indexes)."""
if inherit is False:
return self._node_coord_variables, dict(self._node_indexes)
if inherit is True or inherit == "indexes":
return self._coord_variables, dict(self._indexes)
if inherit == "all_coords":
return self._coord_variables_all, dict(self._indexes)
raise ValueError(
f"Invalid value for inherit: {inherit!r}. "
"Expected True, False, 'indexes', or 'all'."
)

@property
def _dims(self) -> ChainMap[Hashable, int]:
return ChainMap(self._node_dims, *(p._node_dims for p in self.parents))
Expand All @@ -596,8 +618,12 @@ def _dims(self) -> ChainMap[Hashable, int]:
def _indexes(self) -> ChainMap[Hashable, Index]:
return ChainMap(self._node_indexes, *(p._node_indexes for p in self.parents))

def _to_dataset_view(self, rebuild_dims: bool, inherit: bool) -> DatasetView:
coord_vars = self._coord_variables if inherit else self._node_coord_variables
def _to_dataset_view(
self,
rebuild_dims: bool,
inherit: bool | Literal["all_coords", "indexes"] = True,
) -> DatasetView:
coord_vars, indexes = self._resolve_inherit(inherit)
variables = dict(self._data_variables)
variables |= coord_vars
if rebuild_dims:
Expand Down Expand Up @@ -636,10 +662,10 @@ def _to_dataset_view(self, rebuild_dims: bool, inherit: bool) -> DatasetView:
dims = dict(self._node_dims)
return DatasetView._constructor(
variables=variables,
coord_names=set(self._coord_variables),
coord_names=set(coord_vars),
dims=dims,
attrs=self._attrs,
indexes=dict(self._indexes if inherit else self._node_indexes),
indexes=indexes,
encoding=self._encoding,
close=None,
)
Expand Down Expand Up @@ -669,30 +695,39 @@ def dataset(self, data: Dataset | None = None) -> None:
# xarray-contrib/datatree
ds = dataset

def to_dataset(self, inherit: bool = True) -> Dataset:
def to_dataset(
self, inherit: bool | Literal["all_coords", "indexes"] = True
) -> Dataset:
"""
Return the data in this node as a new xarray.Dataset object.

Parameters
----------
inherit : bool, optional
If False, only include coordinates and indexes defined at the level
of this DataTree node, excluding any inherited coordinates and indexes.
inherit : bool or {"all_coords", "indexes"}, default True
Controls which coordinates are inherited from parent nodes.

- True or "indexes": inherit only indexed coordinates (default).
- "all_coords": inherit all coordinates, including non-index coordinates.
- False: only include coordinates defined at this node.

See Also
--------
DataTree.dataset
"""
coord_vars = self._coord_variables if inherit else self._node_coord_variables
coord_vars, indexes = self._resolve_inherit(inherit)
variables = dict(self._data_variables)
variables |= coord_vars
dims = calculate_dimensions(variables) if inherit else dict(self._node_dims)
dims = (
dict(self._node_dims)
if inherit is False
else calculate_dimensions(variables)
)
return Dataset._construct_direct(
variables,
set(coord_vars),
dims,
None if self._attrs is None else dict(self._attrs),
dict(self._indexes if inherit else self._node_indexes),
indexes,
None if self._encoding is None else dict(self._encoding),
None,
)
Expand Down
22 changes: 22 additions & 0 deletions xarray/tests/test_datatree.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,28 @@ def test_to_dataset_inherited(self) -> None:
assert_identical(tree.to_dataset(inherit=True), base)
assert_identical(subtree.to_dataset(inherit=True), sub_and_base)

def test_to_dataset_inherit_all(self) -> None:
base = xr.Dataset(coords={"a": [1], "b": 2})
sub = xr.Dataset(coords={"c": [3]})
tree = DataTree.from_dict({"/": base, "/sub": sub})
subtree = typing.cast(DataTree, tree["sub"])

expected = xr.Dataset(coords={"a": [1], "b": 2, "c": [3]})
assert_identical(subtree.to_dataset(inherit="all_coords"), expected)
assert_identical(tree.to_dataset(inherit="all_coords"), base)

mid = xr.Dataset(coords={"c": 3.0})
leaf = xr.Dataset(coords={"d": [4]})
deep = DataTree.from_dict({"/": base, "/mid": mid, "/mid/leaf": leaf})
leaf_node = typing.cast(DataTree, deep["/mid/leaf"])
result = leaf_node.to_dataset(inherit="all_coords")
assert set(result.coords) == {"a", "b", "c", "d"}

def test_to_dataset_inherit_invalid(self) -> None:
tree = DataTree()
with pytest.raises(ValueError, match="Invalid value for inherit"):
tree.to_dataset(inherit="invalid") # type: ignore[arg-type]


class TestVariablesChildrenNameCollisions:
def test_parent_already_has_variable_with_childs_name(self) -> None:
Expand Down
Loading