Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Provide a generic load_poses.from_file() function #107 #110

Merged
merged 14 commits into from
Feb 19, 2024
Merged
53 changes: 53 additions & 0 deletions movement/io/load_poses.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,59 @@
logger = logging.getLogger(__name__)


def from_file(
file_path: Union[Path, str],
source_software: Literal["DeepLabCut", "SLEAP", "LightningPose"],
fps: Optional[float] = None,
) -> xr.Dataset:
"""Load pose tracking data from a DeepLabCut (DLC), LightningPose (LP) or
SLEAP output file into an xarray Dataset.

Parameters
----------
file_path : pathlib.Path or str
Path to the file containing predicted poses. The file format must
be among those supported by the from_dlc_file(), from_slp_file()
or from_lp_file() functions.
source_software : "DeepLabCut", "SLEAP" or "LightningPose"
The source software of the file.
fps : float, optional
The number of frames per second in the video. If None (default),
the `time` coordinates will be in frame numbers.

Returns
-------
xarray.Dataset
Dataset containing the pose tracks, confidence scores, and metadata.

Notes
-----
Identical to calling any of the functions from_dlc_file(),
from_sleap_file() or from_lp_file().

niksirbi marked this conversation as resolved.
Show resolved Hide resolved
See Also
--------
movement.io.load_poses.from_dlc_file : Load pose tracks directly
from DeepLabCut files.
movement.io.load_poses.from_sleap_file : Load pose tracks directly
from SLEAP files.
movement.io.load_poses.from_lp_file : Load pose tracks directly
from LightningPose files.

"""

if source_software == "DeepLabCut":
return from_dlc_file(file_path, fps)
elif source_software == "SLEAP":
return from_sleap_file(file_path, fps)
elif source_software == "LightningPose":
return from_lp_file(file_path, fps)
niksirbi marked this conversation as resolved.
Show resolved Hide resolved
else:
raise ValueError(
"Unsupported source software: {}".format(source_software)
)
niksirbi marked this conversation as resolved.
Show resolved Hide resolved


def from_dlc_df(df: pd.DataFrame, fps: Optional[float] = None) -> xr.Dataset:
"""Create an xarray.Dataset from a DeepLabCut-style pandas DataFrame.

Expand Down
24 changes: 24 additions & 0 deletions tests/test_unit/test_load_poses.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from unittest.mock import patch

import h5py
import numpy as np
import pytest
Expand Down Expand Up @@ -239,3 +241,25 @@ def test_load_multi_animal_from_lp_file_raises(self):
file_path = POSE_DATA_PATHS.get("DLC_two-mice.predictions.csv")
with pytest.raises(ValueError):
load_poses.from_lp_file(file_path)

@pytest.mark.parametrize(
"source_software", ["SLEAP", "DeepLabCut", "LightningPose", "Unknown"]
)
@pytest.mark.parametrize("fps", [None, 30, 60.0])
def test_from_file_delegates_correctly(self, source_software, fps):
"""Test that the from_file() function delegates to the correct
loader function according to the source_software."""

software_to_loader = {
"SLEAP": "movement.io.load_poses.from_sleap_file",
"DeepLabCut": "movement.io.load_poses.from_dlc_file",
"LightningPose": "movement.io.load_poses.from_lp_file",
}

if source_software == "Unknown":
with pytest.raises(ValueError):
load_poses.from_file("some_file", source_software)
else:
with patch(software_to_loader[source_software]) as mock_loader:
load_poses.from_file("some_file", source_software, fps)
mock_loader.assert_called_with("some_file", fps)
Loading