Skip to content

Commit b2b45d8

Browse files
author
André Böni
committed
2 parents 7cad4ed + 7f40ef9 commit b2b45d8

File tree

4 files changed

+204
-59
lines changed

4 files changed

+204
-59
lines changed

gaitalytics/features.py

Lines changed: 76 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -256,7 +256,7 @@ def _get_progression_vector(self, trial: model.Trial) -> xr.DataArray:
256256
def _get_sagittal_vector(self, trial: model.Trial) -> xr.DataArray:
257257
"""Calculate the sagittal vector for a trial.
258258
259-
The sagittal vector is the vector normal to the sagittal plane.
259+
The sagittal vector is the vector normal to the sagittal plane. Note that this vector will always be pointing towards the right side of the body.
260260
261261
Args:
262262
trial: The trial for which to calculate the sagittal vector.
@@ -507,19 +507,23 @@ class SpatialFeatures(_PointDependentFeature):
507507
- step_width
508508
- minimal_toe_clearance
509509
- AP_margin_of_stability
510+
- AP_base_of_support
511+
- AP_xcom
510512
- ML_margin_of_stability
513+
- ML_base_of_support
514+
- ML_xcom
511515
"""
512516

513517
def _calculate(self, trial: model.Trial) -> xr.DataArray:
514518
"""Calculate the spatial features for a trial.
515519
516520
Definitions of the spatial features:
517521
Step length & Step width: Hollmann et al. 2011 (doi: 10.1016/j.gaitpost.2011.03.024)
518-
Margin of stability: Jinfeng et al. 2021 (doi: 10.1152/jn.00091.2021)
522+
Margin of stability: Jinfeng et al. 2021 (doi: 10.1152/jn.00091.2021), Curtze et al. 2024 (doi: 10.1016/j.jbiomech.2024.112045)
519523
Minimal toe clearance: Schulz 2017 (doi: 10.1016/j.jbiomech.2017.02.024)
520524
521525
Args:
522-
trial: The trial for which to calculate the features.
526+
trial: The trial for which to calculate the features.
523527
524528
Returns:
525529
An xarray DataArray containing the calculated features.
@@ -567,9 +571,9 @@ def _calculate(self, trial: model.Trial) -> xr.DataArray:
567571
results_dict.update(
568572
self._calculate_ap_margin_of_stability(
569573
trial,
570-
marker_dict["ipsi_toe_2"], # type: ignore
574+
marker_dict["ipsi_heel"], # type: ignore
571575
marker_dict["contra_toe_2"], # type: ignore
572-
marker_dict["xcom"],
576+
marker_dict["xcom"]
573577
)
574578
)
575579

@@ -578,7 +582,7 @@ def _calculate(self, trial: model.Trial) -> xr.DataArray:
578582
trial,
579583
marker_dict["ipsi_ankle"],
580584
marker_dict["contra_ankle"],
581-
marker_dict["xcom"],
585+
marker_dict["xcom"]
582586
)
583587
)
584588
except KeyError:
@@ -704,7 +708,7 @@ def _calculate_stride_length(
704708
ipsi_marker: mapping.MappedMarkers,
705709
contra_marker: mapping.MappedMarkers,
706710
) -> dict[str, np.ndarray]:
707-
"""Calculate the stride length for a trial.
711+
"""Calculate the stride length for a trial. It is computed as the two consecutive step lengths constituting the gait cycle.
708712
709713
Args:
710714
trial: The trial for which to calculate the stride length.
@@ -838,26 +842,28 @@ def _find_mtc_index(
838842

839843
return None if not mtc_i else min(mtc_i, key=lambda i: toe_z[i]) # type: ignore
840844

841-
def _calculate_ap_margin_of_stability(
842-
self,
843-
trial: model.Trial,
844-
ipsi_toe_marker: mapping.MappedMarkers,
845-
contra_toe_marker: mapping.MappedMarkers,
846-
xcom_marker: mapping.MappedMarkers,
847-
) -> dict[str, np.ndarray]:
848-
"""Calculate the anterior-posterior margin of stability at heel strike
845+
def _calculate_ap_margin_of_stability(self,
846+
trial: model.Trial,
847+
ipsi_heel_marker: mapping.MappedMarkers,
848+
contra_toe_marker: mapping.MappedMarkers,
849+
xcom_marker: mapping.MappedMarkers,
850+
) -> dict[str, np.ndarray]:
851+
"""Calculate the anterio-posterior margin of stability at heel strike. Result should be interpreted according to Curtze et al. (2024)
849852
Args:
850853
trial: The trial for which to calculate the AP margin of stability
851-
ipsi_toe_marker: The ipsi-lateral toe marker
852-
contra_toe_marker: The contra-lateral toe marker
854+
ipsi_heel_marker: The ipsi-lateral heel marker
855+
contra_marker: The contra-lateral toe marker
853856
xcom_marker: The extrapolated center of mass marker
854857
855858
Returns:
856-
The calculated anterior-posterior margin of stability in a dict
859+
dict: A dictionary containing:
860+
- "AP_margin_of_stability": The calculated anterio-posterior margin of stability.
861+
- "AP_base_of_support": The calculated anterio-posterior base of support.
862+
- "AP_XCOM": The calculated anterio-posterior position of the extrapolated center of mass relative to the back foot.
857863
"""
858864
event_times = self.get_event_times(trial.events)
859865

