Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Compute locomotion features #106

Merged
merged 30 commits into from
Feb 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
7015cf5
Draft compute velocity
lochhh Dec 22, 2023
6121be2
Add test for displacement
lochhh Dec 22, 2023
44ff2b0
Fix confidence values in `valid_pose_dataset` fixture
lochhh Jan 8, 2024
a4d1633
Refactor kinematics test and functions
lochhh Jan 8, 2024
e21fbb3
Vectorise kinematic functions
lochhh Jan 25, 2024
630d7a0
Refactor repeated calls to compute magnitude + direction
lochhh Jan 26, 2024
684930f
Displacement to return 0 instead of NaN
lochhh Jan 26, 2024
0fa3acc
Return x y components in kinematic functions
lochhh Jan 26, 2024
39489c0
Refactor kinematics tests
lochhh Jan 29, 2024
6e55ab6
Remove unnecessary instantiations
lochhh Jan 31, 2024
ed10d94
Improve time diff calculation
lochhh Feb 5, 2024
26351c7
Prefix kinematics methods with `compute_`
lochhh Feb 5, 2024
425d4e8
Add kinematic properties to `PosesAccessor`
lochhh Feb 5, 2024
7914156
Update `test_property` docstring
lochhh Feb 5, 2024
7664e6e
Refactor `_vector` methods and kinematics tests
lochhh Feb 5, 2024
8d96421
Update `expected_dataset` docstring
lochhh Feb 5, 2024
f0fd469
Rename `poses` to `move` in `PosesAccessor`
lochhh Feb 9, 2024
56451c1
Refactor properties in `PosesAccessor`
lochhh Feb 9, 2024
344f8c5
Remove vector util functions and tests
lochhh Feb 9, 2024
d8102fc
Update `not_a_dataset` fixture description
lochhh Feb 19, 2024
4838837
Validate dataset upon accessor property access
lochhh Feb 19, 2024
a945999
Update `poses_accessor` test description
lochhh Feb 19, 2024
15ca8a5
Validate input data in kinematic functions
lochhh Feb 19, 2024
d3a6e0f
Remove unused fixture
lochhh Feb 19, 2024
9c92ae8
Parametrise kinematics tests
lochhh Feb 20, 2024
9f1c464
Set `compute_derivative` as internal function
lochhh Feb 23, 2024
b342100
Update `kinematics.py` docstrings
lochhh Feb 23, 2024
ac0a579
Add new modules to API docs
lochhh Feb 23, 2024
e895de0
Update `move_accessor` docstrings
lochhh Feb 23, 2024
e5cb36c
Rename `test_move_accessor` filename
lochhh Feb 23, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions docs/source/api_index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,24 @@ Sample Data
fetch_sample_data_path
fetch_sample_data

Analysis
-----------
.. currentmodule:: movement.analysis.kinematics
.. autosummary::
:toctree: api

compute_displacement
compute_velocity
compute_acceleration

Move Accessor
-------------
.. currentmodule:: movement.move_accessor
.. autosummary::
:toctree: api

MoveAccessor

Logging
-------
.. currentmodule:: movement.logging
Expand Down
126 changes: 126 additions & 0 deletions movement/analysis/kinematics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
import numpy as np
import xarray as xr

from movement.logging import log_error


def compute_displacement(data: xr.DataArray) -> xr.DataArray:
"""Compute the displacement between consecutive positions
of each keypoint for each individual across time.

Parameters
----------
data : xarray.DataArray
The input data containing ``time`` as a dimension.

Returns
-------
xarray.DataArray
An xarray DataArray containing the computed displacement.
"""
_validate_time_dimension(data)
result = data.diff(dim="time")
result = result.reindex(data.coords, fill_value=0)
return result


def compute_velocity(data: xr.DataArray) -> xr.DataArray:
"""Compute the velocity between consecutive positions
of each keypoint for each individual across time.

Parameters
----------
data : xarray.DataArray
The input data containing ``time`` as a dimension.

Returns
-------
xarray.DataArray
An xarray DataArray containing the computed velocity.

Notes
-----
This function computes velocity using numerical differentiation
and assumes equidistant time spacing.
"""
return _compute_approximate_derivative(data, order=1)


