Skip to content

Collapse dimensions common functionality for plot wrappers #405

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 115 additions & 0 deletions movement/utils/dimensions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
"""Utilities for manipulating dimensions of ``xarray.DataArray`` objects."""

from collections.abc import Hashable, Iterable

import xarray as xr


def collapse_extra_dimensions(
    da: xr.DataArray,
    preserve_dims: Iterable[str] = ("time", "space"),
    **selection: str,
) -> xr.DataArray:
    """Collapse a ``DataArray``, preserving only the specified dimensions.

    By default, each collapsed dimension is reduced to the slice at its 0th
    index, unless a particular index for that dimension is given in
    ``selection``.

    Parameters
    ----------
    da : xarray.DataArray
        DataArray of multiple dimensions, which is to be collapsed.
    preserve_dims : Iterable[str]
        The dimensions of ``da`` that should not be collapsed. Any iterable
        is accepted (including a one-shot iterator); a plain string is
        treated as a single dimension name.
    selection : str
        Mapping from dimension names to a particular index name in that
        dimension. Entries naming a dimension that is preserved, or that
        does not exist in ``da``, are ignored.

    Returns
    -------
    xarray.DataArray
        DataArray whose shape is the same as the shape of the preserved
        dimensions of ``da``, containing the data obtained from a slice along
        the collapsed dimensions. The result is a deep copy, so modifying it
        does not affect ``da``.

    Examples
    --------
    Collapse a ``DataArray`` down to just a ``"time"``-``"space"`` slice.

    >>> import xarray as xr
    >>> import numpy as np
    >>> shape = (7, 2, 3, 2)
    >>> da = xr.DataArray(
    ...     data=np.arange(np.prod(shape)).reshape(shape),
    ...     dims=["time", "space", "keypoints", "individuals"],
    ...     coords={
    ...         "time": np.arange(7),
    ...         "space": np.arange(2),
    ...         "keypoints": ["nose", "left_ear", "right_ear"],
    ...         "individuals": ["Alice", "Bob"],
    ...     },
    ... )
    >>> space_time = collapse_extra_dimensions(da)
    >>> print(space_time.shape)
    (7, 2)

    The call to ``collapse_extra_dimensions`` above is equivalent to
    ``da.isel(keypoints=0, individuals=0)`` (indexing by integer) and to
    ``da.sel(keypoints="nose", individuals="Alice")`` (indexing by label).
    We can change which slice we take from the collapsed dimensions by passing
    them as keyword arguments.

    >>> # Equivalent to da.sel(keypoints="right_ear", individuals="Bob")
    >>> space_time_bob_right_ear = collapse_extra_dimensions(
    ...     da, keypoints="right_ear", individuals="Bob"
    ... )
    >>> print(space_time_bob_right_ear.shape)
    (7, 2)

    We can also change which dimensions are to be preserved.

    >>> time_only = collapse_extra_dimensions(da, preserve_dims=["time"])
    >>> print(time_only.shape)
    (7,)

    """
    # Materialise once: ``preserve_dims`` may be a one-shot iterator (which a
    # repeated ``in`` test would exhaust), and a bare string must count as a
    # single dimension name rather than be matched by substring.
    preserved = (
        (preserve_dims,)
        if isinstance(preserve_dims, str)
        else tuple(preserve_dims)
    )
    dims_to_collapse = [d for d in da.dims if d not in preserved]

    # Dimensions without an explicit request are taken *positionally* at
    # index 0, as documented; a label lookup could pick a different slice
    # when an integer-valued coordinate holds the label 0 elsewhere.
    default_positions = {d: 0 for d in dims_to_collapse if d not in selection}
    # Explicit requests may be labels or integer positions; the helper
    # resolves either form to something ``sel`` understands.
    requested_labels = {
        d: _coord_of_dimension(da, d, selection[d])
        for d in dims_to_collapse
        if d in selection
    }

    collapsed = da.isel(default_positions).sel(requested_labels)
    # Copy only the (small) selected slice instead of the whole input.
    return collapsed.copy(deep=True)


def _coord_of_dimension(
    da: xr.DataArray, dimension: str, coord_index: int | str
) -> Hashable:
    """Retrieve a coordinate of a given dimension.

    This method handles the case where the coordinate is known by name
    within the coordinates of the ``dimension``, as well as the case where
    it is given as an integer position along ``dimension``.

    If ``coord_index`` is an element of ``da[dimension]``, it is treated as a
    label and returned unchanged. Otherwise, ``da[dimension][coord_index]``
    is returned. Note that a label match takes precedence: for an
    integer-valued coordinate, an integer that appears among the labels is
    interpreted as that label, not as a position.

    Out of bounds index errors, or non existent dimension errors are handled by
    the underlying ``xarray.DataArray`` implementation.

    Parameters
    ----------
    da : xarray.DataArray
        DataArray to retrieve a coordinate from.
    dimension : str
        Dimension of the DataArray to fetch the coordinate from.
    coord_index : int | str
        The label, or the integer position, of the coordinate along
        ``dimension`` to fetch.

    Returns
    -------
    Hashable
        ``coord_index`` itself when it is already a label of ``dimension``;
        otherwise the element ``da[dimension][coord_index]`` as produced by
        xarray indexing.

    Raises
    ------
    IndexError
        If ``coord_index`` is an integer position outside the bounds of
        ``dimension``.
    KeyError
        If ``coord_index`` is a name that is not a label of ``dimension``,
        or if ``dimension`` cannot be looked up on ``da``.

    """
    # Item access rather than ``getattr``: attribute access would return the
    # DataArray's own attribute or method whenever the dimension name
    # collides with one (for example "name", "values" or "T").
    dim = da[dimension]
    if coord_index in dim:
        return coord_index
    return dim[coord_index]