860-
ipsi_toe = self._get_marker_data(trial, ipsi_toe_marker).sel(
866+
ipsi_heel = self._get_marker_data(trial, ipsi_heel_marker).sel(
861867
time=event_times[0], method="nearest"
862868
)
863869
contra_toe = self._get_marker_data(trial, contra_toe_marker).sel(
@@ -866,37 +872,42 @@ def _calculate_ap_margin_of_stability(
866872
xcom = self._get_marker_data(trial, xcom_marker).sel(
867873
time=event_times[0], method="nearest"
868874
)
869-
875+
870876
progress_axis = self._get_progression_vector(trial)
871877
progress_axis = linalg.normalize_vector(progress_axis)
872-
873-
projected_ipsi = linalg.project_point_on_vector(ipsi_toe, progress_axis)
874-
projected_contra = linalg.project_point_on_vector(contra_toe, progress_axis)
875-
projected_xcom = linalg.project_point_on_vector(xcom, progress_axis)
876-
877-
bos_len = linalg.calculate_distance(projected_ipsi, projected_contra).values
878-
xcom_len = linalg.calculate_distance(projected_contra, projected_xcom).values
879-
880-
mos = bos_len - xcom_len
881-
882-
return {"AP_margin_of_stability": mos}
883-
884-
def _calculate_ml_margin_of_stability(
885-
self,
886-
trial: model.Trial,
887-
ipsi_ankle_marker: mapping.MappedMarkers,
888-
contra_ankle_marker: mapping.MappedMarkers,
889-
xcom_marker: mapping.MappedMarkers,
890-
) -> dict[str, np.ndarray]:
891-
"""Calculate the medio-lateral margin of stability at heel strike
878+
879+
front_marker = linalg.get_point_in_front(ipsi_heel, contra_toe, progress_axis)
880+
back_marker = linalg.get_point_behind(ipsi_heel, contra_toe, progress_axis)
881+
882+
bos_vect = front_marker - back_marker
883+
xcom_vect = xcom - back_marker
884+
885+
bos_proj = abs(linalg.signed_projection_norm(bos_vect, progress_axis))
886+
xcom_proj = linalg.signed_projection_norm(xcom_vect, progress_axis)
887+
mos = bos_proj - xcom_proj
888+
889+
return {"AP_margin_of_stability": mos,
890+
"AP_base_of_support": bos_proj,
891+
"AP_xcom": xcom_proj}
892+
893+
def _calculate_ml_margin_of_stability(self,
894+
trial: model.Trial,
895+
ipsi_ankle_marker: mapping.MappedMarkers,
896+
contra_ankle_marker: mapping.MappedMarkers,
897+
xcom_marker: mapping.MappedMarkers
898+
) -> dict[str, np.ndarray]:
899+
"""Calculate the medio-lateral margin of stability at heel strike. Result should be interpreted according to Curtze et al. (2024)
892900
Args:
893-
trial: The trial for which to calculate the AP margin of stability
894-
ipsi_ankle_marker: The ipsi-lateral lateral ankle marker
895-
contra_ankle_marker: The contra-lateral lateral ankle marker
901+
trial: The trial for which to calculate the ml margin of stability
902+
ipsi_toe_marker: The ipsi-lateral lateral ankle marker
903+
contra_marker: The contra-lateral lateral ankle marker
896904
xcom_marker: The extrapolated center of mass marker
897905
898906
Returns:
899-
The calculated anterio-posterior margin of stability in a dict
907+
dict: A dictionary containing:
908+
- "ML_margin_of_stability": The calculated medio-lateral margin of stability.
909+
- "ML_base_of_support": The calculated medio-lateral base of support.
910+
- "ML_xcom": The calculated medio-lateral position of the extrapolated center of mass relative to the back foot.
900911
"""
901912
event_times = self.get_event_times(trial.events)
902913

@@ -908,18 +919,27 @@ def _calculate_ml_margin_of_stability(
908919
)
909920
xcom = self._get_marker_data(trial, xcom_marker).sel(
910921
time=event_times[0], method="nearest"
911-
)
922+
)
912923

913924
sagittal_axis = self._get_sagittal_vector(trial)
914925
sagittal_axis = linalg.normalize_vector(sagittal_axis)
915-
916-
projected_ipsi = linalg.project_point_on_vector(ipsi_ankle, sagittal_axis)
917-
projected_contra = linalg.project_point_on_vector(contra_ankle, sagittal_axis)
918-
projected_xcom = linalg.project_point_on_vector(xcom, sagittal_axis)
919-
920-
bos_len = linalg.calculate_distance(projected_contra, projected_ipsi).values
921-
xcom_len = linalg.calculate_distance(projected_contra, projected_xcom).values
922-
923-
mos = bos_len - xcom_len
924-
925-
return {"ML_margin_of_stability": mos}
926+
927+
if trial.events.attrs["context"] == "Left":
928+
#Rotate sagittal axis so it points towards the left side of the body
929+
sagittal_axis = -sagittal_axis
930+
931+
# Lateral is the furthest point in the direction of the sagittal axis
932+
lateral_point = linalg.get_point_in_front(ipsi_ankle, contra_ankle, sagittal_axis)
933+
# Medial is the closest point in the direction of the sagittal axis
934+
medial_point = linalg.get_point_behind(ipsi_ankle, contra_ankle, sagittal_axis)
935+
936+
bos_vect = lateral_point - medial_point
937+
xcom_vect = xcom - medial_point
938+
939+
bos_proj = abs(linalg.signed_projection_norm(bos_vect, sagittal_axis))
940+
xcom_proj = linalg.signed_projection_norm(xcom_vect, sagittal_axis)
941+
mos = bos_proj - xcom_proj
942+
943+
return {"ML_margin_of_stability": mos,
944+
"ML_base_of_support": bos_proj,
945+
"ML_xcom": xcom_proj}

gaitalytics/io.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
"""This module provides classes for reading biomechanical file-types."""
2-
32
import math
43
from abc import abstractmethod, ABC
54
from pathlib import Path

gaitalytics/utils/linalg.py

Lines changed: 54 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,26 @@ def project_point_on_vector(point: xr.DataArray, vector: xr.DataArray) -> xr.Dat
3030
"""
3131
return vector * point.dot(vector, dim="axis")
3232

33+
def signed_projection_norm(vector: xr.DataArray, onto: xr.DataArray) -> xr.DataArray:
34+
"""Compute the signed norm of the projection of a vector onto another vector. <br>
35+
If the projection is in the same direction as the onto vector, the norm is positive. <br>
36+
If the projection is in the opposite direction, the norm is negative. <br>
3337
34-
def get_normal_vector(vector1: xr.DataArray, vector2: xr.DataArray):
38+
Args:
39+
vector: The vector to be projected.
40+
onto: The vector to project onto.
41+
42+
Returns:
43+
An xarray DataArray containing the signed norm of the projected vector.
44+
"""
45+
projection = onto * vector.dot(onto, dim="axis") / onto.dot(onto, dim="axis")
46+
projection_norm = projection.meca.norm(dim="axis")
47+
sign = xr.where(vector.dot(onto, dim="axis") > 0, 1, -1)
48+
sign = xr.where(vector.dot(onto, dim="axis") == 0, 0, sign)
49+
return projection_norm * sign
50+
51+
52+
def get_normal_vector(vector1: xr.DataArray, vector2: xr.DataArray) -> xr.DataArray:
3553
"""Create a vector with norm = 1 normal to two other vectors.
3654
3755
Args:
@@ -80,4 +98,38 @@ def calculate_speed_norm(position: xr.DataArray, dt: float = 0.01) -> np.ndarray
8098
speed_values = np.sqrt(velocity_squared_sum)
8199
speed_values = np.append(speed_values, speed_values[-1])
82100

83-
return speed_values
101+
return xr.DataArray(speed_values, dims=["time"], coords={"time": position.coords["time"]})
102+
103+
def get_point_in_front(point_a: xr.DataArray, point_b: xr.DataArray, direction_vector: xr.DataArray) -> xr.DataArray:
104+
"""Determine which point is in front of the other according to the direction vector.
105+
106+
Args:
107+
point_a: The first point.
108+
point_b: The second point.
109+
direction_vector: The direction vector.
110+
111+
Returns:
112+
The point that is in front according to the direction vector.
113+
"""
114+
direction_vector = direction_vector / direction_vector.meca.norm(dim="axis")
115+
vector_b_to_a = point_a - point_b
116+
signed_distance = vector_b_to_a.dot(direction_vector, dim="axis")
117+
118+
return point_a if signed_distance > 0 else point_b
119+
120+
def get_point_behind(point_a: xr.DataArray, point_b: xr.DataArray, direction_vector: xr.DataArray) -> xr.DataArray:
121+
"""Determine which point is behind the other according to the direction vector.
122+
123+
Args:
124+
point_a: The first point.
125+
point_b: The second point.
126+
direction_vector: The direction vector.
127+
128+
Returns:
129+
The point that is behind according to the direction vector.
130+
"""
131+
direction_vector = direction_vector / direction_vector.meca.norm(dim="axis")
132+
vector_b_to_a = point_a - point_b
133+
signed_distance = vector_b_to_a.dot(direction_vector, dim="axis")
134+
135+
return point_b if signed_distance > 0 else point_a

tests/full/test_linalg.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
import pytest
2+
import xarray as xr
3+
import numpy as np
4+
5+
from gaitalytics.utils.linalg import (
6+
calculate_distance,
7+
project_point_on_vector,
8+
signed_projection_norm,
9+
get_normal_vector,
10+
normalize_vector,
11+
calculate_speed_norm,
12+
get_point_in_front,
13+
get_point_behind
14+
)
15+
16+
@pytest.fixture
17+
def sample_data():
18+
point_a = xr.DataArray([1, 2, 3], dims=["axis"], coords={"axis": ["x", "y", "z"]})
19+
point_b = xr.DataArray([4, 5, 6], dims=["axis"], coords={"axis": ["x", "y", "z"]})
20+
vector_a = xr.DataArray([1, 0, 0], dims=["axis"], coords={"axis": ["x", "y", "z"]})
21+
vector_b = xr.DataArray([0, 1, 0], dims=["axis"], coords={"axis": ["x", "y", "z"]})
22+
return point_a, point_b, vector_a, vector_b
23+
24+
def test_calculate_distance(sample_data):
25+
point_a, point_b, _, _ = sample_data
26+
distance = calculate_distance(point_a, point_b)
27+
expected_distance = np.sqrt(27)
28+
assert distance == pytest.approx(expected_distance)
29+
30+
def test_project_point_on_vector(sample_data):
31+
point_a, _, vector_a, _ = sample_data
32+
projected_point = project_point_on_vector(point_a, vector_a)
33+
expected_projection = xr.DataArray([1, 0, 0], dims=["axis"], coords={"axis": ["x", "y", "z"]})
34+
xr.testing.assert_allclose(projected_point, expected_projection)
35+
36+
def test_signed_projection_norm(sample_data):
37+
_, _, vector_a, vector_b = sample_data
38+
signed_norm = signed_projection_norm(vector_a, vector_b)
39+
expected_signed_norm = 0.0
40+
assert signed_norm == pytest.approx(expected_signed_norm)
41+
42+
def test_get_normal_vector(sample_data):
43+
_, _, vector_a, vector_b = sample_data
44+
normal_vector = get_normal_vector(vector_a, vector_b)
45+
expected_normal_vector = xr.DataArray([0, 0, 1], dims=["axis"], coords={"axis": ["x", "y", "z"]})
46+
xr.testing.assert_allclose(normal_vector, expected_normal_vector)
47+
48+
def test_normalize_vector(sample_data):
49+
_, _, vector_a, _ = sample_data
50+
normalized_vector = normalize_vector(vector_a)
51+
expected_normalized_vector = xr.DataArray([1, 0, 0], dims=["axis"], coords={"axis": ["x", "y", "z"]})
52+
xr.testing.assert_allclose(normalized_vector, expected_normalized_vector)
53+
54+
def test_calculate_speed_norm():
55+
position = xr.DataArray(
56+
np.array([[0, 1, 2], [0, 1, 2], [0, 1, 2]]),
57+
dims=["axis", "time"],
58+
coords={"axis": ["x", "y", "z"], "time": [0, 1, 2]}
59+
)
60+
speed = calculate_speed_norm(position, dt=1.0)
61+
expected_speed = xr.DataArray([np.sqrt(3), np.sqrt(3), np.sqrt(3)], dims=["time"], coords={"time": [0, 1, 2]})
62+
xr.testing.assert_allclose(speed, expected_speed)
63+
64+
def test_get_point_in_front(sample_data):
65+
point_a, point_b, vector_a, _ = sample_data
66+
point_in_front = get_point_in_front(point_a, point_b, vector_a)
67+
expected_point_in_front = point_b
68+
xr.testing.assert_allclose(point_in_front, expected_point_in_front)
69+
70+
def test_get_point_behind(sample_data):
71+
point_a, point_b, vector_a, _ = sample_data
72+
point_behind = get_point_behind(point_a, point_b, vector_a)
73+
expected_point_behind = point_a
74+
xr.testing.assert_allclose(point_behind, expected_point_behind)

0 commit comments

Comments
 (0)