Skip to content

Commit 278b11a

Browse files
authored
Stop inheriting non-indexed coordinates for DataTree (#9555)
1 parent 98596dd commit 278b11a

File tree

6 files changed

+60
-36
lines changed

6 files changed

+60
-36
lines changed

xarray/core/dataarray.py

+3 -4
Original file line number | Diff line number | Diff line change
@@ -69,7 +69,7 @@
6969
)
7070
from xarray.core.utils import (
7171
Default,
72-
HybridMappingProxy,
72+
FilteredMapping,
7373
ReprObject,
7474
_default,
7575
either_dict_or_kwargs,
@@ -929,11 +929,10 @@ def _attr_sources(self) -> Iterable[Mapping[Hashable, Any]]:
929929
@property
930930
def _item_sources(self) -> Iterable[Mapping[Hashable, Any]]:
931931
"""Places to look-up items for key-completion"""
932-
yield HybridMappingProxy(keys=self._coords, mapping=self.coords)
932+
yield FilteredMapping(keys=self._coords, mapping=self.coords)
933933

934934
# virtual coordinates
935-
# uses empty dict -- everything here can already be found in self.coords.
936-
yield HybridMappingProxy(keys=self.dims, mapping={})
935+
yield FilteredMapping(keys=self.dims, mapping=self.coords)
937936

938937
def __contains__(self, key: Any) -> bool:
939938
return key in self.data

xarray/core/dataset.py

+3 -3
Original file line number | Diff line number | Diff line change
@@ -103,9 +103,9 @@
103103
)
104104
from xarray.core.utils import (
105105
Default,
106+
FilteredMapping,
106107
Frozen,
107108
FrozenMappingWarningOnValuesAccess,
108-
HybridMappingProxy,
109109
OrderedSet,
110110
_default,
111111
decode_numpy_dict_values,
@@ -1507,10 +1507,10 @@ def _attr_sources(self) -> Iterable[Mapping[Hashable, Any]]:
15071507
def _item_sources(self) -> Iterable[Mapping[Hashable, Any]]:
15081508
"""Places to look-up items for key-completion"""
15091509
yield self.data_vars
1510-
yield HybridMappingProxy(keys=self._coord_names, mapping=self.coords)
1510+
yield FilteredMapping(keys=self._coord_names, mapping=self.coords)
15111511

15121512
# virtual coordinates
1513-
yield HybridMappingProxy(keys=self.sizes, mapping=self)
1513+
yield FilteredMapping(keys=self.sizes, mapping=self)
15141514

15151515
def __contains__(self, key: object) -> bool:
15161516
"""The 'in' operator will return true or false depending on whether

xarray/core/datatree.py

+11 -4
Original file line number | Diff line number | Diff line change
@@ -40,8 +40,8 @@
4040
from xarray.core.treenode import NamedNode, NodePath
4141
from xarray.core.utils import (
4242
Default,
43+
FilteredMapping,
4344
Frozen,
44-
HybridMappingProxy,
4545
_default,
4646
either_dict_or_kwargs,
4747
maybe_wrap_array,
@@ -516,10 +516,17 @@ def _pre_attach(self: DataTree, parent: DataTree, name: str) -> None:
516516
check_alignment(path, node_ds, parent_ds, self.children)
517517
_deduplicate_inherited_coordinates(self, parent)
518518

519+
@property
520+
def _node_coord_variables_with_index(self) -> Mapping[Hashable, Variable]:
521+
return FilteredMapping(
522+
keys=self._node_indexes, mapping=self._node_coord_variables
523+
)
524+
519525
@property
520526
def _coord_variables(self) -> ChainMap[Hashable, Variable]:
521527
return ChainMap(
522-
self._node_coord_variables, *(p._node_coord_variables for p in self.parents)
528+
self._node_coord_variables,
529+
*(p._node_coord_variables_with_index for p in self.parents),
523530
)
524531

525532
@property
@@ -720,10 +727,10 @@ def _attr_sources(self) -> Iterable[Mapping[Hashable, Any]]:
720727
def _item_sources(self) -> Iterable[Mapping[Any, Any]]:
721728
"""Places to look-up items for key-completion"""
722729
yield self.data_vars
723-
yield HybridMappingProxy(keys=self._coord_variables, mapping=self.coords)
730+
yield FilteredMapping(keys=self._coord_variables, mapping=self.coords)
724731

725732
# virtual coordinates
726-
yield HybridMappingProxy(keys=self.dims, mapping=self)
733+
yield FilteredMapping(keys=self.dims, mapping=self)
727734

728735
# immediate child nodes
729736
yield self.children

xarray/core/utils.py

+14 -8
Original file line number | Diff line number | Diff line change
@@ -465,33 +465,39 @@ def values(self) -> ValuesView[V]:
465465
return super().values()
466466

467467

468-
class HybridMappingProxy(Mapping[K, V]):
468+
class FilteredMapping(Mapping[K, V]):
469469
"""Implements the Mapping interface. Uses the wrapped mapping for item lookup
470470
and a separate wrapped keys collection for iteration.
471471
472472
Can be used to construct a mapping object from another dict-like object without
473473
eagerly accessing its items or when a mapping object is expected but only
474474
iteration over keys is actually used.
475475
476-
Note: HybridMappingProxy does not validate consistency of the provided `keys`
477-
and `mapping`. It is the caller's responsibility to ensure that they are
478-
suitable for the task at hand.
476+
Note: keys should be a subset of mapping, but FilteredMapping does not
477+
validate consistency of the provided `keys` and `mapping`. It is the
478+
caller's responsibility to ensure that they are suitable for the task at
479+
hand.
479480
"""
480481

481-
__slots__ = ("_keys", "mapping")
482+
__slots__ = ("keys_", "mapping")
482483

483484
def __init__(self, keys: Collection[K], mapping: Mapping[K, V]):
484-
self._keys = keys
485+
self.keys_ = keys # .keys is already a property on Mapping
485486
self.mapping = mapping
486487

487488
def __getitem__(self, key: K) -> V:
489+
if key not in self.keys_:
490+
raise KeyError(key)
488491
return self.mapping[key]
489492

490493
def __iter__(self) -> Iterator[K]:
491-
return iter(self._keys)
494+
return iter(self.keys_)
492495

493496
def __len__(self) -> int:
494-
return len(self._keys)
497+
return len(self.keys_)
498+
499+
def __repr__(self) -> str:
500+
return f"{type(self).__name__}(keys={self.keys_!r}, mapping={self.mapping!r})"
495501

496502

497503
class OrderedSet(MutableSet[T]):

xarray/tests/test_datatree.py

+19 -17
Original file line number | Diff line number | Diff line change
@@ -216,16 +216,16 @@ def test_is_hollow(self):
216216

217217

218218
class TestToDataset:
219-
def test_to_dataset(self):
220-
base = xr.Dataset(coords={"a": 1})
221-
sub = xr.Dataset(coords={"b": 2})
219+
def test_to_dataset_inherited(self):
220+
base = xr.Dataset(coords={"a": [1], "b": 2})
221+
sub = xr.Dataset(coords={"c": [3]})
222222
tree = DataTree.from_dict({"/": base, "/sub": sub})
223223
subtree = typing.cast(DataTree, tree["sub"])
224224

225225
assert_identical(tree.to_dataset(inherited=False), base)
226226
assert_identical(subtree.to_dataset(inherited=False), sub)
227227

228-
sub_and_base = xr.Dataset(coords={"a": 1, "b": 2})
228+
sub_and_base = xr.Dataset(coords={"a": [1], "c": [3]}) # no "b"
229229
assert_identical(tree.to_dataset(inherited=True), base)
230230
assert_identical(subtree.to_dataset(inherited=True), sub_and_base)
231231

@@ -714,7 +714,8 @@ def test_inherited(self):
714714
dt["child"] = DataTree()
715715
child = dt["child"]
716716

717-
assert set(child.coords) == {"x", "y", "a", "b"}
717+
assert set(dt.coords) == {"x", "y", "a", "b"}
718+
assert set(child.coords) == {"x", "y"}
718719

719720
actual = child.copy(deep=True)
720721
actual.coords["x"] = ("x", ["a", "b"])
@@ -729,7 +730,7 @@ def test_inherited(self):
729730

730731
with pytest.raises(KeyError):
731732
# cannot delete inherited coordinate from child node
732-
del child["b"]
733+
del child["x"]
733734

734735
# TODO requires a fix for #9472
735736
# actual = child.copy(deep=True)
@@ -1278,22 +1279,23 @@ def test_inherited_coords_index(self):
12781279
assert "x" in dt["/b"].coords
12791280
xr.testing.assert_identical(dt["/x"], dt["/b/x"])
12801281

1281-
def test_inherited_coords_override(self):
1282+
def test_inherit_only_index_coords(self):
12821283
dt = DataTree.from_dict(
12831284
{
1284-
"/": xr.Dataset(coords={"x": 1, "y": 2}),
1285-
"/b": xr.Dataset(coords={"x": 4, "z": 3}),
1285+
"/": xr.Dataset(coords={"x": [1], "y": 2}),
1286+
"/b": xr.Dataset(coords={"z": 3}),
12861287
}
12871288
)
12881289
assert dt.coords.keys() == {"x", "y"}
1289-
root_coords = {"x": 1, "y": 2}
1290-
sub_coords = {"x": 4, "y": 2, "z": 3}
1291-
xr.testing.assert_equal(dt["/x"], xr.DataArray(1, coords=root_coords))
1292-
xr.testing.assert_equal(dt["/y"], xr.DataArray(2, coords=root_coords))
1293-
assert dt["/b"].coords.keys() == {"x", "y", "z"}
1294-
xr.testing.assert_equal(dt["/b/x"], xr.DataArray(4, coords=sub_coords))
1295-
xr.testing.assert_equal(dt["/b/y"], xr.DataArray(2, coords=sub_coords))
1296-
xr.testing.assert_equal(dt["/b/z"], xr.DataArray(3, coords=sub_coords))
1290+
xr.testing.assert_equal(
1291+
dt["/x"], xr.DataArray([1], dims=["x"], coords={"x": [1], "y": 2})
1292+
)
1293+
xr.testing.assert_equal(dt["/y"], xr.DataArray(2, coords={"y": 2}))
1294+
assert dt["/b"].coords.keys() == {"x", "z"}
1295+
xr.testing.assert_equal(
1296+
dt["/b/x"], xr.DataArray([1], dims=["x"], coords={"x": [1], "z": 3})
1297+
)
1298+
xr.testing.assert_equal(dt["/b/z"], xr.DataArray(3, coords={"z": 3}))
12971299

12981300
def test_inherited_coords_with_index_are_deduplicated(self):
12991301
dt = DataTree.from_dict(

xarray/tests/test_utils.py

+10
Original file line number | Diff line number | Diff line change
@@ -139,6 +139,16 @@ def test_frozen(self):
139139
"Frozen({'b': 'B', 'a': 'A'})",
140140
)
141141

142+
def test_filtered(self):
143+
x = utils.FilteredMapping(keys={"a"}, mapping={"a": 1, "b": 2})
144+
assert "a" in x
145+
assert "b" not in x
146+
assert x["a"] == 1
147+
assert list(x) == ["a"]
148+
assert len(x) == 1
149+
assert repr(x) == "FilteredMapping(keys={'a'}, mapping={'a': 1, 'b': 2})"
150+
assert dict(x) == {"a": 1}
151+
142152

143153
def test_repr_object():
144154
obj = utils.ReprObject("foo")

0 commit comments

Comments
 (0)