143 changes: 143 additions & 0 deletions tests/test_unit/test_collapse_dimensions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
import re
from typing import Any

import numpy as np
import pytest
import xarray as xr

from movement.utils.dimensions import (
_coord_of_dimension,
collapse_extra_dimensions,
)


@pytest.fixture
def shape() -> tuple[int, ...]:
    """Return the axis lengths used to build the ``da`` fixture.

    The order matches the fixture's dimensions:
    (time, space, keypoints, individuals).
    """
    n_time, n_space, n_keypoints, n_individuals = 7, 2, 4, 3
    return (n_time, n_space, n_keypoints, n_individuals)


@pytest.fixture
def da(shape: tuple[int, ...]) -> xr.DataArray:
    """Return a 4D DataArray filled with consecutive integers.

    The ``time`` dimension is deliberately left without coordinates, while
    the remaining dimensions carry string labels.
    """
    values = np.arange(np.prod(shape)).reshape(shape)
    labelled_coords = {
        "space": ["x", "y"],
        "keypoints": ["head", "shoulders", "knees", "toes"],
        "individuals": ["a", "b", "c"],
    }
    return xr.DataArray(
        data=values,
        dims=["time", "space", "keypoints", "individuals"],
        coords=labelled_coords,
    )


@pytest.mark.parametrize(
    "pass_to_function, equivalent_to_sel",
    [
        # No arguments: keep time and space, take the first slice elsewhere.
        pytest.param(
            {},
            {"individuals": "a", "keypoints": "head"},
            id="Default preserve time-space",
        ),
        # Preserving only space also collapses the (unlabelled) time axis.
        pytest.param(
            {"preserve_dims": ["space"]},
            {"time": 0, "individuals": "a", "keypoints": "head"},
            id="Keep space only",
        ),
        pytest.param(
            {"individuals": "b"},
            {"individuals": "b", "keypoints": "head"},
            id="Request non-default slice",
        ),
        pytest.param(
            {"individuals": "c"},
            {"individuals": "c", "keypoints": "head"},
            id="Request by coordinate",
        ),
        # Unknown dimension names in the selection must be silently dropped.
        pytest.param(
            {
                "individuals": "b",
                "elephants": "this is a non-existent dimension",
                "crabs": 42,
            },
            {"individuals": "b", "keypoints": "head"},
            id="Selection ignores dimensions that don't exist",
        ),
        # Nothing preserved: the result is a single scalar element.
        pytest.param(
            {"preserve_dims": []},
            {"time": 0, "space": "x", "individuals": "a", "keypoints": "head"},
            id="Collapse everything",
        ),
    ],
)
def test_collapse_dimensions(
    da: xr.DataArray,
    pass_to_function: dict[str, str],
    equivalent_to_sel: dict[str, int | str],
) -> None:
    """Check collapsing against the equivalent explicit ``sel`` call."""
    # Reference result obtained directly through xarray's own selection.
    reference = da.sel(**equivalent_to_sel)

    collapsed = collapse_extra_dimensions(da, **pass_to_function)

    assert collapsed.shape == reference.values.shape
    xr.testing.assert_allclose(collapsed, reference)


@pytest.mark.parametrize(
    "pass_to_function",
    [
        {"keypoints": ["head", "toes"]},
        {"individuals": ["a", "b"]},
    ],
    ids=[
        "Multiple keypoints",
        "Multiple individuals",
    ],
)
def test_collapse_dimensions_value_error(
    da: xr.DataArray,
    pass_to_function: dict[str, Any],
) -> None:
    """Check that requesting several labels for one dimension is rejected."""
    with pytest.raises(ValueError):
        collapse_extra_dimensions(da, **pass_to_function)


@pytest.mark.parametrize(
    "args_to_fn, expected",
    [
        # Integer position resolves to the label stored at that position.
        pytest.param(
            {"dimension": "individuals", "coord_index": 1},
            "b",
            id="Fetch coord from index",
        ),
        # "time" has no explicit coordinates in the fixture.
        pytest.param(
            {"dimension": "time", "coord_index": 6},
            6,
            id="Dimension with no coordinates",
        ),
        # An existing label is handed back as-is.
        pytest.param(
            {"dimension": "space", "coord_index": "x"},
            "x",
            id="Fetch coord from name",
        ),
        # Failure cases: the expected exception instance carries the message.
        pytest.param(
            {"dimension": "keypoints", "coord_index": 10},
            IndexError("index 10 is out of bounds for axis 0 with size 4"),
            id="Out of bounds index",
        ),
        pytest.param(
            {"dimension": "keypoints", "coord_index": "arms"},
            KeyError("arms"),
            id="Non existent coord name",
        ),
    ],
)
def test_coord_of_dimension(
    da: xr.DataArray, args_to_fn: dict[str, str], expected: str | Exception
) -> None:
    """Check the coordinate lookup for both valid and invalid requests.

    When ``expected`` is an exception instance, the call must raise that
    exception type with a matching message; otherwise the returned value
    must compare equal to ``expected``.
    """
    if not isinstance(expected, Exception):
        assert expected == _coord_of_dimension(da, **args_to_fn)
        return

    error_type = type(expected)
    message_pattern = re.escape(str(expected))
    with pytest.raises(error_type, match=message_pattern):
        _coord_of_dimension(da, **args_to_fn)
Loading