def compute_acceleration(data: xr.DataArray) -> xr.DataArray:
"""Compute the acceleration between consecutive positions
of each keypoint for each individual across time.

Parameters
----------
data : xarray.DataArray
The input data containing ``time`` as a dimension.

Returns
-------
xarray.DataArray
An xarray DataArray containing the computed acceleration.

Notes
-----
This function computes acceleration using numerical differentiation
and assumes equidistant time spacing.
"""
return _compute_approximate_derivative(data, order=2)


def _compute_approximate_derivative(
data: xr.DataArray, order: int
) -> xr.DataArray:
"""Compute velocity or acceleration using numerical differentiation,
assuming equidistant time spacing.

Parameters
----------
data : xarray.DataArray
The input data containing ``time`` as a dimension.
order : int
The order of the derivative. 1 for velocity, 2 for
acceleration. Value must be a positive integer.

Returns
-------
xarray.DataArray
An xarray DataArray containing the derived variable.
"""
if not isinstance(order, int):
raise log_error(
TypeError, f"Order must be an integer, but got {type(order)}."
)
if order <= 0:
raise log_error(ValueError, "Order must be a positive integer.")
_validate_time_dimension(data)
result = data
dt = data["time"].values[1] - data["time"].values[0]
for _ in range(order):
result = xr.apply_ufunc(
np.gradient,
result,
dt,
kwargs={"axis": 0},
)
result = result.reindex_like(data)
return result


def _validate_time_dimension(data: xr.DataArray) -> None:
"""Validate the input data contains a ``time`` dimension.

Parameters
----------
data : xarray.DataArray
The input data to validate.

Raises
------
ValueError
If the input data does not contain a ``time`` dimension.
"""
if "time" not in data.dims:
raise log_error(
ValueError, "Input data must contain 'time' as a dimension."
)
69 changes: 64 additions & 5 deletions movement/move_accessor.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import logging
from typing import ClassVar
from typing import Callable, ClassVar

import xarray as xr

from movement.analysis import kinematics
from movement.io.validators import ValidPoseTracks

logger = logging.getLogger(__name__)
Expand All @@ -13,9 +14,10 @@

@xr.register_dataset_accessor("move")
class MoveAccessor:
"""An accessor that extends an xarray Dataset object.
"""An accessor that extends an xarray Dataset by implementing
`movement`-specific properties and methods.

The xarray Dataset has the following dimensions:
The xarray Dataset contains the following expected dimensions:
- ``time``: the number of frames in the video
- ``individuals``: the number of individuals in the video
- ``keypoints``: the number of keypoints in the skeleton
Expand All @@ -26,11 +28,18 @@ class MoveAccessor:
['x','y',('z')] for ``space``. The coordinates of the ``time`` dimension
are in seconds if ``fps`` is provided, otherwise they are in frame numbers.

The dataset contains two data variables (xarray DataArray objects):
The dataset contains two expected data variables (xarray DataArrays):
- ``pose_tracks``: with shape (``time``, ``individuals``,
``keypoints``, ``space``)
- ``confidence``: with shape (``time``, ``individuals``, ``keypoints``)

When accessing a ``.move`` property (e.g. ``displacement``, ``velocity``,
``acceleration``) for the first time, the property is computed and stored
as a data variable with the same name in the dataset. The ``.move``
accessor can be omitted in subsequent accesses, i.e.
``ds.move.displacement`` and ``ds.displacement`` will return the same data
variable.

The dataset may also contain following attributes as metadata:
- ``fps``: the number of frames per second in the video
- ``time_unit``: the unit of the ``time`` coordinates, frames or
Expand All @@ -45,7 +54,7 @@ class MoveAccessor:
Using an accessor is the recommended way to extend xarray objects.
See [1]_ for more details.

Methods/properties that are specific to this class can be used via
Methods/properties that are specific to this class can be accessed via
the ``.move`` accessor, e.g. ``ds.move.validate()``.

