Skip to content

Commit 8a7a349

Browse files
committed
change selection input to str only, add value error tests
1 parent c7cfd0e commit 8a7a349

File tree

2 files changed

+49
-27
lines changed

2 files changed

+49
-27
lines changed

movement/utils/dimensions.py

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
def collapse_extra_dimensions(
99
da: xr.DataArray,
1010
preserve_dims: Iterable[str] = ("time", "space"),
11-
**selection: int | str,
11+
**selection: str,
1212
) -> xr.DataArray:
1313
"""Collapse a ``DataArray``, preserving only the specified dimensions.
1414
@@ -22,10 +22,9 @@ def collapse_extra_dimensions(
2222
DataArray of multiple dimensions, which is to be collapsed.
2323
preserve_dims : Iterable[str]
2424
The dimensions of ``da`` that should not be collapsed.
25-
selection : int | str
26-
Mapping from dimension names to a particular index in that dimension.
27-
Dimensions that appear with an index in ``selection`` retain that index
28-
slice when collapsed, rather than the default 0th index slice.
25+
selection : str
26+
Mapping from dimension names to a particular index name in that
27+
dimension.
2928
3029
Returns
3130
-------
@@ -40,19 +39,24 @@ def collapse_extra_dimensions(
4039
4140
>>> import xarray as xr
4241
>>> import numpy as np
43-
>>> shape = (7, 2, 3, 4)
42+
>>> shape = (7, 2, 3, 2)
4443
>>> da = xr.DataArray(
4544
... data=np.arange(np.prod(shape)).reshape(shape),
46-
... dims=["time", "space", "dim_to_collapse_0", "dim_to_collapse_1"],
45+
... dims=["time", "space", "keypoints", "individuals"],
46+
... coords={
47+
... "time": np.arange(7),
48+
... "space": np.arange(2),
49+
... "keypoints": ["nose", "left_ear", "right_ear"],
50+
... "individuals": ["Alice", "Bob"],
4751
... )
4852
>>> space_time = collapse_extra_dimensions(da)
4953
>>> print(space_time.shape)
5054
(7, 2)
5155
5256
The call to ``collapse_extra_dimensions`` above is equivalent to
53-
``da.sel(dim_to_collapse_0=0, dim_to_collapse_1=0)``. We can change which
54-
slice we take from the collapsed dimensions by passing them as keyword
55-
arguments.
57+
``da.sel(keypoints="head", individuals="Alice")`` (indexing by label).
58+
We can change which slice we take from the collapsed dimensions by passing
59+
them as keyword arguments.
5660
5761
>>> # Equivalent to da.sel(dim_to_collapse_0=2, dim_to_collapse_1=1)
5862
>>> space_time_different_slice = collapse_extra_dimensions(
@@ -71,22 +75,19 @@ def collapse_extra_dimensions(
7175
data = da.copy(deep=True)
7276
dims_to_collapse = [d for d in data.dims if d not in preserve_dims]
7377
make_selection = {
74-
d: coord_of_dimension(da, d, selection.pop(d, 0))
78+
d: _coord_of_dimension(da, d, selection.pop(d, 0))
7579
for d in dims_to_collapse
7680
}
7781
return data.sel(make_selection)
7882

7983

80-
def coord_of_dimension(
84+
def _coord_of_dimension(
8185
da: xr.DataArray, dimension: str, coord_index: int | str
8286
) -> Hashable:
8387
"""Retrieve a coordinate of a given dimension.
8488
85-
This method handles the case where the coordinate is known either by name
86-
or only known by its index within the coordinates of the ``dimension``. It
87-
is predominantly useful when we want to give the user the flexibility to
88-
refer to (part of) a ``DataArray`` by index (as in regular ``numpy``-style
89-
array slicing) or by name (via ``xarray`` coordinates).
89+
This method handles the case where the coordinate is known by name
90+
within the coordinates of the ``dimension``.
9091
9192
If ``coord_index`` is an element of ``da.dimension``, it can just be
9293
returned. Otherwise, we need to return ``da.dimension[coord_index]``.

