Skip to content

Commit

Permalink
change selection input to str only, add value error tests
Browse files Browse the repository at this point in the history
  • Loading branch information
stellaprins committed Feb 6, 2025
1 parent c7cfd0e commit 8a7a349
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 27 deletions.
35 changes: 18 additions & 17 deletions movement/utils/dimensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
def collapse_extra_dimensions(
da: xr.DataArray,
preserve_dims: Iterable[str] = ("time", "space"),
**selection: int | str,
**selection: str,
) -> xr.DataArray:
"""Collapse a ``DataArray``, preserving only the specified dimensions.
Expand All @@ -22,10 +22,9 @@ def collapse_extra_dimensions(
DataArray of multiple dimensions, which is to be collapsed.
preserve_dims : Iterable[str]
The dimensions of ``da`` that should not be collapsed.
selection : int | str
Mapping from dimension names to a particular index in that dimension.
Dimensions that appear with an index in ``selection`` retain that index
slice when collapsed, rather than the default 0th index slice.
selection : str
Mapping from dimension names to a particular index name in that
dimension.
Returns
-------
Expand All @@ -40,19 +39,24 @@ def collapse_extra_dimensions(
>>> import xarray as xr
>>> import numpy as np
>>> shape = (7, 2, 3, 4)
>>> shape = (7, 2, 3, 2)
>>> da = xr.DataArray(
... data=np.arange(np.prod(shape)).reshape(shape),
... dims=["time", "space", "dim_to_collapse_0", "dim_to_collapse_1"],
... dims=["time", "space", "keypoints", "individuals"],
... coords={
... "time": np.arange(7),
... "space": np.arange(2),
... "keypoints": ["nose", "left_ear", "right_ear"],
... "individuals": ["Alice", "Bob"],
... )
>>> space_time = collapse_extra_dimensions(da)
>>> print(space_time.shape)
(7, 2)
The call to ``collapse_extra_dimensions`` above is equivalent to
``da.sel(dim_to_collapse_0=0, dim_to_collapse_1=0)``. We can change which
slice we take from the collapsed dimensions by passing them as keyword
arguments.
``da.sel(keypoints="head", individuals="Alice")`` (indexing by label).
We can change which slice we take from the collapsed dimensions by passing
them as keyword arguments.
>>> # Equivalent to da.sel(dim_to_collapse_0=2, dim_to_collapse_1=1)
>>> space_time_different_slice = collapse_extra_dimensions(
Expand All @@ -71,22 +75,19 @@ def collapse_extra_dimensions(
data = da.copy(deep=True)
dims_to_collapse = [d for d in data.dims if d not in preserve_dims]
make_selection = {
d: coord_of_dimension(da, d, selection.pop(d, 0))
d: _coord_of_dimension(da, d, selection.pop(d, 0))
for d in dims_to_collapse
}
return data.sel(make_selection)


def coord_of_dimension(
def _coord_of_dimension(
da: xr.DataArray, dimension: str, coord_index: int | str
) -> Hashable:
"""Retrieve a coordinate of a given dimension.
This method handles the case where the coordinate is known either by name
or only known by its index within the coordinates of the ``dimension``. It
is predominantly useful when we want to give the user the flexibility to
refer to (part of) a ``DataArray`` by index (as in regular ``numpy``-style
array slicing) or by name (via ``xarray`` coordinates).
This method handles the case where the coordinate is known by name
within the coordinates of the ``dimension``.
If ``coord_index`` is an element of ``da.dimension``, it can just be
returned. Otherwise, we need to return ``da.dimension[coord_index]``.
Expand Down
41 changes: 31 additions & 10 deletions tests/test_unit/test_collapse_dimensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,25 +6,25 @@
import xarray as xr

from movement.utils.dimensions import (
_coord_of_dimension,
collapse_extra_dimensions,
coord_of_dimension,
)


@pytest.fixture
def shape() -> tuple[int, ...]:
return (7, 2, 3, 4)
return (7, 2, 4, 3)


@pytest.fixture
def da(shape: tuple[int, ...]) -> xr.DataArray:
return xr.DataArray(
data=np.arange(np.prod(shape)).reshape(shape),
dims=["time", "space", "individuals", "keypoints"],
dims=["time", "space", "keypoints", "individuals"],
coords={
"space": ["x", "y"],
"individuals": ["a", "b", "c"],
"keypoints": ["head", "shoulders", "knees", "toes"],
"individuals": ["a", "b", "c"],
},
)

Expand All @@ -43,7 +43,7 @@ def da(shape: tuple[int, ...]) -> xr.DataArray:
id="Keep space only",
),
pytest.param(
{"individuals": 1},
{"individuals": "b"},
{"individuals": "b", "keypoints": "head"},
id="Request non-default slice",
),
Expand All @@ -54,7 +54,7 @@ def da(shape: tuple[int, ...]) -> xr.DataArray:
),
pytest.param(
{
"individuals": 1,
"individuals": "b",
"elephants": "this is a non-existent dimension",
"crabs": 42,
},
Expand All @@ -70,7 +70,7 @@ def da(shape: tuple[int, ...]) -> xr.DataArray:
)
def test_collapse_dimensions(
da: xr.DataArray,
pass_to_function: dict[str, Any],
pass_to_function: dict[str, str],
equivalent_to_sel: dict[str, int | str],
) -> None:
result_from_collapsing = collapse_extra_dimensions(da, **pass_to_function)
Expand All @@ -82,6 +82,27 @@ def test_collapse_dimensions(
xr.testing.assert_allclose(result_from_collapsing, expected_result)


@pytest.mark.parametrize(
["pass_to_function"],
[
pytest.param(
{"keypoints": ["head", "toes"]},
id="Multiple keypoints",
),
pytest.param(
{"individuals": ["a", "b"]},
id="Multiple individuals",
),
],
)
def test_collapse_dimensions_value_error(
da: xr.DataArray,
pass_to_function: dict[str, Any],
) -> None:
with pytest.raises(ValueError):
collapse_extra_dimensions(da, **pass_to_function)


@pytest.mark.parametrize(
["args_to_fn", "expected"],
[
Expand Down Expand Up @@ -113,10 +134,10 @@ def test_collapse_dimensions(
],
)
def test_coord_of_dimension(
da: xr.DataArray, args_to_fn: dict[str, Any], expected: str | Exception
da: xr.DataArray, args_to_fn: dict[str, str], expected: str | Exception
) -> None:
if isinstance(expected, Exception):
with pytest.raises(type(expected), match=re.escape(str(expected))):
coord_of_dimension(da, **args_to_fn)
_coord_of_dimension(da, **args_to_fn)
else:
assert expected == coord_of_dimension(da, **args_to_fn)
assert expected == _coord_of_dimension(da, **args_to_fn)

0 comments on commit 8a7a349

Please sign in to comment.