Skip to content

Commit 12c690f

Browse files
authored
Disallow passing a DataArray as data into the DataTree constructor (#9444)
1 parent a74a605 commit 12c690f

File tree

2 files changed

+32
-19
lines changed

2 files changed

+32
-19
lines changed

Diff for: xarray/core/datatree.py

+19-19
Original file line numberDiff line numberDiff line change
@@ -91,17 +91,13 @@ def _collect_data_and_coord_variables(
9191
return data_variables, coord_variables
9292

9393

94-
def _coerce_to_dataset(data: Dataset | DataArray | None) -> Dataset:
95-
if isinstance(data, DataArray):
96-
ds = data.to_dataset()
97-
elif isinstance(data, Dataset):
94+
def _to_new_dataset(data: Dataset | None) -> Dataset:
95+
if isinstance(data, Dataset):
9896
ds = data.copy(deep=False)
9997
elif data is None:
10098
ds = Dataset()
10199
else:
102-
raise TypeError(
103-
f"data object is not an xarray Dataset, DataArray, or None, it is of type {type(data)}"
104-
)
100+
raise TypeError(f"data object is not an xarray.Dataset, dict, or None: {data}")
105101
return ds
106102

107103

@@ -422,7 +418,8 @@ class DataTree(
422418

423419
def __init__(
424420
self,
425-
data: Dataset | DataArray | None = None,
421+
data: Dataset | None = None,
422+
parent: DataTree | None = None,
426423
children: Mapping[str, DataTree] | None = None,
427424
name: str | None = None,
428425
):
@@ -435,9 +432,8 @@ def __init__(
435432
436433
Parameters
437434
----------
438-
data : Dataset, DataArray, or None, optional
439-
Data to store under the .ds attribute of this node. DataArrays will
440-
be promoted to Datasets. Default is None.
435+
data : Dataset, optional
436+
Data to store under the .ds attribute of this node.
441437
children : Mapping[str, DataTree], optional
442438
Any child nodes of this node. Default is None.
443439
name : str, optional
@@ -455,7 +451,7 @@ def __init__(
455451
children = {}
456452

457453
super().__init__(name=name)
458-
self._set_node_data(_coerce_to_dataset(data))
454+
self._set_node_data(_to_new_dataset(data))
459455

460456
# shallow copy to avoid modifying arguments in-place (see GH issue #9196)
461457
self.children = {name: child.copy() for name, child in children.items()}
@@ -540,8 +536,8 @@ def ds(self) -> DatasetView:
540536
return self._to_dataset_view(rebuild_dims=True)
541537

542538
@ds.setter
543-
def ds(self, data: Dataset | DataArray | None = None) -> None:
544-
ds = _coerce_to_dataset(data)
539+
def ds(self, data: Dataset | None = None) -> None:
540+
ds = _to_new_dataset(data)
545541
self._replace_node(ds)
546542

547543
def to_dataset(self, inherited: bool = True) -> Dataset:
@@ -1050,7 +1046,7 @@ def drop_nodes(
10501046
@classmethod
10511047
def from_dict(
10521048
cls,
1053-
d: Mapping[str, Dataset | DataArray | DataTree | None],
1049+
d: Mapping[str, Dataset | DataTree | None],
10541050
name: str | None = None,
10551051
) -> DataTree:
10561052
"""
@@ -1059,10 +1055,10 @@ def from_dict(
10591055
Parameters
10601056
----------
10611057
d : dict-like
1062-
A mapping from path names to xarray.Dataset, xarray.DataArray, or DataTree objects.
1058+
A mapping from path names to xarray.Dataset or DataTree objects.
10631059
1064-
Path names are to be given as unix-like path. If path names containing more than one part are given, new
1065-
tree nodes will be constructed as necessary.
1060+
Path names are to be given as unix-like path. If path names containing more than one
1061+
part are given, new tree nodes will be constructed as necessary.
10661062
10671063
To assign data to the root node of the tree use "/" as the path.
10681064
name : Hashable | None, optional
@@ -1083,8 +1079,12 @@ def from_dict(
10831079
if isinstance(root_data, DataTree):
10841080
obj = root_data.copy()
10851081
obj.orphan()
1086-
else:
1082+
elif root_data is None or isinstance(root_data, Dataset):
10871083
obj = cls(name=name, data=root_data, children=None)
1084+
else:
1085+
raise TypeError(
1086+
f'root node data (at "/") must be a Dataset or DataTree, got {type(root_data)}'
1087+
)
10881088

10891089
def depth(item) -> int:
10901090
pathstr, _ = item

Diff for: xarray/tests/test_datatree.py

+13
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,14 @@ def test_bad_names(self):
3333
with pytest.raises(ValueError):
3434
DataTree(name="folder/data")
3535

36+
def test_data_arg(self):
37+
ds = xr.Dataset({"foo": 42})
38+
tree: DataTree = DataTree(data=ds)
39+
assert_identical(tree.to_dataset(), ds)
40+
41+
with pytest.raises(TypeError):
42+
DataTree(data=xr.DataArray(42, name="foo")) # type: ignore
43+
3644

3745
class TestFamilyTree:
3846
def test_dont_modify_children_inplace(self):
@@ -613,6 +621,11 @@ def test_insertion_order(self):
613621
# despite 'Bart' coming before 'Lisa' when sorted alphabetically
614622
assert list(reversed["Homer"].children.keys()) == ["Lisa", "Bart"]
615623

624+
def test_array_values(self):
625+
data = {"foo": xr.DataArray(1, name="bar")}
626+
with pytest.raises(TypeError):
627+
DataTree.from_dict(data) # type: ignore
628+
616629

617630
class TestDatasetView:
618631
def test_view_contents(self):

0 commit comments

Comments
 (0)