tests/test_unit/test_collapse_dimensions.py

Lines changed: 31 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,25 +6,25 @@
66
import xarray as xr
77

88
from movement.utils.dimensions import (
9+
_coord_of_dimension,
910
collapse_extra_dimensions,
10-
coord_of_dimension,
1111
)
1212

1313

1414
@pytest.fixture
1515
def shape() -> tuple[int, ...]:
16-
return (7, 2, 3, 4)
16+
return (7, 2, 4, 3)
1717

1818

1919
@pytest.fixture
2020
def da(shape: tuple[int, ...]) -> xr.DataArray:
2121
return xr.DataArray(
2222
data=np.arange(np.prod(shape)).reshape(shape),
23-
dims=["time", "space", "individuals", "keypoints"],
23+
dims=["time", "space", "keypoints", "individuals"],
2424
coords={
2525
"space": ["x", "y"],
26-
"individuals": ["a", "b", "c"],
2726
"keypoints": ["head", "shoulders", "knees", "toes"],
27+
"individuals": ["a", "b", "c"],
2828
},
2929
)
3030

@@ -43,7 +43,7 @@ def da(shape: tuple[int, ...]) -> xr.DataArray:
4343
id="Keep space only",
4444
),
4545
pytest.param(
46-
{"individuals": 1},
46+
{"individuals": "b"},
4747
{"individuals": "b", "keypoints": "head"},
4848
id="Request non-default slice",
4949
),
@@ -54,7 +54,7 @@ def da(shape: tuple[int, ...]) -> xr.DataArray:
5454
),
5555
pytest.param(
5656
{
57-
"individuals": 1,
57+
"individuals": "b",
5858
"elephants": "this is a non-existent dimension",
5959
"crabs": 42,
6060
},
@@ -70,7 +70,7 @@ def da(shape: tuple[int, ...]) -> xr.DataArray:
7070
)
7171
def test_collapse_dimensions(
7272
da: xr.DataArray,
73-
pass_to_function: dict[str, Any],
73+
pass_to_function: dict[str, str],
7474
equivalent_to_sel: dict[str, int | str],
7575
) -> None:
7676
result_from_collapsing = collapse_extra_dimensions(da, **pass_to_function)
@@ -82,6 +82,27 @@ def test_collapse_dimensions(
8282
xr.testing.assert_allclose(result_from_collapsing, expected_result)
8383

8484

85+
@pytest.mark.parametrize(
86+
["pass_to_function"],
87+
[
88+
pytest.param(
89+
{"keypoints": ["head", "toes"]},
90+
id="Multiple keypoints",
91+
),
92+
pytest.param(
93+
{"individuals": ["a", "b"]},
94+
id="Multiple individuals",
95+
),
96+
],
97+
)
98+
def test_collapse_dimensions_value_error(
99+
da: xr.DataArray,
100+
pass_to_function: dict[str, Any],
101+
) -> None:
102+
with pytest.raises(ValueError):
103+
collapse_extra_dimensions(da, **pass_to_function)
104+
105+
85106
@pytest.mark.parametrize(
86107
["args_to_fn", "expected"],
87108
[
@@ -113,10 +134,10 @@ def test_collapse_dimensions(
113134
],
114135
)
115136
def test_coord_of_dimension(
116-
da: xr.DataArray, args_to_fn: dict[str, Any], expected: str | Exception
137+
da: xr.DataArray, args_to_fn: dict[str, str], expected: str | Exception
117138
) -> None:
118139
if isinstance(expected, Exception):
119140
with pytest.raises(type(expected), match=re.escape(str(expected))):
120-
coord_of_dimension(da, **args_to_fn)
141+
_coord_of_dimension(da, **args_to_fn)
121142
else:
122-
assert expected == coord_of_dimension(da, **args_to_fn)
143+
assert expected == _coord_of_dimension(da, **args_to_fn)

0 commit comments

Comments
 (0)