Skip to content

Commit b82cc9f

Browse files
committed
Validate input data in kinematic functions
1 parent bc91614 commit b82cc9f

File tree

3 files changed

+106
-57
lines changed

3 files changed

+106
-57
lines changed

movement/analysis/kinematics.py

Lines changed: 51 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,37 @@
11
import numpy as np
22
import xarray as xr
33

4+
from movement.logging import log_error
5+
46

57
def compute_displacement(data: xr.DataArray) -> xr.DataArray:
6-
"""Compute the displacement between consecutive x, y
7-
locations of each keypoint of each individual.
8+
"""Compute the displacement between consecutive locations
9+
of each keypoint of each individual across time.
810
911
Parameters
1012
----------
1113
data : xarray.DataArray
12-
The input data, assumed to be of shape (..., 2), where
13-
the last dimension contains the x and y coordinates.
14+
The input data containing `time` as a dimension.
1415
1516
Returns
1617
-------
1718
xarray.DataArray
1819
An xarray DataArray containing the computed displacement.
1920
"""
20-
displacement_xy = data.diff(dim="time")
21-
displacement_xy = displacement_xy.reindex(data.coords, fill_value=0)
22-
return displacement_xy
21+
_validate_time_dimension(data)
22+
result = data.diff(dim="time")
23+
result = result.reindex(data.coords, fill_value=0)
24+
return result
2325

2426

2527
def compute_velocity(data: xr.DataArray) -> xr.DataArray:
26-
"""Compute the velocity between consecutive x, y locations
27-
of each keypoint of each individual.
28+
"""Compute the velocity between consecutive locations
29+
of each keypoint of each individual across time.
2830
2931
Parameters
3032
----------
3133
data : xarray.DataArray
32-
The input data, assumed to be of shape (..., 2), where the last
33-
dimension contains the x and y coordinates.
34+
The input data containing `time` as a dimension.
3435
3536
Returns
3637
-------
@@ -41,14 +42,13 @@ def compute_velocity(data: xr.DataArray) -> xr.DataArray:
4142

4243

4344
def compute_acceleration(data: xr.DataArray) -> xr.DataArray:
44-
"""Compute the acceleration between consecutive x, y
45-
locations of each keypoint of each individual.
45+
"""Compute the acceleration between consecutive locations
46+
of each keypoint of each individual.
4647
4748
Parameters
4849
----------
4950
data : xarray.DataArray
50-
The input data, assumed to be of shape (..., 2), where the last
51-
dimension contains the x and y coordinates.
51+
The input data containing `time` as a dimension.
5252
5353
Returns
5454
-------
@@ -60,48 +60,58 @@ def compute_acceleration(data: xr.DataArray) -> xr.DataArray:
6060

6161

6262
def compute_approximate_derivative(
63-
data: xr.DataArray, order: int = 1
63+
data: xr.DataArray, order: int
6464
) -> xr.DataArray:
6565
"""Compute velocity or acceleration using numerical differentiation,
6666
assuming equidistant time spacing.
6767
6868
Parameters
6969
----------
7070
data : xarray.DataArray
71-
The input data, assumed to be of shape (..., 2), where the last
72-
dimension contains data in the x and y dimensions.
71+
The input data containing `time` as a dimension.
7372
order : int
7473
The order of the derivative. 1 for velocity, 2 for
75-
acceleration. Default is 1.
74+
acceleration. Value must be a positive integer.
7675
7776
Returns
7877
-------
7978
xarray.DataArray
8079
An xarray DataArray containing the derived variable.
8180
"""
81+
if not isinstance(order, int):
82+
raise log_error(
83+
TypeError, f"Order must be an integer, but got {type(order)}."
84+
)
8285
if order <= 0:
83-
raise ValueError("order must be a positive integer.")
84-
else:
85-
result = data
86-
dt = data["time"].values[1] - data["time"].values[0]
87-
for _ in range(order):
88-
result = xr.apply_ufunc(
89-
np.gradient,
90-
result,
91-
dt,
92-
kwargs={"axis": 0},
93-
)
94-
result = result.reindex_like(data)
86+
raise log_error(ValueError, "Order must be a positive integer.")
87+
_validate_time_dimension(data)
88+
result = data
89+
dt = data["time"].values[1] - data["time"].values[0]
90+
for _ in range(order):
91+
result = xr.apply_ufunc(
92+
np.gradient,
93+
result,
94+
dt,
95+
kwargs={"axis": 0},
96+
)
97+
result = result.reindex_like(data)
9598
return result
9699

97100

98-
# Locomotion Features
99-
# speed
100-
# speed_centroid
101-
# acceleration
102-
# acceleration_centroid
103-
# speed_fwd
104-
# radial_vel
105-
# tangential_vel
106-
# speed_centroid_w(s)
107-
# speed_(p)_w(s)
101+
def _validate_time_dimension(data: xr.DataArray) -> None:
102+
"""Validate the input data contains a 'time' dimension.
103+
104+
Parameters
105+
----------
106+
data : xarray.DataArray
107+
The input data to validate.
108+
109+
Raises
110+
------
111+
ValueError
112+
If the input data does not contain a 'time' dimension.
113+
"""
114+
if "time" not in data.dims:
115+
raise log_error(
116+
ValueError, "Input data must contain 'time' as a dimension."
117+
)

tests/conftest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,7 @@ def _valid_tracks_array(array_type):
222222
n_keypoints = 1
223223
elif array_type == "single_track_array":
224224
n_individuals = 1
225-
x_points = np.repeat(base * 3, n_individuals * n_keypoints)
225+
x_points = np.repeat(base * base, n_individuals * n_keypoints)
226226
y_points = np.repeat(base * 4, n_individuals * n_keypoints)
227227
tracks_array = np.ravel(np.column_stack((x_points, y_points)))
228228
return tracks_array.reshape(n_frames, n_individuals, n_keypoints, 2)

tests/test_unit/test_kinematics.py

Lines changed: 54 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,41 @@ class TestKinematics:
1010

1111
@pytest.fixture
1212
def expected_dataarray(self, valid_pose_dataset):
13-
"""Return an xarray.DataArray with default values and
14-
the expected dimensions and coordinates."""
15-
return xr.DataArray(
16-
np.full((10, 2, 2, 2), [3.0, 4.0]),
17-
dims=valid_pose_dataset.dims,
18-
coords=valid_pose_dataset.coords,
19-
)
13+
"""Return a function to generate the expected dataarray
14+
for different kinematic properties."""
15+
16+
def _expected_dataarray(property):
17+
"""Return an xarray.DataArray with default values and
18+
the expected dimensions and coordinates."""
19+
# Expected x,y values for velocity
20+
x_vals = np.array(
21+
[1.0, 2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0, 17.0]
22+
)
23+
y_vals = np.full((10, 2, 2, 1), 4.0)
24+
if property == "acceleration":
25+
x_vals = np.array(
26+
[1.0, 1.5, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 1.5, 1.0]
27+
)
28+
y_vals = np.full((10, 2, 2, 1), 0)
29+
elif property == "displacement":
30+
x_vals = np.array(
31+
[0.0, 1.0, 3.0, 5.0, 7.0, 9.0, 11.0, 13.0, 15.0, 17.0]
32+
)
33+
y_vals[0] = 0
34+
35+
x_vals = x_vals.reshape(-1, 1, 1, 1)
36+
# Repeat the x_vals to match the shape of the pose_tracks
37+
x_vals = np.tile(x_vals, (1, 2, 2, 1))
38+
return xr.DataArray(
39+
np.concatenate(
40+
[x_vals, y_vals],
41+
axis=-1,
42+
),
43+
dims=valid_pose_dataset.dims,
44+
coords=valid_pose_dataset.coords,
45+
)
46+
47+
return _expected_dataarray
2048

2149
@pytest.fixture
2250
def expected_dataset(self, valid_pose_dataset):
@@ -61,25 +89,36 @@ def test_displacement(self, valid_pose_dataset, expected_dataarray):
6189
result = kinematics.compute_displacement(
6290
valid_pose_dataset.pose_tracks
6391
)
64-
# Set the first displacement to zero
65-
expected_dataarray[0, :, :, :] = 0
66-
xr.testing.assert_allclose(result, expected_dataarray)
92+
xr.testing.assert_allclose(result, expected_dataarray("displacement"))
6793

6894
def test_velocity(self, valid_pose_dataset, expected_dataarray):
6995
"""Test velocity computation."""
7096
result = kinematics.compute_velocity(valid_pose_dataset.pose_tracks)
71-
xr.testing.assert_allclose(result, expected_dataarray)
97+
xr.testing.assert_allclose(result, expected_dataarray("velocity"))
7298

7399
def test_acceleration(self, valid_pose_dataset, expected_dataarray):
74100
"""Test acceleration computation."""
75101
result = kinematics.compute_acceleration(
76102
valid_pose_dataset.pose_tracks
77103
)
78-
expected_dataarray[:] = 0
79-
xr.testing.assert_allclose(result, expected_dataarray)
104+
xr.testing.assert_allclose(result, expected_dataarray("acceleration"))
80105

81-
def test_approximate_derivative_with_nonpositive_order(self):
106+
@pytest.mark.parametrize("order", [0, -1, 1.0, "1"])
107+
def test_approximate_derivative_with_invalid_order(self, order):
82108
"""Test that an error is raised when the order is non-positive."""
83109
data = np.arange(10)
110+
expected_exception = (
111+
ValueError if isinstance(order, int) else TypeError
112+
)
113+
with pytest.raises(expected_exception):
114+
kinematics.compute_approximate_derivative(data, order=order)
115+
116+
def test_compute_with_missing_time_dimension(
117+
self, missing_dim_dataset, kinematic_property
118+
):
119+
"""Test that computing a property of a pose dataset with
120+
missing 'time' dimension raises the appropriate error."""
84121
with pytest.raises(ValueError):
85-
kinematics.compute_approximate_derivative(data, order=0)
122+
eval(f"kinematics.compute_{kinematic_property}")(
123+
missing_dim_dataset
124+
)

0 commit comments

Comments
 (0)