Skip to content

Commit c4a2de4

Browse files
author
André Böni
committed
+ clean stride length
+ clean documentation
1 parent 08e7f72 commit c4a2de4

File tree

2 files changed

+25
-23
lines changed

2 files changed

+25
-23
lines changed

gaitalytics/features.py

Lines changed: 24 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@
99
import gaitalytics.mapping as mapping
1010
import gaitalytics.model as model
1111
import gaitalytics.utils.linalg as linalg
12-
import gaitalytics.utils.mocap as mocap
1312
import gaitalytics.utils.math as math
13+
import gaitalytics.utils.mocap as mocap
1414

1515

1616
class FeatureCalculation(ABC):
@@ -545,7 +545,7 @@ def _calculate(self, trial: model.Trial) -> xr.DataArray:
545545
)
546546
if marker_dict["xcom"] is not None:
547547
results_dict.update(
548-
self._calculate_AP_margin_of_stability(
548+
self._calculate_ap_margin_of_stability(
549549
trial,
550550
marker_dict["ipsi_toe_2"],
551551
marker_dict["contra_toe_2"],
@@ -577,7 +577,7 @@ def select_markers_for_spatial_features(
577577
Returns:
578578
A dictionary of markers based on context
579579
"""
580-
if trial.events.attrs["context"] == "Right":
580+
if trial.events is not None and trial.events.attrs["context"] == "Right":
581581
ipsi_heel_marker = mapping.MappedMarkers.R_HEEL
582582
ipsi_toe_2_marker = mapping.MappedMarkers.R_TOE
583583
ipsi_toe_5_marker = self.get_optional_marker("R_TOE_5")
@@ -704,28 +704,32 @@ def _calculate_stride_length(
704704
progress_axis = self._get_progression_vector(trial)
705705
progress_axis = linalg.normalize_vector(progress_axis)
706706

707-
total_distance = 0
707+
total_distance = None
708708

709+
# Add the distance of the ipsi and contra heel step length
709710
for event_time in [event_times[2], event_times[-1]]:
710711
ipsi_heel = self._get_marker_data(trial, ipsi_marker).sel(
711712
time=event_time, method="nearest"
712713
)
713714
contra_heel = self._get_marker_data(trial, contra_marker).sel(
714715
time=event_time, method="nearest"
715716
)
716-
print(f"ipsi heel: {ipsi_heel}")
717717

718718
projected_ipsi = linalg.project_point_on_vector(ipsi_heel, progress_axis)
719719
projected_contra = linalg.project_point_on_vector(
720720
contra_heel, progress_axis
721721
)
722-
print(f"projected ipsi: {projected_ipsi}")
723722

724723
distance = linalg.calculate_distance(
725724
projected_ipsi, projected_contra
726725
).values
727-
print(f"distance: {distance}")
728-
total_distance += distance
726+
if total_distance is None:
727+
total_distance = distance
728+
else:
729+
total_distance += distance
730+
731+
if total_distance is None:
732+
total_distance = np.empty(1)
729733

730734
return {"stride_length": total_distance}
731735

@@ -782,23 +786,21 @@ def _calculate_minimal_toe_clearance(
782786
mtc_values.append(meta_data.sel(axis="z")[i])
783787

784788
if not mtc_values:
785-
return {"minimal_toe_clearance": np.NaN}
789+
return {"minimal_toe_clearance": np.empty(1)}
786790

787791
return {"minimal_toe_clearance": min(mtc_values)}
788792

793+
@staticmethod
789794
def _find_mtc_index(
790-
self,
791-
toe_position: xr.DataArray,
792-
heel_position: xr.DataArray,
793-
toes_vel: xr.DataArray,
795+
toe_position: xr.DataArray, heel_position: xr.DataArray, toes_vel: xr.DataArray
794796
):
795797
"""Find the time corresponding to minimal toe clearance of a specific toe.
796-
Valid minimal toe clearance point must validates conditions
798+
Valid minimal toe clearance point must pass conditions
797799
defined in Schulz 2017 (doi: 10.1016/j.jbiomech.2017.02.024)
798800
Args:
799801
toe_position: A DataArray containing positions of the toe
800802
heel_position: A DataArray containing positions of the heel
801-
toes_vel: A DataArray containing the mean toes velocity at each timepoint
803+
toes_vel: A DataArray containing the mean toes velocity at each point in time
802804
Returns:
803805
The time corresponding to minimal toe clearance for the input toe.
804806
"""
@@ -811,24 +813,24 @@ def _find_mtc_index(
811813
mtc_i = [i for i in mtc_i if toes_vel[i] >= toes_vel_up_quant]
812814
mtc_i = [i for i in mtc_i if toe_z[i] <= heel_z[i]]
813815

814-
return np.NaN if not mtc_i else min(mtc_i, key=lambda i: toe_z[i])
816+
return None if not mtc_i else min(mtc_i, key=lambda i: toe_z[i])
815817

816-
def _calculate_AP_margin_of_stability(
818+
def _calculate_ap_margin_of_stability(
817819
self,
818820
trial: model.Trial,
819821
ipsi_toe_marker: mapping.MappedMarkers,
820822
contra_toe_marker: mapping.MappedMarkers,
821823
xcom_marker: mapping.MappedMarkers,
822824
) -> dict[str, np.ndarray]:
823-
"""Calculate the anterio-posterior margin of stability at heel strike
825+
"""Calculate the anterior-posterior margin of stability at heel strike
824826
Args:
825827
trial: The trial for which to calculate the AP margin of stability
826828
ipsi_toe_marker: The ipsi-lateral toe marker
827-
contra_marker: The contra-lateral toe marker
829+
contra_toe_marker: The contra-lateral toe marker
828830
xcom_marker: The extrapolated center of mass marker
829831
830832
Returns:
831-
The calculated anterio-posterior margin of stability in a dict
833+
The calculated anterior-posterior margin of stability in a dict
832834
"""
833835
event_times = self.get_event_times(trial.events)
834836

@@ -866,8 +868,8 @@ def _calculate_ML_margin_of_stability(
866868
"""Calculate the medio-lateral margin of stability at heel strike
867869
Args:
868870
trial: The trial for which to calculate the AP margin of stability
869-
ipsi_toe_marker: The ipsi-lateral lateral ankle marker
870-
contra_marker: The contra-lateral lateral ankle marker
871+
ipsi_ankle_marker: The ipsi-lateral lateral ankle marker
872+
contra_ankle_marker: The contra-lateral lateral ankle marker
871873
xcom_marker: The extrapolated center of mass marker
872874
873875
Returns:

gaitalytics/utils/linalg.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def calculate_speed_norm(position: xr.DataArray, dt: float = 0.01) -> xr.DataArr
6666
Compute the speed from a 3xN position data array obtained with constant sampling rate
6767
6868
Args:
69-
position_data: A 3xN xarray.DataArray where each row corresponds to an axis (x, y, z),
69+
position: A 3xN xarray.DataArray where each row corresponds to an axis (x, y, z),
7070
and each column represents a time point (positions in space).
7171
dt: Time interval between samples.
7272

0 commit comments

Comments
 (0)