Skip to content

Commit 728033f

Browse files
committed
create feature documentation
1 parent 20eea58 commit 728033f

File tree

5 files changed

+143
-56
lines changed

5 files changed

+143
-56
lines changed

docs/extend/features.rst

+62-17
Original file line numberDiff line numberDiff line change
@@ -11,36 +11,81 @@ Extend features
1111
Your new features therefore may not work within different setups. To make your features available to the community and integrated to gaitalytics,
1212
please consider :ref:`contributing <Development>` to the project and our staff will take care of the integration.
1313

14-
Before you start creating new features, you should think about the following questions:
14+
To implement new feature algorithms you need to define a new class inheriting :class:`gaitalytics.features.PointDependentFeature`.
15+
This class internally loops though each cycle and context ('left' and 'right') and calculates the feature for each point in the cycle and calls the :meth:`gaitalytics.features.CycleFeaturesCalculation._calculate` method.
16+
With implementing :meth:`gaitalytics.features.CycleFeaturesCalculation._calculate` in your class you can add new functionalities.
1517

16-
- Will my feature be calculated for each gait cycle?
17-
- Are all the needed markers defined through the :ref:`mappings <Config Mapping>`?
18+
.. code-block:: python
1819
19-
If you have answered all the questions, you can start creating your new feature.
20+
from gaitalytics.features import PointDependentFeature
2021
21-
Class inheritance
22-
-----------------
22+
class NewFeature(PointDependentFeature):
2323
24-
With with your answers in mind, you can now decide which class you want to inherit from. The following classes are available in the library:
24+
def _calculate(self, cycle):
25+
# Your feature calculation here
26+
return feature_value
2527
26-
- :class:`gaitalytics.features.FeatureCalculation`
27-
- :class:`gaitalytics.features.CycleFeaturesCalculation`
28-
- :class:`gaitalytics.features.PointDependentFeature`
28+
Through parameter `cycle` you can access the data of the current cycle and calculate your feature, including the data of the current cycle and the context ('left' or 'right').
29+
As return the framework expects an xarray DataArray object.
2930

30-
The following flowchart will help you to decide which class you should inherit from:
31+
Helping functions
32+
-----------------
3133

32-
.. mermaid::
34+
To help you with the implementation of new features, gaitalytics provides a set of helper functions handle framework specific restrictions.
3335

34-
flowchart LR
35-
A{Per cycle?}---|Yes|B{Mapped markers?}
36-
A---|No|FeatureCalculation
37-
B---|Yes|PointDependentFeature
38-
B---|No|CycleFeaturesCalculation
36+
Markers & Analogs
37+
^^^^^^^^^^^^^^^^^
38+
If your planning to use markers or analogs which are mapped by Gaitalytics (:class:`gaitalytics.mapping.MappedMarkers`) you can use the following helper functions:
39+
:meth:`gaitalytics.features.PointDependentFeature._get_marker_data`. This function returns the data of the mapped marker for the current cycle.
3940

41+
.. hint::
42+
Getting the data for the sacrum marker is handled as a special case. Since Gaitalytics allows either a single sacrum marker or a sacrum marker or two posterior pelvis markers, the method :meth:`gaitalytics.features.PointDependentFeature._get_sacrum_marker` will handle the logic to extract jut a sacrum marker.
4043
..
4144
45+
If you want to use markers or analogs which are not mapped by Gaitalytics, you can find the data in the `cycle` xarray object.
46+
Be aware that approach is not generalized and may not work with different marker models. Therefore, it is recommended to use the mapped markers whenever possible.
47+
Future efforts will be made to generalize this approach.
48+
49+
Event timings
50+
^^^^^^^^^^^^^
51+
Ease the work with event timings in a cycle the :meth:`gaitalytics.features.PointDependentFeature.get_event_times` function can be used to extract the event timings for the current cycle.
52+
53+
54+
Vectors
55+
^^^^^^^
56+
It is often necessary to obtain progression vectors or sagittal plane vectors. To help you with this, gaitalytics provides the following helper functions:
57+
58+
- :meth:`gaitalytics.features.PointDependentFeature._get_progression_vector`
59+
- :meth:`gaitalytics.features.PointDependentFeature._get_sagittal_vector`
60+
61+
Return values
62+
^^^^^^^^^^^^^
63+
The expected return value of the feature calculation is an xarray DataArray object in a specific format.
64+
To help you with the creation of this object, gaitalytics provides the following helper functions:
65+
66+
- :meth:`gaitalytics.features.CycleFeaturesCalculation._create_result_from_dict` to create a DataArray object from a dictionary.
67+
- :meth:`gaitalytics.features.CycleFeaturesCalculation._flatten_features` to flatten an xarray DataArray object.
68+
69+
Including your feature
70+
----------------------
71+
72+
To include your feature in the calculation of the gait metrics, you need to add it to the parameters of your :func:`gaitalytics.api.calculate_features` call.
73+
74+
.. code-block:: python
4275
76+
from gaitalytics import api
77+
from gaitalytics.features import NewFeature
4378
79+
config = api.load_config("./config.yaml")
80+
trial = api.load_c3d_trial("./example.c3d", config)
4481
82+
# Calculate the only the new feature
83+
features = api.calculate_features(trial, config, [NewFeature])
4584
85+
# Calculate all features including the new feature
86+
features = api.calculate_features(trial, config, [gaitalytics.features.TimeSeriesFeatures,
87+
gaitalytics.features.PhaseTimeSeriesFeatures,
88+
gaitalytics.features.TemporalFeatures,
89+
gaitalytics.features.SpatialFeatures,
90+
NewFeature])
4691

