diff --git a/movement/utils/dimensions.py b/movement/utils/dimensions.py new file mode 100644 index 000000000..e03ab7951 --- /dev/null +++ b/movement/utils/dimensions.py @@ -0,0 +1,115 @@ +"""Utilities for manipulating dimensions of ``xarray.DataArray`` objects.""" + +from collections.abc import Hashable, Iterable + +import xarray as xr + + +def collapse_extra_dimensions( + da: xr.DataArray, + preserve_dims: Iterable[str] = ("time", "space"), + **selection: str, +) -> xr.DataArray: + """Collapse a ``DataArray``, preserving only the specified dimensions. + + By default, dimensions that are collapsed retain the corresponding 'slice' + along their 0th index of those dimensions, unless a particular index for is + given in the ``selection``. + + Parameters + ---------- + da : xarray.DataArray + DataArray of multiple dimensions, which is to be collapsed. + preserve_dims : Iterable[str] + The dimensions of ``da`` that should not be collapsed. + selection : str + Mapping from dimension names to a particular index name in that + dimension. + + Returns + ------- + xarray.DataArray + DataArray whose shape is the same as the shape of the preserved + dimensions of ``da``, containing the data obtained from a slice along + the collapsed dimensions. + + Examples + -------- + Collapse a ``DataArray`` down to just a ``"time"``-``"space"`` slice. + + >>> import xarray as xr + >>> import numpy as np + >>> shape = (7, 2, 3, 2) + >>> da = xr.DataArray( + ... data=np.arange(np.prod(shape)).reshape(shape), + ... dims=["time", "space", "keypoints", "individuals"], + ... coords={ + ... "time": np.arange(7), + ... "space": np.arange(2), + ... "keypoints": ["nose", "left_ear", "right_ear"], + ... "individuals": ["Alice", "Bob"], + ... ) + >>> space_time = collapse_extra_dimensions(da) + >>> print(space_time.shape) + (7, 2) + + The call to ``collapse_extra_dimensions`` above is equivalent to + ``da.isel(keypoints=0, individuals=0)`` (indexing by integer) and to + ``da.sel(keypoints="nose", individuals="Alice")`` (indexing by label). + We can change which slice we take from the collapsed dimensions by passing + them as keyword arguments. + + >>> # Equivalent to da.sel(keypoints="right_ear", individuals="Bob") + >>> space_time_bob_right_ear = collapse_extra_dimensions( + ... da, keypoints="right_ear", individuals="Bob" + ... ) + >>> print(space_time_bob_right_ear.shape) + (7, 2) + + We can also change which dimensions are to be preserved. + + >>> time_only = collapse_extra_dims(da, preserve_dims=["time"]) + >>> print(time_only.shape) + (7,) + + """ + data = da.copy(deep=True) + dims_to_collapse = [d for d in data.dims if d not in preserve_dims] + make_selection = { + d: _coord_of_dimension(da, d, selection.pop(d, 0)) + for d in dims_to_collapse + } + return data.sel(make_selection) + + +def _coord_of_dimension( + da: xr.DataArray, dimension: str, coord_index: int | str +) -> Hashable: + """Retrieve a coordinate of a given dimension. + + This method handles the case where the coordinate is known by name + within the coordinates of the ``dimension``. + + If ``coord_index`` is an element of ``da.dimension``, it can just be + returned. Otherwise, we need to return ``da.dimension[coord_index]``. + + Out of bounds index errors, or non existent dimension errors are handled by + the underlying ``xarray.DataArray`` implementation. + + Parameters + ---------- + da : xarray.DataArray + DataArray to retrieve a coordinate from. + dimension : str + Dimension of the DataArray to fetch the coordinate from. + coord_index : int | str + The index of the coordinate along ``dimension`` to fetch. + + Returns + ------- + Hashable + The requested coordinate name at ``da.dimension[coord_index]``. + + """ + dim = getattr(da, dimension) + return dim[coord_index] if coord_index not in dim else coord_index diff --git a/tests/test_unit/test_collapse_dimensions.py b/tests/test_unit/test_collapse_dimensions.py new file mode 100644 index 000000000..333d748c5 --- /dev/null +++ b/tests/test_unit/test_collapse_dimensions.py @@ -0,0 +1,143 @@ +import re +from typing import Any + +import numpy as np +import pytest +import xarray as xr + +from movement.utils.dimensions import ( + _coord_of_dimension, + collapse_extra_dimensions, +) + + +@pytest.fixture +def shape() -> tuple[int, ...]: + return (7, 2, 4, 3) + + +@pytest.fixture +def da(shape: tuple[int, ...]) -> xr.DataArray: + return xr.DataArray( + data=np.arange(np.prod(shape)).reshape(shape), + dims=["time", "space", "keypoints", "individuals"], + coords={ + "space": ["x", "y"], + "keypoints": ["head", "shoulders", "knees", "toes"], + "individuals": ["a", "b", "c"], + }, + ) + + +@pytest.mark.parametrize( + ["pass_to_function", "equivalent_to_sel"], + [ + pytest.param( + {}, + {"individuals": "a", "keypoints": "head"}, + id="Default preserve time-space", + ), + pytest.param( + {"preserve_dims": ["space"]}, + {"time": 0, "individuals": "a", "keypoints": "head"}, + id="Keep space only", + ), + pytest.param( + {"individuals": "b"}, + {"individuals": "b", "keypoints": "head"}, + id="Request non-default slice", + ), + pytest.param( + {"individuals": "c"}, + {"individuals": "c", "keypoints": "head"}, + id="Request by coordinate", + ), + pytest.param( + { + "individuals": "b", + "elephants": "this is a non-existent dimension", + "crabs": 42, + }, + {"individuals": "b", "keypoints": "head"}, + id="Selection ignores dimensions that don't exist", + ), + pytest.param( + {"preserve_dims": []}, + {"time": 0, "space": "x", "individuals": "a", "keypoints": "head"}, + id="Collapse everything", + ), + ], +) +def test_collapse_dimensions( + da: xr.DataArray, + pass_to_function: dict[str, str], + equivalent_to_sel: dict[str, int | str], +) -> None: + result_from_collapsing = collapse_extra_dimensions(da, **pass_to_function) + + # We should be equivalent to this method + expected_result = da.sel(**equivalent_to_sel) + + assert result_from_collapsing.shape == expected_result.values.shape + xr.testing.assert_allclose(result_from_collapsing, expected_result) + + +@pytest.mark.parametrize( + ["pass_to_function"], + [ + pytest.param( + {"keypoints": ["head", "toes"]}, + id="Multiple keypoints", + ), + pytest.param( + {"individuals": ["a", "b"]}, + id="Multiple individuals", + ), + ], +) +def test_collapse_dimensions_value_error( + da: xr.DataArray, + pass_to_function: dict[str, Any], +) -> None: + with pytest.raises(ValueError): + collapse_extra_dimensions(da, **pass_to_function) + + +@pytest.mark.parametrize( + ["args_to_fn", "expected"], + [ + pytest.param( + {"dimension": "individuals", "coord_index": 1}, + "b", + id="Fetch coord from index", + ), + pytest.param( + {"dimension": "time", "coord_index": 6}, + 6, + id="Dimension with no coordinates", + ), + pytest.param( + {"dimension": "space", "coord_index": "x"}, + "x", + id="Fetch coord from name", + ), + pytest.param( + {"dimension": "keypoints", "coord_index": 10}, + IndexError("index 10 is out of bounds for axis 0 with size 4"), + id="Out of bounds index", + ), + pytest.param( + {"dimension": "keypoints", "coord_index": "arms"}, + KeyError("arms"), + id="Non existent coord name", + ), + ], +) +def test_coord_of_dimension( + da: xr.DataArray, args_to_fn: dict[str, str], expected: str | Exception +) -> None: + if isinstance(expected, Exception): + with pytest.raises(type(expected), match=re.escape(str(expected))): + _coord_of_dimension(da, **args_to_fn) + else: + assert expected == _coord_of_dimension(da, **args_to_fn)