Skip to content

Commit ed54b61

Browse files
author
André Böni
committed
+ open events api for external methods
1 parent 3bb8da9 commit ed54b61

File tree

3 files changed

+11
-13
lines changed

3 files changed

+11
-13
lines changed

gaitalytics/api.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -93,14 +93,17 @@ def load_c3d_trial(
9393

9494

9595
def detect_events(
96-
trial: model.Trial, config: mapping.MappingConfigs, method: str = "Marker", **kwargs
96+
trial: model.Trial,
97+
config: mapping.MappingConfigs,
98+
method: type[events.BaseEventDetection] = events.MarkerEventDetection,
99+
**kwargs,
97100
) -> pd.DataFrame:
98101
"""Detects the events in the trial.
99102
100103
Args:
101104
trial: The trial to detect the events for.
102105
config: The mapping configurations
103-
method: The method to use for detecting the events.
106+
method: The class to use for detecting the events.
104107
Currently, only "Marker" is supported, which implements
105108
the method from Zenis et al. 2008.
106109
Default is "Marker".
@@ -113,15 +116,9 @@ def detect_events(
113116
Returns:
114117
A DataFrame containing the detected events.
115118
116-
Raises:
117-
ValueError: If the method is not supported.
118119
"""
119120

120-
match method:
121-
case "Marker":
122-
method_obj = events.MarkerEventDetection(config, **kwargs)
123-
case _:
124-
raise ValueError(f"Unsupported method: {method}")
121+
method_obj = method(config, **kwargs)
125122

126123
event_table = method_obj.detect_events(trial)
127124
return event_table

gaitalytics/events.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ def _check_contexts(self, events: pd.DataFrame) -> list[tuple]:
143143
return incorrect_times
144144

145145

146-
class _BaseEventDetection(ABC):
146+
class BaseEventDetection(ABC):
147147
"""Abstract class for event detectors.
148148
149149
This class provides a common interface for detecting events in a trial,
@@ -171,7 +171,7 @@ def detect_events(self, trial: model.Trial) -> pd.DataFrame:
171171
raise NotImplementedError
172172

173173

174-
class MarkerEventDetection(_BaseEventDetection):
174+
class MarkerEventDetection(BaseEventDetection):
175175
"""A class for detecting events using marker data.
176176
177177
This class provides a method to detect events using marker data in a trial.

tests/test_api.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import pandas as pd
55
import pytest
66

7+
import gaitalytics.events as events
78
import gaitalytics.api as api
89
import gaitalytics.mapping as mapping
910
import gaitalytics.model as model
@@ -61,8 +62,8 @@ def test_detect_events():
6162
def test_detect_events_methode():
6263
config = api.load_config("./tests/pig_config.yaml")
6364
trial = model.Trial()
64-
with pytest.raises(ValueError):
65-
api.detect_events(trial, config, method="ForcePlate")
65+
with pytest.raises(TypeError):
66+
api.detect_events(trial, config, method= events.BaseEventDetection)
6667

6768

6869
def test_check_events():

0 commit comments

Comments
 (0)