gaitalytics/features.py

+46-34
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,7 @@ def _create_result_from_dict(result_dict: dict) -> xr.DataArray:
172172
173173
Returns:
174174
An xarray DataArray containing the data from the dictionary.
175+
:meta public:
175176
"""
176177
xr_dict = {
177178
"coords": {
@@ -193,6 +194,8 @@ def _flatten_features(features: xr.DataArray) -> xr.DataArray:
193194
194195
Returns:
195196
The reshaped features.
197+
198+
:meta public:
196199
"""
197200

198201
np_data = features.to_numpy()
@@ -256,6 +259,8 @@ def _get_progression_vector(self, trial: model.Trial) -> xr.DataArray:
256259
257260
Returns:
258261
An xarray DataArray containing the calculated progression vector.
262+
263+
:meta public:
259264
"""
260265
return mocap.get_progression_vector(trial, self._config)
261266

@@ -268,6 +273,7 @@ def _get_sagittal_vector(self, trial: model.Trial) -> xr.DataArray:
268273
trial: The trial for which to calculate the sagittal vector.
269274
Returns:
270275
An xarray DataArray containing the calculated sagittal vector.
276+
:meta public:
271277
"""
272278
progression_vector = self._get_progression_vector(trial)
273279
vertical_vector = xr.DataArray(
@@ -276,8 +282,6 @@ def _get_sagittal_vector(self, trial: model.Trial) -> xr.DataArray:
276282
return linalg.get_normal_vector(progression_vector, vertical_vector)
277283

278284

279-
280-
281285
class TimeSeriesFeatures(CycleFeaturesCalculation):
282286
"""Calculate time series features for a trial.
283287
@@ -531,7 +535,7 @@ def _calculate(self, trial: model.Trial) -> xr.DataArray:
531535
Minimal toe clearance: Schulz 2017 (doi: 10.1016/j.jbiomech.2017.02.024)
532536
533537
Args:
534-
trial: The trial for which to calculate the features.
538+
trial: The trial for which to calculate the features.
535539
536540
Returns:
537541
An xarray DataArray containing the calculated features.
@@ -581,7 +585,7 @@ def _calculate(self, trial: model.Trial) -> xr.DataArray:
581585
trial,
582586
marker_dict["ipsi_heel"], # type: ignore
583587
marker_dict["contra_toe_2"], # type: ignore
584-
marker_dict["xcom"]
588+
marker_dict["xcom"],
585589
)
586590
)
587591

@@ -590,7 +594,7 @@ def _calculate(self, trial: model.Trial) -> xr.DataArray:
590594
trial,
591595
marker_dict["ipsi_ankle"],
592596
marker_dict["contra_ankle"],
593-
marker_dict["xcom"]
597+
marker_dict["xcom"],
594598
)
595599
)
596600
except KeyError:
@@ -850,12 +854,13 @@ def _find_mtc_index(
850854

851855
return None if not mtc_i else min(mtc_i, key=lambda i: toe_z[i]) # type: ignore
852856

853-
def _calculate_ap_margin_of_stability(self,
854-
trial: model.Trial,
855-
ipsi_heel_marker: mapping.MappedMarkers,
856-
contra_toe_marker: mapping.MappedMarkers,
857-
xcom_marker: mapping.MappedMarkers,
858-
) -> dict[str, np.ndarray]:
857+
def _calculate_ap_margin_of_stability(
858+
self,
859+
trial: model.Trial,
860+
ipsi_heel_marker: mapping.MappedMarkers,
861+
contra_toe_marker: mapping.MappedMarkers,
862+
xcom_marker: mapping.MappedMarkers,
863+
) -> dict[str, np.ndarray]:
859864
"""Calculate the anterio-posterior margin of stability at heel strike. Result should be interpreted according to Curtze et al. (2024)
860865
Args:
861866
trial: The trial for which to calculate the AP margin of stability
@@ -880,30 +885,33 @@ def _calculate_ap_margin_of_stability(self,
880885
xcom = self._get_marker_data(trial, xcom_marker).sel(
881886
time=event_times[0], method="nearest"
882887
)
883-
888+
884889
progress_axis = self._get_progression_vector(trial)
885890
progress_axis = linalg.normalize_vector(progress_axis)
886-
891+
887892
front_marker = linalg.get_point_in_front(ipsi_heel, contra_toe, progress_axis)
888893
back_marker = linalg.get_point_behind(ipsi_heel, contra_toe, progress_axis)
889-
894+
890895
bos_vect = front_marker - back_marker
891896
xcom_vect = xcom - back_marker
892-
897+
893898
bos_proj = abs(linalg.signed_projection_norm(bos_vect, progress_axis))
894899
xcom_proj = linalg.signed_projection_norm(xcom_vect, progress_axis)
895900
mos = bos_proj - xcom_proj
896901

897-
return {"AP_margin_of_stability": mos,
898-
"AP_base_of_support": bos_proj,
899-
"AP_xcom": xcom_proj}
902+
return {
903+
"AP_margin_of_stability": mos,
904+
"AP_base_of_support": bos_proj,
905+
"AP_xcom": xcom_proj,
906+
}
900907

901-
def _calculate_ml_margin_of_stability(self,
902-
trial: model.Trial,
903-
ipsi_ankle_marker: mapping.MappedMarkers,
904-
contra_ankle_marker: mapping.MappedMarkers,
905-
xcom_marker: mapping.MappedMarkers
906-
) -> dict[str, np.ndarray]:
908+
def _calculate_ml_margin_of_stability(
909+
self,
910+
trial: model.Trial,
911+
ipsi_ankle_marker: mapping.MappedMarkers,
912+
contra_ankle_marker: mapping.MappedMarkers,
913+
xcom_marker: mapping.MappedMarkers,
914+
) -> dict[str, np.ndarray]:
907915
"""Calculate the medio-lateral margin of stability at heel strike. Result should be interpreted according to Curtze et al. (2024)
908916
Args:
909917
trial: The trial for which to calculate the ml margin of stability
@@ -927,27 +935,31 @@ def _calculate_ml_margin_of_stability(self,
927935
)
928936
xcom = self._get_marker_data(trial, xcom_marker).sel(
929937
time=event_times[0], method="nearest"
930-
)
938+
)
931939

932940
sagittal_axis = self._get_sagittal_vector(trial)
933941
sagittal_axis = linalg.normalize_vector(sagittal_axis)
934-
942+
935943
if trial.events.attrs["context"] == "Left":
936-
#Rotate sagittal axis so it points towards the left side of the body
944+
# Rotate sagittal axis so it points towards the left side of the body
937945
sagittal_axis = -sagittal_axis
938-
946+
939947
# Lateral is the furthest point in the direction of the sagittal axis
940-
lateral_point = linalg.get_point_in_front(ipsi_ankle, contra_ankle, sagittal_axis)
948+
lateral_point = linalg.get_point_in_front(
949+
ipsi_ankle, contra_ankle, sagittal_axis
950+
)
941951
# Medial is the closest point in the direction of the sagittal axis
942952
medial_point = linalg.get_point_behind(ipsi_ankle, contra_ankle, sagittal_axis)
943953

944954
bos_vect = lateral_point - medial_point
945955
xcom_vect = xcom - medial_point
946-
956+
947957
bos_proj = abs(linalg.signed_projection_norm(bos_vect, sagittal_axis))
948958
xcom_proj = linalg.signed_projection_norm(xcom_vect, sagittal_axis)
949959
mos = bos_proj - xcom_proj
950-
951-
return {"ML_margin_of_stability": mos,
952-
"ML_base_of_support": bos_proj,
953-
"ML_xcom": xcom_proj}
960+
961+
return {
962+
"ML_margin_of_stability": mos,
963+
"ML_base_of_support": bos_proj,
964+
"ML_xcom": xcom_proj,
965+
}

gaitalytics/io.py

+1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
"""This module provides classes for reading biomechanical file-types."""
2+
23
import math
34
from abc import abstractmethod, ABC
45
from pathlib import Path

gaitalytics/mapping.py

+20
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,26 @@
55

66

77
class MappedMarkers(Enum):
8+
"""This defines the markers that are mapped in the configuration file.
9+
10+
Attributes:
11+
L_HEEL (str): Left heel marker
12+
R_HEEL (str): Right heel marker
13+
L_TOE (str): Left toe marker
14+
R_TOE (str): Right toe marker
15+
L_ANKLE (str): Left ankle marker
16+
R_ANKLE (str): Right ankle marker
17+
L_TOE_2 (str): Left toe 2 marker
18+
R_TOE_2 (str): Right toe 2 marker
19+
L_ANT_HIP (str): Left anterior hip marker
20+
R_ANT_HIP (str): Right anterior hip marker
21+
L_POST_HIP (str): Left posterior hip marker
22+
R_POST_HIP (str): Right posterior hip marker
23+
SACRUM (str): Sacrum marker
24+
XCOM (str): Extrapolated center of mass marker
25+
26+
"""
27+
828
# Foot
929
L_HEEL = "l_heel"
1030
R_HEEL = "r_heel"

gaitalytics/utils/linalg.py

+14-5
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ def project_point_on_vector(point: xr.DataArray, vector: xr.DataArray) -> xr.Dat
3030
"""
3131
return vector * point.dot(vector, dim="axis")
3232

33+
3334
def signed_projection_norm(vector: xr.DataArray, onto: xr.DataArray) -> xr.DataArray:
3435
"""Compute the signed norm of the projection of a vector onto another vector. <br>
3536
If the projection is in the same direction as the onto vector, the norm is positive. <br>
@@ -98,9 +99,14 @@ def calculate_speed_norm(position: xr.DataArray, dt: float = 0.01) -> np.ndarray
9899
speed_values = np.sqrt(velocity_squared_sum)
99100
speed_values = np.append(speed_values, speed_values[-1])
100101

101-
return xr.DataArray(speed_values, dims=["time"], coords={"time": position.coords["time"]})
102+
return xr.DataArray(
103+
speed_values, dims=["time"], coords={"time": position.coords["time"]}
104+
)
105+
102106

103-
def get_point_in_front(point_a: xr.DataArray, point_b: xr.DataArray, direction_vector: xr.DataArray) -> xr.DataArray:
107+
def get_point_in_front(
108+
point_a: xr.DataArray, point_b: xr.DataArray, direction_vector: xr.DataArray
109+
) -> xr.DataArray:
104110
"""Determine which point is in front of the other according to the direction vector.
105111
106112
Args:
@@ -114,10 +120,13 @@ def get_point_in_front(point_a: xr.DataArray, point_b: xr.DataArray, direction_v
114120
direction_vector = direction_vector / direction_vector.meca.norm(dim="axis")
115121
vector_b_to_a = point_a - point_b
116122
signed_distance = vector_b_to_a.dot(direction_vector, dim="axis")
117-
123+
118124
return point_a if signed_distance > 0 else point_b
119125

120-
def get_point_behind(point_a: xr.DataArray, point_b: xr.DataArray, direction_vector: xr.DataArray) -> xr.DataArray:
126+
127+
def get_point_behind(
128+
point_a: xr.DataArray, point_b: xr.DataArray, direction_vector: xr.DataArray
129+
) -> xr.DataArray:
121130
"""Determine which point is behind the other according to the direction vector.
122131
123132
Args:
@@ -131,5 +140,5 @@ def get_point_behind(point_a: xr.DataArray, point_b: xr.DataArray, direction_vec
131140
direction_vector = direction_vector / direction_vector.meca.norm(dim="axis")
132141
vector_b_to_a = point_a - point_b
133142
signed_distance = vector_b_to_a.dot(direction_vector, dim="axis")
134-
143+
135144
return point_b if signed_distance > 0 else point_a

0 commit comments

Comments
 (0)