From 41942e1158c61537d06611e7c04b9e1cf06991d2 Mon Sep 17 00:00:00 2001 From: Stella <30465823+stellaprins@users.noreply.github.com> Date: Thu, 30 Jan 2025 16:58:56 +0000 Subject: [PATCH 1/6] Add transforms module with scale function (#384) * add transforms module with scale function * scale drops unit if None provided * allow scaling across multiple dimensions * add docstrings to transforms module and tests * Fix mypy being angry at us * add test_scale_value_error * Use ArrayLike typehint for factor and rename unit to space_unit * ensure broadcasting always happens along the space dimension * parametrize and refactor test_scale_space_dimension * add validate_dims_coords to scale function and compare space shape to factor shape (instead of length) * add test case with number of spatial dimensions equal to the number of timepoints, individuals, and keypoints * add float to scale typing for factor * add 3D space coords validation and test * Revert "add 3D space coords validation and test" This reverts commit 7d5423dcb4f76e4f36fcf8220f4c0cbdebbb71b5. * add 3D space coords validation and test --------- Co-authored-by: willGraham01 --- movement/transforms.py | 72 +++++++++ tests/test_unit/test_transforms.py | 243 +++++++++++++++++++++++++++++ 2 files changed, 315 insertions(+) create mode 100644 movement/transforms.py create mode 100644 tests/test_unit/test_transforms.py diff --git a/movement/transforms.py b/movement/transforms.py new file mode 100644 index 00000000..02fcd44c --- /dev/null +++ b/movement/transforms.py @@ -0,0 +1,72 @@ +"""Transform and add unit attributes to xarray.DataArray datasets.""" + +import numpy as np +import xarray as xr +from numpy.typing import ArrayLike + +from movement.validators.arrays import validate_dims_coords + + +def scale( + data: xr.DataArray, + factor: ArrayLike | float = 1.0, + space_unit: str | None = None, +) -> xr.DataArray: + """Scale data by a given factor with an optional unit. + + Parameters + ---------- + data : xarray.DataArray + The input data to be scaled. + factor : ArrayLike or float + The scaling factor to apply to the data. If factor is a scalar (a + single float), the data array is uniformly scaled by the same factor. + If factor is an object that can be converted to a 1D numpy array (e.g. + a list of floats), the length of the resulting array must match the + length of data array's space dimension along which it will be + broadcasted. + space_unit : str or None + The unit of the scaled data stored as a property in + xarray.DataArray.attrs['space_unit']. In case of the default (``None``) + the ``space_unit`` attribute is dropped. + + Returns + ------- + xarray.DataArray + The scaled data array. + + Notes + ----- + When scale is used multiple times on the same xarray.DataArray, + xarray.DataArray.attrs["space_unit"] is overwritten each time or is dropped + if ``None`` is passed by default or explicitly. + + """ + if len(data.coords["space"]) == 2: + validate_dims_coords(data, {"space": ["x", "y"]}) + else: + validate_dims_coords(data, {"space": ["x", "y", "z"]}) + + if not np.isscalar(factor): + factor = np.array(factor).squeeze() + if factor.ndim != 1: + raise ValueError( + "Factor must be an object that can be converted to a 1D numpy" + f" array, got {factor.ndim}D" + ) + elif factor.shape != data.space.values.shape: + raise ValueError( + f"Factor shape {factor.shape} does not match the shape " + f"of the space dimension {data.space.values.shape}" + ) + else: + factor_dims = [1] * data.ndim # 1s array matching data dimensions + factor_dims[data.get_axis_num("space")] = factor.shape[0] + factor = factor.reshape(factor_dims) + scaled_data = data * factor + + if space_unit is not None: + scaled_data.attrs["space_unit"] = space_unit + elif space_unit is None: + scaled_data.attrs.pop("space_unit", None) + return scaled_data diff --git a/tests/test_unit/test_transforms.py b/tests/test_unit/test_transforms.py new file mode 100644 index 00000000..13fb9b51 --- /dev/null +++ b/tests/test_unit/test_transforms.py @@ -0,0 +1,243 @@ +from typing import Any + +import numpy as np +import pytest +import xarray as xr + +from movement.transforms import scale + +SPATIAL_COORDS_2D = {"space": ["x", "y"]} +SPATIAL_COORDS_3D = {"space": ["x", "y", "z"]} + + +def nparray_0_to_23() -> np.ndarray: + """Create a 2D nparray from 0 to 23.""" + return np.arange(0, 24).reshape(12, 2) + + +def data_array_with_dims_and_coords( + data: np.ndarray, + dims: list | tuple = ("time", "space"), + coords: dict[str, list[str]] = SPATIAL_COORDS_2D, + **attributes: Any, +) -> xr.DataArray: + """Create a DataArray with given data, dimensions, coordinates, and + attributes (e.g. space_unit or factor). + """ + return xr.DataArray( + data, + dims=dims, + coords=coords, + attrs=attributes, + ) + + +@pytest.fixture +def sample_data_2d() -> xr.DataArray: + """Turn the nparray_0_to_23 into a DataArray.""" + return data_array_with_dims_and_coords(nparray_0_to_23()) + + +@pytest.fixture +def sample_data_3d() -> xr.DataArray: + """Turn the nparray_0_to_23 into a DataArray with 3D space.""" + return data_array_with_dims_and_coords( + nparray_0_to_23().reshape(8, 3), + coords=SPATIAL_COORDS_3D, + ) + + +@pytest.mark.parametrize( + ["optional_arguments", "expected_output"], + [ + pytest.param( + {}, + data_array_with_dims_and_coords(nparray_0_to_23()), + id="Do nothing", + ), + pytest.param( + {"space_unit": "elephants"}, + data_array_with_dims_and_coords( + nparray_0_to_23(), space_unit="elephants" + ), + id="No scaling, add space_unit", + ), + pytest.param( + {"factor": 2}, + data_array_with_dims_and_coords(nparray_0_to_23() * 2), + id="Double, no space_unit", + ), + pytest.param( + {"factor": 0.5}, + data_array_with_dims_and_coords(nparray_0_to_23() * 0.5), + id="Halve, no space_unit", + ), + pytest.param( + {"factor": 0.5, "space_unit": "elephants"}, + data_array_with_dims_and_coords( + nparray_0_to_23() * 0.5, space_unit="elephants" + ), + id="Halve, add space_unit", + ), + pytest.param( + {"factor": [0.5, 2]}, + data_array_with_dims_and_coords( + nparray_0_to_23() * [0.5, 2], + ), + id="x / 2, y * 2", + ), + pytest.param( + {"factor": np.array([0.5, 2]).reshape(1, 2)}, + data_array_with_dims_and_coords( + nparray_0_to_23() * [0.5, 2], + ), + id="x / 2, y * 2, should squeeze to cast across space", + ), + ], +) +def test_scale( + sample_data_2d: xr.DataArray, + optional_arguments: dict[str, Any], + expected_output: xr.DataArray, +): + """Test scaling with different factors and space_units.""" + scaled_data = scale(sample_data_2d, **optional_arguments) + xr.testing.assert_equal(scaled_data, expected_output) + assert scaled_data.attrs == expected_output.attrs + + +@pytest.mark.parametrize( + "dims, data_shape", + [ + (["time", "space"], (3, 2)), + (["space", "time"], (2, 3)), + (["time", "individuals", "keypoints", "space"], (3, 6, 4, 2)), + (["time", "individuals", "keypoints", "space"], (2, 2, 2, 2)), + ], + ids=[ + "time-space", + "space-time", + "time-individuals-keypoints-space", + "2x2x2x2", + ], +) +def test_scale_space_dimension(dims: list[str], data_shape): + """Test scaling with transposed data along the correct dimension. + + The scaling factor should be broadcasted along the space axis irrespective + of the order of the dimensions in the input data. + """ + factor = [0.5, 2] + numerical_data = np.arange(np.prod(data_shape)).reshape(data_shape) + data = xr.DataArray(numerical_data, dims=dims, coords=SPATIAL_COORDS_2D) + scaled_data = scale(data, factor=factor) + broadcast_list = [1 if dim != "space" else len(factor) for dim in dims] + expected_output_data = data * np.array(factor).reshape(broadcast_list) + + assert scaled_data.shape == data.shape + xr.testing.assert_equal(scaled_data, expected_output_data) + + +@pytest.mark.parametrize( + ["optional_arguments_1", "optional_arguments_2", "expected_output"], + [ + pytest.param( + {"factor": 2, "space_unit": "elephants"}, + {"factor": 0.5, "space_unit": "crabs"}, + data_array_with_dims_and_coords( + nparray_0_to_23(), space_unit="crabs" + ), + id="No net scaling, final crabs space_unit", + ), + pytest.param( + {"factor": 2, "space_unit": "elephants"}, + {"factor": 0.5, "space_unit": None}, + data_array_with_dims_and_coords(nparray_0_to_23()), + id="No net scaling, no final space_unit", + ), + pytest.param( + {"factor": 2, "space_unit": None}, + {"factor": 0.5, "space_unit": "elephants"}, + data_array_with_dims_and_coords( + nparray_0_to_23(), space_unit="elephants" + ), + id="No net scaling, final elephant space_unit", + ), + ], +) +def test_scale_twice( + sample_data_2d: xr.DataArray, + optional_arguments_1: dict[str, Any], + optional_arguments_2: dict[str, Any], + expected_output: xr.DataArray, +): + """Test scaling when applied twice. + The second scaling operation should update the space_unit attribute if + provided, or remove it if None is passed explicitly or by default. + """ + output_data_array = scale( + scale(sample_data_2d, **optional_arguments_1), + **optional_arguments_2, + ) + xr.testing.assert_equal(output_data_array, expected_output) + assert output_data_array.attrs == expected_output.attrs + + +@pytest.mark.parametrize( + "invalid_factor, expected_error_message", + [ + pytest.param( + np.zeros((3, 3, 4)), + "Factor must be an object that can be converted to a 1D numpy" + " array, got 3D", + id="3D factor", + ), + pytest.param( + np.zeros(3), + "Factor shape (3,) does not match the shape " + "of the space dimension (2,)", + id="space dimension mismatch", + ), + ], +) +def test_scale_value_error( + sample_data_2d: xr.DataArray, + invalid_factor: np.ndarray, + expected_error_message: str, +): + """Test invalid factors raise correct error type and message.""" + with pytest.raises(ValueError) as error: + scale(sample_data_2d, factor=invalid_factor) + assert str(error.value) == expected_error_message + + +@pytest.mark.parametrize( + "factor", [2, [1, 2, 0.5]], ids=["uniform scaling", "multi-axis scaling"] +) +def test_scale_3d_space(factor, sample_data_3d: xr.DataArray): + """Test scaling a DataArray with 3D space.""" + scaled_data = scale(sample_data_3d, factor=factor) + expected_output = data_array_with_dims_and_coords( + nparray_0_to_23().reshape(8, 3) * np.array(factor).reshape(1, -1), + coords=SPATIAL_COORDS_3D, + ) + xr.testing.assert_equal(scaled_data, expected_output) + + +@pytest.mark.parametrize( + "factor", + [2, [1, 2, 0.5]], + ids=["uniform scaling", "multi-axis scaling"], +) +def test_scale_invalid_3d_space(factor): + """Test scaling data with invalid 3D space coordinates.""" + invalid_coords = {"space": ["x", "flubble", "y"]} # "z" is missing + invalid_sample_data_3d = data_array_with_dims_and_coords( + nparray_0_to_23().reshape(8, 3), + coords=invalid_coords, + ) + with pytest.raises(ValueError) as error: + scale(invalid_sample_data_3d, factor=factor) + assert str(error.value) == ( + "Input data must contain ['z'] in the 'space' coordinates.\n" + ) From 84b245cfe8c2375238d3da07abb45e18e98efcbc Mon Sep 17 00:00:00 2001 From: sfmig <33267254+sfmig@users.noreply.github.com> Date: Thu, 30 Jan 2025 17:47:23 +0000 Subject: [PATCH 2/6] bounding boxes' tracks --> bounding boxes tracks (#392) * bounding boxes' tracks --> bounding boxes tracks * Remove remaining bounding boxes apostrophes for consistency * bounding boxes tracks --> bounding box tracks * bounding boxes centroids -> bounding box centroids * Apply suggestions from code review --------- Co-authored-by: Chang Huan Lo --- docs/source/blog/movement-v0_0_21.md | 2 +- docs/source/community/mission-scope.md | 4 ++-- docs/source/user_guide/input_output.md | 20 ++++++++++---------- docs/source/user_guide/movement_dataset.md | 18 +++++++++--------- examples/load_and_upsample_bboxes.py | 4 ++-- movement/io/load_bboxes.py | 12 ++++++------ movement/validators/datasets.py | 4 ++-- 7 files changed, 32 insertions(+), 32 deletions(-) diff --git a/docs/source/blog/movement-v0_0_21.md b/docs/source/blog/movement-v0_0_21.md index 5527ab2d..a0060e26 100644 --- a/docs/source/blog/movement-v0_0_21.md +++ b/docs/source/blog/movement-v0_0_21.md @@ -21,7 +21,7 @@ install the latest version or upgrade from an existing installation. __Input/Output__ - We have added the {func}`movement.io.load_poses.from_multiview_files` function to support loading pose tracking data from multiple camera views. -- We have made several small improvements to reading bounding boxes tracks. See our new {ref}`example ` to learn more about working with bounding boxes. +- We have made several small improvements to reading bounding box tracks. See our new {ref}`example ` to learn more about working with bounding boxes. - We have added a new {ref}`example ` on using `movement` to convert pose tracking data between different file formats. __Kinematics__ diff --git a/docs/source/community/mission-scope.md b/docs/source/community/mission-scope.md index f4f5dc9d..e9663080 100644 --- a/docs/source/community/mission-scope.md +++ b/docs/source/community/mission-scope.md @@ -27,8 +27,8 @@ Animal tracking frameworks such as [DeepLabCut](dlc:) or [SLEAP](sleap:) can generate keypoint representations from video data by detecting body parts and tracking them across frames. In the context of `movement`, we refer to these trajectories as _tracks_: we use _pose tracks_ to refer to the trajectories -of a set of keypoints, _bounding boxes' tracks_ to refer to the trajectories -of bounding boxes' centroids, or _motion tracks_ in the more general case. +of a set of keypoints, _bounding box tracks_ to refer to the trajectories +of bounding box centroids, or _motion tracks_ in the more general case. Our vision is to present a **consistent interface for representing motion tracks** along with **modular and accessible analysis tools**. We aim to diff --git a/docs/source/user_guide/input_output.md b/docs/source/user_guide/input_output.md index 4dd7445f..d25d4510 100644 --- a/docs/source/user_guide/input_output.md +++ b/docs/source/user_guide/input_output.md @@ -4,7 +4,7 @@ (target-formats)= ## Supported formats (target-supported-formats)= -`movement` supports the analysis of trajectories of keypoints (_pose tracks_) and of bounding boxes' centroids (_bounding boxes' tracks_). +`movement` supports the analysis of trajectories of keypoints (_pose tracks_) and of bounding box centroids (_bounding box tracks_). To analyse pose tracks, `movement` supports loading data from various frameworks: - [DeepLabCut](dlc:) (DLC) @@ -12,13 +12,13 @@ To analyse pose tracks, `movement` supports loading data from various frameworks - [LightingPose](lp:) (LP) - [Anipose](anipose:) (Anipose) -To analyse bounding boxes' tracks, `movement` currently supports the [VGG Image Annotator](via:) (VIA) format for [tracks annotation](via:docs/face_track_annotation.html). +To analyse bounding box tracks, `movement` currently supports the [VGG Image Annotator](via:) (VIA) format for [tracks annotation](via:docs/face_track_annotation.html). :::{note} At the moment `movement` only deals with tracked data: either keypoints or bounding boxes whose identities are known from one frame to the next, for a consecutive set of frames. For the pose estimation case, this means it only deals with the predictions output by the software packages above. It currently does not support loading manually labelled data (since this is most often defined over a non-continuous set of frames). ::: -Below we explain how you can load pose tracks and bounding boxes' tracks into `movement`, and how you can export a `movement` poses dataset to different file formats. You can also try `movement` out on some [sample data](target-sample-data) +Below we explain how you can load pose tracks and bounding box tracks into `movement`, and how you can export a `movement` poses dataset to different file formats. You can also try `movement` out on some [sample data](target-sample-data) included with the package. @@ -129,15 +129,15 @@ For more information on the poses data structure, see the [movement poses datase (target-loading-bbox-tracks)= -## Loading bounding boxes' tracks -To load bounding boxes' tracks into a [movement bounding boxes dataset](target-poses-and-bboxes-dataset), we need the functions from the +## Loading bounding box tracks +To load bounding box tracks into a [movement bounding boxes dataset](target-poses-and-bboxes-dataset), we need the functions from the {mod}`movement.io.load_bboxes` module. This module can be imported as: ```python from movement.io import load_bboxes ``` -We currently support loading bounding boxes' tracks in the VGG Image Annotator (VIA) format only. However, like in the poses datasets, we additionally provide a `from_numpy()` method, with which we can build a [movement bounding boxes dataset](target-poses-and-bboxes-dataset) from a set of NumPy arrays. +We currently support loading bounding box tracks in the VGG Image Annotator (VIA) format only. However, like in the poses datasets, we additionally provide a `from_numpy()` method, with which we can build a [movement bounding boxes dataset](target-poses-and-bboxes-dataset) from a set of NumPy arrays. ::::{tab-set} :::{tab-item} VGG Image Annotator @@ -247,9 +247,9 @@ save_poses.to_dlc_file(ds, "/path/to/file.csv", split_individuals=True) (target-saving-bboxes-tracks)= -## Saving bounding boxes' tracks +## Saving bounding box tracks -We currently do not provide explicit methods to export a movement bounding boxes dataset in a specific format. However, you can easily save the bounding boxes' trajectories to a .csv file using the standard Python library `csv`. +We currently do not provide explicit methods to export a movement bounding boxes dataset in a specific format. However, you can easily save the bounding box tracks to a .csv file using the standard Python library `csv`. Here is an example of how you can save a bounding boxes dataset to a .csv file: @@ -273,14 +273,14 @@ with open(filepath, mode="w", newline="") as file: writer.writerow([frame, individual, x, y, width, height, confidence]) ``` -Alternatively, we can convert the `movement` bounding boxes' dataset to a pandas DataFrame with the {meth}`xarray.DataArray.to_dataframe` method, wrangle the dataframe as required, and then apply the {meth}`pandas.DataFrame.to_csv` method to save the data as a .csv file. +Alternatively, we can convert the `movement` bounding boxes dataset to a pandas DataFrame with the {meth}`xarray.DataArray.to_dataframe` method, wrangle the dataframe as required, and then apply the {meth}`pandas.DataFrame.to_csv` method to save the data as a .csv file. (target-sample-data)= ## Sample data `movement` includes some sample data files that you can use to -try the package out. These files contain pose and bounding boxes' tracks from +try the package out. These files contain pose and bounding box tracks from various [supported formats](target-supported-formats). You can list the available sample data files using: diff --git a/docs/source/user_guide/movement_dataset.md b/docs/source/user_guide/movement_dataset.md index 2b080ae1..8c83b378 100644 --- a/docs/source/user_guide/movement_dataset.md +++ b/docs/source/user_guide/movement_dataset.md @@ -1,14 +1,14 @@ (target-poses-and-bboxes-dataset)= # The movement datasets -In `movement`, poses or bounding boxes' tracks are represented +In `movement`, poses or bounding box tracks are represented as an {class}`xarray.Dataset` object. An {class}`xarray.Dataset` object is a container for multiple arrays. Each array is an {class}`xarray.DataArray` object holding different aspects of the collected data (position, time, confidence scores...). You can think of a {class}`xarray.DataArray` object as a multi-dimensional {class}`numpy.ndarray` with pandas-style indexing and labelling. So a `movement` dataset is simply an {class}`xarray.Dataset` with a specific -structure to represent pose tracks or bounding boxes' tracks. Because pose data and bounding boxes data are somewhat different, `movement` provides two types of datasets: `poses` datasets and `bboxes` datasets. +structure to represent pose tracks or bounding box tracks. Because pose data and bounding box data are somewhat different, `movement` provides two types of datasets: `poses` datasets and `bboxes` datasets. To discuss the specifics of both types of `movement` datasets, it is useful to clarify some concepts such as **data variables**, **dimensions**, **coordinates** and **attributes**. In the next section, we will describe these concepts and the `movement` datasets' structure in some detail. @@ -64,8 +64,8 @@ Attributes: ::: -:::{tab-item} Bounding boxes' dataset -To inspect a sample bounding boxes' dataset, we can run: +:::{tab-item} Bounding boxes dataset +To inspect a sample bounding boxes dataset, we can run: ```python from movement import sample_data @@ -119,7 +119,7 @@ A `movement` poses dataset has the following **dimensions**: - `individuals`, with size equal to the number of tracked individuals/instances. ::: -:::{tab-item} Bounding boxes' dataset +:::{tab-item} Bounding boxes dataset A `movement` bounding boxes dataset has the following **dimensions**s: - `time`, with size equal to the number of frames in the video. - `space`, which is the number of spatial dimensions. Currently, we support only 2D bounding boxes data. @@ -139,7 +139,7 @@ In both cases, appropriate **coordinates** are assigned to each **dimension**. :icon: info The above **dimensions** and **coordinates** are created by default when loading a `movement` dataset from a single -file containing pose or bounding boxes tracks. +file containing pose or bounding box tracks. In some cases, you may encounter or create datasets with extra **dimensions**. For example, the @@ -160,9 +160,9 @@ A `movement` poses dataset contains two **data variables**: - `confidence`: the confidence scores associated with each predicted keypoint (as reported by the pose estimation model), with shape (`time`, `keypoints`, `individuals`). ::: -:::{tab-item} Bounding boxes' dataset +:::{tab-item} Bounding boxes dataset A `movement` bounding boxes dataset contains three **data variables**: -- `position`: the 2D locations of the bounding boxes' centroids over time, with shape (`time`, `space`, `individuals`). +- `position`: the 2D locations of the bounding box centroids over time, with shape (`time`, `space`, `individuals`). - `shape`: the width and height of the bounding boxes over time, with shape (`time`, `space`, `individuals`). - `confidence`: the confidence scores associated with each predicted bounding box, with shape (`time`, `individuals`). ::: @@ -179,7 +179,7 @@ Both poses and bounding boxes datasets in `movement` have associated metadata. T Right after loading a `movement` dataset, the following **attributes** are created: - `fps`: the number of frames per second in the video. If not provided, it is set to `None`. - `time_unit`: the unit of the `time` **coordinates** (either `frames` or `seconds`). -- `source_software`: the software that produced the pose or bounding boxes tracks. +- `source_software`: the software that produced the pose or bounding box tracks. - `source_file`: the path to the file from which the data were loaded. - `ds_type`: the type of dataset loaded (either `poses` or `bboxes`). diff --git a/examples/load_and_upsample_bboxes.py b/examples/load_and_upsample_bboxes.py index cfb1bba3..045832dd 100644 --- a/examples/load_and_upsample_bboxes.py +++ b/examples/load_and_upsample_bboxes.py @@ -1,7 +1,7 @@ -"""Load and upsample bounding boxes tracks +"""Load and upsample bounding box tracks ========================================== -Load bounding boxes tracks and upsample them to match the video frame rate. +Load bounding box tracks and upsample them to match the video frame rate. """ # %% diff --git a/movement/io/load_bboxes.py b/movement/io/load_bboxes.py index c6f3bef1..e5c2014b 100644 --- a/movement/io/load_bboxes.py +++ b/movement/io/load_bboxes.py @@ -1,4 +1,4 @@ -"""Load bounding boxes' tracking data into ``movement``.""" +"""Load bounding boxes tracking data into ``movement``.""" import ast import logging @@ -37,7 +37,7 @@ def from_numpy( ---------- position_array : np.ndarray Array of shape (n_frames, n_space, n_individuals) - containing the tracks of the bounding boxes' centroids. + containing the tracks of the bounding box centroids. It will be converted to a :class:`xarray.DataArray` object named "position". shape_array : np.ndarray @@ -277,7 +277,7 @@ def from_via_tracks_file( Notes ----- - The bounding boxes' IDs specified in the "track" field of the VIA + The bounding boxes IDs specified in the "track" field of the VIA tracks .csv file are mapped to the "individual_name" column of the ``movement`` dataset. The individual names follow the format ``id_``, with N being the bounding box ID. @@ -377,7 +377,7 @@ def _numpy_arrays_from_via_tracks_file( keys: - position_array (n_frames, n_space, n_individuals): - contains the trajectories of the bounding boxes' centroids. + contains the trajectories of the bounding box centroids. - shape_array (n_frames, n_space, n_individuals): contains the shape of the bounding boxes (width and height). - confidence_array (n_frames, n_individuals): @@ -391,7 +391,7 @@ def _numpy_arrays_from_via_tracks_file( Parameters ---------- file_path : pathlib.Path - Path to the VIA tracks .csv file containing the bounding boxes' tracks. + Path to the VIA tracks .csv file containing the bounding box tracks. frame_regexp : str Regular expression pattern to extract the frame number from the frame @@ -402,7 +402,7 @@ def _numpy_arrays_from_via_tracks_file( Returns ------- dict - The validated bounding boxes' arrays. + The validated bounding boxes arrays. """ # Extract 2D dataframe from input data diff --git a/movement/validators/datasets.py b/movement/validators/datasets.py index 6e1dc066..0f51eec5 100644 --- a/movement/validators/datasets.py +++ b/movement/validators/datasets.py @@ -226,7 +226,7 @@ def __attrs_post_init__(self): @define(kw_only=True) class ValidBboxesDataset: - """Class for validating bounding boxes' data for a ``movement`` dataset. + """Class for validating bounding boxes data for a ``movement`` dataset. The validator considers 2D bounding boxes only. It ensures that within the ``movement bboxes`` dataset: @@ -250,7 +250,7 @@ class ValidBboxesDataset: ---------- position_array : np.ndarray Array of shape (n_frames, n_space, n_individuals) - containing the tracks of the bounding boxes' centroids. + containing the tracks of the bounding box centroids. shape_array : np.ndarray Array of shape (n_frames, n_space, n_individuals) containing the shape of the bounding boxes. The shape of a bounding From 199f8f22e4881afcedbf7987cf1a9776d3fa632e Mon Sep 17 00:00:00 2001 From: Chang Huan Lo Date: Fri, 31 Jan 2025 12:32:48 +0100 Subject: [PATCH 3/6] Reorganise test fixtures (#380) * Use consistent test ids (starting from `id_0`) * Use consistent names for fixtures `with_nan` * Align uniform linear motion fixture with bboxes fixture * Replace poses fixture in test_filtering * Group filtering tests by common dataset params * Replace poses fixture in test_kinematics * Replace poses fixture in test_save_poses * Replace poses fixture in datasets missing dim and var * Replace poses fixture in test_reports * Replace poses fixture in test_io * Replace poses fixture in test_logging * Remove valid_poses_dataset fixtures * Rename valid_poses_dataset_uniform_linear_motion fixtures * Replace valid_position_array fixtures in test_load_poses * Replace valid_position_array fixtures in test_datasets_validators * Remove valid_position_array fixture * Rename valid_poses_array_uniform_linear_motion * Fix newlines * Fix up rebase merge error * Modularise fixtures * Group kinematics tests by common params * Shorten arg name * Refactor fixtures * Rename datasets.py to data.py * Swap poses "centroid" and "left" keypoint NaNs * Fix up filtering test expectations after rebase * Suggestion to move dataset validator fixture under helpers * Apply suggestions from code review Co-authored-by: sfmig <33267254+sfmig@users.noreply.github.com> * Fix up `valid_dlc_poses_df` rename and E501 * Rename dataset fixtures module * Fix newlines * Update fixture descriptions * Apply suggestions from code review Co-authored-by: sfmig <33267254+sfmig@users.noreply.github.com> * Update anipose fixture descriptions Co-authored-by: sfmig <33267254+sfmig@users.noreply.github.com> * Rename `wrong_extension`-related fixtures * Refer to fixtures in test descriptions * Rename `compute_time_derivative` test + parametrise expectations --------- Co-authored-by: sfmig <33267254+sfmig@users.noreply.github.com> --- .pre-commit-config.yaml | 1 + tests/conftest.py | 992 +----------------- tests/fixtures/__init__.py | 0 tests/fixtures/datasets.py | 356 +++++++ tests/fixtures/files.py | 440 ++++++++ tests/fixtures/helpers.py | 97 ++ tests/test_integration/test_io.py | 6 +- .../test_kinematics_vector_transform.py | 13 +- tests/test_unit/test_filtering.py | 408 +++---- tests/test_unit/test_kinematics.py | 409 +++----- tests/test_unit/test_load_bboxes.py | 12 +- tests/test_unit/test_load_poses.py | 38 +- tests/test_unit/test_logging.py | 5 +- tests/test_unit/test_reports.py | 97 +- tests/test_unit/test_save_poses.py | 12 +- .../test_validators/test_array_validators.py | 4 +- .../test_datasets_validators.py | 30 +- .../test_validators/test_files_validators.py | 28 +- 18 files changed, 1335 insertions(+), 1613 deletions(-) create mode 100644 tests/fixtures/__init__.py create mode 100644 tests/fixtures/datasets.py create mode 100644 tests/fixtures/files.py create mode 100644 tests/fixtures/helpers.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 54fb80ad..9258a3fb 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -20,6 +20,7 @@ repos: args: [--fix=lf] - id: name-tests-test args: ["--pytest-test-first"] + exclude: ^tests/fixtures/ - id: requirements-txt-fixer - id: trailing-whitespace - repo: https://github.com/pre-commit/pygrep-hooks diff --git a/tests/conftest.py b/tests/conftest.py index 4a760514..29e823ac 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,19 +1,24 @@ -"""Fixtures and configurations applied to the entire test suite.""" +"""Fixtures and configurations shared by the entire test suite.""" import logging -import os -from pathlib import Path -from unittest.mock import mock_open, patch +from glob import glob -import h5py -import numpy as np -import pandas as pd import pytest -import xarray as xr from movement.sample_data import fetch_dataset_paths, list_datasets from movement.utils.logging import configure_logging -from movement.validators.datasets import ValidBboxesDataset, ValidPosesDataset + + +def _to_module_string(path: str) -> str: + """Convert a file path to a module string.""" + return path.replace("/", ".").replace("\\", ".").replace(".py", "") + + +pytest_plugins = [ + _to_module_string(fixture) + for fixture in glob("tests/fixtures/*.py") + if "__" not in fixture +] def pytest_configure(): @@ -37,972 +42,3 @@ def setup_logging(tmp_path): logger_name="movement", log_directory=(tmp_path / ".movement"), ) - - -# --------- File validator fixtures --------------------------------- -@pytest.fixture -def unreadable_file(tmp_path): - """Return a dictionary containing the file path and - expected permission for an unreadable .h5 file. - """ - file_path = tmp_path / "unreadable.h5" - file_mock = mock_open() - file_mock.return_value.read.side_effect = PermissionError - with ( - patch("builtins.open", side_effect=file_mock), - patch.object(Path, "exists", return_value=True), - ): - yield { - "file_path": file_path, - "expected_permission": "r", - } - - -@pytest.fixture -def unwriteable_file(tmp_path): - """Return a dictionary containing the file path and - expected permission for an unwriteable .h5 file. - """ - unwriteable_dir = tmp_path / "no_write" - unwriteable_dir.mkdir() - original_access = os.access - - def mock_access(path, mode): - if path == unwriteable_dir and mode == os.W_OK: - return False - # Ensure that the original access function is called - # for all other cases - return original_access(path, mode) - - with patch("os.access", side_effect=mock_access): - file_path = unwriteable_dir / "unwriteable.h5" - yield { - "file_path": file_path, - "expected_permission": "w", - } - - -@pytest.fixture -def wrong_ext_file(tmp_path): - """Return a dictionary containing the file path, - expected permission, and expected suffix for a file - with an incorrect extension. - """ - file_path = tmp_path / "wrong_extension.txt" - with open(file_path, "w") as f: - f.write("") - return { - "file_path": file_path, - "expected_permission": "r", - "expected_suffix": ["h5", "csv"], - } - - -@pytest.fixture -def nonexistent_file(tmp_path): - """Return a dictionary containing the file path and - expected permission for a nonexistent file. - """ - file_path = tmp_path / "nonexistent.h5" - return { - "file_path": file_path, - "expected_permission": "r", - } - - -@pytest.fixture -def directory(tmp_path): - """Return a dictionary containing the file path and - expected permission for a directory. - """ - file_path = tmp_path / "directory" - file_path.mkdir() - return { - "file_path": file_path, - "expected_permission": "r", - } - - -@pytest.fixture -def h5_file_no_dataframe(tmp_path): - """Return a dictionary containing the file path and - expected datasets for a .h5 file with no dataframe. - """ - file_path = tmp_path / "no_dataframe.h5" - with h5py.File(file_path, "w") as f: - f.create_dataset("data_in_list", data=[1, 2, 3]) - return { - "file_path": file_path, - "expected_datasets": ["dataframe"], - } - - -@pytest.fixture -def fake_h5_file(tmp_path): - """Return a dictionary containing the file path, - expected exception, and expected datasets for - a file with .h5 extension that is not in HDF5 format. - """ - file_path = tmp_path / "fake.h5" - with open(file_path, "w") as f: - f.write("") - return { - "file_path": file_path, - "expected_datasets": ["dataframe"], - "expected_permission": "w", - } - - -@pytest.fixture -def invalid_single_individual_csv_file(tmp_path): - """Return the file path for a fake single-individual .csv file.""" - file_path = tmp_path / "fake_single_individual.csv" - with open(file_path, "w") as f: - f.write("scorer,columns\nsome,columns\ncoords,columns\n") - f.write("1,2") - return file_path - - -@pytest.fixture -def invalid_multi_individual_csv_file(tmp_path): - """Return the file path for a fake multi-individual .csv file.""" - file_path = tmp_path / "fake_multi_individual.csv" - with open(file_path, "w") as f: - f.write( - "scorer,columns\nindividuals,columns\nbodyparts,columns\nsome,columns\n" - ) - f.write("1,2") - return file_path - - -@pytest.fixture -def new_file_wrong_ext(tmp_path): - """Return the file path for a new file with the wrong extension.""" - return tmp_path / "new_file_wrong_ext.txt" - - -@pytest.fixture -def new_h5_file(tmp_path): - """Return the file path for a new .h5 file.""" - return tmp_path / "new_file.h5" - - -@pytest.fixture -def new_csv_file(tmp_path): - """Return the file path for a new .csv file.""" - return tmp_path / "new_file.csv" - - -@pytest.fixture -def dlc_style_df(): - """Return a valid DLC-style DataFrame.""" - return pd.read_hdf(pytest.DATA_PATHS.get("DLC_single-wasp.predictions.h5")) - - -@pytest.fixture -def missing_keypoint_columns_anipose_csv_file(tmp_path): - """Return the file path for a fake single-individual .csv file.""" - file_path = tmp_path / "missing_keypoint_columns.csv" - columns = [ - "fnum", - "center_0", - "center_1", - "center_2", - "M_00", - "M_01", - "M_02", - "M_10", - "M_11", - "M_12", - "M_20", - "M_21", - "M_22", - ] - # Here we are missing kp0_z: - columns.extend(["kp0_x", "kp0_y", "kp0_score", "kp0_error", "kp0_ncams"]) - with open(file_path, "w") as f: - f.write(",".join(columns)) - f.write("\n") - f.write(",".join(["1"] * len(columns))) - return file_path - - -@pytest.fixture -def spurious_column_anipose_csv_file(tmp_path): - """Return the file path for a fake single-individual .csv file.""" - file_path = tmp_path / "spurious_column.csv" - columns = [ - "fnum", - "center_0", - "center_1", - "center_2", - "M_00", - "M_01", - "M_02", - "M_10", - "M_11", - "M_12", - "M_20", - "M_21", - "M_22", - ] - columns.extend(["funny_column"]) - with open(file_path, "w") as f: - f.write(",".join(columns)) - f.write("\n") - f.write(",".join(["1"] * len(columns))) - return file_path - - -@pytest.fixture( - params=[ - "SLEAP_single-mouse_EPM.analysis.h5", - "SLEAP_single-mouse_EPM.predictions.slp", - "SLEAP_three-mice_Aeon_proofread.analysis.h5", - "SLEAP_three-mice_Aeon_proofread.predictions.slp", - "SLEAP_three-mice_Aeon_mixed-labels.analysis.h5", - "SLEAP_three-mice_Aeon_mixed-labels.predictions.slp", - ] -) -def sleap_file(request): - """Return the file path for a SLEAP .h5 or .slp file.""" - return pytest.DATA_PATHS.get(request.param) - - -# ------------ Dataset validator fixtures --------------------------------- - - -@pytest.fixture -def valid_bboxes_arrays_all_zeros(): - """Return a dictionary of valid zero arrays (in terms of shape) for a - ValidBboxesDataset. - """ - # define the shape of the arrays - n_frames, n_space, n_individuals = (10, 2, 2) - - # build a valid array for position or shape with all zeros - valid_bbox_array_all_zeros = np.zeros((n_frames, n_space, n_individuals)) - - # return as a dict - return { - "position": valid_bbox_array_all_zeros, - "shape": valid_bbox_array_all_zeros, - "individual_names": ["id_" + str(id) for id in range(n_individuals)], - } - - -# --------------------- Bboxes dataset fixtures ---------------------------- -@pytest.fixture -def valid_bboxes_arrays(): - """Return a dictionary of valid arrays for a - ValidBboxesDataset representing a uniform linear motion. - - It represents 2 individuals for 10 frames, in 2D space. - - Individual 0 moves along the x=y line from the origin. - - Individual 1 moves along the x=-y line line from the origin. - - All confidence values are set to 0.9 except the following which are set - to 0.1: - - Individual 0 at frames 2, 3, 4 - - Individual 1 at frames 2, 3 - """ - # define the shape of the arrays - n_frames, n_space, n_individuals = (10, 2, 2) - - # build a valid array for position - # make bbox with id_i move along x=((-1)**(i))*y line from the origin - # if i is even: along x = y line - # if i is odd: along x = -y line - # moving one unit along each axis in each frame - position = np.zeros((n_frames, n_space, n_individuals)) - for i in range(n_individuals): - position[:, 0, i] = np.arange(n_frames) - position[:, 1, i] = (-1) ** i * np.arange(n_frames) - - # build a valid array for constant bbox shape (60, 40) - constant_shape = (60, 40) # width, height in pixels - shape = np.tile(constant_shape, (n_frames, n_individuals, 1)).transpose( - 0, 2, 1 - ) - - # build an array of confidence values, all 0.9 - confidence = np.full((n_frames, n_individuals), 0.9) - - # set 5 low-confidence values - # - set 3 confidence values for bbox id_0 to 0.1 - # - set 2 confidence values for bbox id_1 to 0.1 - idx_start = 2 - confidence[idx_start : idx_start + 3, 0] = 0.1 - confidence[idx_start : idx_start + 2, 1] = 0.1 - - return { - "position": position, - "shape": shape, - "confidence": confidence, - } - - -@pytest.fixture -def valid_bboxes_dataset( - valid_bboxes_arrays, -): - """Return a valid bboxes dataset for two individuals moving in uniform - linear motion, with 5 frames with low confidence values and time in frames. - """ - dim_names = ValidBboxesDataset.DIM_NAMES - - position_array = valid_bboxes_arrays["position"] - shape_array = valid_bboxes_arrays["shape"] - confidence_array = valid_bboxes_arrays["confidence"] - - n_frames, n_individuals, _ = position_array.shape - - return xr.Dataset( - data_vars={ - "position": xr.DataArray(position_array, dims=dim_names), - "shape": xr.DataArray(shape_array, dims=dim_names), - "confidence": xr.DataArray( - confidence_array, dims=dim_names[:1] + dim_names[2:] - ), - }, - coords={ - dim_names[0]: np.arange(n_frames), - dim_names[1]: ["x", "y"], - dim_names[2]: [f"id_{id}" for id in range(n_individuals)], - }, - attrs={ - "fps": None, - "time_unit": "frames", - "source_software": "test", - "source_file": "test_bboxes.csv", - "ds_type": "bboxes", - }, - ) - - -@pytest.fixture -def valid_bboxes_dataset_in_seconds(valid_bboxes_dataset): - """Return a valid bboxes dataset with time in seconds. - - The origin of time is assumed to be time = frame 0 = 0 seconds. - """ - fps = 60 - valid_bboxes_dataset["time"] = valid_bboxes_dataset.time / fps - valid_bboxes_dataset.attrs["time_unit"] = "seconds" - valid_bboxes_dataset.attrs["fps"] = fps - return valid_bboxes_dataset - - -@pytest.fixture -def valid_bboxes_dataset_with_nan(valid_bboxes_dataset): - """Return a valid bboxes dataset with NaN values in the position array.""" - # Set 3 NaN values in the position array for id_0 - valid_bboxes_dataset.position.loc[ - {"individuals": "id_0", "time": [3, 7, 8]} - ] = np.nan - return valid_bboxes_dataset - - -# --------------------- Poses dataset fixtures ---------------------------- -@pytest.fixture -def valid_position_array(): - """Return a function that generates different kinds - of a valid position array. - """ - - def _valid_position_array(array_type): - """Return a valid position array.""" - # Unless specified, default is a multi_individual_array with - # 10 frames, 2 keypoints, and 2 individuals. - n_frames = 10 - n_keypoints = 2 - n_individuals = 2 - base = np.arange(n_frames, dtype=float)[ - :, np.newaxis, np.newaxis, np.newaxis - ] - if array_type == "single_keypoint_array": - n_keypoints = 1 - elif array_type == "single_individual_array": - n_individuals = 1 - x_points = np.repeat(base * base, n_keypoints * n_individuals) - y_points = np.repeat(base * 4, n_keypoints * n_individuals) - position_array = np.vstack((x_points, y_points)) - return position_array.reshape(n_frames, 2, n_keypoints, n_individuals) - - return _valid_position_array - - -@pytest.fixture -def valid_poses_dataset(valid_position_array, request): - """Return a valid pose tracks dataset.""" - dim_names = ValidPosesDataset.DIM_NAMES - # create a multi_individual_array by default unless overridden via param - try: - array_format = request.param - except AttributeError: - array_format = "multi_individual_array" - position_array = valid_position_array(array_format) - n_frames, n_keypoints, n_individuals = ( - position_array.shape[:1] + position_array.shape[2:] - ) - return xr.Dataset( - data_vars={ - "position": xr.DataArray(position_array, dims=dim_names), - "confidence": xr.DataArray( - np.repeat( - np.linspace(0.1, 1.0, n_frames), - n_keypoints * n_individuals, - ).reshape(position_array.shape[:1] + position_array.shape[2:]), - dims=dim_names[:1] + dim_names[2:], # exclude "space" - ), - }, - coords={ - "time": np.arange(n_frames), - "space": ["x", "y"], - "keypoints": [f"key{i}" for i in range(1, n_keypoints + 1)], - "individuals": [f"ind{i}" for i in range(1, n_individuals + 1)], - }, - attrs={ - "fps": None, - "time_unit": "frames", - "source_software": "SLEAP", - "source_file": "test.h5", - "ds_type": "poses", - }, - ) - - -@pytest.fixture -def valid_poses_dataset_with_nan(valid_poses_dataset): - """Return a valid pose tracks dataset with NaN values.""" - # Sets position for all keypoints in individual ind1 to NaN - # at timepoints 3, 7, 8 - valid_poses_dataset.position.loc[ - {"individuals": "ind1", "time": [3, 7, 8]} - ] = np.nan - return valid_poses_dataset - - -@pytest.fixture -def valid_poses_array_uniform_linear_motion(): - """Return a dictionary of valid arrays for a - ValidPosesDataset representing a uniform linear motion. - - It represents 2 individuals with 3 keypoints, for 10 frames, in 2D space. - - Individual 0 moves along the x=y line from the origin. - - Individual 1 moves along the x=-y line line from the origin. - - All confidence values for all keypoints are set to 0.9 except - for the keypoints at the following frames which are set to 0.1: - - Individual 0 at frames 2, 3, 4 - - Individual 1 at frames 2, 3 - """ - # define the shape of the arrays - n_frames, n_space, n_keypoints, n_individuals = (10, 2, 3, 2) - - # define centroid (index=0) trajectory in position array - # for each individual, the centroid moves along - # the x=+/-y line, starting from the origin. - # - individual 0 moves along x = y line - # - individual 1 moves along x = -y line - # They move one unit along x and y axes in each frame - frames = np.arange(n_frames) - position = np.zeros((n_frames, n_space, n_keypoints, n_individuals)) - position[:, 0, 0, :] = frames[:, None] # reshape to (n_frames, 1) - position[:, 1, 0, 0] = frames - position[:, 1, 0, 1] = -frames - - # define trajectory of left and right keypoints - # for individual 0, at each timepoint: - # - the left keypoint (index=1) is at x_centroid, y_centroid + 1 - # - the right keypoint (index=2) is at x_centroid + 1, y_centroid - # for individual 1, at each timepoint: - # - the left keypoint (index=1) is at x_centroid - 1, y_centroid - # - the right keypoint (index=2) is at x_centroid, y_centroid + 1 - offsets = [ - [(0, 1), (1, 0)], # individual 0: left, right keypoints (x,y) offsets - [(-1, 0), (0, 1)], # individual 1: left, right keypoints (x,y) offsets - ] - for i in range(n_individuals): - for kpt in range(1, n_keypoints): - position[:, 0, kpt, i] = ( - position[:, 0, 0, i] + offsets[i][kpt - 1][0] - ) - position[:, 1, kpt, i] = ( - position[:, 1, 0, i] + offsets[i][kpt - 1][1] - ) - - # build an array of confidence values, all 0.9 - confidence = np.full((n_frames, n_keypoints, n_individuals), 0.9) - # set 5 low-confidence values - # - set 3 confidence values for individual id_0's centroid to 0.1 - # - set 2 confidence values for individual id_1's centroid to 0.1 - idx_start = 2 - confidence[idx_start : idx_start + 3, 0, 0] = 0.1 - confidence[idx_start : idx_start + 2, 0, 1] = 0.1 - - return {"position": position, "confidence": confidence} - - -@pytest.fixture -def valid_poses_dataset_uniform_linear_motion( - valid_poses_array_uniform_linear_motion, -): - """Return a valid poses dataset for two individuals moving in uniform - linear motion, with 5 frames with low confidence values and time in frames. - """ - dim_names = ValidPosesDataset.DIM_NAMES - - position_array = valid_poses_array_uniform_linear_motion["position"] - confidence_array = valid_poses_array_uniform_linear_motion["confidence"] - - n_frames, _, _, n_individuals = position_array.shape - - return xr.Dataset( - data_vars={ - "position": xr.DataArray(position_array, dims=dim_names), - "confidence": xr.DataArray( - confidence_array, dims=dim_names[:1] + dim_names[2:] - ), - }, - coords={ - dim_names[0]: np.arange(n_frames), - dim_names[1]: ["x", "y"], - dim_names[2]: ["centroid", "left", "right"], - dim_names[3]: [f"id_{i}" for i in range(1, n_individuals + 1)], - }, - attrs={ - "fps": None, - "time_unit": "frames", - "source_software": "test", - "source_file": "test_poses.h5", - "ds_type": "poses", - }, - ) - - -@pytest.fixture -def valid_poses_dataset_uniform_linear_motion_with_nans( - valid_poses_dataset_uniform_linear_motion, -): - """Return a valid poses dataset with NaN values in the position array. - - Specifically, we will introducde: - - 1 NaN value in the centroid keypoint of individual id_1 at time=0 - - 5 NaN values in the left keypoint of individual id_1 (frames 3-7) - - 10 NaN values in the right keypoint of individual id_1 (all frames) - """ - valid_poses_dataset_uniform_linear_motion.position.loc[ - { - "individuals": "id_1", - "keypoints": "centroid", - "time": 0, - } - ] = np.nan - valid_poses_dataset_uniform_linear_motion.position.loc[ - { - "individuals": "id_1", - "keypoints": "left", - "time": slice(3, 7), - } - ] = np.nan - valid_poses_dataset_uniform_linear_motion.position.loc[ - { - "individuals": "id_1", - "keypoints": "right", - } - ] = np.nan - return valid_poses_dataset_uniform_linear_motion - - -# -------------------- Invalid datasets fixtures ------------------------------ -@pytest.fixture -def not_a_dataset(): - """Return data that is not a pose tracks dataset.""" - return [1, 2, 3] - - -@pytest.fixture -def empty_dataset(): - """Return an empty pose tracks dataset.""" - return xr.Dataset() - - -@pytest.fixture -def missing_var_poses_dataset(valid_poses_dataset): - """Return a poses dataset missing position variable.""" - return valid_poses_dataset.drop_vars("position") - - -@pytest.fixture -def missing_var_bboxes_dataset(valid_bboxes_dataset): - """Return a bboxes dataset missing position variable.""" - return valid_bboxes_dataset.drop_vars("position") - - -@pytest.fixture -def missing_two_vars_bboxes_dataset(valid_bboxes_dataset): - """Return a bboxes dataset missing position and shape variables.""" - return valid_bboxes_dataset.drop_vars(["position", "shape"]) - - -@pytest.fixture -def missing_dim_poses_dataset(valid_poses_dataset): - """Return a poses dataset missing the time dimension.""" - return valid_poses_dataset.rename({"time": "tame"}) - - -@pytest.fixture -def missing_dim_bboxes_dataset(valid_bboxes_dataset): - """Return a bboxes dataset missing the time dimension.""" - return valid_bboxes_dataset.rename({"time": "tame"}) - - -@pytest.fixture -def missing_two_dims_bboxes_dataset(valid_bboxes_dataset): - """Return a bboxes dataset missing the time and space dimensions.""" - return valid_bboxes_dataset.rename({"time": "tame", "space": "spice"}) - - -# --------------------------- Kinematics fixtures --------------------------- -@pytest.fixture(params=["displacement", "velocity", "acceleration"]) -def kinematic_property(request): - """Return a kinematic property.""" - return request.param - - -# ---------------- VIA tracks CSV file fixtures ---------------------------- -@pytest.fixture -def via_tracks_csv_with_invalid_header(tmp_path): - """Return the file path for a file with invalid header.""" - file_path = tmp_path / "invalid_via_tracks.csv" - with open(file_path, "w") as f: - f.write("filename,file_size,file_attributes\n") - f.write("1,2,3") - return file_path - - -@pytest.fixture -def via_tracks_csv_with_valid_header(tmp_path): - file_path = tmp_path / "sample_via_tracks.csv" - with open(file_path, "w") as f: - f.write( - "filename," - "file_size," - "file_attributes," - "region_count," - "region_id," - "region_shape_attributes," - "region_attributes" - ) - f.write("\n") - return file_path - - -@pytest.fixture -def frame_number_in_file_attribute_not_integer( - via_tracks_csv_with_valid_header, -): - """Return the file path for a VIA tracks .csv file with invalid frame - number defined as file_attribute. - """ - file_path = via_tracks_csv_with_valid_header - with open(file_path, "a") as f: - f.write( - "04.09.2023-04-Right_RE_test_frame_A.png," - "26542080," - '"{""clip"":123, ""frame"":""FOO""}",' # frame number is a string - "1," - "0," - '"{""name"":""rect"",""x"":526.236,""y"":393.281,""width"":46,""height"":38}",' - '"{""track"":""71""}"' - ) - return file_path - - -@pytest.fixture -def frame_number_in_filename_wrong_pattern( - via_tracks_csv_with_valid_header, -): - """Return the file path for a VIA tracks .csv file with invalid frame - number defined in the frame's filename. - """ - file_path = via_tracks_csv_with_valid_header - with open(file_path, "a") as f: - f.write( - "04.09.2023-04-Right_RE_test_frame_1.png," # frame not zero-padded - "26542080," - '"{""clip"":123}",' - "1," - "0," - '"{""name"":""rect"",""x"":526.236,""y"":393.281,""width"":46,""height"":38}",' - '"{""track"":""71""}"' - ) - return file_path - - -@pytest.fixture -def more_frame_numbers_than_filenames( - via_tracks_csv_with_valid_header, -): - """Return the file path for a VIA tracks .csv file with more - frame numbers than filenames. - """ - file_path = via_tracks_csv_with_valid_header - with open(file_path, "a") as f: - f.write( - "04.09.2023-04-Right_RE_test.png," - "26542080," - '"{""clip"":123, ""frame"":24}",' - "1," - "0," - '"{""name"":""rect"",""x"":526.236,""y"":393.281,""width"":46,""height"":38}",' - '"{""track"":""71""}"' - ) - f.write("\n") - f.write( - "04.09.2023-04-Right_RE_test.png," # same filename as previous row - "26542080," - '"{""clip"":123, ""frame"":25}",' # different frame number - "1," - "0," - '"{""name"":""rect"",""x"":526.236,""y"":393.281,""width"":46,""height"":38}",' - '"{""track"":""71""}"' - ) - return file_path - - -@pytest.fixture -def less_frame_numbers_than_filenames( - via_tracks_csv_with_valid_header, -): - """Return the file path for a VIA tracks .csv file with with less - frame numbers than filenames. - """ - file_path = via_tracks_csv_with_valid_header - with open(file_path, "a") as f: - f.write( - "04.09.2023-04-Right_RE_test_A.png," - "26542080," - '"{""clip"":123, ""frame"":24}",' - "1," - "0," - '"{""name"":""rect"",""x"":526.236,""y"":393.281,""width"":46,""height"":38}",' - '"{""track"":""71""}"' - ) - f.write("\n") - f.write( - "04.09.2023-04-Right_RE_test_B.png," # different filename - "26542080," - '"{""clip"":123, ""frame"":24}",' # same frame as previous row - "1," - "0," - '"{""name"":""rect"",""x"":526.236,""y"":393.281,""width"":46,""height"":38}",' - '"{""track"":""71""}"' - ) - return file_path - - -@pytest.fixture -def region_shape_attribute_not_rect( - via_tracks_csv_with_valid_header, -): - """Return the file path for a VIA tracks .csv file with invalid shape in - region_shape_attributes. - """ - file_path = via_tracks_csv_with_valid_header - with open(file_path, "a") as f: - f.write( - "04.09.2023-04-Right_RE_test_frame_01.png," - "26542080," - '"{""clip"":123}",' - "1," - "0," - '"{""name"":""circle"",""cx"":1049,""cy"":1006,""r"":125}",' - '"{""track"":""71""}"' - ) # annotation of circular shape - return file_path - - -@pytest.fixture -def region_shape_attribute_missing_x( - via_tracks_csv_with_valid_header, -): - """Return the file path for a VIA tracks .csv file with missing `x` key in - region_shape_attributes. - """ - file_path = via_tracks_csv_with_valid_header - with open(file_path, "a") as f: - f.write( - "04.09.2023-04-Right_RE_test_frame_01.png," - "26542080," - '"{""clip"":123}",' - "1," - "0," - '"{""name"":""rect"",""y"":393.281,""width"":46,""height"":38}",' - '"{""track"":""71""}"' - ) # region_shape_attributes is missing ""x"" key - return file_path - - -@pytest.fixture -def region_attribute_missing_track( - via_tracks_csv_with_valid_header, -): - """Return the file path for a VIA tracks .csv file with missing track - attribute in region_attributes. - """ - file_path = via_tracks_csv_with_valid_header - with open(file_path, "a") as f: - f.write( - "04.09.2023-04-Right_RE_test_frame_01.png," - "26542080," - '"{""clip"":123}",' - "1," - "0," - '"{""name"":""rect"",""x"":526.236,""y"":393.281,""width"":46,""height"":38}",' - '"{""foo"":""71""}"' # missing ""track"" - ) - return file_path - - -@pytest.fixture -def track_id_not_castable_as_int( - via_tracks_csv_with_valid_header, -): - """Return the file path for a VIA tracks .csv file with a track ID - attribute not castable as an integer. - """ - file_path = via_tracks_csv_with_valid_header - with open(file_path, "a") as f: - f.write( - "04.09.2023-04-Right_RE_test_frame_01.png," - "26542080," - '"{""clip"":123}",' - "1," - "0," - '"{""name"":""rect"",""x"":526.236,""y"":393.281,""width"":46,""height"":38}",' - '"{""track"":""FOO""}"' # ""track"" not castable as int - ) - return file_path - - -@pytest.fixture -def track_ids_not_unique_per_frame( - via_tracks_csv_with_valid_header, -): - """Return the file path for a VIA tracks .csv file with a track ID - that appears twice in the same frame. - """ - file_path = via_tracks_csv_with_valid_header - with open(file_path, "a") as f: - f.write( - "04.09.2023-04-Right_RE_test_frame_01.png," - "26542080," - '"{""clip"":123}",' - "1," - "0," - '"{""name"":""rect"",""x"":526.236,""y"":393.281,""width"":46,""height"":38}",' - '"{""track"":""71""}"' - ) - f.write("\n") - f.write( - "04.09.2023-04-Right_RE_test_frame_01.png," - "26542080," - '"{""clip"":123}",' - "1," - "0," - '"{""name"":""rect"",""x"":2567.627,""y"":466.888,""width"":40,""height"":37}",' - '"{""track"":""71""}"' # same track ID as the previous row - ) - return file_path - - -# ----------------- Helpers fixture ----------------- -class Helpers: - """Generic helper methods for ``movement`` test modules.""" - - @staticmethod - def count_nans(da): - """Count number of NaNs in a DataArray.""" - return da.isnull().sum().item() - - @staticmethod - def count_consecutive_nans(da): - """Count occurrences of consecutive NaNs in a DataArray.""" - return (da.isnull().astype(int).diff("time") == 1).sum().item() - - -@pytest.fixture -def helpers(): - """Return an instance of the ``Helpers`` class.""" - return Helpers - - -# --------- movement dataset assertion fixtures --------- -class MovementDatasetAsserts: - """Class for asserting valid ``movement`` poses or bboxes datasets.""" - - @staticmethod - def valid_dataset(dataset, expected_values): - """Assert the dataset is a proper ``movement`` Dataset. - - Parameters - ---------- - dataset : xr.Dataset - The dataset to validate. - expected_values : dict - A dictionary containing the expected values for the dataset. - It must contain the following keys: - - - dim_names: list of expected dimension names as defined in - movement.validators.datasets - - vars_dims: dictionary of data variable names and the - corresponding dimension sizes - - Optional keys include: - - - file_path: Path to the source file - - fps: int, frames per second - - source_software: str, name of the software used to generate - the dataset - - """ - expected_dim_names = expected_values.get("dim_names") - expected_file_path = expected_values.get("file_path") - assert isinstance(dataset, xr.Dataset) - # Expected variables are present and of right shape/type - for var, ndim in expected_values.get("vars_dims").items(): - data_var = dataset.get(var) - assert isinstance(data_var, xr.DataArray) - assert data_var.ndim == ndim - position_shape = dataset.position.shape - # Confidence has the same shape as position, except for the space dim - assert ( - dataset.confidence.shape == position_shape[:1] + position_shape[2:] - ) - # Check the dims and coords - expected_dim_length_dict = dict( - zip(expected_dim_names, position_shape, strict=True) - ) - assert expected_dim_length_dict == dataset.sizes - # Check the coords - for dim in expected_dim_names[1:]: - assert all(isinstance(s, str) for s in dataset.coords[dim].values) - assert all(coord in dataset.coords["space"] for coord in ["x", "y"]) - # Check the metadata attributes - assert dataset.source_file == ( - expected_file_path.as_posix() - if expected_file_path is not None - else None - ) - assert dataset.source_software == expected_values.get( - "source_software" - ) - assert dataset.fps == expected_values.get("fps") - - -@pytest.fixture -def movement_dataset_asserts(): - """Return an instance of the ``MovementDatasetAsserts`` class.""" - return MovementDatasetAsserts diff --git a/tests/fixtures/__init__.py b/tests/fixtures/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/fixtures/datasets.py b/tests/fixtures/datasets.py new file mode 100644 index 00000000..3ee48a14 --- /dev/null +++ b/tests/fixtures/datasets.py @@ -0,0 +1,356 @@ +"""Valid and invalid data fixtures.""" + +import numpy as np +import pandas as pd +import pytest +import xarray as xr + +from movement.validators.datasets import ValidBboxesDataset, ValidPosesDataset + + +# -------------------- Valid bboxes datasets and arrays -------------------- +@pytest.fixture +def valid_bboxes_arrays_all_zeros(): + """Return a dictionary of valid zero arrays (in terms of shape) for a + ValidBboxesDataset. + """ + # define the shape of the arrays + n_frames, n_space, n_individuals = (10, 2, 2) + + # build a valid array for position or shape with all zeros + valid_bbox_array_all_zeros = np.zeros((n_frames, n_space, n_individuals)) + + # return as a dict + return { + "position": valid_bbox_array_all_zeros, + "shape": valid_bbox_array_all_zeros, + "individual_names": ["id_" + str(id) for id in range(n_individuals)], + } + + +@pytest.fixture +def valid_bboxes_arrays(): + """Return a dictionary of valid arrays for a + ValidBboxesDataset representing a uniform linear motion. + + It represents 2 individuals for 10 frames, in 2D space. + - Individual 0 moves along the x=y line from the origin. + - Individual 1 moves along the x=-y line line from the origin. + + All confidence values are set to 0.9 except the following which are set + to 0.1: + - Individual 0 at frames 2, 3, 4 + - Individual 1 at frames 2, 3 + """ + # define the shape of the arrays + n_frames, n_space, n_individuals = (10, 2, 2) + + # build a valid array for position + # make bbox with id_i move along x=((-1)**(i))*y line from the origin + # if i is even: along x = y line + # if i is odd: along x = -y line + # moving one unit along each axis in each frame + position = np.zeros((n_frames, n_space, n_individuals)) + for i in range(n_individuals): + position[:, 0, i] = np.arange(n_frames) + position[:, 1, i] = (-1) ** i * np.arange(n_frames) + + # build a valid array for constant bbox shape (60, 40) + constant_shape = (60, 40) # width, height in pixels + shape = np.tile(constant_shape, (n_frames, n_individuals, 1)).transpose( + 0, 2, 1 + ) + + # build an array of confidence values, all 0.9 + confidence = np.full((n_frames, n_individuals), 0.9) + + # set 5 low-confidence values + # - set 3 confidence values for bbox id_0 to 0.1 + # - set 2 confidence values for bbox id_1 to 0.1 + idx_start = 2 + confidence[idx_start : idx_start + 3, 0] = 0.1 + confidence[idx_start : idx_start + 2, 1] = 0.1 + + return { + "position": position, + "shape": shape, + "confidence": confidence, + } + + +@pytest.fixture +def valid_bboxes_dataset(valid_bboxes_arrays): + """Return a valid bboxes dataset for two individuals moving in uniform + linear motion, with 5 frames with low confidence values and time in frames. + """ + dim_names = ValidBboxesDataset.DIM_NAMES + + position_array = valid_bboxes_arrays["position"] + shape_array = valid_bboxes_arrays["shape"] + confidence_array = valid_bboxes_arrays["confidence"] + + n_frames, n_individuals, _ = position_array.shape + + return xr.Dataset( + data_vars={ + "position": xr.DataArray(position_array, dims=dim_names), + "shape": xr.DataArray(shape_array, dims=dim_names), + "confidence": xr.DataArray( + confidence_array, dims=dim_names[:1] + dim_names[2:] + ), + }, + coords={ + dim_names[0]: np.arange(n_frames), + dim_names[1]: ["x", "y"], + dim_names[2]: [f"id_{id}" for id in range(n_individuals)], + }, + attrs={ + "fps": None, + "time_unit": "frames", + "source_software": "test", + "source_file": "test_bboxes.csv", + "ds_type": "bboxes", + }, + ) + + +@pytest.fixture +def valid_bboxes_dataset_in_seconds(valid_bboxes_dataset): + """Return a valid bboxes dataset with time in seconds. + + The origin of time is assumed to be time = frame 0 = 0 seconds. + """ + fps = 60 + valid_bboxes_dataset["time"] = valid_bboxes_dataset.time / fps + valid_bboxes_dataset.attrs["time_unit"] = "seconds" + valid_bboxes_dataset.attrs["fps"] = fps + return valid_bboxes_dataset + + +@pytest.fixture +def valid_bboxes_dataset_with_nan(valid_bboxes_dataset): + """Return a valid bboxes dataset with NaN values in the position array.""" + # Set 3 NaN values in the position array for id_0 + valid_bboxes_dataset.position.loc[ + {"individuals": "id_0", "time": [3, 7, 8]} + ] = np.nan + return valid_bboxes_dataset + + +# -------------------- Valid poses datasets and arrays -------------------- +@pytest.fixture +def valid_poses_arrays(): + """Return a dictionary of valid arrays for a + ValidPosesDataset representing a uniform linear motion. + + This fixture is a factory of fixtures. + Depending on the ``array_type`` requested (``multi_individual_array``, + ``single_keypoint_array``, or ``single_individual_array``), + the returned array can represent up to 2 individuals with + up to 3 keypoints, moving at constant velocity for 10 frames in 2D space. + Default is a ``multi_individual_array`` (2 individuals, 3 keypoints each). + At each frame the individuals cover a distance of sqrt(2) in x-y space. + Specifically: + - Individual 0 moves along the x=y line from the origin. + - Individual 1 moves along the x=-y line line from the origin. + + All confidence values for all keypoints are set to 0.9 except + for the "centroid" (index=0) at the following frames, + which are set to 0.1: + - Individual 0 at frames 2, 3, 4 + - Individual 1 at frames 2, 3 + """ + + def _valid_poses_arrays(array_type): + """Return a dictionary of valid arrays for a ValidPosesDataset.""" + # Unless specified, default is a ``multi_individual_array`` with + # 10 frames, 3 keypoints, and 2 individuals in 2D space. + n_frames, n_space, n_keypoints, n_individuals = (10, 2, 3, 2) + + # define centroid (index=0) trajectory in position array + # for each individual, the centroid moves along + # the x=+/-y line, starting from the origin. + # - individual 0 moves along x = y line + # - individual 1 moves along x = -y line (if applicable) + # They move one unit along x and y axes in each frame + frames = np.arange(n_frames) + position = np.zeros((n_frames, n_space, n_keypoints, n_individuals)) + position[:, 0, 0, :] = frames[:, None] # reshape to (n_frames, 1) + position[:, 1, 0, 0] = frames + position[:, 1, 0, 1] = -frames + + # define trajectory of left and right keypoints + # for individual 0, at each timepoint: + # - the left keypoint (index=1) is at x_centroid, y_centroid + 1 + # - the right keypoint (index=2) is at x_centroid + 1, y_centroid + # for individual 1, at each timepoint: + # - the left keypoint (index=1) is at x_centroid - 1, y_centroid + # - the right keypoint (index=2) is at x_centroid, y_centroid + 1 + offsets = [ + [ + (0, 1), + (1, 0), + ], # individual 0: left, right keypoints (x,y) offsets + [ + (-1, 0), + (0, 1), + ], # individual 1: left, right keypoints (x,y) offsets + ] + for i in range(n_individuals): + for kpt in range(1, n_keypoints): + position[:, 0, kpt, i] = ( + position[:, 0, 0, i] + offsets[i][kpt - 1][0] + ) + position[:, 1, kpt, i] = ( + position[:, 1, 0, i] + offsets[i][kpt - 1][1] + ) + + # build an array of confidence values, all 0.9 + confidence = np.full((n_frames, n_keypoints, n_individuals), 0.9) + # set 5 low-confidence values + # - set 3 confidence values for individual id_0's centroid to 0.1 + # - set 2 confidence values for individual id_1's centroid to 0.1 + idx_start = 2 + confidence[idx_start : idx_start + 3, 0, 0] = 0.1 + confidence[idx_start : idx_start + 2, 0, 1] = 0.1 + + if array_type == "single_keypoint_array": + # return only the centroid keypoint + position = position[:, :, :1, :] + confidence = confidence[:, :1, :] + elif array_type == "single_individual_array": + # return only the first individual + position = position[:, :, :, :1] + confidence = confidence[:, :, :1] + return {"position": position, "confidence": confidence} + + return _valid_poses_arrays + + +@pytest.fixture +def valid_poses_dataset(valid_poses_arrays, request): + """Return a valid poses dataset. + + Depending on the ``array_type`` requested (``multi_individual_array``, + ``single_keypoint_array``, or ``single_individual_array``), + the dataset can represent up to 2 individuals ("id_0" and "id_1") + with up to 3 keypoints ("centroid", "left", "right") + moving in uniform linear motion for 10 frames in 2D space. + Default is a ``multi_individual_array`` (2 individuals, 3 keypoints each). + See the ``valid_poses_arrays`` fixture for details. + """ + dim_names = ValidPosesDataset.DIM_NAMES + # create a multi_individual_array by default unless overridden via param + try: + array_type = request.param + except AttributeError: + array_type = "multi_individual_array" + poses_array = valid_poses_arrays(array_type) + position_array = poses_array["position"] + confidence_array = poses_array["confidence"] + n_frames, _, n_keypoints, n_individuals = position_array.shape + return xr.Dataset( + data_vars={ + "position": xr.DataArray(position_array, dims=dim_names), + "confidence": xr.DataArray( + confidence_array, dims=dim_names[:1] + dim_names[2:] + ), + }, + coords={ + dim_names[0]: np.arange(n_frames), + dim_names[1]: ["x", "y"], + dim_names[2]: ["centroid", "left", "right"][:n_keypoints], + dim_names[3]: [f"id_{i}" for i in range(n_individuals)], + }, + attrs={ + "fps": None, + "time_unit": "frames", + "source_software": "test", + "source_file": "test_poses.h5", + "ds_type": "poses", + }, + ) + + +@pytest.fixture +def valid_poses_dataset_with_nan(valid_poses_dataset): + """Return a valid poses dataset with NaNs introduced in the position array. + + Using ``valid_poses_dataset`` as the base dataset, + the following NaN values are introduced: + - Individual "id_0": + - 3 NaNs in the centroid keypoint of individual id_0 (frames 3, 7, 8) + - 1 NaN in the left keypoint of individual id_0 at time=0 + - 10 NaNs in the right keypoint of individual id_0 (all frames) + - Individual "id_1" has no missing values. + """ + valid_poses_dataset.position.loc[ + {"individuals": "id_0", "keypoints": "centroid", "time": [3, 7, 8]} + ] = np.nan + valid_poses_dataset.position.loc[ + {"individuals": "id_0", "keypoints": "left", "time": 0} + ] = np.nan + valid_poses_dataset.position.loc[ + {"individuals": "id_0", "keypoints": "right"} + ] = np.nan + return valid_poses_dataset + + +@pytest.fixture +def valid_dlc_poses_df(): + """Return a valid DLC-style poses DataFrame.""" + return pd.read_hdf(pytest.DATA_PATHS.get("DLC_single-wasp.predictions.h5")) + + +# -------------------- Invalid bboxes datasets -------------------- +@pytest.fixture +def missing_var_bboxes_dataset(valid_bboxes_dataset): + """Return a bboxes dataset missing the required position variable.""" + return valid_bboxes_dataset.drop_vars("position") + + +@pytest.fixture +def missing_two_vars_bboxes_dataset(valid_bboxes_dataset): + """Return a bboxes dataset missing the required position + and shape variables. + """ + return valid_bboxes_dataset.drop_vars(["position", "shape"]) + + +@pytest.fixture +def missing_dim_bboxes_dataset(valid_bboxes_dataset): + """Return a bboxes dataset missing the required time dimension.""" + return valid_bboxes_dataset.rename({"time": "tame"}) + + +@pytest.fixture +def missing_two_dims_bboxes_dataset(valid_bboxes_dataset): + """Return a bboxes dataset missing the required time + and space dimensions. + """ + return valid_bboxes_dataset.rename({"time": "tame", "space": "spice"}) + + +# -------------------- Invalid poses datasets -------------------- +@pytest.fixture +def not_a_dataset(): + """Return data that is not a pose tracks dataset.""" + return [1, 2, 3] + + +@pytest.fixture +def empty_dataset(): + """Return an empty pose tracks dataset.""" + return xr.Dataset() + + +@pytest.fixture +def missing_var_poses_dataset(valid_poses_dataset): + """Return a poses dataset missing the required position variable.""" + return valid_poses_dataset.drop_vars("position") + + +@pytest.fixture +def missing_dim_poses_dataset(valid_poses_dataset): + """Return a poses dataset missing the required time dimension.""" + return valid_poses_dataset.rename({"time": "tame"}) diff --git a/tests/fixtures/files.py b/tests/fixtures/files.py new file mode 100644 index 00000000..0ad892ea --- /dev/null +++ b/tests/fixtures/files.py @@ -0,0 +1,440 @@ +"""Valid and invalid file fixtures.""" + +import os +from pathlib import Path +from unittest.mock import mock_open, patch + +import h5py +import pytest + + +# ------------------ Generic file fixtures ---------------------- +@pytest.fixture +def unreadable_file(tmp_path): + """Return a dictionary containing the file path and + expected permission for an unreadable .h5 file. + """ + file_path = tmp_path / "unreadable.h5" + file_mock = mock_open() + file_mock.return_value.read.side_effect = PermissionError + with ( + patch("builtins.open", side_effect=file_mock), + patch.object(Path, "exists", return_value=True), + ): + yield { + "file_path": file_path, + "expected_permission": "r", + } + + +@pytest.fixture +def unwriteable_file(tmp_path): + """Return a dictionary containing the file path and + expected permission for an unwriteable .h5 file. + """ + unwriteable_dir = tmp_path / "no_write" + unwriteable_dir.mkdir() + original_access = os.access + + def mock_access(path, mode): + if path == unwriteable_dir and mode == os.W_OK: + return False + # Ensure that the original access function is called + # for all other cases + return original_access(path, mode) + + with patch("os.access", side_effect=mock_access): + file_path = unwriteable_dir / "unwriteable.h5" + yield { + "file_path": file_path, + "expected_permission": "w", + } + + +@pytest.fixture +def wrong_extension_file(tmp_path): + """Return a dictionary containing the file path, + expected permission, and expected suffix for a file + with unsupported extension. + """ + file_path = tmp_path / "wrong_extension.txt" + with open(file_path, "w") as f: + f.write("") + return { + "file_path": file_path, + "expected_permission": "r", + "expected_suffix": ["h5", "csv"], + } + + +@pytest.fixture +def nonexistent_file(tmp_path): + """Return a dictionary containing the file path and + expected permission for a nonexistent file. + """ + file_path = tmp_path / "nonexistent.h5" + return { + "file_path": file_path, + "expected_permission": "r", + } + + +@pytest.fixture +def no_dataframe_h5_file(tmp_path): + """Return a dictionary containing the file path and + expected datasets for a .h5 file that lacks the + dataset "dataframe". + """ + file_path = tmp_path / "no_dataframe.h5" + with h5py.File(file_path, "w") as f: + f.create_dataset("data_in_list", data=[1, 2, 3]) + return { + "file_path": file_path, + "expected_datasets": ["dataframe"], + } + + +@pytest.fixture +def fake_h5_file(tmp_path): + """Return a dictionary containing the file path, + expected permission, and expected datasets for + a file with .h5 extension that is not in HDF5 format. + """ + file_path = tmp_path / "fake.h5" + with open(file_path, "w") as f: + f.write("") + return { + "file_path": file_path, + "expected_datasets": ["dataframe"], + "expected_permission": "w", + } + + +@pytest.fixture +def invalid_single_individual_csv_file(tmp_path): + """Return the file path for a fake single-individual .csv file.""" + file_path = tmp_path / "fake_single_individual.csv" + with open(file_path, "w") as f: + f.write("scorer,columns\nsome,columns\ncoords,columns\n") + f.write("1,2") + return file_path + + +@pytest.fixture +def invalid_multi_individual_csv_file(tmp_path): + """Return the file path for a fake multi-individual .csv file.""" + file_path = tmp_path / "fake_multi_individual.csv" + with open(file_path, "w") as f: + f.write( + "scorer,columns\nindividuals,columns\nbodyparts,columns\nsome,columns\n" + ) + f.write("1,2") + return file_path + + +@pytest.fixture +def wrong_extension_new_file(tmp_path): + """Return the file path for a new file with unsupported extension.""" + return tmp_path / "wrong_extension_new_file.txt" + + +@pytest.fixture +def directory(tmp_path): + """Return a dictionary containing the file path and + expected permission for a directory. + """ + file_path = tmp_path / "directory" + file_path.mkdir() + return { + "file_path": file_path, + "expected_permission": "r", + } + + +@pytest.fixture +def new_h5_file(tmp_path): + """Return the file path for a new .h5 file.""" + return tmp_path / "new_file.h5" + + +@pytest.fixture +def new_csv_file(tmp_path): + """Return the file path for a new .csv file.""" + return tmp_path / "new_file.csv" + + +# ---------------- Anipose file fixtures ---------------------------- +@pytest.fixture +def missing_keypoint_columns_anipose_csv_file(tmp_path): + """Return the file path for a single-individual anipose .csv file + missing the z-coordinate of keypoint kp0 "kp0_z". + """ + file_path = tmp_path / "missing_keypoint_columns.csv" + columns = [ + "fnum", + "center_0", + "center_1", + "center_2", + "M_00", + "M_01", + "M_02", + "M_10", + "M_11", + "M_12", + "M_20", + "M_21", + "M_22", + ] + # Here we are missing kp0_z: + columns.extend(["kp0_x", "kp0_y", "kp0_score", "kp0_error", "kp0_ncams"]) + with open(file_path, "w") as f: + f.write(",".join(columns)) + f.write("\n") + f.write(",".join(["1"] * len(columns))) + return file_path + + +@pytest.fixture +def spurious_column_anipose_csv_file(tmp_path): + """Return the file path for a single-individual anipose .csv file + with an unexpected column. + """ + file_path = tmp_path / "spurious_column.csv" + columns = [ + "fnum", + "center_0", + "center_1", + "center_2", + "M_00", + "M_01", + "M_02", + "M_10", + "M_11", + "M_12", + "M_20", + "M_21", + "M_22", + ] + columns.extend(["funny_column"]) + with open(file_path, "w") as f: + f.write(",".join(columns)) + f.write("\n") + f.write(",".join(["1"] * len(columns))) + return file_path + + +# ---------------- SLEAP file fixtures ---------------------------- +@pytest.fixture( + params=[ + "SLEAP_single-mouse_EPM.analysis.h5", + "SLEAP_single-mouse_EPM.predictions.slp", + "SLEAP_three-mice_Aeon_proofread.analysis.h5", + "SLEAP_three-mice_Aeon_proofread.predictions.slp", + "SLEAP_three-mice_Aeon_mixed-labels.analysis.h5", + "SLEAP_three-mice_Aeon_mixed-labels.predictions.slp", + ] +) +def sleap_file(request): + """Return the file path for a SLEAP .h5 or .slp file.""" + return pytest.DATA_PATHS.get(request.param) + + +# ---------------- VIA tracks CSV file fixtures ---------------------------- +via_tracks_csv_file_valid_header = ( + "filename,file_size,file_attributes,region_count," + "region_id,region_shape_attributes,region_attributes\n" +) + + +@pytest.fixture +def invalid_via_tracks_csv_file(tmp_path, request): + """Return the file path for an invalid VIA tracks .csv file.""" + + def _invalid_via_tracks_csv_file(invalid_content): + file_path = tmp_path / "invalid_via_tracks.csv" + with open(file_path, "w") as f: + f.write(request.getfixturevalue(invalid_content)) + return file_path + + return _invalid_via_tracks_csv_file + + +@pytest.fixture +def via_invalid_header(): + """Return the content of a VIA tracks .csv file with invalid header.""" + return "filename,file_size,file_attributes\n1,2,3" + + +@pytest.fixture +def via_frame_number_in_file_attribute_not_integer(): + """Return the content of a VIA tracks .csv file with invalid frame + number defined as file_attribute. + """ + return ( + via_tracks_csv_file_valid_header + + "04.09.2023-04-Right_RE_test_frame_A.png," + "26542080," + '"{""clip"":123, ""frame"":""FOO""}",' # frame number is a string + "1," + "0," + '"{""name"":""rect"",""x"":526.236,""y"":393.281,""width"":46,""height"":38}",' + '"{""track"":""71""}"' + ) + + +@pytest.fixture +def via_frame_number_in_filename_wrong_pattern(): + """Return the content of a VIA tracks .csv file with invalid frame + number defined in the frame's filename. + """ + return ( + via_tracks_csv_file_valid_header + + "04.09.2023-04-Right_RE_test_frame_1.png," # frame not zero-padded + "26542080," + '"{""clip"":123}",' + "1," + "0," + '"{""name"":""rect"",""x"":526.236,""y"":393.281,""width"":46,""height"":38}",' + '"{""track"":""71""}"' + ) + + +@pytest.fixture +def via_more_frame_numbers_than_filenames(): + """Return the content of a VIA tracks .csv file with more + frame numbers than filenames. + """ + return ( + via_tracks_csv_file_valid_header + "04.09.2023-04-Right_RE_test.png," + "26542080," + '"{""clip"":123, ""frame"":24}",' + "1," + "0," + '"{""name"":""rect"",""x"":526.236,""y"":393.281,""width"":46,""height"":38}",' + '"{""track"":""71""}"' + "\n" + "04.09.2023-04-Right_RE_test.png," # same filename as previous row + "26542080," + '"{""clip"":123, ""frame"":25}",' # different frame number + "1," + "0," + '"{""name"":""rect"",""x"":526.236,""y"":393.281,""width"":46,""height"":38}",' + '"{""track"":""71""}"' + ) + + +@pytest.fixture +def via_less_frame_numbers_than_filenames(): + """Return the content of a VIA tracks .csv file with with less + frame numbers than filenames. + """ + return ( + via_tracks_csv_file_valid_header + "04.09.2023-04-Right_RE_test_A.png," + "26542080," + '"{""clip"":123, ""frame"":24}",' + "1," + "0," + '"{""name"":""rect"",""x"":526.236,""y"":393.281,""width"":46,""height"":38}",' + '"{""track"":""71""}"' + "\n" + "04.09.2023-04-Right_RE_test_B.png," # different filename + "26542080," + '"{""clip"":123, ""frame"":24}",' # same frame as previous row + "1," + "0," + '"{""name"":""rect"",""x"":526.236,""y"":393.281,""width"":46,""height"":38}",' + '"{""track"":""71""}"' + ) + + +@pytest.fixture +def via_region_shape_attribute_not_rect(): + """Return the content of a VIA tracks .csv file with invalid shape in + region_shape_attributes. + """ + return ( + via_tracks_csv_file_valid_header + + "04.09.2023-04-Right_RE_test_frame_01.png," + "26542080," + '"{""clip"":123}",' + "1," + "0," + '"{""name"":""circle"",""cx"":1049,""cy"":1006,""r"":125}",' + '"{""track"":""71""}"' + ) # annotation of circular shape + + +@pytest.fixture +def via_region_shape_attribute_missing_x(): + """Return the content of a VIA tracks .csv file with missing `x` key in + region_shape_attributes. + """ + return ( + via_tracks_csv_file_valid_header + + "04.09.2023-04-Right_RE_test_frame_01.png," + "26542080," + '"{""clip"":123}",' + "1," + "0," + '"{""name"":""rect"",""y"":393.281,""width"":46,""height"":38}",' + '"{""track"":""71""}"' + ) # region_shape_attributes is missing ""x"" key + + +@pytest.fixture +def via_region_attribute_missing_track(): + """Return the content of a VIA tracks .csv file with missing track + attribute in region_attributes. + """ + return ( + via_tracks_csv_file_valid_header + + "04.09.2023-04-Right_RE_test_frame_01.png," + "26542080," + '"{""clip"":123}",' + "1," + "0," + '"{""name"":""rect"",""x"":526.236,""y"":393.281,""width"":46,""height"":38}",' + '"{""foo"":""71""}"' # missing ""track"" + ) + + +@pytest.fixture +def via_track_id_not_castable_as_int(): + """Return the content of a VIA tracks .csv file with a track ID + attribute not castable as an integer. + """ + return ( + via_tracks_csv_file_valid_header + + "04.09.2023-04-Right_RE_test_frame_01.png," + "26542080," + '"{""clip"":123}",' + "1," + "0," + '"{""name"":""rect"",""x"":526.236,""y"":393.281,""width"":46,""height"":38}",' + '"{""track"":""FOO""}"' # ""track"" not castable as int + ) + + +@pytest.fixture +def via_track_ids_not_unique_per_frame(): + """Return the content of a VIA tracks .csv file with a track ID + that appears twice in the same frame. + """ + return ( + via_tracks_csv_file_valid_header + + "04.09.2023-04-Right_RE_test_frame_01.png," + "26542080," + '"{""clip"":123}",' + "1," + "0," + '"{""name"":""rect"",""x"":526.236,""y"":393.281,""width"":46,""height"":38}",' + '"{""track"":""71""}"' + "\n" + "04.09.2023-04-Right_RE_test_frame_01.png," + "26542080," + '"{""clip"":123}",' + "1," + "0," + '"{""name"":""rect"",""x"":2567.627,""y"":466.888,""width"":40,""height"":37}",' + '"{""track"":""71""}"' # same track ID as the previous row + ) diff --git a/tests/fixtures/helpers.py b/tests/fixtures/helpers.py new file mode 100644 index 00000000..32b5ee1d --- /dev/null +++ b/tests/fixtures/helpers.py @@ -0,0 +1,97 @@ +"""Helpers fixture for ``movement`` test modules.""" + +import pytest +import xarray as xr + + +class Helpers: + """General helper methods for ``movement`` test modules.""" + + @staticmethod + def assert_valid_dataset(dataset, expected_values): + """Assert the dataset is a valid ``movement`` Dataset. + + The validation includes: + - checking the dataset is an xarray Dataset + - checking the expected variables are present and are of the right + shape and type + - checking the confidence array shape matches the position array + - checking the dimensions and coordinates against the expected values + - checking the coordinates' names and size + - checking the metadata attributes + + Parameters + ---------- + dataset : xr.Dataset + The dataset to validate. + expected_values : dict + A dictionary containing the expected values for the dataset. + It must contain the following keys: + + - dim_names: list of expected dimension names as defined in + movement.validators.datasets + - vars_dims: dictionary of data variable names and the + corresponding dimension sizes + + Optional keys include: + + - file_path: Path to the source file + - fps: int, frames per second + - source_software: str, name of the software used to generate + the dataset + + """ + # Check dataset is an xarray Dataset + assert isinstance(dataset, xr.Dataset) + + # Expected variables are present and of right shape/type + for var, ndim in expected_values.get("vars_dims").items(): + data_var = dataset.get(var) + assert isinstance(data_var, xr.DataArray) + assert data_var.ndim == ndim + position_shape = dataset.position.shape + + # Confidence has the same shape as position, except for the space dim + assert ( + dataset.confidence.shape == position_shape[:1] + position_shape[2:] + ) + + # Check the dims and coords + expected_dim_names = expected_values.get("dim_names") + expected_dim_length_dict = dict( + zip(expected_dim_names, position_shape, strict=True) + ) + assert expected_dim_length_dict == dataset.sizes + + # Check the coords + for dim in expected_dim_names[1:]: + assert all(isinstance(s, str) for s in dataset.coords[dim].values) + assert all(coord in dataset.coords["space"] for coord in ["x", "y"]) + + # Check the metadata attributes + expected_file_path = expected_values.get("file_path") + assert dataset.source_file == ( + expected_file_path.as_posix() + if expected_file_path is not None + else None + ) + assert dataset.source_software == expected_values.get( + "source_software" + ) + assert dataset.fps == expected_values.get("fps") + + @staticmethod + def count_nans(da): + """Count number of NaNs in a DataArray.""" + return da.isnull().sum().item() + + @staticmethod + def count_consecutive_nans(da): + """Count occurrences of consecutive NaNs in a DataArray.""" + return (da.isnull().astype(int).diff("time") != 0).sum().item() + + +@pytest.fixture +def helpers(): + """Return an instance of the ``Helpers`` class.""" + return Helpers diff --git a/tests/test_integration/test_io.py b/tests/test_integration/test_io.py index 50f03933..2e1bf36a 100644 --- a/tests/test_integration/test_io.py +++ b/tests/test_integration/test_io.py @@ -15,13 +15,13 @@ def dlc_output_file(self, request, tmp_path): """Return the output file path for a DLC .h5 or .csv file.""" return tmp_path / request.param - def test_load_and_save_to_dlc_style_df(self, dlc_style_df): + def test_load_and_save_to_dlc_style_df(self, valid_dlc_poses_df): """Test that loading pose tracks from a DLC-style DataFrame and converting back to a DataFrame returns the same data values. """ - ds = load_poses.from_dlc_style_df(dlc_style_df) + ds = load_poses.from_dlc_style_df(valid_dlc_poses_df) df = save_poses.to_dlc_style_df(ds, split_individuals=False) - np.testing.assert_allclose(df.values, dlc_style_df.values) + np.testing.assert_allclose(df.values, valid_dlc_poses_df.values) def test_save_and_load_dlc_file( self, dlc_output_file, valid_poses_dataset diff --git a/tests/test_integration/test_kinematics_vector_transform.py b/tests/test_integration/test_kinematics_vector_transform.py index c1a3ce91..feae39bb 100644 --- a/tests/test_integration/test_kinematics_vector_transform.py +++ b/tests/test_integration/test_kinematics_vector_transform.py @@ -9,11 +9,7 @@ @pytest.mark.parametrize( - "valid_dataset_uniform_linear_motion", - [ - "valid_poses_dataset_uniform_linear_motion", - "valid_bboxes_dataset", - ], + "valid_dataset", ["valid_poses_dataset", "valid_bboxes_dataset"] ) @pytest.mark.parametrize( "kinematic_variable, expected_kinematics_polar", @@ -56,15 +52,12 @@ ], ) def test_cart2pol_transform_on_kinematics( - valid_dataset_uniform_linear_motion, - kinematic_variable, - expected_kinematics_polar, - request, + valid_dataset, kinematic_variable, expected_kinematics_polar, request ): """Test transformation between Cartesian and polar coordinates with various kinematic properties. """ - ds = request.getfixturevalue(valid_dataset_uniform_linear_motion) + ds = request.getfixturevalue(valid_dataset) kinematic_array_cart = getattr(kin, f"compute_{kinematic_variable}")( ds.position ) diff --git a/tests/test_unit/test_filtering.py b/tests/test_unit/test_filtering.py index d51af1be..d97dcc8e 100644 --- a/tests/test_unit/test_filtering.py +++ b/tests/test_unit/test_filtering.py @@ -24,81 +24,176 @@ @pytest.mark.parametrize( - "valid_dataset_with_nan", - list_valid_datasets_with_nans, -) -@pytest.mark.parametrize( - "max_gap, expected_n_nans_in_position", [(None, 0), (0, 3), (1, 2), (2, 0)] + "valid_dataset", + list_all_valid_datasets, ) -def test_interpolate_over_time_on_position( - valid_dataset_with_nan, - max_gap, - expected_n_nans_in_position, - helpers, - request, -): - """Test that the number of NaNs decreases after linearly interpolating - over time and that the resulting number of NaNs is as expected - for different values of ``max_gap``. - """ - valid_dataset_in_frames = request.getfixturevalue(valid_dataset_with_nan) - - # Get position array with time unit in frames & seconds - # assuming 10 fps = 0.1 s per frame - valid_dataset_in_seconds = valid_dataset_in_frames.copy() - valid_dataset_in_seconds.coords["time"] = ( - valid_dataset_in_seconds.coords["time"] * 0.1 +class TestFilteringValidDataset: + """Test median and savgol filtering on valid datasets with/without NaNs.""" + + @pytest.mark.parametrize( + ("filter_func, filter_kwargs"), + [ + (median_filter, {"window": 3}), + (savgol_filter, {"window": 3, "polyorder": 2}), + ], ) - position = { - "frames": valid_dataset_in_frames.position, - "seconds": valid_dataset_in_seconds.position, - } - - # Count number of NaNs before and after interpolating position - n_nans_before = helpers.count_nans(position["frames"]) - n_nans_after_per_time_unit = {} - for time_unit in ["frames", "seconds"]: - # interpolate - position_interp = interpolate_over_time( - position[time_unit], method="linear", max_gap=max_gap + def test_filter_with_nans_on_position( + self, filter_func, filter_kwargs, valid_dataset, helpers, request + ): + """Test NaN behaviour of the median and SG filters. + Both filters should set all values to NaN if one element of the + sliding window is NaN. + """ + # Expected number of nans in the position array per individual + expected_nans_in_filtered_position_per_indiv = { + "valid_poses_dataset": [0, 0], # no nans in input + "valid_bboxes_dataset": [0, 0], # no nans in input + "valid_poses_dataset_with_nan": [38, 0], + "valid_bboxes_dataset_with_nan": [14, 0], + } + # Filter position + valid_input_dataset = request.getfixturevalue(valid_dataset) + position_filtered = filter_func( + valid_input_dataset.position, **filter_kwargs ) - # count nans - n_nans_after_per_time_unit[time_unit] = helpers.count_nans( - position_interp + # Compute n nans in position after filtering per individual + n_nans_after_filtering_per_indiv = [ + helpers.count_nans(position_filtered.isel(individuals=i)) + for i in range(valid_input_dataset.sizes["individuals"]) + ] + # Check number of nans per indiv is as expected + assert ( + n_nans_after_filtering_per_indiv + == expected_nans_in_filtered_position_per_indiv[valid_dataset] ) - # The number of NaNs should be the same for both datasets - # as max_gap is based on number of missing observations (NaNs) - assert ( - n_nans_after_per_time_unit["frames"] - == n_nans_after_per_time_unit["seconds"] + @pytest.mark.parametrize( + "override_kwargs, expected_exception", + [ + ({"mode": "nearest"}, does_not_raise()), + ({"axis": 1}, pytest.raises(ValueError)), + ({"mode": "nearest", "axis": 1}, pytest.raises(ValueError)), + ], ) + def test_savgol_filter_kwargs_override( + self, valid_dataset, override_kwargs, expected_exception, request + ): + """Test that overriding keyword arguments in the + Savitzky-Golay filter works, except for the ``axis`` argument, + which should raise a ValueError. + """ + with expected_exception: + savgol_filter( + request.getfixturevalue(valid_dataset).position, + window=3, + **override_kwargs, + ) - # The number of NaNs should decrease after interpolation - n_nans_after = n_nans_after_per_time_unit["frames"] - if max_gap == 0: - assert n_nans_after == n_nans_before - else: - assert n_nans_after < n_nans_before - # The number of NaNs after interpolating should be as expected - assert n_nans_after == ( - valid_dataset_in_frames.sizes["space"] - * valid_dataset_in_frames.sizes.get("keypoints", 1) - # in bboxes dataset there is no keypoints dimension - * expected_n_nans_in_position +@pytest.mark.parametrize( + "valid_dataset_with_nan", + list_valid_datasets_with_nans, +) +class TestFilteringValidDatasetWithNaNs: + """Test filtering functions on datasets with NaNs.""" + + @pytest.mark.parametrize( + "max_gap, expected_n_nans_in_position", + [(None, [22, 0]), (0, [28, 6]), (1, [26, 4]), (2, [22, 0])], + # expected total n nans: [poses, bboxes] + ) + def test_interpolate_over_time_on_position( + self, + valid_dataset_with_nan, + max_gap, + expected_n_nans_in_position, + helpers, + request, + ): + """Test that the number of NaNs decreases after linearly interpolating + over time and that the resulting number of NaNs is as expected + for different values of ``max_gap``. + """ + valid_dataset_in_frames = request.getfixturevalue( + valid_dataset_with_nan + ) + # Get position array with time unit in frames & seconds + # assuming 10 fps = 0.1 s per frame + valid_dataset_in_seconds = valid_dataset_in_frames.copy() + valid_dataset_in_seconds.coords["time"] = ( + valid_dataset_in_seconds.coords["time"] * 0.1 + ) + position = { + "frames": valid_dataset_in_frames.position, + "seconds": valid_dataset_in_seconds.position, + } + # Count number of NaNs + n_nans_after_per_time_unit = {} + for time_unit in ["frames", "seconds"]: + # interpolate + position_interp = interpolate_over_time( + position[time_unit], method="linear", max_gap=max_gap + ) + # count nans + n_nans_after_per_time_unit[time_unit] = helpers.count_nans( + position_interp + ) + # The number of NaNs should be the same for both datasets + # as max_gap is based on number of missing observations (NaNs) + assert ( + n_nans_after_per_time_unit["frames"] + == n_nans_after_per_time_unit["seconds"] + ) + # The number of NaNs after interpolating should be as expected + n_nans_after = n_nans_after_per_time_unit["frames"] + dataset_index = list_valid_datasets_with_nans.index( + valid_dataset_with_nan + ) + assert n_nans_after == expected_n_nans_in_position[dataset_index] + + @pytest.mark.parametrize( + "window", + [3, 5, 6, 10], # input data has 10 frames ) + @pytest.mark.parametrize("filter_func", [median_filter, savgol_filter]) + def test_filter_with_nans_on_position_varying_window( + self, valid_dataset_with_nan, window, filter_func, helpers, request + ): + """Test that the number of NaNs in the filtered position data + increases at most by the filter's window length minus one + multiplied by the number of consecutive NaNs in the input data. + """ + # Prepare kwargs per filter + kwargs = {"window": window} + if filter_func == savgol_filter: + kwargs["polyorder"] = 2 + # Filter position + valid_input_dataset = request.getfixturevalue(valid_dataset_with_nan) + position_filtered = filter_func( + valid_input_dataset.position, + **kwargs, + ) + # Count number of NaNs in the input and filtered position data + n_total_nans_initial = helpers.count_nans(valid_input_dataset.position) + n_consecutive_nans_initial = helpers.count_consecutive_nans( + valid_input_dataset.position + ) + n_total_nans_filtered = helpers.count_nans(position_filtered) + max_nans_increase = (window - 1) * n_consecutive_nans_initial + # Check that filtering does not reduce number of nans + assert n_total_nans_filtered >= n_total_nans_initial + # Check that the increase in nans is below the expected threshold + assert ( + n_total_nans_filtered - n_total_nans_initial <= max_nans_increase + ) @pytest.mark.parametrize( - "valid_dataset_no_nans, n_low_confidence_kpts", - [ - ("valid_poses_dataset", 20), - ("valid_bboxes_dataset", 5), - ], + "valid_dataset_no_nans", + list_valid_datasets_without_nans, ) def test_filter_by_confidence_on_position( - valid_dataset_no_nans, n_low_confidence_kpts, helpers, request + valid_dataset_no_nans, helpers, request ): """Test that points below the default 0.6 confidence threshold are converted to NaN. @@ -110,202 +205,13 @@ def test_filter_by_confidence_on_position( confidence=valid_input_dataset.confidence, threshold=0.6, ) - # Count number of NaNs in the full array n_nans = helpers.count_nans(position_filtered) - # expected number of nans for poses: # 5 timepoints * 2 individuals * 2 keypoints # Note: we count the number of nans in the array, so we multiply # the number of low confidence keypoints by the number of # space dimensions + n_low_confidence_kpts = 5 assert isinstance(position_filtered, xr.DataArray) assert n_nans == valid_input_dataset.sizes["space"] * n_low_confidence_kpts - - -@pytest.mark.parametrize( - "valid_dataset", - list_all_valid_datasets, -) -@pytest.mark.parametrize( - ("filter_func, filter_kwargs"), - [ - (median_filter, {"window": 2}), - (median_filter, {"window": 4}), - (savgol_filter, {"window": 2, "polyorder": 1}), - (savgol_filter, {"window": 4, "polyorder": 2}), - ], -) -def test_filter_on_position( - filter_func, filter_kwargs, valid_dataset, request -): - """Test that applying a filter to the position data returns - a different xr.DataArray than the input position data. - """ - # Filter position - valid_input_dataset = request.getfixturevalue(valid_dataset) - position_filtered = filter_func( - valid_input_dataset.position, **filter_kwargs - ) - - del position_filtered.attrs["log"] - - # filtered array is an xr.DataArray - assert isinstance(position_filtered, xr.DataArray) - - # filtered data should not be equal to the original data - assert not position_filtered.equals(valid_input_dataset.position) - - -# Expected number of nans in the position array per -# individual, after applying a filter with window size 3 -@pytest.mark.parametrize( - ("valid_dataset, expected_nans_in_filtered_position_per_indiv"), - [ - ( - "valid_poses_dataset", - {0: 0, 1: 0}, - ), # filtering should not introduce nans if input has no nans - ("valid_bboxes_dataset", {0: 0, 1: 0}), - ("valid_poses_dataset_with_nan", {0: 7, 1: 0}), - ("valid_bboxes_dataset_with_nan", {0: 7, 1: 0}), - ], -) -@pytest.mark.parametrize( - ("filter_func, filter_kwargs"), - [ - (median_filter, {"window": 3}), - (savgol_filter, {"window": 3, "polyorder": 2}), - ], -) -def test_filter_with_nans_on_position( - filter_func, - filter_kwargs, - valid_dataset, - expected_nans_in_filtered_position_per_indiv, - helpers, - request, -): - """Test NaN behaviour of the selected filter. The median and SG filters - should set all values to NaN if one element of the sliding window is NaN. - """ - - def _assert_n_nans_in_position_per_individual( - valid_input_dataset, - position_filtered, - expected_nans_in_filt_position_per_indiv, - ): - # compute n nans in position after filtering per individual - n_nans_after_filtering_per_indiv = { - i: helpers.count_nans(position_filtered.isel(individuals=i)) - for i in range(valid_input_dataset.sizes["individuals"]) - } - - # check number of nans per indiv is as expected - for i in range(valid_input_dataset.sizes["individuals"]): - assert n_nans_after_filtering_per_indiv[i] == ( - expected_nans_in_filt_position_per_indiv[i] - * valid_input_dataset.sizes["space"] - * valid_input_dataset.sizes.get("keypoints", 1) - ) - - # Filter position - valid_input_dataset = request.getfixturevalue(valid_dataset) - position_filtered = filter_func( - valid_input_dataset.position, **filter_kwargs - ) - - # check number of nans per indiv is as expected - _assert_n_nans_in_position_per_individual( - valid_input_dataset, - position_filtered, - expected_nans_in_filtered_position_per_indiv, - ) - - # if input had nans, - # individual 1's position at exact timepoints 0, 1 and 5 is not nan - n_nans_input = helpers.count_nans(valid_input_dataset.position) - if n_nans_input != 0: - assert not ( - position_filtered.isel(individuals=0, time=[0, 1, 5]) - .isnull() - .any() - ) - - -@pytest.mark.parametrize( - "valid_dataset_with_nan", - list_valid_datasets_with_nans, -) -@pytest.mark.parametrize( - "window", - [3, 5, 6, 10], # data is nframes = 10 -) -@pytest.mark.parametrize( - "filter_func", - [median_filter, savgol_filter], -) -def test_filter_with_nans_on_position_varying_window( - valid_dataset_with_nan, window, filter_func, helpers, request -): - """Test that the number of NaNs in the filtered position data - increases at most by the filter's window length minus one - multiplied by the number of consecutive NaNs in the input data. - """ - # Prepare kwargs per filter - kwargs = {"window": window} - if filter_func == savgol_filter: - kwargs["polyorder"] = 2 - - # Filter position - valid_input_dataset = request.getfixturevalue(valid_dataset_with_nan) - position_filtered = filter_func( - valid_input_dataset.position, - **kwargs, - ) - - # Count number of NaNs in the input and filtered position data - n_total_nans_initial = helpers.count_nans(valid_input_dataset.position) - n_consecutive_nans_initial = helpers.count_consecutive_nans( - valid_input_dataset.position - ) - - n_total_nans_filtered = helpers.count_nans(position_filtered) - - max_nans_increase = (window - 1) * n_consecutive_nans_initial - - # Check that filtering does not reduce number of nans - assert n_total_nans_filtered >= n_total_nans_initial - # Check that the increase in nans is below the expected threshold - assert n_total_nans_filtered - n_total_nans_initial <= max_nans_increase - - -@pytest.mark.parametrize( - "valid_dataset", - list_all_valid_datasets, -) -@pytest.mark.parametrize( - "override_kwargs", - [ - {"mode": "nearest"}, - {"axis": 1}, - {"mode": "nearest", "axis": 1}, - ], -) -def test_savgol_filter_kwargs_override( - valid_dataset, override_kwargs, request -): - """Test that overriding keyword arguments in the Savitzky-Golay filter - works, except for the ``axis`` argument, which should raise a ValueError. - """ - expected_exception = ( - pytest.raises(ValueError) - if "axis" in override_kwargs - else does_not_raise() - ) - with expected_exception: - savgol_filter( - request.getfixturevalue(valid_dataset).position, - window=3, - **override_kwargs, - ) diff --git a/tests/test_unit/test_kinematics.py b/tests/test_unit/test_kinematics.py index ff3ffc22..b8a2c9c3 100644 --- a/tests/test_unit/test_kinematics.py +++ b/tests/test_unit/test_kinematics.py @@ -10,194 +10,149 @@ @pytest.mark.parametrize( - "valid_dataset_uniform_linear_motion", - [ - "valid_poses_dataset_uniform_linear_motion", - "valid_bboxes_dataset", - ], + "kinematic_variable", ["displacement", "velocity", "acceleration", "speed"] ) -@pytest.mark.parametrize( - "kinematic_variable, expected_kinematics", - [ - ( - "displacement", - [ - np.vstack([np.zeros((1, 2)), np.ones((9, 2))]), # Individual 0 - np.multiply( - np.vstack([np.zeros((1, 2)), np.ones((9, 2))]), - np.array([1, -1]), - ), # Individual 1 - ], - ), - ( - "velocity", - [ - np.ones((10, 2)), # Individual 0 - np.multiply( - np.ones((10, 2)), np.array([1, -1]) - ), # Individual 1 - ], - ), - ( - "acceleration", - [ - np.zeros((10, 2)), # Individual 0 - np.zeros((10, 2)), # Individual 1 - ], - ), - ( - "speed", # magnitude of velocity - [ - np.ones(10) * np.sqrt(2), # Individual 0 - np.ones(10) * np.sqrt(2), # Individual 1 - ], - ), - ], -) -def test_kinematics_uniform_linear_motion( - valid_dataset_uniform_linear_motion, - kinematic_variable, - expected_kinematics, # 2D: n_frames, n_space_dims - request, -): - """Test computed kinematics for a uniform linear motion case. - - Uniform linear motion means the individuals move along a line - at constant velocity. - - We consider 2 individuals ("id_0" and "id_1"), - tracked for 10 frames, along x and y: - - id_0 moves along x=y line from the origin - - id_1 moves along x=-y line from the origin - - they both move one unit (pixel) along each axis in each frame +class TestComputeKinematics: + """Test ``compute_[kinematic_variable]`` with valid and invalid inputs.""" + + expected_kinematics = { + "displacement": [ + np.vstack([np.zeros((1, 2)), np.ones((9, 2))]), + np.multiply( + np.vstack([np.zeros((1, 2)), np.ones((9, 2))]), + np.array([1, -1]), + ), + ], # [Individual 0, Individual 1] + "velocity": [ + np.ones((10, 2)), + np.multiply(np.ones((10, 2)), np.array([1, -1])), + ], + "acceleration": [np.zeros((10, 2)), np.zeros((10, 2))], + "speed": [np.ones(10) * np.sqrt(2), np.ones(10) * np.sqrt(2)], + } # 2D: n_frames, n_space_dims - If the dataset is a poses dataset, we consider 3 keypoints per individual - (centroid, left, right), that are always in front of the centroid keypoint - at 45deg from the trajectory. - """ - # Compute kinematic array from input dataset - position = request.getfixturevalue( - valid_dataset_uniform_linear_motion - ).position - kinematic_array = getattr(kinematics, f"compute_{kinematic_variable}")( - position + @pytest.mark.parametrize( + "valid_dataset", ["valid_poses_dataset", "valid_bboxes_dataset"] ) + def test_kinematics(self, valid_dataset, kinematic_variable, request): + """Test computed kinematics for a uniform linear motion case. + See the ``valid_poses_dataset`` and ``valid_bboxes_dataset`` fixtures + for details. + """ + # Compute kinematic array from input dataset + position = request.getfixturevalue(valid_dataset).position + kinematic_array = getattr(kinematics, f"compute_{kinematic_variable}")( + position + ) + # Figure out which dimensions to expect in kinematic_array + # and in the final xarray.DataArray + expected_dims = ["time", "individuals"] + if kinematic_variable in ["displacement", "velocity", "acceleration"]: + expected_dims.insert(1, "space") + # Build expected data array from the expected numpy array + expected_array = xr.DataArray( + # Stack along the "individuals" axis + np.stack( + self.expected_kinematics.get(kinematic_variable), axis=-1 + ), + dims=expected_dims, + ) + if "keypoints" in position.coords: + expected_array = expected_array.expand_dims( + {"keypoints": position.coords["keypoints"].size} + ) + expected_dims.insert(-1, "keypoints") + expected_array = expected_array.transpose(*expected_dims) + # Compare the values of the kinematic_array against the expected_array + np.testing.assert_allclose( + kinematic_array.values, expected_array.values + ) - # Figure out which dimensions to expect in kinematic_array - # and in the final xarray.DataArray - expected_dims = ["time", "individuals"] - if kinematic_variable in ["displacement", "velocity", "acceleration"]: - expected_dims.insert(1, "space") - - # Build expected data array from the expected numpy array - expected_array = xr.DataArray( - # Stack along the "individuals" axis - np.stack(expected_kinematics, axis=-1), - dims=expected_dims, + @pytest.mark.parametrize( + "valid_dataset_with_nan, expected_nans_per_individual", + [ + ( + "valid_poses_dataset_with_nan", + { + "displacement": [30, 0], + "velocity": [36, 0], + "acceleration": [40, 0], + "speed": [18, 0], + }, + ), + ( + "valid_bboxes_dataset_with_nan", + { + "displacement": [10, 0], + "velocity": [12, 0], + "acceleration": [14, 0], + "speed": [6, 0], + }, + ), + ], ) - if "keypoints" in position.coords: - expected_array = expected_array.expand_dims( - {"keypoints": position.coords["keypoints"].size} + def test_kinematics_with_dataset_with_nans( + self, + valid_dataset_with_nan, + expected_nans_per_individual, + kinematic_variable, + helpers, + request, + ): + """Test kinematics computation for a dataset with nans. + See the ``valid_poses_dataset_with_nan`` and + ``valid_bboxes_dataset_with_nan`` fixtures for details. + """ + # compute kinematic array + valid_dataset = request.getfixturevalue(valid_dataset_with_nan) + position = valid_dataset.position + kinematic_array = getattr(kinematics, f"compute_{kinematic_variable}")( + position + ) + # compute n nans in kinematic array per individual + n_nans_kinematics_per_indiv = [ + helpers.count_nans(kinematic_array.isel(individuals=i)) + for i in range(valid_dataset.sizes["individuals"]) + ] + # assert n nans in kinematic array per individual matches expected + assert ( + n_nans_kinematics_per_indiv + == expected_nans_per_individual[kinematic_variable] ) - expected_dims.insert(-1, "keypoints") - expected_array = expected_array.transpose(*expected_dims) - - # Compare the values of the kinematic_array against the expected_array - np.testing.assert_allclose(kinematic_array.values, expected_array.values) - - -@pytest.mark.parametrize( - "valid_dataset_with_nan", - [ - "valid_poses_dataset_with_nan", - "valid_bboxes_dataset_with_nan", - ], -) -@pytest.mark.parametrize( - "kinematic_variable, expected_nans_per_individual", - [ - ("displacement", [5, 0]), # individual 0, individual 1 - ("velocity", [6, 0]), - ("acceleration", [7, 0]), - ("speed", [6, 0]), - ], -) -def test_kinematics_with_dataset_with_nans( - valid_dataset_with_nan, - kinematic_variable, - expected_nans_per_individual, - helpers, - request, -): - """Test kinematics computation for a dataset with nans. - - We test that the kinematics can be computed and that the number - of nan values in the kinematic array is as expected. - """ - # compute kinematic array - valid_dataset = request.getfixturevalue(valid_dataset_with_nan) - position = valid_dataset.position - kinematic_array = getattr(kinematics, f"compute_{kinematic_variable}")( - position - ) - # compute n nans in kinematic array per individual - n_nans_kinematics_per_indiv = [ - helpers.count_nans(kinematic_array.isel(individuals=i)) - for i in range(valid_dataset.sizes["individuals"]) - ] - # expected nans per individual adjusted for space and keypoints dimensions - n_space_dims = ( - position.sizes["space"] if "space" in kinematic_array.dims else 1 - ) - expected_nans_adjusted = [ - n * n_space_dims * valid_dataset.sizes.get("keypoints", 1) - for n in expected_nans_per_individual - ] - # check number of nans per individual is as expected in kinematic array - np.testing.assert_array_equal( - n_nans_kinematics_per_indiv, expected_nans_adjusted + @pytest.mark.parametrize( + "invalid_dataset, expected_exception", + [ + ("not_a_dataset", pytest.raises(AttributeError)), + ("empty_dataset", pytest.raises(AttributeError)), + ("missing_var_poses_dataset", pytest.raises(AttributeError)), + ("missing_var_bboxes_dataset", pytest.raises(AttributeError)), + ("missing_dim_poses_dataset", pytest.raises(ValueError)), + ("missing_dim_bboxes_dataset", pytest.raises(ValueError)), + ], ) + def test_kinematics_with_invalid_dataset( + self, invalid_dataset, expected_exception, kinematic_variable, request + ): + """Test kinematics computation with an invalid dataset.""" + with expected_exception: + position = request.getfixturevalue(invalid_dataset).position + getattr(kinematics, f"compute_{kinematic_variable}")(position) @pytest.mark.parametrize( - "invalid_dataset, expected_exception", - [ - ("not_a_dataset", pytest.raises(AttributeError)), - ("empty_dataset", pytest.raises(AttributeError)), - ("missing_var_poses_dataset", pytest.raises(AttributeError)), - ("missing_var_bboxes_dataset", pytest.raises(AttributeError)), - ("missing_dim_poses_dataset", pytest.raises(ValueError)), - ("missing_dim_bboxes_dataset", pytest.raises(ValueError)), - ], -) -@pytest.mark.parametrize( - "kinematic_variable", + "order, expected_exception", [ - "displacement", - "velocity", - "acceleration", - "speed", + (0, pytest.raises(ValueError)), + (-1, pytest.raises(ValueError)), + (1.0, pytest.raises(TypeError)), + ("1", pytest.raises(TypeError)), ], ) -def test_kinematics_with_invalid_dataset( - invalid_dataset, - expected_exception, - kinematic_variable, - request, -): - """Test kinematics computation with an invalid dataset.""" - with expected_exception: - position = request.getfixturevalue(invalid_dataset).position - getattr(kinematics, f"compute_{kinematic_variable}")(position) - - -@pytest.mark.parametrize("order", [0, -1, 1.0, "1"]) -def test_approximate_derivative_with_invalid_order(order): +def test_time_derivative_with_invalid_order(order, expected_exception): """Test that an error is raised when the order is non-positive.""" data = np.arange(10) - expected_exception = ValueError if isinstance(order, int) else TypeError - with pytest.raises(expected_exception): + with expected_exception: kinematics.compute_time_derivative(data, order=order) @@ -228,22 +183,19 @@ def test_approximate_derivative_with_invalid_order(order): ], ) def test_path_length_across_time_ranges( - valid_poses_dataset_uniform_linear_motion, - start, - stop, - expected_exception, + valid_poses_dataset, start, stop, expected_exception ): """Test path length computation for a uniform linear motion case, across different time ranges. - The test dataset ``valid_poses_dataset_uniform_linear_motion`` + The test dataset ``valid_poses_dataset`` contains 2 individuals ("id_0" and "id_1"), moving along x=y and x=-y lines, respectively, at a constant velocity. At each frame they cover a distance of sqrt(2) in x-y space, so in total we expect a path length of sqrt(2) * num_segments, where num_segments is the number of selected frames minus 1. """ - position = valid_poses_dataset_uniform_linear_motion.position + position = valid_poses_dataset.position with expected_exception: path_length = kinematics.compute_path_length( position, start=start, stop=stop @@ -271,11 +223,11 @@ def test_path_length_across_time_ranges( @pytest.mark.parametrize( - "nan_policy, expected_path_lengths_id_1, expected_exception", + "nan_policy, expected_path_lengths_id_0, expected_exception", [ ( "ffill", - np.array([np.sqrt(2) * 8, np.sqrt(2) * 9, np.nan]), + np.array([np.sqrt(2) * 9, np.sqrt(2) * 8, np.nan]), does_not_raise(), ), ( @@ -290,43 +242,30 @@ def test_path_length_across_time_ranges( ), ], ) -def test_path_length_with_nans( - valid_poses_dataset_uniform_linear_motion_with_nans, +def test_path_length_with_nan( + valid_poses_dataset_with_nan, nan_policy, - expected_path_lengths_id_1, + expected_path_lengths_id_0, expected_exception, ): """Test path length computation for a uniform linear motion case, with varying number of missing values per individual and keypoint. - - The test dataset ``valid_poses_dataset_uniform_linear_motion_with_nans`` - contains 2 individuals ("id_0" and "id_1"), moving - along x=y and x=-y lines, respectively, at a constant velocity. - At each frame they cover a distance of sqrt(2) in x-y space. - - Individual "id_1" has some missing values per keypoint: - - "centroid" is missing a value on the very first frame - - "left" is missing 5 values in middle frames (not at the edges) - - "right" is missing values in all frames - - Individual "id_0" has no missing values. - Because the underlying motion is uniform linear, the "scale" policy should - perfectly restore the path length for individual "id_1" to its true value. + perfectly restore the path length for individual "id_0" to its true value. The "ffill" policy should do likewise if frames are missing in the middle, but will not "correct" for missing values at the edges. """ - position = valid_poses_dataset_uniform_linear_motion_with_nans.position + position = valid_poses_dataset_with_nan.position with expected_exception: path_length = kinematics.compute_path_length( position, nan_policy=nan_policy, ) - # Get path_length for individual "id_1" as a numpy array - path_length_id_1 = path_length.sel(individuals="id_1").values + # Get path_length for individual "id_0" as a numpy array + path_length_id_0 = path_length.sel(individuals="id_0").values # Check them against the expected values np.testing.assert_allclose( - path_length_id_1, expected_path_lengths_id_1 + path_length_id_0, expected_path_lengths_id_0 ) @@ -339,35 +278,32 @@ def test_path_length_with_nans( ], ) def test_path_length_warns_about_nans( - valid_poses_dataset_uniform_linear_motion_with_nans, + valid_poses_dataset_with_nan, nan_warn_threshold, expected_exception, caplog, ): """Test that a warning is raised when the number of missing values exceeds a given threshold. - - See the docstring of ``test_path_length_with_nans`` for details - about what's in the dataset. """ - position = valid_poses_dataset_uniform_linear_motion_with_nans.position + position = valid_poses_dataset_with_nan.position with expected_exception: kinematics.compute_path_length( position, nan_warn_threshold=nan_warn_threshold ) - - if (nan_warn_threshold > 0.1) and (nan_warn_threshold < 0.5): + if 0.1 < nan_warn_threshold < 0.5: # Make sure that a warning was emitted assert caplog.records[0].levelname == "WARNING" assert "The result may be unreliable" in caplog.records[0].message # Make sure that the NaN report only mentions # the individual and keypoint that violate the threshold + info_msg = caplog.records[1].message assert caplog.records[1].levelname == "INFO" - assert "Individual: id_1" in caplog.records[1].message - assert "Individual: id_2" not in caplog.records[1].message - assert "left: 5/10 (50.0%)" in caplog.records[1].message - assert "right: 10/10 (100.0%)" in caplog.records[1].message - assert "centroid" not in caplog.records[1].message + assert "Individual: id_0" in info_msg + assert "Individual: id_1" not in info_msg + assert "centroid: 3/10 (30.0%)" in info_msg + assert "right: 10/10 (100.0%)" in info_msg + assert "left" not in info_msg @pytest.fixture @@ -376,7 +312,7 @@ def valid_data_array_for_forward_vector(): (left ear, right ear and nose), tracked for 4 frames, in x-y space. """ time = [0, 1, 2, 3] - individuals = ["individual_0"] + individuals = ["id_0"] keypoints = ["left_ear", "right_ear", "nose"] space = ["x", "y"] @@ -426,7 +362,7 @@ def invalid_spatial_dimensions_for_forward_vector( @pytest.fixture -def valid_data_array_for_forward_vector_with_nans( +def valid_data_array_for_forward_vector_with_nan( valid_data_array_for_forward_vector, ): """Return a position DataArray where position values are NaN for the @@ -522,7 +458,7 @@ def test_compute_forward_vector_with_invalid_input( def test_nan_behavior_forward_vector( - valid_data_array_for_forward_vector_with_nans: xr.DataArray, + valid_data_array_for_forward_vector_with_nan, ): """Test that ``compute_forward_vector()`` generates the expected output for a valid input DataArray containing ``NaN`` @@ -531,13 +467,13 @@ def test_nan_behavior_forward_vector( """ nan_time = 1 forward_vector = kinematics.compute_forward_vector( - valid_data_array_for_forward_vector_with_nans, "left_ear", "right_ear" + valid_data_array_for_forward_vector_with_nan, "left_ear", "right_ear" ) # Check coord preservation for preserved_coord in ["time", "space", "individuals"]: assert np.all( forward_vector[preserved_coord] - == valid_data_array_for_forward_vector_with_nans[preserved_coord] + == valid_data_array_for_forward_vector_with_nan[preserved_coord] ) assert set(forward_vector["space"].values) == {"x", "y"} # Should have NaN values in the forward vector at time 1 and left_ear @@ -551,7 +487,7 @@ def test_nan_behavior_forward_vector( forward_vector.sel( time=[ t - for t in valid_data_array_for_forward_vector_with_nans.time + for t in valid_data_array_for_forward_vector_with_nan.time if t != nan_time ] ) @@ -586,12 +522,10 @@ def test_nan_behavior_forward_vector( ), ], ) -def test_cdist_with_known_values( - dim, expected_data, valid_poses_dataset_uniform_linear_motion -): +def test_cdist_with_known_values(dim, expected_data, valid_poses_dataset): """Test the computation of pairwise distances with known values.""" labels_dim = "keypoints" if dim == "individuals" else "individuals" - input_dataarray = valid_poses_dataset_uniform_linear_motion.position.sel( + input_dataarray = valid_poses_dataset.position.sel( time=slice(0, 1) ) # Use only the first two frames for simplicity pairs = input_dataarray[dim].values[:2] @@ -615,10 +549,7 @@ def test_cdist_with_known_values( @pytest.mark.parametrize( "valid_dataset", - [ - "valid_poses_dataset_uniform_linear_motion", - "valid_bboxes_dataset", - ], + ["valid_poses_dataset", "valid_bboxes_dataset"], ) @pytest.mark.parametrize( "selection_fn", @@ -688,12 +619,12 @@ def test_cdist_with_single_dim_inputs(valid_dataset, selection_fn, request): @pytest.mark.parametrize( "dim, pairs, expected_data_vars", [ - ("individuals", {"id_1": ["id_2"]}, None), # list input - ("individuals", {"id_1": "id_2"}, None), # string input + ("individuals", {"id_0": ["id_1"]}, None), # list input + ("individuals", {"id_0": "id_1"}, None), # string input ( "individuals", - {"id_1": ["id_2"], "id_2": "id_1"}, - [("id_1", "id_2"), ("id_2", "id_1")], + {"id_0": ["id_1"], "id_1": "id_0"}, + [("id_0", "id_1"), ("id_1", "id_0")], ), ("individuals", "all", None), # all pairs ("keypoints", {"centroid": ["left"]}, None), # list input @@ -711,13 +642,13 @@ def test_cdist_with_single_dim_inputs(valid_dataset, selection_fn, request): ], ) def test_compute_pairwise_distances_with_valid_pairs( - valid_poses_dataset_uniform_linear_motion, dim, pairs, expected_data_vars + valid_poses_dataset, dim, pairs, expected_data_vars ): """Test that the expected pairwise distances are computed for valid ``pairs`` inputs. """ result = kinematics.compute_pairwise_distances( - valid_poses_dataset_uniform_linear_motion.position, dim, pairs + valid_poses_dataset.position, dim, pairs ) if isinstance(result, dict): expected_data_vars = [ @@ -732,20 +663,16 @@ def test_compute_pairwise_distances_with_valid_pairs( "ds, dim, pairs", [ ( - "valid_poses_dataset_uniform_linear_motion", + "valid_poses_dataset", "invalid_dim", - {"id_1": "id_2"}, + {"id_0": "id_1"}, ), # invalid dim ( - "valid_poses_dataset_uniform_linear_motion", + "valid_poses_dataset", "keypoints", "invalid_string", ), # invalid pairs - ( - "valid_poses_dataset_uniform_linear_motion", - "individuals", - {}, - ), # empty pairs + ("valid_poses_dataset", "individuals", {}), # empty pairs ("missing_dim_poses_dataset", "keypoints", "all"), # invalid dataset ( "missing_dim_bboxes_dataset", diff --git a/tests/test_unit/test_load_bboxes.py b/tests/test_unit/test_load_bboxes.py index a15e2644..5922d498 100644 --- a/tests/test_unit/test_load_bboxes.py +++ b/tests/test_unit/test_load_bboxes.py @@ -261,11 +261,7 @@ def test_from_file( @pytest.mark.parametrize("use_frame_numbers_from_file", [True, False]) @pytest.mark.parametrize("frame_regexp", [None, r"(00\d*)\.\w+$"]) def test_from_via_tracks_file( - via_file_path, - fps, - use_frame_numbers_from_file, - frame_regexp, - movement_dataset_asserts, + via_file_path, fps, use_frame_numbers_from_file, frame_regexp, helpers ): """Test that loading tracked bounding box data from a valid VIA tracks .csv file returns a proper Dataset. @@ -283,7 +279,7 @@ def test_from_via_tracks_file( "fps": fps, "file_path": via_file_path, } - movement_dataset_asserts.valid_dataset(ds, expected_values) + helpers.assert_valid_dataset(ds, expected_values) @pytest.mark.parametrize( @@ -342,7 +338,7 @@ def test_from_numpy( with_frame_array, fps, source_software, - movement_dataset_asserts, + helpers, ): """Test that loading bounding boxes trajectories from the input numpy arrays returns a proper Dataset. @@ -360,7 +356,7 @@ def test_from_numpy( "source_software": source_software, "fps": fps, } - movement_dataset_asserts.valid_dataset(ds, expected_values) + helpers.assert_valid_dataset(ds, expected_values) # check time coordinates are as expected start_frame = ( from_numpy_inputs["frame_array"][0, 0] diff --git a/tests/test_unit/test_load_poses.py b/tests/test_unit/test_load_poses.py index 8c07ae9d..90126f15 100644 --- a/tests/test_unit/test_load_poses.py +++ b/tests/test_unit/test_load_poses.py @@ -74,7 +74,7 @@ def sleap_file_without_tracks(request): } -def test_load_from_sleap_file(sleap_file, movement_dataset_asserts): +def test_load_from_sleap_file(sleap_file, helpers): """Test that loading pose tracks from valid SLEAP files returns a proper Dataset. """ @@ -84,7 +84,7 @@ def test_load_from_sleap_file(sleap_file, movement_dataset_asserts): "source_software": "SLEAP", "file_path": sleap_file, } - movement_dataset_asserts.valid_dataset(ds, expected_values) + helpers.assert_valid_dataset(ds, expected_values) def test_load_from_sleap_file_without_tracks(sleap_file_without_tracks): @@ -142,7 +142,7 @@ def test_load_from_sleap_slp_file_or_h5_file_returns_same(slp_file, h5_file): "DLC_two-mice.predictions.csv", ], ) -def test_load_from_dlc_file(file_name, movement_dataset_asserts): +def test_load_from_dlc_file(file_name, helpers): """Test that loading pose tracks from valid DLC files returns a proper Dataset. """ @@ -153,26 +153,24 @@ def test_load_from_dlc_file(file_name, movement_dataset_asserts): "source_software": "DeepLabCut", "file_path": file_path, } - movement_dataset_asserts.valid_dataset(ds, expected_values) + helpers.assert_valid_dataset(ds, expected_values) @pytest.mark.parametrize( "source_software", ["DeepLabCut", "LightningPose", None] ) -def test_load_from_dlc_style_df( - dlc_style_df, source_software, movement_dataset_asserts -): +def test_load_from_dlc_style_df(valid_dlc_poses_df, source_software, helpers): """Test that loading pose tracks from a valid DLC-style DataFrame returns a proper Dataset. """ ds = load_poses.from_dlc_style_df( - dlc_style_df, source_software=source_software + valid_dlc_poses_df, source_software=source_software ) expected_values = { **expected_values_poses, "source_software": source_software, } - movement_dataset_asserts.valid_dataset(ds, expected_values) + helpers.assert_valid_dataset(ds, expected_values) def test_load_from_dlc_file_csv_or_h5_file_returns_same(): @@ -220,7 +218,7 @@ def test_fps_and_time_coords(fps, expected_fps, expected_time_unit): "LP_mouse-twoview_AIND.predictions.csv", ], ) -def test_load_from_lp_file(file_name, movement_dataset_asserts): +def test_load_from_lp_file(file_name, helpers): """Test that loading pose tracks from valid LightningPose (LP) files returns a proper Dataset. """ @@ -231,7 +229,7 @@ def test_load_from_lp_file(file_name, movement_dataset_asserts): "source_software": "LightningPose", "file_path": file_path, } - movement_dataset_asserts.valid_dataset(ds, expected_values) + helpers.assert_valid_dataset(ds, expected_values) def test_load_from_lp_or_dlc_file_returns_same(): @@ -281,20 +279,16 @@ def test_from_file_delegates_correctly(source_software, fps): @pytest.mark.parametrize("source_software", [None, "SLEAP"]) -def test_from_numpy_valid( - valid_position_array, source_software, movement_dataset_asserts -): +def test_from_numpy_valid(valid_poses_arrays, source_software, helpers): """Test that loading pose tracks from a multi-animal numpy array with valid parameters returns a proper Dataset. """ - valid_position = valid_position_array("multi_individual_array") - rng = np.random.default_rng(seed=42) - valid_confidence = rng.random(valid_position.shape[:-1]) + poses_arrays = valid_poses_arrays("multi_individual_array") ds = load_poses.from_numpy( - valid_position, - valid_confidence, - individual_names=["mouse1", "mouse2"], - keypoint_names=["snout", "tail"], + poses_arrays["position"], + poses_arrays["confidence"], + individual_names=["id_0", "id_1"], + keypoint_names=["centroid", "left", "right"], fps=None, source_software=source_software, ) @@ -302,7 +296,7 @@ def test_from_numpy_valid( **expected_values_poses, "source_software": source_software, } - movement_dataset_asserts.valid_dataset(ds, expected_values) + helpers.assert_valid_dataset(ds, expected_values) def test_from_multiview_files(): diff --git a/tests/test_unit/test_logging.py b/tests/test_unit/test_logging.py index 348a3687..583fd79a 100644 --- a/tests/test_unit/test_logging.py +++ b/tests/test_unit/test_logging.py @@ -48,10 +48,7 @@ def test_log_warning(caplog): @pytest.mark.parametrize( "input_data", - [ - "valid_poses_dataset", - "valid_bboxes_dataset", - ], + ["valid_poses_dataset", "valid_bboxes_dataset"], ) @pytest.mark.parametrize( "selector_fn, expected_selector_type", diff --git a/tests/test_unit/test_reports.py b/tests/test_unit/test_reports.py index 79c3bc89..d7c896a4 100644 --- a/tests/test_unit/test_reports.py +++ b/tests/test_unit/test_reports.py @@ -13,85 +13,76 @@ ], ) @pytest.mark.parametrize( - "data_selection, list_expected_individuals_indices", + "data_selection, expected_individuals_indices", [ (lambda ds: ds.position, [0, 1]), # full position data array ( lambda ds: ds.position.isel(individuals=0), [0], - ), # position of individual 0 only + ), # individual 0 only ], ) def test_report_nan_values_in_position_selecting_individual( valid_dataset, data_selection, - list_expected_individuals_indices, + expected_individuals_indices, request, ): """Test that the nan-value reporting function handles position data - with specific ``individuals`` , and that the data array name (position) + with specific ``individuals``, and that the data array name (position) and only the relevant individuals are included in the report. """ # extract relevant position data input_dataset = request.getfixturevalue(valid_dataset) output_data_array = data_selection(input_dataset) - # produce report report_str = report_nan_values(output_data_array) - # check report of nan values includes name of data array assert output_data_array.name in report_str - # check report of nan values includes selected individuals only - list_expected_individuals = [ - input_dataset["individuals"][idx].item() - for idx in list_expected_individuals_indices - ] - list_not_expected_individuals = [ - indiv.item() - for indiv in input_dataset["individuals"] - if indiv.item() not in list_expected_individuals - ] - assert all([ind in report_str for ind in list_expected_individuals]) - assert all( - [ind not in report_str for ind in list_not_expected_individuals] + list_of_individuals = input_dataset["individuals"].values.tolist() + all_individuals = set(list_of_individuals) + expected_individuals = set( + list_of_individuals[i] for i in expected_individuals_indices ) + not_expected_individuals = all_individuals - expected_individuals + assert all(ind in report_str for ind in expected_individuals) and all( + ind not in report_str for ind in not_expected_individuals + ), "Report contains incorrect individuals." @pytest.mark.parametrize( "valid_dataset", - [ - "valid_poses_dataset", - "valid_poses_dataset_with_nan", - ], + ["valid_poses_dataset", "valid_poses_dataset_with_nan"], ) @pytest.mark.parametrize( - "data_selection, list_expected_keypoints, list_expected_individuals", + "data_selection, expected_keypoints, expected_individuals", [ ( lambda ds: ds.position, - ["key1", "key2"], - ["ind1", "ind2"], + {"centroid", "left", "right"}, + {"id_0", "id_1"}, ), # Report nans in position for all keypoints and individuals ( - lambda ds: ds.position.sel(keypoints="key1"), - [], - ["ind1", "ind2"], - ), # Report nans in position for keypoint "key1", for all individuals - # Note: if only one keypoint exists, it is not explicitly reported + lambda ds: ds.position.sel(keypoints=["centroid", "left"]), + {"centroid", "left"}, + {"id_0", "id_1"}, + ), # Report nans in position for 2 keypoints, for all individuals ( - lambda ds: ds.position.sel(individuals="ind1", keypoints="key1"), - [], - ["ind1"], - ), # Report nans in position for individual "ind1" and keypoint "key1" - # Note: if only one keypoint exists, it is not explicitly reported + lambda ds: ds.position.sel( + individuals="id_0", keypoints="centroid" + ), + set(), + {"id_0"}, + ), # Report nans in position for centroid of individual id_0 + # Note: if only 1 keypoint exists, its name is not explicitly reported ], ) def test_report_nan_values_in_position_selecting_keypoint( valid_dataset, data_selection, - list_expected_keypoints, - list_expected_individuals, + expected_keypoints, + expected_individuals, request, ): """Test that the nan-value reporting function handles position data @@ -101,29 +92,19 @@ def test_report_nan_values_in_position_selecting_keypoint( # extract relevant position data input_dataset = request.getfixturevalue(valid_dataset) output_data_array = data_selection(input_dataset) - # produce report report_str = report_nan_values(output_data_array) - # check report of nan values includes name of data array assert output_data_array.name in report_str - # check report of nan values includes only selected keypoints - list_not_expected_keypoints = [ - indiv.item() - for indiv in input_dataset["keypoints"] - if indiv.item() not in list_expected_keypoints - ] - assert all([kpt in report_str for kpt in list_expected_keypoints]) - assert all([kpt not in report_str for kpt in list_not_expected_keypoints]) - + all_keypoints = set(input_dataset["keypoints"].values.tolist()) + not_expected_keypoints = all_keypoints - expected_keypoints + assert all(kpt in report_str for kpt in expected_keypoints) and all( + kpt not in report_str for kpt in not_expected_keypoints + ), "Report contains incorrect keypoints." # check report of nan values includes selected individuals only - list_not_expected_individuals = [ - indiv.item() - for indiv in input_dataset["individuals"] - if indiv.item() not in list_expected_individuals - ] - assert all([ind in report_str for ind in list_expected_individuals]) - assert all( - [ind not in report_str for ind in list_not_expected_individuals] - ) + all_individuals = set(input_dataset["individuals"].values.tolist()) + not_expected_individuals = all_individuals - expected_individuals + assert all(ind in report_str for ind in expected_individuals) and all( + ind not in report_str for ind in not_expected_individuals + ), "Report contains incorrect individuals." diff --git a/tests/test_unit/test_save_poses.py b/tests/test_unit/test_save_poses.py index 592f0c9a..a495e790 100644 --- a/tests/test_unit/test_save_poses.py +++ b/tests/test_unit/test_save_poses.py @@ -31,7 +31,7 @@ class TestSavePoses: # invalid file path }, { - "file_fixture": "new_file_wrong_ext", + "file_fixture": "wrong_extension_new_file", "to_dlc_file_expected_exception": pytest.raises(ValueError), "to_sleap_file_expected_exception": pytest.raises(ValueError), "to_lp_file_expected_exception": pytest.raises(ValueError), @@ -175,9 +175,7 @@ def test_auto_split_individuals(self, valid_poses_dataset, split_value): indirect=["valid_poses_dataset"], ) def test_to_dlc_style_df_split_individuals( - self, - valid_poses_dataset, - split_individuals, + self, valid_poses_dataset, split_individuals ): """Test that the `split_individuals` argument affects the behaviour of the `to_dlc_style_df` function as expected. @@ -231,9 +229,7 @@ def test_to_dlc_file_split_individuals( """ with expected_exception: save_poses.to_dlc_file( - valid_poses_dataset, - new_h5_file, - split_individuals, + valid_poses_dataset, new_h5_file, split_individuals ) # Get the names of the individuals in the dataset ind_names = valid_poses_dataset.individuals.values @@ -311,7 +307,7 @@ def test_remove_unoccupied_tracks(self, valid_poses_dataset): """Test that removing unoccupied tracks from a valid pose dataset returns the expected result. """ - new_individuals = [f"ind{i}" for i in range(1, 4)] + new_individuals = [f"id_{i}" for i in range(3)] # Add new individual with NaN data ds = valid_poses_dataset.reindex(individuals=new_individuals) ds = save_poses._remove_unoccupied_tracks(ds) diff --git a/tests/test_unit/test_validators/test_array_validators.py b/tests/test_unit/test_validators/test_array_validators.py index 674d7b80..39305bc8 100644 --- a/tests/test_unit/test_validators/test_array_validators.py +++ b/tests/test_unit/test_validators/test_array_validators.py @@ -59,13 +59,13 @@ def expect_value_error_with_message(error_msg): valid_cases + invalid_cases, ) def test_validate_dims_coords( - valid_poses_dataset_uniform_linear_motion, # fixture from conftest.py + valid_poses_dataset, # fixture from conftest.py required_dims_coords, exact_coords, expected_exception, ): """Test validate_dims_coords for both valid and invalid inputs.""" - position_array = valid_poses_dataset_uniform_linear_motion["position"] + position_array = valid_poses_dataset["position"] with expected_exception: validate_dims_coords( position_array, required_dims_coords, exact_coords=exact_coords diff --git a/tests/test_unit/test_validators/test_datasets_validators.py b/tests/test_unit/test_validators/test_datasets_validators.py index 17048334..fcc63d14 100644 --- a/tests/test_unit/test_validators/test_datasets_validators.py +++ b/tests/test_unit/test_validators/test_datasets_validators.py @@ -120,9 +120,9 @@ def test_poses_dataset_validator_with_invalid_position_array( "confidence_array, expected_exception", [ ( - np.ones((10, 3, 2)), + np.ones((10, 2, 2)), pytest.raises(ValueError), - ), # will not match position_array shape + ), # will not match position_array shape (10, 2, 3, 2) ( [1, 2, 3], pytest.raises(ValueError), @@ -134,14 +134,14 @@ def test_poses_dataset_validator_with_invalid_position_array( ], ) def test_poses_dataset_validator_confidence_array( - confidence_array, - expected_exception, - valid_position_array, + confidence_array, expected_exception, valid_poses_arrays ): """Test that invalid confidence arrays raise the appropriate errors.""" with expected_exception: poses = ValidPosesDataset( - position_array=valid_position_array("multi_individual_array"), + position_array=valid_poses_arrays("multi_individual_array")[ + "position" + ], confidence_array=confidence_array, ) if confidence_array is None: @@ -149,28 +149,28 @@ def test_poses_dataset_validator_confidence_array( def test_poses_dataset_validator_keypoint_names( - position_array_params, valid_position_array + position_array_params, valid_poses_arrays ): """Test that invalid keypoint names raise the appropriate errors.""" with position_array_params.get("keypoint_names_expected_exception") as e: poses = ValidPosesDataset( - position_array=valid_position_array( + position_array=valid_poses_arrays( position_array_params.get("array_type") - ), + )["position"][:, :, :2, :], # select up to the first 2 keypoints keypoint_names=position_array_params.get("names"), ) assert poses.keypoint_names == e def test_poses_dataset_validator_individual_names( - position_array_params, valid_position_array + position_array_params, valid_poses_arrays ): """Test that invalid keypoint names raise the appropriate errors.""" with position_array_params.get("individual_names_expected_exception") as e: poses = ValidPosesDataset( - position_array=valid_position_array( + position_array=valid_poses_arrays( position_array_params.get("array_type") - ), + )["position"], individual_names=position_array_params.get("names"), ) assert poses.individual_names == e @@ -188,14 +188,16 @@ def test_poses_dataset_validator_individual_names( ], ) def test_poses_dataset_validator_source_software( - valid_position_array, source_software, expected_exception + valid_poses_arrays, source_software, expected_exception ): """Test that the source_software attribute is validated properly. LightnigPose is incompatible with multi-individual arrays. """ with expected_exception: ds = ValidPosesDataset( - position_array=valid_position_array("multi_individual_array"), + position_array=valid_poses_arrays("multi_individual_array")[ + "position" + ], source_software=source_software, ) diff --git a/tests/test_unit/test_validators/test_files_validators.py b/tests/test_unit/test_validators/test_files_validators.py index b3149d64..4b5288cf 100644 --- a/tests/test_unit/test_validators/test_files_validators.py +++ b/tests/test_unit/test_validators/test_files_validators.py @@ -15,7 +15,7 @@ ("unreadable_file", pytest.raises(PermissionError)), ("unwriteable_file", pytest.raises(PermissionError)), ("fake_h5_file", pytest.raises(FileExistsError)), - ("wrong_ext_file", pytest.raises(ValueError)), + ("wrong_extension_file", pytest.raises(ValueError)), ("nonexistent_file", pytest.raises(FileNotFoundError)), ("directory", pytest.raises(IsADirectoryError)), ], @@ -36,7 +36,7 @@ def test_file_validator_with_invalid_input( @pytest.mark.parametrize( "invalid_input, expected_exception", [ - ("h5_file_no_dataframe", pytest.raises(ValueError)), + ("no_dataframe_h5_file", pytest.raises(ValueError)), ("fake_h5_file", pytest.raises(ValueError)), ], ) @@ -72,7 +72,7 @@ def test_deeplabcut_csv_validator_with_invalid_input( "invalid_input, error_type, log_message", [ ( - "via_tracks_csv_with_invalid_header", + "via_invalid_header", ValueError, ".csv header row does not match the known format for " "VIA tracks .csv files. " @@ -83,7 +83,7 @@ def test_deeplabcut_csv_validator_with_invalid_input( "but got ['filename', 'file_size', 'file_attributes'].", ), ( - "frame_number_in_file_attribute_not_integer", + "via_frame_number_in_file_attribute_not_integer", ValueError, "04.09.2023-04-Right_RE_test_frame_A.png (row 0): " "'frame' file attribute cannot be cast as an integer. " @@ -91,7 +91,7 @@ def test_deeplabcut_csv_validator_with_invalid_input( "{'clip': 123, 'frame': 'FOO'}.", ), ( - "frame_number_in_filename_wrong_pattern", + "via_frame_number_in_filename_wrong_pattern", AttributeError, "04.09.2023-04-Right_RE_test_frame_1.png (row 0): " "The provided frame regexp ((0\d*)\.\w+$) did not return " @@ -100,28 +100,28 @@ def test_deeplabcut_csv_validator_with_invalid_input( "filename.", ), ( - "more_frame_numbers_than_filenames", + "via_more_frame_numbers_than_filenames", ValueError, "The number of unique frame numbers does not match the number " "of unique image files. Please review the VIA tracks .csv file " "and ensure a unique frame number is defined for each file. ", ), ( - "less_frame_numbers_than_filenames", + "via_less_frame_numbers_than_filenames", ValueError, "The number of unique frame numbers does not match the number " "of unique image files. Please review the VIA tracks .csv file " "and ensure a unique frame number is defined for each file. ", ), ( - "region_shape_attribute_not_rect", + "via_region_shape_attribute_not_rect", ValueError, "04.09.2023-04-Right_RE_test_frame_01.png (row 0): " "bounding box shape must be 'rect' (rectangular) " "but instead got 'circle'.", ), ( - "region_shape_attribute_missing_x", + "via_region_shape_attribute_missing_x", ValueError, "04.09.2023-04-Right_RE_test_frame_01.png (row 0): " "at least one bounding box shape parameter is missing. " @@ -130,7 +130,7 @@ def test_deeplabcut_csv_validator_with_invalid_input( "'['name', 'y', 'width', 'height']'.", ), ( - "region_attribute_missing_track", + "via_region_attribute_missing_track", ValueError, "04.09.2023-04-Right_RE_test_frame_01.png (row 0): " "bounding box does not have a 'track' attribute defined " @@ -138,7 +138,7 @@ def test_deeplabcut_csv_validator_with_invalid_input( "Please review the VIA tracks .csv file.", ), ( - "track_id_not_castable_as_int", + "via_track_id_not_castable_as_int", ValueError, "04.09.2023-04-Right_RE_test_frame_01.png (row 0): " "the track ID for the bounding box cannot be cast " @@ -146,7 +146,7 @@ def test_deeplabcut_csv_validator_with_invalid_input( "Please review the VIA tracks .csv file.", ), ( - "track_ids_not_unique_per_frame", + "via_track_ids_not_unique_per_frame", ValueError, "04.09.2023-04-Right_RE_test_frame_01.png: " "multiple bounding boxes in this file have the same track ID. " @@ -155,7 +155,7 @@ def test_deeplabcut_csv_validator_with_invalid_input( ], ) def test_via_tracks_csv_validator_with_invalid_input( - invalid_input, error_type, log_message, request + invalid_via_tracks_csv_file, invalid_input, error_type, log_message ): """Test that invalid VIA tracks .csv files raise the appropriate errors. @@ -171,7 +171,7 @@ def test_via_tracks_csv_validator_with_invalid_input( (i.e., bboxes IDs must exist only once per frame) - error if bboxes IDs are not 1-based integers """ - file_path = request.getfixturevalue(invalid_input) + file_path = invalid_via_tracks_csv_file(invalid_input) with pytest.raises(error_type) as excinfo: ValidVIATracksCSV(file_path) From 25a37178da8c9a8bbc47937e1789ae27e82acec3 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 3 Feb 2025 17:09:35 +0000 Subject: [PATCH 4/6] [pre-commit.ci] pre-commit autoupdate (#401) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * [pre-commit.ci] pre-commit autoupdate updates: - [github.com/astral-sh/ruff-pre-commit: v0.8.6 → v0.9.4](https://github.com/astral-sh/ruff-pre-commit/compare/v0.8.6...v0.9.4) - [github.com/codespell-project/codespell: v2.3.0 → v2.4.1](https://github.com/codespell-project/codespell/compare/v2.3.0...v2.4.1) * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .pre-commit-config.yaml | 4 ++-- movement/kinematics.py | 3 +-- movement/validators/files.py | 2 +- tests/test_unit/test_kinematics.py | 6 +++--- tests/test_unit/test_sample_data.py | 18 +++++++++--------- 5 files changed, 16 insertions(+), 17 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 9258a3fb..318f29a6 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -30,7 +30,7 @@ repos: - id: rst-directive-colons - id: rst-inline-touching-normal - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.8.6 + rev: v0.9.4 hooks: - id: ruff - id: ruff-format @@ -53,7 +53,7 @@ repos: additional_dependencies: [setuptools-scm, wheel] - repo: https://github.com/codespell-project/codespell # Configuration for codespell is in pyproject.toml - rev: v2.3.0 + rev: v2.4.1 hooks: - id: codespell additional_dependencies: diff --git a/movement/kinematics.py b/movement/kinematics.py index 4abe9e99..2880a488 100644 --- a/movement/kinematics.py +++ b/movement/kinematics.py @@ -742,8 +742,7 @@ def compute_pairwise_distances( paired_elements = [ (elem1, elem2) for elem1, elem2_list in pairs.items() - for elem2 in - ( + for elem2 in ( # Ensure elem2_list is a list [elem2_list] if isinstance(elem2_list, str) else elem2_list ) diff --git a/movement/validators/files.py b/movement/validators/files.py index bfca9116..3d179efe 100644 --- a/movement/validators/files.py +++ b/movement/validators/files.py @@ -156,7 +156,7 @@ def _file_is_h5(self, attribute, value): except Exception as e: raise log_error( ValueError, - f"File {value} does not seem to be in valid" "HDF5 format.", + f"File {value} does not seem to be in validHDF5 format.", ) from e @path.validator diff --git a/tests/test_unit/test_kinematics.py b/tests/test_unit/test_kinematics.py index b8a2c9c3..57fd2479 100644 --- a/tests/test_unit/test_kinematics.py +++ b/tests/test_unit/test_kinematics.py @@ -479,9 +479,9 @@ def test_nan_behavior_forward_vector( # Should have NaN values in the forward vector at time 1 and left_ear nan_values = forward_vector.sel(time=nan_time) assert nan_values.shape == (1, 2) - assert np.isnan( - nan_values - ).all(), "NaN values not returned where expected!" + assert np.isnan(nan_values).all(), ( + "NaN values not returned where expected!" + ) # Should have no NaN values in the forward vector in other positions assert not np.isnan( forward_vector.sel( diff --git a/tests/test_unit/test_sample_data.py b/tests/test_unit/test_sample_data.py index ad408126..7f548490 100644 --- a/tests/test_unit/test_sample_data.py +++ b/tests/test_unit/test_sample_data.py @@ -55,15 +55,15 @@ def validate_metadata(metadata: dict[str, dict]) -> None: "note", ] check_yaml_msg = "Check the format of the metadata .yaml file." - assert isinstance( - metadata, dict - ), f"Expected metadata to be a dictionary. {check_yaml_msg}" - assert all( - isinstance(ds, str) for ds in metadata - ), f"Expected metadata keys to be strings. {check_yaml_msg}" - assert all( - isinstance(val, dict) for val in metadata.values() - ), f"Expected metadata values to be dicts. {check_yaml_msg}" + assert isinstance(metadata, dict), ( + f"Expected metadata to be a dictionary. {check_yaml_msg}" + ) + assert all(isinstance(ds, str) for ds in metadata), ( + f"Expected metadata keys to be strings. {check_yaml_msg}" + ) + assert all(isinstance(val, dict) for val in metadata.values()), ( + f"Expected metadata values to be dicts. {check_yaml_msg}" + ) assert all( set(val.keys()) == set(metadata_fields) for val in metadata.values() ), f"Found issues with the names of metadata fields. {check_yaml_msg}" From 4fd9366dc8d1c8c93c30e1d92f49d144084b5bda Mon Sep 17 00:00:00 2001 From: Will Graham <32364977+willGraham01@users.noreply.github.com> Date: Tue, 4 Feb 2025 12:39:23 +0000 Subject: [PATCH 5/6] Introduce basic classes for regions of interest (#396) * Create RoI base skeleton * Write 1D and 2D ROI classes * Write some basic instantiation tests * Pre-commit lint * Fix polygon boundary attributes, return our wrappers instead * CodeCov and SonarQube recommendations * rst != markdown, again, Will * Apply batch suggestions from code review Co-authored-by: Niko Sirmpilatze * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Expose non-base ROI classes via __init__ * shapely is a core dependency now * Address kwargs and boundary naming conventions * Update tests to compute holes too * Fix docstring of holes now disambiguation is complete. Co-authored-by: Niko Sirmpilatze --------- Co-authored-by: Niko Sirmpilatze Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- movement/roi/__init__.py | 2 + movement/roi/base.py | 160 ++++++++++++++++++ movement/roi/line.py | 58 +++++++ movement/roi/polygon.py | 105 ++++++++++++ pyproject.toml | 1 + tests/test_unit/test_roi/conftest.py | 21 +++ tests/test_unit/test_roi/test_instantiate.py | 108 ++++++++++++ .../test_roi/test_polygon_boundary.py | 64 +++++++ 8 files changed, 519 insertions(+) create mode 100644 movement/roi/__init__.py create mode 100644 movement/roi/base.py create mode 100644 movement/roi/line.py create mode 100644 movement/roi/polygon.py create mode 100644 tests/test_unit/test_roi/conftest.py create mode 100644 tests/test_unit/test_roi/test_instantiate.py create mode 100644 tests/test_unit/test_roi/test_polygon_boundary.py diff --git a/movement/roi/__init__.py b/movement/roi/__init__.py new file mode 100644 index 00000000..35657bc1 --- /dev/null +++ b/movement/roi/__init__.py @@ -0,0 +1,2 @@ +from movement.roi.line import LineOfInterest +from movement.roi.polygon import PolygonOfInterest diff --git a/movement/roi/base.py b/movement/roi/base.py new file mode 100644 index 00000000..d91b8aac --- /dev/null +++ b/movement/roi/base.py @@ -0,0 +1,160 @@ +"""Class for representing 1- or 2-dimensional regions of interest (RoIs).""" + +from __future__ import annotations + +from collections.abc import Sequence +from typing import Literal, TypeAlias + +import shapely +from shapely.coords import CoordinateSequence + +from movement.utils.logging import log_error + +LineLike: TypeAlias = shapely.LinearRing | shapely.LineString +PointLike: TypeAlias = tuple[float, float] +PointLikeList: TypeAlias = Sequence[PointLike] +RegionLike: TypeAlias = shapely.Polygon +SupportedGeometry: TypeAlias = LineLike | RegionLike + + +class BaseRegionOfInterest: + """Base class for representing regions of interest (RoIs). + + Regions of interest can be either 1 or 2 dimensional, and are represented + by appropriate ``shapely.Geometry`` objects depending on which. Note that + there are a number of discussions concerning subclassing ``shapely`` + objects; + + - https://github.com/shapely/shapely/issues/1233. + - https://stackoverflow.com/questions/10788976/how-do-i-properly-inherit-from-a-superclass-that-has-a-new-method + + To avoid the complexities of subclassing ourselves, we simply elect to wrap + the appropriate ``shapely`` object in the ``_shapely_geometry`` attribute, + accessible via the property ``region``. This also has the benefit of + allowing us to 'forbid' certain operations (that ``shapely`` would + otherwise interpret in a set-theoretic sense, giving confusing answers to + users). + + This class is not designed to be instantiated directly. It can be + instantiated, however its primary purpose is to reduce code duplication. + """ + + __default_name: str = "Un-named region" + + _name: str | None + _shapely_geometry: SupportedGeometry + + @property + def coords(self) -> CoordinateSequence: + """Coordinates of the points that define the region. + + These are the points passed to the constructor argument ``points``. + + Note that for Polygonal regions, these are the coordinates of the + exterior boundary, interior boundaries must be accessed via + ``self.region.interior.coords``. + """ + return ( + self.region.coords + if self.dimensions < 2 + else self.region.exterior.coords + ) + + @property + def dimensions(self) -> int: + """Dimensionality of the region.""" + return shapely.get_dimensions(self.region) + + @property + def is_closed(self) -> bool: + """Return True if the region is closed. + + A closed region is either: + - A polygon (2D RoI). + - A 1D LoI whose final point connects back to its first. + """ + return self.dimensions > 1 or ( + self.dimensions == 1 + and self.region.coords[0] == self.region.coords[-1] + ) + + @property + def name(self) -> str: + """Name of the instance.""" + return self._name if self._name else self.__default_name + + @property + def region(self) -> SupportedGeometry: + """``shapely.Geometry`` representation of the region.""" + return self._shapely_geometry + + def __init__( + self, + points: PointLikeList, + dimensions: Literal[1, 2] = 2, + closed: bool = False, + holes: Sequence[PointLikeList] | None = None, + name: str | None = None, + ) -> None: + """Initialise a region of interest. + + Parameters + ---------- + points : Sequence of (x, y) values + Sequence of (x, y) coordinate pairs that will form the region. + dimensions : Literal[1, 2], default 2 + The dimensionality of the region to construct. + '1' creates a sequence of joined line segments, + '2' creates a polygon whose boundary is defined by ``points``. + closed : bool, default False + Whether the line to be created should be closed. That is, whether + the final point should also link to the first point. + Ignored if ``dimensions`` is 2. + holes : sequence of sequences of (x, y) pairs, default None + A sequence of items, where each item will be interpreted like + ``points``. These items will be used to construct internal holes + within the region. See the ``holes`` argument to + ``shapely.Polygon`` for details. Ignored if ``dimensions`` is 1. + name : str, default None + Human-readable name to assign to the given region, for + user-friendliness. Default name given is 'Un-named region' if no + explicit name is provided. + + """ + self._name = name + if len(points) < dimensions + 1: + raise log_error( + ValueError, + f"Need at least {dimensions + 1} points to define a " + f"{dimensions}D region (got {len(points)}).", + ) + elif dimensions < 1 or dimensions > 2: + raise log_error( + ValueError, + "Only regions of interest of dimension 1 or 2 are supported " + f"(requested {dimensions})", + ) + elif dimensions == 1 and len(points) < 3 and closed: + raise log_error( + ValueError, + "Cannot create a loop from a single line segment.", + ) + if dimensions == 2: + self._shapely_geometry = shapely.Polygon(shell=points, holes=holes) + else: + self._shapely_geometry = ( + shapely.LinearRing(coordinates=points) + if closed + else shapely.LineString(coordinates=points) + ) + + def __repr__(self) -> str: # noqa: D105 + return str(self) + + def __str__(self) -> str: # noqa: D105 + display_type = "-gon" if self.dimensions > 1 else " line segment(s)" + n_points = len(self.coords) - 1 + return ( + f"{self.__class__.__name__} {self.name} " + f"({n_points}{display_type})\n" + ) + " -> ".join(f"({c[0]}, {c[1]})" for c in self.coords) diff --git a/movement/roi/line.py b/movement/roi/line.py new file mode 100644 index 00000000..5359c830 --- /dev/null +++ b/movement/roi/line.py @@ -0,0 +1,58 @@ +"""1-dimensional lines of interest.""" + +from movement.roi.base import BaseRegionOfInterest, PointLikeList + + +class LineOfInterest(BaseRegionOfInterest): + """Representation of boundaries or other lines of interest. + + This class can be used to represent boundaries or other internal divisions + of the area in which the experimental data was gathered. These might + include segments of a wall that are removed partway through a behavioural + study, or coloured marking on the floor of the experimental enclosure that + have some significance. Instances of this class also constitute the + boundary of two-dimensional regions (polygons) of interest. + + An instance of this class can be used to represent these "one dimensional + regions" (lines of interest, LoIs) in an analysis. The basic usage is to + construct an instance of this class by passing in a list of points, which + will then be joined (in sequence) by straight lines between consecutive + pairs of points, to form the LoI that is to be studied. + """ + + def __init__( + self, + points: PointLikeList, + loop: bool = False, + name: str | None = None, + ) -> None: + """Create a new line of interest (LoI). + + Parameters + ---------- + points : tuple of (x, y) pairs + The points (in sequence) that make up the line segment. At least + two points must be provided. + loop : bool, default False + If True, the final point in ``points`` will be connected by an + additional line segment to the first, creating a closed loop. + (See Notes). + name : str, optional + Name of the LoI that is to be created. A default name will be + inherited from the base class if not provided, and + defaults are inherited from. + + Notes + ----- + The constructor supports 'rings' or 'closed loops' via the ``loop`` + argument. However, if you want to define an enclosed region for your + analysis, we recommend you create a ``PolygonOfInterest`` and use + its ``boundary`` property instead. + + See Also + -------- + movement.roi.base.BaseRegionOfInterest + The base class that constructor arguments are passed to. + + """ + super().__init__(points, dimensions=1, closed=loop, name=name) diff --git a/movement/roi/polygon.py b/movement/roi/polygon.py new file mode 100644 index 00000000..b3c8e4c8 --- /dev/null +++ b/movement/roi/polygon.py @@ -0,0 +1,105 @@ +"""2-dimensional regions of interest.""" + +from __future__ import annotations + +from collections.abc import Sequence + +from movement.roi.base import BaseRegionOfInterest, PointLikeList +from movement.roi.line import LineOfInterest + + +class PolygonOfInterest(BaseRegionOfInterest): + """Representation of a two-dimensional region in the x-y plane. + + This class can be used to represent polygonal regions or subregions + of the area in which the experimental data was gathered. These might + include the arms of a maze, a nesting area, a food source, or other + similar areas of the experimental enclosure that have some significance. + + An instance of this class can be used to represent these regions of + interest (RoIs) in an analysis. The basic usage is to construct an + instance of this class by passing in a list of points, which will then be + joined (in sequence) by straight lines between consecutive pairs of points, + to form the exterior boundary of the RoI. Note that the exterior boundary + (accessible as via the ``.exterior`` property) is a (closed) + ``LineOfInterest``, and may be treated accordingly. + + The class also supports holes - subregions properly contained inside the + region that are not part of the region itself. These can be specified by + the ``holes`` argument, and define the interior boundaries of the region. + These interior boundaries are accessible via the ``.interior_boundaries`` + property, and the polygonal regions that make up the holes are accessible + via the ``holes`` property. + """ + + def __init__( + self, + exterior_boundary: PointLikeList, + holes: Sequence[PointLikeList] | None = None, + name: str | None = None, + ) -> None: + """Create a new region of interest (RoI). + + Parameters + ---------- + exterior_boundary : tuple of (x, y) pairs + The points (in sequence) that make up the boundary of the region. + At least three points must be provided. + holes : sequence of sequences of (x, y) pairs, default None + A sequence of items, where each item will be interpreted as the + ``exterior_boundary`` of an internal hole within the region. See + the ``holes`` argument to ``shapely.Polygon`` for details. + name : str, optional + Name of the RoI that is to be created. A default name will be + inherited from the base class if not provided. + + See Also + -------- + movement.roi.base.BaseRegionOfInterest : The base class that + constructor arguments are passed to, and defaults are inherited + from. + + """ + super().__init__( + points=exterior_boundary, dimensions=2, holes=holes, name=name + ) + + @property + def exterior_boundary(self) -> LineOfInterest: + """The exterior boundary of this RoI.""" + return LineOfInterest( + self.region.exterior.coords, + loop=True, + name=f"Exterior boundary of {self.name}", + ) + + @property + def holes(self) -> tuple[PolygonOfInterest, ...]: + """The interior holes of this RoI. + + Holes are regions properly contained within the exterior boundary of + the RoI that are not part of the RoI itself (like the centre of a + doughnut, for example). A region with no holes returns the empty tuple. + """ + return tuple( + PolygonOfInterest( + int_boundary.coords, name=f"Hole {i} of {self.name}" + ) + for i, int_boundary in enumerate(self.region.interiors) + ) + + @property + def interior_boundaries(self) -> tuple[LineOfInterest, ...]: + """The interior boundaries of this RoI. + + Interior boundaries are the boundaries of holes contained within the + polygon. A region with no holes returns the empty tuple. + """ + return tuple( + LineOfInterest( + int_boundary.coords, + loop=True, + name=f"Interior boundary {i} of {self.name}", + ) + for i, int_boundary in enumerate(self.region.interiors) + ) diff --git a/pyproject.toml b/pyproject.toml index 27348c29..911a4e4e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,6 +19,7 @@ dependencies = [ "attrs", "pooch", "tqdm", + "shapely", "sleap-io", "xarray[accel,viz]", "PyYAML", diff --git a/tests/test_unit/test_roi/conftest.py b/tests/test_unit/test_roi/conftest.py new file mode 100644 index 00000000..8ba15cb9 --- /dev/null +++ b/tests/test_unit/test_roi/conftest.py @@ -0,0 +1,21 @@ +import numpy as np +import pytest + + +@pytest.fixture() +def unit_square_pts() -> np.ndarray: + return np.array( + [ + [0.0, 0.0], + [1.0, 0.0], + [1.0, 1.0], + [0.0, 1.0], + ], + dtype=float, + ) + + +@pytest.fixture() +def unit_square_hole(unit_square_pts: np.ndarray) -> np.ndarray: + """Hole in the shape of a 0.5 side-length square centred on (0.5, 0.5).""" + return 0.25 + (unit_square_pts.copy() * 0.5) diff --git a/tests/test_unit/test_roi/test_instantiate.py b/tests/test_unit/test_roi/test_instantiate.py new file mode 100644 index 00000000..4d8f7cb1 --- /dev/null +++ b/tests/test_unit/test_roi/test_instantiate.py @@ -0,0 +1,108 @@ +import re +from typing import Any + +import numpy as np +import pytest +import shapely + +from movement.roi.base import BaseRegionOfInterest + + +@pytest.mark.parametrize( + ["input_pts", "kwargs_for_creation", "expected_results"], + [ + pytest.param( + "unit_square_pts", + {"dimensions": 2, "closed": False}, + {"is_closed": True, "dimensions": 2, "name": "Un-named region"}, + id="Polygon, closed is ignored", + ), + pytest.param( + "unit_square_pts", + {"dimensions": 1, "closed": False}, + {"is_closed": False, "dimensions": 1}, + id="Line segment(s)", + ), + pytest.param( + "unit_square_pts", + {"dimensions": 1, "closed": True}, + {"is_closed": True, "dimensions": 1}, + id="Looped lines", + ), + pytest.param( + "unit_square_pts", + {"dimensions": 2, "name": "elephant"}, + {"is_closed": True, "dimensions": 2, "name": "elephant"}, + id="Explicit name", + ), + pytest.param( + np.array([[0.0, 0.0], [1.0, 0.0]]), + {"dimensions": 2}, + ValueError("Need at least 3 points to define a 2D region (got 2)"), + id="Too few points (2D)", + ), + pytest.param( + np.array([[0.0, 0.0]]), + {"dimensions": 1}, + ValueError("Need at least 2 points to define a 1D region (got 1)"), + id="Too few points (1D)", + ), + pytest.param( + np.array([[0.0, 0.0], [1.0, 0.0]]), + {"dimensions": 1}, + {"is_closed": False}, + id="Borderline enough points (1D)", + ), + pytest.param( + np.array([[0.0, 0.0], [1.0, 0.0]]), + {"dimensions": 1, "closed": True}, + ValueError("Cannot create a loop from a single line segment."), + id="Cannot close single line segment.", + ), + pytest.param( + "unit_square_pts", + {"dimensions": 3, "closed": False}, + ValueError( + "Only regions of interest of dimension 1 or 2 " + "are supported (requested 3)" + ), + id="Bad dimensionality", + ), + ], +) +def test_creation( + input_pts, + kwargs_for_creation: dict[str, Any], + expected_results: dict[str, Any] | Exception, + request, +) -> None: + if isinstance(input_pts, str): + input_pts = request.getfixturevalue(input_pts) + + if isinstance(expected_results, Exception): + with pytest.raises( + type(expected_results), match=re.escape(str(expected_results)) + ): + BaseRegionOfInterest(input_pts, **kwargs_for_creation) + else: + roi = BaseRegionOfInterest(input_pts, **kwargs_for_creation) + + expected_dim = kwargs_for_creation.pop("dimensions", 2) + expected_closure = kwargs_for_creation.pop("closed", False) + if expected_dim == 2: + assert isinstance(roi.region, shapely.Polygon) + assert len(roi.coords) == len(input_pts) + 1 + string_should_contain = "-gon" + elif expected_closure: + assert isinstance(roi.region, shapely.LinearRing) + assert len(roi.coords) == len(input_pts) + 1 + string_should_contain = "line segment(s)" + else: + assert isinstance(roi.region, shapely.LineString) + assert len(roi.coords) == len(input_pts) + string_should_contain = "line segment(s)" + assert string_should_contain in roi.__str__() + assert string_should_contain in roi.__repr__() + + for attribute_name, expected_value in expected_results.items(): + assert getattr(roi, attribute_name) == expected_value diff --git a/tests/test_unit/test_roi/test_polygon_boundary.py b/tests/test_unit/test_roi/test_polygon_boundary.py new file mode 100644 index 00000000..0ecd1ee0 --- /dev/null +++ b/tests/test_unit/test_roi/test_polygon_boundary.py @@ -0,0 +1,64 @@ +import numpy as np +import pytest +import shapely + +from movement.roi.line import LineOfInterest +from movement.roi.polygon import PolygonOfInterest + + +@pytest.mark.parametrize( + ["exterior_boundary", "interior_boundaries"], + [ + pytest.param("unit_square_pts", tuple(), id="No holes"), + pytest.param( + "unit_square_pts", tuple(["unit_square_hole"]), id="One hole" + ), + pytest.param( + "unit_square_pts", + ( + np.array([[0.0, 0.0], [0.25, 0.0], [0.0, 0.25]]), + np.array([[0.75, 0.0], [1.0, 0.0], [1.0, 0.25]]), + ), + id="Corners shaved off", + ), + ], +) +def test_boundary(exterior_boundary, interior_boundaries, request) -> None: + if isinstance(exterior_boundary, str): + exterior_boundary = request.getfixturevalue(exterior_boundary) + interior_boundaries = tuple( + request.getfixturevalue(ib) if isinstance(ib, str) else ib + for ib in interior_boundaries + ) + tolerance = 1.0e-8 + + polygon = PolygonOfInterest( + exterior_boundary, holes=interior_boundaries, name="Holey" + ) + expected_exterior = shapely.LinearRing(exterior_boundary) + expected_interiors = tuple( + shapely.LinearRing(ib) for ib in interior_boundaries + ) + expected_holes = tuple(shapely.Polygon(ib) for ib in interior_boundaries) + + computed_exterior = polygon.exterior_boundary + computed_interiors = polygon.interior_boundaries + computed_holes = polygon.holes + + assert isinstance(computed_exterior, LineOfInterest) + assert expected_exterior.equals_exact(computed_exterior.region, tolerance) + assert isinstance(computed_interiors, tuple) + assert isinstance(computed_holes, tuple) + assert len(computed_interiors) == len(expected_interiors) + assert len(computed_holes) == len(expected_holes) + assert len(computed_holes) == len(computed_interiors) + for i, interior_line in enumerate(computed_interiors): + assert isinstance(interior_line, LineOfInterest) + + assert expected_interiors[i].equals_exact( + interior_line.region, tolerance + ) + for i, interior_hole in enumerate(computed_holes): + assert isinstance(interior_hole, PolygonOfInterest) + + assert expected_holes[i].equals_exact(interior_hole.region, tolerance) From c8d0d3e27925f5dc059c15c24921095f73c9e265 Mon Sep 17 00:00:00 2001 From: Chang Huan Lo Date: Tue, 4 Feb 2025 17:01:05 +0000 Subject: [PATCH 6/6] Use pypa/gh-action-pypi-publish stable release v1 (#404) --- .github/workflows/test_and_deploy.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test_and_deploy.yml b/.github/workflows/test_and_deploy.yml index cd295632..47ff488e 100644 --- a/.github/workflows/test_and_deploy.yml +++ b/.github/workflows/test_and_deploy.yml @@ -78,7 +78,7 @@ jobs: with: name: artifact path: dist - - uses: pypa/gh-action-pypi-publish@v1.12.3 + - uses: pypa/gh-action-pypi-publish@release/v1 with: user: __token__ password: ${{ secrets.TWINE_API_KEY }}