Skip to content

Commit

Permalink
Parametrise kinematics tests
Browse files Browse the repository at this point in the history
  • Loading branch information
lochhh committed Feb 22, 2024
1 parent d3a6e0f commit 9c92ae8
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 24 deletions.
13 changes: 12 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,9 @@ def _valid_tracks_array(array_type):
n_frames = 10
n_individuals = 2
n_keypoints = 2
base = np.arange(n_frames)[:, np.newaxis, np.newaxis, np.newaxis]
base = np.arange(n_frames, dtype=float)[
:, np.newaxis, np.newaxis, np.newaxis
]
if array_type == "single_keypoint_array":
n_keypoints = 1
elif array_type == "single_track_array":
Expand Down Expand Up @@ -264,6 +266,15 @@ def valid_pose_dataset(valid_tracks_array, request):
)


@pytest.fixture
def valid_pose_dataset_with_nan(valid_pose_dataset):
"""Return a valid pose tracks dataset with NaN values."""
valid_pose_dataset.pose_tracks.loc[
{"individuals": "ind1", "time": [3, 7, 8]}
] = np.nan
return valid_pose_dataset


@pytest.fixture
def not_a_dataset():
"""Return data that is not a pose tracks dataset."""
Expand Down
70 changes: 47 additions & 23 deletions tests/test_unit/test_kinematics.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from contextlib import nullcontext as does_not_raise

import numpy as np
import pytest
import xarray as xr
Expand Down Expand Up @@ -46,24 +48,56 @@ def _expected_dataarray(property):

return _expected_dataarray

def test_displacement(self, valid_pose_dataset, expected_dataarray):
kinematic_test_params = [
("valid_pose_dataset", does_not_raise()),
("valid_pose_dataset_with_nan", does_not_raise()),
("missing_dim_dataset", pytest.raises(ValueError)),
]

@pytest.mark.parametrize("ds, expected_exception", kinematic_test_params)
def test_displacement(
self, ds, expected_exception, expected_dataarray, request
):
"""Test displacement computation."""
result = kinematics.compute_displacement(
valid_pose_dataset.pose_tracks
)
xr.testing.assert_allclose(result, expected_dataarray("displacement"))
ds = request.getfixturevalue(ds)
with expected_exception:
result = kinematics.compute_displacement(ds.pose_tracks)
expected = expected_dataarray("displacement")
if ds.pose_tracks.isnull().any():
expected.loc[
{"individuals": "ind1", "time": [3, 4, 7, 8, 9]}
] = np.nan
xr.testing.assert_allclose(result, expected)

def test_velocity(self, valid_pose_dataset, expected_dataarray):
@pytest.mark.parametrize("ds, expected_exception", kinematic_test_params)
def test_velocity(
self, ds, expected_exception, expected_dataarray, request
):
"""Test velocity computation."""
result = kinematics.compute_velocity(valid_pose_dataset.pose_tracks)
xr.testing.assert_allclose(result, expected_dataarray("velocity"))
ds = request.getfixturevalue(ds)
with expected_exception:
result = kinematics.compute_velocity(ds.pose_tracks)
expected = expected_dataarray("velocity")
if ds.pose_tracks.isnull().any():
expected.loc[
{"individuals": "ind1", "time": [2, 4, 6, 7, 8, 9]}
] = np.nan
xr.testing.assert_allclose(result, expected)

def test_acceleration(self, valid_pose_dataset, expected_dataarray):
@pytest.mark.parametrize("ds, expected_exception", kinematic_test_params)
def test_acceleration(
self, ds, expected_exception, expected_dataarray, request
):
"""Test acceleration computation."""
result = kinematics.compute_acceleration(
valid_pose_dataset.pose_tracks
)
xr.testing.assert_allclose(result, expected_dataarray("acceleration"))
ds = request.getfixturevalue(ds)
with expected_exception:
result = kinematics.compute_acceleration(ds.pose_tracks)
expected = expected_dataarray("acceleration")
if ds.pose_tracks.isnull().any():
expected.loc[
{"individuals": "ind1", "time": [1, 3, 5, 6, 7, 8, 9]}
] = np.nan
xr.testing.assert_allclose(result, expected)

@pytest.mark.parametrize("order", [0, -1, 1.0, "1"])
def test_approximate_derivative_with_invalid_order(self, order):
Expand All @@ -74,13 +108,3 @@ def test_approximate_derivative_with_invalid_order(self, order):
)
with pytest.raises(expected_exception):
kinematics.compute_approximate_derivative(data, order=order)

def test_compute_with_missing_time_dimension(
self, missing_dim_dataset, kinematic_property
):
"""Test that computing a property of a pose dataset with
missing 'time' dimension raises the appropriate error."""
with pytest.raises(ValueError):
eval(f"kinematics.compute_{kinematic_property}")(
missing_dim_dataset
)

0 comments on commit 9c92ae8

Please sign in to comment.