References
Expand All @@ -70,6 +79,56 @@ class MoveAccessor:
def __init__(self, ds: xr.Dataset):
self._obj = ds

def _compute_property(
self,
property: str,
compute_function: Callable[[xr.DataArray], xr.DataArray],
) -> xr.DataArray:
"""Compute a kinematic property and store it in the dataset.

Parameters
----------
property : str
The name of the property to compute.
compute_function : Callable[[xarray.DataArray], xarray.DataArray]
The function to compute the property.

Returns
-------
xarray.DataArray
The computed property.
"""
self.validate()
if property not in self._obj:
pose_tracks = self._obj[self.var_names[0]]
self._obj[property] = compute_function(pose_tracks)
return self._obj[property]

@property
def displacement(self) -> xr.DataArray:
"""Return the displacement between consecutive positions
of each keypoint for each individual across time.
"""
return self._compute_property(
"displacement", kinematics.compute_displacement
)

@property
def velocity(self) -> xr.DataArray:
"""Return the velocity between consecutive positions
of each keypoint for each individual across time.
"""
return self._compute_property("velocity", kinematics.compute_velocity)

@property
def acceleration(self) -> xr.DataArray:
"""Return the acceleration between consecutive positions
of each keypoint for each individual across time.
"""
return self._compute_property(
"acceleration", kinematics.compute_acceleration
)

def validate(self) -> None:
"""Validate the PoseTracks dataset."""
fps = self._obj.attrs.get("fps", None)
Expand Down
38 changes: 32 additions & 6 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,12 +212,22 @@ def valid_tracks_array():

def _valid_tracks_array(array_type):
"""Return a valid tracks array."""
# Unless specified, default is a multi_track_array with
# 10 frames, 2 individuals, and 2 keypoints.
n_frames = 10
n_individuals = 2
n_keypoints = 2
base = np.arange(n_frames, dtype=float)[
:, np.newaxis, np.newaxis, np.newaxis
]
if array_type == "single_keypoint_array":
return np.zeros((10, 2, 1, 2))
n_keypoints = 1
elif array_type == "single_track_array":
return np.zeros((10, 1, 2, 2))
else: # "multi_track_array":
return np.zeros((10, 2, 2, 2))
n_individuals = 1
x_points = np.repeat(base * base, n_individuals * n_keypoints)
y_points = np.repeat(base * 4, n_individuals * n_keypoints)
tracks_array = np.ravel(np.column_stack((x_points, y_points)))
return tracks_array.reshape(n_frames, n_individuals, n_keypoints, 2)

return _valid_tracks_array

Expand All @@ -237,7 +247,7 @@ def valid_pose_dataset(valid_tracks_array, request):
data_vars={
"pose_tracks": xr.DataArray(tracks_array, dims=dim_names),
"confidence": xr.DataArray(
tracks_array[..., 0],
np.ones(tracks_array.shape[:-1]),
dims=dim_names[:-1],
),
},
Expand All @@ -256,9 +266,18 @@ def valid_pose_dataset(valid_tracks_array, request):
)


@pytest.fixture
def valid_pose_dataset_with_nan(valid_pose_dataset):
"""Return a valid pose tracks dataset with NaN values."""
valid_pose_dataset.pose_tracks.loc[
{"individuals": "ind1", "time": [3, 7, 8]}
] = np.nan
return valid_pose_dataset


@pytest.fixture
def not_a_dataset():
"""Return an invalid pose tracks dataset."""
"""Return data that is not a pose tracks dataset."""
return [1, 2, 3]


Expand Down Expand Up @@ -289,4 +308,11 @@ def missing_dim_dataset(valid_pose_dataset):
]
)
def invalid_pose_dataset(request):
"""Return an invalid pose tracks dataset."""
return request.getfixturevalue(request.param)


@pytest.fixture(params=["displacement", "velocity", "acceleration"])
def kinematic_property(request):
"""Return a kinematic property."""
return request.param
Loading
Loading