From 8a7a349f58bd9f67291baaa936c2160d0695b4c4 Mon Sep 17 00:00:00 2001 From: Stella <30465823+stellaprins@users.noreply.github.com> Date: Thu, 6 Feb 2025 15:52:20 +0000 Subject: [PATCH] change selection input to str only, add value error tests --- movement/utils/dimensions.py | 35 +++++++++--------- tests/test_unit/test_collapse_dimensions.py | 41 ++++++++++++++++----- 2 files changed, 49 insertions(+), 27 deletions(-) diff --git a/movement/utils/dimensions.py b/movement/utils/dimensions.py index 14c8d3f2..4cba096c 100644 --- a/movement/utils/dimensions.py +++ b/movement/utils/dimensions.py @@ -8,7 +8,7 @@ def collapse_extra_dimensions( da: xr.DataArray, preserve_dims: Iterable[str] = ("time", "space"), - **selection: int | str, + **selection: str, ) -> xr.DataArray: """Collapse a ``DataArray``, preserving only the specified dimensions. @@ -22,10 +22,9 @@ def collapse_extra_dimensions( DataArray of multiple dimensions, which is to be collapsed. preserve_dims : Iterable[str] The dimensions of ``da`` that should not be collapsed. - selection : int | str - Mapping from dimension names to a particular index in that dimension. - Dimensions that appear with an index in ``selection`` retain that index - slice when collapsed, rather than the default 0th index slice. + selection : str + Mapping from dimension names to a particular index name in that + dimension. Returns ------- @@ -40,19 +39,24 @@ def collapse_extra_dimensions( >>> import xarray as xr >>> import numpy as np - >>> shape = (7, 2, 3, 4) + >>> shape = (7, 2, 3, 2) >>> da = xr.DataArray( ... data=np.arange(np.prod(shape)).reshape(shape), - ... dims=["time", "space", "dim_to_collapse_0", "dim_to_collapse_1"], + ... dims=["time", "space", "keypoints", "individuals"], + ... coords={ + ... "time": np.arange(7), + ... "space": np.arange(2), + ... "keypoints": ["nose", "left_ear", "right_ear"], + ... "individuals": ["Alice", "Bob"], ... 
) >>> space_time = collapse_extra_dimensions(da) >>> print(space_time.shape) (7, 2) The call to ``collapse_extra_dimensions`` above is equivalent to - ``da.sel(dim_to_collapse_0=0, dim_to_collapse_1=0)``. We can change which - slice we take from the collapsed dimensions by passing them as keyword - arguments. + ``da.sel(keypoints="nose", individuals="Alice")`` (indexing by label). + We can change which slice we take from the collapsed dimensions by passing + them as keyword arguments. >>> # Equivalent to da.sel(dim_to_collapse_0=2, dim_to_collapse_1=1) >>> space_time_different_slice = collapse_extra_dimensions( @@ -71,22 +75,19 @@ def collapse_extra_dimensions( data = da.copy(deep=True) dims_to_collapse = [d for d in data.dims if d not in preserve_dims] make_selection = { - d: coord_of_dimension(da, d, selection.pop(d, 0)) + d: _coord_of_dimension(da, d, selection.pop(d, 0)) for d in dims_to_collapse } return data.sel(make_selection) -def coord_of_dimension( +def _coord_of_dimension( da: xr.DataArray, dimension: str, coord_index: int | str ) -> Hashable: """Retrieve a coordinate of a given dimension. - This method handles the case where the coordinate is known either by name - or only known by its index within the coordinates of the ``dimension``. It - is predominantly useful when we want to give the user the flexibility to - refer to (part of) a ``DataArray`` by index (as in regular ``numpy``-style - array slicing) or by name (via ``xarray`` coordinates). + This method handles the case where the coordinate is known by name + within the coordinates of the ``dimension``. If ``coord_index`` is an element of ``da.dimension``, it can just be returned. Otherwise, we need to return ``da.dimension[coord_index]``. 
diff --git a/tests/test_unit/test_collapse_dimensions.py b/tests/test_unit/test_collapse_dimensions.py index 8f6a1a5e..333d748c 100644 --- a/tests/test_unit/test_collapse_dimensions.py +++ b/tests/test_unit/test_collapse_dimensions.py @@ -6,25 +6,25 @@ import xarray as xr from movement.utils.dimensions import ( + _coord_of_dimension, collapse_extra_dimensions, - coord_of_dimension, ) @pytest.fixture def shape() -> tuple[int, ...]: - return (7, 2, 3, 4) + return (7, 2, 4, 3) @pytest.fixture def da(shape: tuple[int, ...]) -> xr.DataArray: return xr.DataArray( data=np.arange(np.prod(shape)).reshape(shape), - dims=["time", "space", "individuals", "keypoints"], + dims=["time", "space", "keypoints", "individuals"], coords={ "space": ["x", "y"], - "individuals": ["a", "b", "c"], "keypoints": ["head", "shoulders", "knees", "toes"], + "individuals": ["a", "b", "c"], }, ) @@ -43,7 +43,7 @@ def da(shape: tuple[int, ...]) -> xr.DataArray: id="Keep space only", ), pytest.param( - {"individuals": 1}, + {"individuals": "b"}, {"individuals": "b", "keypoints": "head"}, id="Request non-default slice", ), @@ -54,7 +54,7 @@ def da(shape: tuple[int, ...]) -> xr.DataArray: ), pytest.param( { - "individuals": 1, + "individuals": "b", "elephants": "this is a non-existent dimension", "crabs": 42, }, @@ -70,7 +70,7 @@ def da(shape: tuple[int, ...]) -> xr.DataArray: ) def test_collapse_dimensions( da: xr.DataArray, - pass_to_function: dict[str, Any], + pass_to_function: dict[str, str], equivalent_to_sel: dict[str, int | str], ) -> None: result_from_collapsing = collapse_extra_dimensions(da, **pass_to_function) @@ -82,6 +82,27 @@ def test_collapse_dimensions( xr.testing.assert_allclose(result_from_collapsing, expected_result) +@pytest.mark.parametrize( + ["pass_to_function"], + [ + pytest.param( + {"keypoints": ["head", "toes"]}, + id="Multiple keypoints", + ), + pytest.param( + {"individuals": ["a", "b"]}, + id="Multiple individuals", + ), + ], +) +def test_collapse_dimensions_value_error( + 
da: xr.DataArray, + pass_to_function: dict[str, Any], +) -> None: + with pytest.raises(ValueError): + collapse_extra_dimensions(da, **pass_to_function) + + @pytest.mark.parametrize( ["args_to_fn", "expected"], [ @@ -113,10 +134,10 @@ def test_collapse_dimensions( ], ) def test_coord_of_dimension( - da: xr.DataArray, args_to_fn: dict[str, Any], expected: str | Exception + da: xr.DataArray, args_to_fn: dict[str, str], expected: str | Exception ) -> None: if isinstance(expected, Exception): with pytest.raises(type(expected), match=re.escape(str(expected))): - coord_of_dimension(da, **args_to_fn) + _coord_of_dimension(da, **args_to_fn) else: - assert expected == coord_of_dimension(da, **args_to_fn) + assert expected == _coord_of_dimension(da, **args_to_fn)