diff --git a/docs/source/api_index.rst b/docs/source/api_index.rst index 78e27563b..146c7fa1f 100644 --- a/docs/source/api_index.rst +++ b/docs/source/api_index.rst @@ -44,6 +44,24 @@ Sample Data fetch_sample_data_path fetch_sample_data +Analysis +----------- +.. currentmodule:: movement.analysis.kinematics +.. autosummary:: + :toctree: api + + compute_displacement + compute_velocity + compute_acceleration + +Move Accessor +------------- +.. currentmodule:: movement.move_accessor +.. autosummary:: + :toctree: api + + MoveAccessor + Logging ------- .. currentmodule:: movement.logging diff --git a/movement/analysis/kinematics.py b/movement/analysis/kinematics.py new file mode 100644 index 000000000..a89ffdf47 --- /dev/null +++ b/movement/analysis/kinematics.py @@ -0,0 +1,126 @@ +import numpy as np +import xarray as xr + +from movement.logging import log_error + + +def compute_displacement(data: xr.DataArray) -> xr.DataArray: + """Compute the displacement between consecutive positions + of each keypoint for each individual across time. + + Parameters + ---------- + data : xarray.DataArray + The input data containing ``time`` as a dimension. + + Returns + ------- + xarray.DataArray + An xarray DataArray containing the computed displacement. + """ + _validate_time_dimension(data) + result = data.diff(dim="time") + result = result.reindex(data.coords, fill_value=0) + return result + + +def compute_velocity(data: xr.DataArray) -> xr.DataArray: + """Compute the velocity between consecutive positions + of each keypoint for each individual across time. + + Parameters + ---------- + data : xarray.DataArray + The input data containing ``time`` as a dimension. + + Returns + ------- + xarray.DataArray + An xarray DataArray containing the computed velocity. + + Notes + ----- + This function computes velocity using numerical differentiation + and assumes equidistant time spacing. + """ + return _compute_approximate_derivative(data, order=1) + + +def compute_acceleration(data: xr.DataArray) -> xr.DataArray: + """Compute the acceleration between consecutive positions + of each keypoint for each individual across time. + + Parameters + ---------- + data : xarray.DataArray + The input data containing ``time`` as a dimension. + + Returns + ------- + xarray.DataArray + An xarray DataArray containing the computed acceleration. + + Notes + ----- + This function computes acceleration using numerical differentiation + and assumes equidistant time spacing. + """ + return _compute_approximate_derivative(data, order=2) + + +def _compute_approximate_derivative( + data: xr.DataArray, order: int +) -> xr.DataArray: + """Compute velocity or acceleration using numerical differentiation, + assuming equidistant time spacing. + + Parameters + ---------- + data : xarray.DataArray + The input data containing ``time`` as a dimension. + order : int + The order of the derivative. 1 for velocity, 2 for + acceleration. Value must be a positive integer. + + Returns + ------- + xarray.DataArray + An xarray DataArray containing the derived variable. + """ + if not isinstance(order, int): + raise log_error( + TypeError, f"Order must be an integer, but got {type(order)}." + ) + if order <= 0: + raise log_error(ValueError, "Order must be a positive integer.") + _validate_time_dimension(data) + result = data + dt = data["time"].values[1] - data["time"].values[0] + for _ in range(order): + result = xr.apply_ufunc( + np.gradient, + result, + dt, + kwargs={"axis": 0}, + ) + result = result.reindex_like(data) + return result + + +def _validate_time_dimension(data: xr.DataArray) -> None: + """Validate the input data contains a ``time`` dimension. + + Parameters + ---------- + data : xarray.DataArray + The input data to validate. + + Raises + ------ + ValueError + If the input data does not contain a ``time`` dimension. + """ + if "time" not in data.dims: + raise log_error( + ValueError, "Input data must contain 'time' as a dimension." + ) diff --git a/movement/move_accessor.py b/movement/move_accessor.py index 3186fe3bd..36a3c6eee 100644 --- a/movement/move_accessor.py +++ b/movement/move_accessor.py @@ -1,8 +1,9 @@ import logging -from typing import ClassVar +from typing import Callable, ClassVar import xarray as xr +from movement.analysis import kinematics from movement.io.validators import ValidPoseTracks logger = logging.getLogger(__name__) @@ -13,9 +14,10 @@ @xr.register_dataset_accessor("move") class MoveAccessor: - """An accessor that extends an xarray Dataset object. + """An accessor that extends an xarray Dataset by implementing + `movement`-specific properties and methods. - The xarray Dataset has the following dimensions: + The xarray Dataset contains the following expected dimensions: - ``time``: the number of frames in the video - ``individuals``: the number of individuals in the video - ``keypoints``: the number of keypoints in the skeleton @@ -26,11 +28,18 @@ class MoveAccessor: ['x','y',('z')] for ``space``. The coordinates of the ``time`` dimension are in seconds if ``fps`` is provided, otherwise they are in frame numbers. - The dataset contains two data variables (xarray DataArray objects): + The dataset contains two expected data variables (xarray DataArrays): - ``pose_tracks``: with shape (``time``, ``individuals``, ``keypoints``, ``space``) - ``confidence``: with shape (``time``, ``individuals``, ``keypoints``) + When accessing a ``.move`` property (e.g. ``displacement``, ``velocity``, + ``acceleration``) for the first time, the property is computed and stored + as a data variable with the same name in the dataset. The ``.move`` + accessor can be omitted in subsequent accesses, i.e. + ``ds.move.displacement`` and ``ds.displacement`` will return the same data + variable. + The dataset may also contain following attributes as metadata: - ``fps``: the number of frames per second in the video - ``time_unit``: the unit of the ``time`` coordinates, frames or @@ -45,7 +54,7 @@ class MoveAccessor: Using an accessor is the recommended way to extend xarray objects. See [1]_ for more details. - Methods/properties that are specific to this class can be used via + Methods/properties that are specific to this class can be accessed via the ``.move`` accessor, e.g. ``ds.move.validate()``. References @@ -70,6 +79,56 @@ class MoveAccessor: def __init__(self, ds: xr.Dataset): self._obj = ds + def _compute_property( + self, + property: str, + compute_function: Callable[[xr.DataArray], xr.DataArray], + ) -> xr.DataArray: + """Compute a kinematic property and store it in the dataset. + + Parameters + ---------- + property : str + The name of the property to compute. + compute_function : Callable[[xarray.DataArray], xarray.DataArray] + The function to compute the property. + + Returns + ------- + xarray.DataArray + The computed property. + """ + self.validate() + if property not in self._obj: + pose_tracks = self._obj[self.var_names[0]] + self._obj[property] = compute_function(pose_tracks) + return self._obj[property] + + @property + def displacement(self) -> xr.DataArray: + """Return the displacement between consecutive positions + of each keypoint for each individual across time. + """ + return self._compute_property( + "displacement", kinematics.compute_displacement + ) + + @property + def velocity(self) -> xr.DataArray: + """Return the velocity between consecutive positions + of each keypoint for each individual across time. + """ + return self._compute_property("velocity", kinematics.compute_velocity) + + @property + def acceleration(self) -> xr.DataArray: + """Return the acceleration between consecutive positions + of each keypoint for each individual across time. + """ + return self._compute_property( + "acceleration", kinematics.compute_acceleration + ) + def validate(self) -> None: """Validate the PoseTracks dataset.""" fps = self._obj.attrs.get("fps", None) diff --git a/tests/conftest.py b/tests/conftest.py index 361305dd6..c952d44f4 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -212,12 +212,22 @@ def valid_tracks_array(): def _valid_tracks_array(array_type): """Return a valid tracks array.""" + # Unless specified, default is a multi_track_array with + # 10 frames, 2 individuals, and 2 keypoints. + n_frames = 10 + n_individuals = 2 + n_keypoints = 2 + base = np.arange(n_frames, dtype=float)[ + :, np.newaxis, np.newaxis, np.newaxis + ] if array_type == "single_keypoint_array": - return np.zeros((10, 2, 1, 2)) + n_keypoints = 1 elif array_type == "single_track_array": - return np.zeros((10, 1, 2, 2)) - else: # "multi_track_array": - return np.zeros((10, 2, 2, 2)) + n_individuals = 1 + x_points = np.repeat(base * base, n_individuals * n_keypoints) + y_points = np.repeat(base * 4, n_individuals * n_keypoints) + tracks_array = np.ravel(np.column_stack((x_points, y_points))) + return tracks_array.reshape(n_frames, n_individuals, n_keypoints, 2) return _valid_tracks_array @@ -237,7 +247,7 @@ def valid_pose_dataset(valid_tracks_array, request): data_vars={ "pose_tracks": xr.DataArray(tracks_array, dims=dim_names), "confidence": xr.DataArray( - tracks_array[..., 0], + np.ones(tracks_array.shape[:-1]), dims=dim_names[:-1], ), }, @@ -256,9 +266,18 @@ def valid_pose_dataset(valid_tracks_array, request): ) +@pytest.fixture +def valid_pose_dataset_with_nan(valid_pose_dataset): + """Return a valid pose tracks dataset with NaN values.""" + valid_pose_dataset.pose_tracks.loc[ + {"individuals": "ind1", "time": [3, 7, 8]} + ] = np.nan + return valid_pose_dataset + + @pytest.fixture def not_a_dataset(): - """Return an invalid pose tracks dataset.""" + """Return data that is not a pose tracks dataset.""" return [1, 2, 3] @@ -289,4 +308,11 @@ def missing_dim_dataset(valid_pose_dataset): ] ) def invalid_pose_dataset(request): + """Return an invalid pose tracks dataset.""" return request.getfixturevalue(request.param) + + +@pytest.fixture(params=["displacement", "velocity", "acceleration"]) +def kinematic_property(request): + """Return a kinematic property.""" + return request.param diff --git a/tests/test_unit/test_kinematics.py b/tests/test_unit/test_kinematics.py new file mode 100644 index 000000000..ca9a78d55 --- /dev/null +++ b/tests/test_unit/test_kinematics.py @@ -0,0 +1,110 @@ +from contextlib import nullcontext as does_not_raise + +import numpy as np +import pytest +import xarray as xr + +from movement.analysis import kinematics + + +class TestKinematics: + """Test suite for the kinematics module.""" + + @pytest.fixture + def expected_dataarray(self, valid_pose_dataset): + """Return a function to generate the expected dataarray + for different kinematic properties.""" + + def _expected_dataarray(property): + """Return an xarray.DataArray with default values and + the expected dimensions and coordinates.""" + # Expected x,y values for velocity + x_vals = np.array( + [1.0, 2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0, 17.0] + ) + y_vals = np.full((10, 2, 2, 1), 4.0) + if property == "acceleration": + x_vals = np.array( + [1.0, 1.5, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 1.5, 1.0] + ) + y_vals = np.full((10, 2, 2, 1), 0) + elif property == "displacement": + x_vals = np.array( + [0.0, 1.0, 3.0, 5.0, 7.0, 9.0, 11.0, 13.0, 15.0, 17.0] + ) + y_vals[0] = 0 + + x_vals = x_vals.reshape(-1, 1, 1, 1) + # Repeat the x_vals to match the shape of the pose_tracks + x_vals = np.tile(x_vals, (1, 2, 2, 1)) + return xr.DataArray( + np.concatenate( + [x_vals, y_vals], + axis=-1, + ), + dims=valid_pose_dataset.dims, + coords=valid_pose_dataset.coords, + ) + + return _expected_dataarray + + kinematic_test_params = [ + ("valid_pose_dataset", does_not_raise()), + ("valid_pose_dataset_with_nan", does_not_raise()), + ("missing_dim_dataset", pytest.raises(ValueError)), + ] + + @pytest.mark.parametrize("ds, expected_exception", kinematic_test_params) + def test_displacement( + self, ds, expected_exception, expected_dataarray, request + ): + """Test displacement computation.""" + ds = request.getfixturevalue(ds) + with expected_exception: + result = kinematics.compute_displacement(ds.pose_tracks) + expected = expected_dataarray("displacement") + if ds.pose_tracks.isnull().any(): + expected.loc[ + {"individuals": "ind1", "time": [3, 4, 7, 8, 9]} + ] = np.nan + xr.testing.assert_allclose(result, expected) + + @pytest.mark.parametrize("ds, expected_exception", kinematic_test_params) + def test_velocity( + self, ds, expected_exception, expected_dataarray, request + ): + """Test velocity computation.""" + ds = request.getfixturevalue(ds) + with expected_exception: + result = kinematics.compute_velocity(ds.pose_tracks) + expected = expected_dataarray("velocity") + if ds.pose_tracks.isnull().any(): + expected.loc[ + {"individuals": "ind1", "time": [2, 4, 6, 7, 8, 9]} + ] = np.nan + xr.testing.assert_allclose(result, expected) + + @pytest.mark.parametrize("ds, expected_exception", kinematic_test_params) + def test_acceleration( + self, ds, expected_exception, expected_dataarray, request + ): + """Test acceleration computation.""" + ds = request.getfixturevalue(ds) + with expected_exception: + result = kinematics.compute_acceleration(ds.pose_tracks) + expected = expected_dataarray("acceleration") + if ds.pose_tracks.isnull().any(): + expected.loc[ + {"individuals": "ind1", "time": [1, 3, 5, 6, 7, 8, 9]} + ] = np.nan + xr.testing.assert_allclose(result, expected) + + @pytest.mark.parametrize("order", [0, -1, 1.0, "1"]) + def test_approximate_derivative_with_invalid_order(self, order): + """Test that an error is raised when the order is non-positive.""" + data = np.arange(10) + expected_exception = ( + ValueError if isinstance(order, int) else TypeError + ) + with pytest.raises(expected_exception): + kinematics._compute_approximate_derivative(data, order=order) diff --git a/tests/test_unit/test_move_accessor.py b/tests/test_unit/test_move_accessor.py new file mode 100644 index 000000000..94e410ca9 --- /dev/null +++ b/tests/test_unit/test_move_accessor.py @@ -0,0 +1,31 @@ +import pytest +import xarray as xr + + +class TestMoveAccessor: + """Test suite for the move_accessor module.""" + + def test_property_with_valid_dataset( + self, valid_pose_dataset, kinematic_property + ): + """Test that accessing a property of a valid pose dataset + returns an instance of xr.DataArray with the correct name, + and that the input xr.Dataset now contains the property as + a data variable.""" + result = getattr(valid_pose_dataset.move, kinematic_property) + assert isinstance(result, xr.DataArray) + assert result.name == kinematic_property + assert kinematic_property in valid_pose_dataset.data_vars + + def test_property_with_invalid_dataset( + self, invalid_pose_dataset, kinematic_property + ): + """Test that accessing a property of an invalid pose dataset + raises the appropriate error.""" + expected_exception = ( + ValueError + if isinstance(invalid_pose_dataset, xr.Dataset) + else AttributeError + ) + with pytest.raises(expected_exception): + getattr(invalid_pose_dataset.move, kinematic_property)