Skip to content

Commit 9c92ae8

Browse files
committed
Parametrise kinematics tests
1 parent d3a6e0f commit 9c92ae8

File tree

2 files changed

+59
-24
lines changed

2 files changed

+59
-24
lines changed

tests/conftest.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,9 @@ def _valid_tracks_array(array_type):
217217
n_frames = 10
218218
n_individuals = 2
219219
n_keypoints = 2
220-
base = np.arange(n_frames)[:, np.newaxis, np.newaxis, np.newaxis]
220+
base = np.arange(n_frames, dtype=float)[
221+
:, np.newaxis, np.newaxis, np.newaxis
222+
]
221223
if array_type == "single_keypoint_array":
222224
n_keypoints = 1
223225
elif array_type == "single_track_array":
@@ -264,6 +266,15 @@ def valid_pose_dataset(valid_tracks_array, request):
264266
)
265267

266268

269+
@pytest.fixture
270+
def valid_pose_dataset_with_nan(valid_pose_dataset):
271+
"""Return a valid pose tracks dataset with NaN values."""
272+
valid_pose_dataset.pose_tracks.loc[
273+
{"individuals": "ind1", "time": [3, 7, 8]}
274+
] = np.nan
275+
return valid_pose_dataset
276+
277+
267278
@pytest.fixture
268279
def not_a_dataset():
269280
"""Return data that is not a pose tracks dataset."""

tests/test_unit/test_kinematics.py

Lines changed: 47 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from contextlib import nullcontext as does_not_raise
2+
13
import numpy as np
24
import pytest
35
import xarray as xr
@@ -46,24 +48,56 @@ def _expected_dataarray(property):
4648

4749
return _expected_dataarray
4850

49-
def test_displacement(self, valid_pose_dataset, expected_dataarray):
51+
kinematic_test_params = [
52+
("valid_pose_dataset", does_not_raise()),
53+
("valid_pose_dataset_with_nan", does_not_raise()),
54+
("missing_dim_dataset", pytest.raises(ValueError)),
55+
]
56+
57+
@pytest.mark.parametrize("ds, expected_exception", kinematic_test_params)
58+
def test_displacement(
59+
self, ds, expected_exception, expected_dataarray, request
60+
):
5061
"""Test displacement computation."""
51-
result = kinematics.compute_displacement(
52-
valid_pose_dataset.pose_tracks
53-
)
54-
xr.testing.assert_allclose(result, expected_dataarray("displacement"))
62+
ds = request.getfixturevalue(ds)
63+
with expected_exception:
64+
result = kinematics.compute_displacement(ds.pose_tracks)
65+
expected = expected_dataarray("displacement")
66+
if ds.pose_tracks.isnull().any():
67+
expected.loc[
68+
{"individuals": "ind1", "time": [3, 4, 7, 8, 9]}
69+
] = np.nan
70+
xr.testing.assert_allclose(result, expected)
5571

56-
def test_velocity(self, valid_pose_dataset, expected_dataarray):
72+
@pytest.mark.parametrize("ds, expected_exception", kinematic_test_params)
73+
def test_velocity(
74+
self, ds, expected_exception, expected_dataarray, request
75+
):
5776
"""Test velocity computation."""
58-
result = kinematics.compute_velocity(valid_pose_dataset.pose_tracks)
59-
xr.testing.assert_allclose(result, expected_dataarray("velocity"))
77+
ds = request.getfixturevalue(ds)
78+
with expected_exception:
79+
result = kinematics.compute_velocity(ds.pose_tracks)
80+
expected = expected_dataarray("velocity")
81+
if ds.pose_tracks.isnull().any():
82+
expected.loc[
83+
{"individuals": "ind1", "time": [2, 4, 6, 7, 8, 9]}
84+
] = np.nan
85+
xr.testing.assert_allclose(result, expected)
6086

61-
def test_acceleration(self, valid_pose_dataset, expected_dataarray):
87+
@pytest.mark.parametrize("ds, expected_exception", kinematic_test_params)
88+
def test_acceleration(
89+
self, ds, expected_exception, expected_dataarray, request
90+
):
6291
"""Test acceleration computation."""
63-
result = kinematics.compute_acceleration(
64-
valid_pose_dataset.pose_tracks
65-
)
66-
xr.testing.assert_allclose(result, expected_dataarray("acceleration"))
92+
ds = request.getfixturevalue(ds)
93+
with expected_exception:
94+
result = kinematics.compute_acceleration(ds.pose_tracks)
95+
expected = expected_dataarray("acceleration")
96+
if ds.pose_tracks.isnull().any():
97+
expected.loc[
98+
{"individuals": "ind1", "time": [1, 3, 5, 6, 7, 8, 9]}
99+
] = np.nan
100+
xr.testing.assert_allclose(result, expected)
67101

68102
@pytest.mark.parametrize("order", [0, -1, 1.0, "1"])
69103
def test_approximate_derivative_with_invalid_order(self, order):
@@ -74,13 +108,3 @@ def test_approximate_derivative_with_invalid_order(self, order):
74108
)
75109
with pytest.raises(expected_exception):
76110
kinematics.compute_approximate_derivative(data, order=order)
77-
78-
def test_compute_with_missing_time_dimension(
79-
self, missing_dim_dataset, kinematic_property
80-
):
81-
"""Test that computing a property of a pose dataset with
82-
missing 'time' dimension raises the appropriate error."""
83-
with pytest.raises(ValueError):
84-
eval(f"kinematics.compute_{kinematic_property}")(
85-
missing_dim_dataset
86-
)

0 commit comments

Comments
 (0)