Skip to content

Commit 08e7f72

Browse files
author
André Böni
committed
+ clean code
+ fix tests
1 parent 1f41e7a commit 08e7f72

File tree

6 files changed

+177
-134
lines changed

6 files changed

+177
-134
lines changed

gaitalytics/features.py

Lines changed: 134 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,7 @@ def _get_progression_vector(self, trial: model.Trial) -> xr.DataArray:
247247
An xarray DataArray containing the calculated progression vector.
248248
"""
249249
return mocap.get_progression_vector(trial, self._config)
250-
250+
251251
def _get_sagittal_vector(self, trial: model.Trial) -> xr.DataArray:
252252
"""Calculate the sagittal vector for a trial.
253253
@@ -259,9 +259,12 @@ def _get_sagittal_vector(self, trial: model.Trial) -> xr.DataArray:
259259
An xarray DataArray containing the calculated sagittal vector.
260260
"""
261261
progression_vector = self._get_progression_vector(trial)
262-
vertical_vector = xr.DataArray([0,0,1], dims=['axis'], coords={'axis': ['x', 'y', 'z']})
262+
vertical_vector = xr.DataArray(
263+
[0, 0, 1], dims=["axis"], coords={"axis": ["x", "y", "z"]}
264+
)
263265
return linalg.get_normal_vector(progression_vector, vertical_vector)
264266

267+
265268
class TimeSeriesFeatures(_CycleFeaturesCalculation):
266269
"""Calculate time series features for a trial.
267270
@@ -497,7 +500,6 @@ class SpatialFeatures(_PointDependentFeature):
497500
"""
498501

499502
def _calculate(self, trial: model.Trial) -> xr.DataArray:
500-
501503
"""Calculate the spatial features for a trial.
502504
503505
Definitions of the spatial features:
@@ -519,32 +521,54 @@ def _calculate(self, trial: model.Trial) -> xr.DataArray:
519521
raise ValueError("Trial does not have events.")
520522

521523
marker_dict = self.select_markers_for_spatial_features(trial)
522-
523-
results_dict = self._calculate_step_length(trial, marker_dict["ipsi_heel"], marker_dict["contra_heel"])
524+
525+
results_dict = self._calculate_step_length(
526+
trial, marker_dict["ipsi_heel"], marker_dict["contra_heel"]
527+
)
524528
results_dict.update(
525-
self._calculate_step_width(trial, marker_dict["ipsi_heel"], marker_dict["contra_heel"])
529+
self._calculate_step_width(
530+
trial, marker_dict["ipsi_heel"], marker_dict["contra_heel"]
531+
)
526532
)
527533
results_dict.update(
528-
self._calculate_stride_length(trial, marker_dict["ipsi_heel"], marker_dict["contra_heel"])
534+
self._calculate_stride_length(
535+
trial, marker_dict["ipsi_heel"], marker_dict["contra_heel"]
536+
)
529537
)
530538
results_dict.update(
531-
self._calculate_minimal_toe_clearance(trial, marker_dict["ipsi_toe_2"], marker_dict["ipsi_heel"], marker_dict["ipsi_toe_5"])
539+
self._calculate_minimal_toe_clearance(
540+
trial,
541+
marker_dict["ipsi_toe_2"],
542+
marker_dict["ipsi_heel"],
543+
marker_dict["ipsi_toe_5"],
544+
)
532545
)
533546
if marker_dict["xcom"] is not None:
534547
results_dict.update(
535-
self._calculate_AP_margin_of_stability(trial, marker_dict["ipsi_toe_2"], marker_dict["contra_toe_2"], marker_dict["xcom"])
548+
self._calculate_AP_margin_of_stability(
549+
trial,
550+
marker_dict["ipsi_toe_2"],
551+
marker_dict["contra_toe_2"],
552+
marker_dict["xcom"],
553+
)
536554
)
537-
if (marker_dict['ipsi_ankle'] is not None) and (marker_dict['contra_ankle'] is not None):
555+
if (marker_dict["ipsi_ankle"] is not None) and (
556+
marker_dict["contra_ankle"] is not None
557+
):
538558
results_dict.update(
539-
self._calculate_ML_margin_of_stability(trial, marker_dict['ipsi_ankle'], marker_dict['contra_ankle'], marker_dict["xcom"])
559+
self._calculate_ML_margin_of_stability(
560+
trial,
561+
marker_dict["ipsi_ankle"],
562+
marker_dict["contra_ankle"],
563+
marker_dict["xcom"],
564+
)
540565
)
541-
566+
542567
return self._create_result_from_dict(results_dict)
543-
544-
545-
def select_markers_for_spatial_features(self,
546-
trial: model.Trial
547-
) -> dict[str, mapping.MappedMarkers]:
568+
569+
def select_markers_for_spatial_features(
570+
self, trial: model.Trial
571+
) -> dict[str, mapping.MappedMarkers]:
548572
"""Select markers based on the trial's context (Right or Left). If some markers are missing, return them as None
549573
550574
Args:
@@ -556,25 +580,25 @@ def select_markers_for_spatial_features(self,
556580
if trial.events.attrs["context"] == "Right":
557581
ipsi_heel_marker = mapping.MappedMarkers.R_HEEL
558582
ipsi_toe_2_marker = mapping.MappedMarkers.R_TOE
559-
ipsi_toe_5_marker = self.get_optional_marker('R_TOE_5')
560-
ipsi_ankle_marker = self.get_optional_marker('R_ANKLE')
583+
ipsi_toe_5_marker = self.get_optional_marker("R_TOE_5")
584+
ipsi_ankle_marker = self.get_optional_marker("R_ANKLE")
561585

562586
contra_toe_2_marker = mapping.MappedMarkers.L_TOE
563587
contra_heel_marker = mapping.MappedMarkers.L_HEEL
564-
contra_ankle_marker = self.get_optional_marker('L_ANKLE')
565-
588+
contra_ankle_marker = self.get_optional_marker("L_ANKLE")
589+
566590
else:
567591
ipsi_heel_marker = mapping.MappedMarkers.L_HEEL
568592
ipsi_toe_2_marker = mapping.MappedMarkers.L_TOE
569-
ipsi_toe_5_marker = self.get_optional_marker('L_TOE_5')
570-
ipsi_ankle_marker = self.get_optional_marker('L_ANKLE')
593+
ipsi_toe_5_marker = self.get_optional_marker("L_TOE_5")
594+
ipsi_ankle_marker = self.get_optional_marker("L_ANKLE")
571595

572596
contra_toe_2_marker = mapping.MappedMarkers.R_TOE
573597
contra_heel_marker = mapping.MappedMarkers.R_HEEL
574-
contra_ankle_marker = self.get_optional_marker('R_ANKLE')
598+
contra_ankle_marker = self.get_optional_marker("R_ANKLE")
599+
600+
xcom_marker = self.get_optional_marker("XCOM")
575601

576-
xcom_marker = self.get_optional_marker('XCOM')
577-
578602
return {
579603
"ipsi_toe_2": ipsi_toe_2_marker,
580604
"ipsi_toe_5": ipsi_toe_5_marker,
@@ -583,18 +607,16 @@ def select_markers_for_spatial_features(self,
583607
"contra_toe_2": contra_toe_2_marker,
584608
"contra_heel": contra_heel_marker,
585609
"contra_ankle": contra_ankle_marker,
586-
"xcom": xcom_marker
610+
"xcom": xcom_marker,
587611
}
588-
589-
612+
590613
def get_optional_marker(self, marker_name: str) -> mapping.MappedMarkers | None:
591-
""" Returns the marker if exists, else returns None
614+
"""Returns the marker if exists, else returns None
592615
593616
Args:
594617
marker_name (str): The marker name
595618
"""
596619
return getattr(mapping.MappedMarkers, marker_name, None)
597-
598620

599621
def _calculate_step_length(
600622
self,
@@ -628,7 +650,6 @@ def _calculate_step_length(
628650
distance = linalg.calculate_distance(projected_ipsi, projected_contra).values
629651
return {"step_length": distance}
630652

631-
632653
def _calculate_step_width(
633654
self,
634655
trial: model.Trial,
@@ -661,8 +682,7 @@ def _calculate_step_width(
661682
distance = linalg.calculate_distance(ipsi_heel, projected_ipsi).values
662683

663684
return {"step_width": distance}
664-
665-
685+
666686
def _calculate_stride_length(
667687
self,
668688
trial: model.Trial,
@@ -687,29 +707,35 @@ def _calculate_stride_length(
687707
total_distance = 0
688708

689709
for event_time in [event_times[2], event_times[-1]]:
690-
ipsi_heel = self._get_marker_data(trial, ipsi_marker).sel(time=event_time, method="nearest")
691-
contra_heel = self._get_marker_data(trial, contra_marker).sel(time=event_time, method="nearest")
692-
print(f'ipsi heel: {ipsi_heel}')
693-
710+
ipsi_heel = self._get_marker_data(trial, ipsi_marker).sel(
711+
time=event_time, method="nearest"
712+
)
713+
contra_heel = self._get_marker_data(trial, contra_marker).sel(
714+
time=event_time, method="nearest"
715+
)
716+
print(f"ipsi heel: {ipsi_heel}")
717+
694718
projected_ipsi = linalg.project_point_on_vector(ipsi_heel, progress_axis)
695-
projected_contra = linalg.project_point_on_vector(contra_heel, progress_axis)
696-
print(f'projected ipsi: {projected_ipsi}')
719+
projected_contra = linalg.project_point_on_vector(
720+
contra_heel, progress_axis
721+
)
722+
print(f"projected ipsi: {projected_ipsi}")
697723

698-
distance = linalg.calculate_distance(projected_ipsi, projected_contra).values
699-
print(f'distance: {distance}')
724+
distance = linalg.calculate_distance(
725+
projected_ipsi, projected_contra
726+
).values
727+
print(f"distance: {distance}")
700728
total_distance += distance
701-
702-
729+
703730
return {"stride_length": total_distance}
704-
705-
731+
706732
def _calculate_minimal_toe_clearance(
707-
self,
708-
trial: model.Trial,
709-
ipsi_toe_marker: mapping.MappedMarkers,
710-
ipsi_heel_marker: mapping.MappedMarkers,
711-
*opt_ipsi_toe_markers: mapping.MappedMarkers
712-
) -> dict[str, np.ndarray]:
733+
self,
734+
trial: model.Trial,
735+
ipsi_toe_marker: mapping.MappedMarkers,
736+
ipsi_heel_marker: mapping.MappedMarkers,
737+
*opt_ipsi_toe_markers: mapping.MappedMarkers,
738+
) -> dict[str, np.ndarray]:
713739
"""Calculate the minimal toe clearance for a trial. Toe clearance is computed for all toe markers passed, only the minimal is returned
714740
715741
Args:
@@ -722,20 +748,26 @@ def _calculate_minimal_toe_clearance(
722748
"""
723749
event_times = self.get_event_times(trial.events)
724750

725-
ipsi_heel = self._get_marker_data(trial, ipsi_heel_marker).sel(time=slice(event_times[3], event_times[4]))
726-
ipsi_toe = self._get_marker_data(trial, ipsi_toe_marker).sel(time=slice(event_times[3], event_times[4]))
751+
ipsi_heel = self._get_marker_data(trial, ipsi_heel_marker).sel(
752+
time=slice(event_times[3], event_times[4])
753+
)
754+
ipsi_toe = self._get_marker_data(trial, ipsi_toe_marker).sel(
755+
time=slice(event_times[3], event_times[4])
756+
)
727757

728758
toes_vel = linalg.calculate_speed_norm(ipsi_toe)
729759

730760
additional_meta_data = []
731-
761+
732762
for meta_marker in opt_ipsi_toe_markers:
733763
if meta_marker is not None:
734-
meta_data = self._get_marker_data(trial, meta_marker).sel(time=slice(event_times[3], event_times[4]))
764+
meta_data = self._get_marker_data(trial, meta_marker).sel(
765+
time=slice(event_times[3], event_times[4])
766+
)
735767
toes_vel += linalg.calculate_speed_norm(meta_data)
736768
additional_meta_data.append(meta_data)
737-
738-
toes_vel /= (1 + len(additional_meta_data))
769+
770+
toes_vel /= 1 + len(additional_meta_data)
739771

740772
mtc_i = self._find_mtc_index(ipsi_toe, ipsi_heel, toes_vel)
741773
mtc_additional_indices = [
@@ -744,49 +776,50 @@ def _calculate_minimal_toe_clearance(
744776
]
745777

746778
# Handle NaN cases and find minimal clearance
747-
mtc_values = [] if np.isnan(mtc_i) else [ipsi_toe.sel(axis='z')[mtc_i]]
779+
mtc_values = [] if np.isnan(mtc_i) else [ipsi_toe.sel(axis="z")[mtc_i]]
748780
for i, meta_data in zip(mtc_additional_indices, additional_meta_data):
749781
if not np.isnan(i):
750-
mtc_values.append(meta_data.sel(axis='z')[i])
782+
mtc_values.append(meta_data.sel(axis="z")[i])
751783

752784
if not mtc_values:
753785
return {"minimal_toe_clearance": np.NaN}
754786

755787
return {"minimal_toe_clearance": min(mtc_values)}
756-
757-
758-
def _find_mtc_index(self,
759-
toe_position: xr.DataArray,
760-
heel_position: xr.DataArray,
761-
toes_vel: xr.DataArray):
788+
789+
def _find_mtc_index(
790+
self,
791+
toe_position: xr.DataArray,
792+
heel_position: xr.DataArray,
793+
toes_vel: xr.DataArray,
794+
):
762795
"""Find the time corresponding to minimal toe clearance of a specific toe.
763-
Valid minimal toe clearance point must validates conditions
764-
defined in Schulz 2017 (doi: 10.1016/j.jbiomech.2017.02.024)
765-
Args:
766-
toe_position: A DataArray containing positions of the toe
767-
heel_position: A DataArray containing positions of the heel
768-
toes_vel: A DataArray containing the mean toes velocity at each timepoint
769-
Returns:
770-
The time corresponding to minimal toe clearance for the input toe.
771-
"""
772-
toes_vel_up_quant = np.quantile(toes_vel, .5)
773-
toe_z = toe_position.sel(axis='z')
774-
heel_z = heel_position.sel(axis='z')
796+
Valid minimal toe clearance point must validates conditions
797+
defined in Schulz 2017 (doi: 10.1016/j.jbiomech.2017.02.024)
798+
Args:
799+
toe_position: A DataArray containing positions of the toe
800+
heel_position: A DataArray containing positions of the heel
801+
toes_vel: A DataArray containing the mean toes velocity at each timepoint
802+
Returns:
803+
The time corresponding to minimal toe clearance for the input toe.
804+
"""
805+
toes_vel_up_quant = np.quantile(toes_vel, 0.5)
806+
toe_z = toe_position.sel(axis="z")
807+
heel_z = heel_position.sel(axis="z")
775808

776809
# Check conditions according to Schulz 2017
777810
mtc_i = math.find_local_minimas(toe_z)
778811
mtc_i = [i for i in mtc_i if toes_vel[i] >= toes_vel_up_quant]
779812
mtc_i = [i for i in mtc_i if toe_z[i] <= heel_z[i]]
780813

781814
return np.NaN if not mtc_i else min(mtc_i, key=lambda i: toe_z[i])
782-
783-
784-
def _calculate_AP_margin_of_stability(self,
785-
trial: model.Trial,
786-
ipsi_toe_marker: mapping.MappedMarkers,
787-
contra_toe_marker: mapping.MappedMarkers,
788-
xcom_marker: mapping.MappedMarkers,
789-
) -> dict[str, np.ndarray]:
815+
816+
def _calculate_AP_margin_of_stability(
817+
self,
818+
trial: model.Trial,
819+
ipsi_toe_marker: mapping.MappedMarkers,
820+
contra_toe_marker: mapping.MappedMarkers,
821+
xcom_marker: mapping.MappedMarkers,
822+
) -> dict[str, np.ndarray]:
790823
"""Calculate the anterio-posterior margin of stability at heel strike
791824
Args:
792825
trial: The trial for which to calculate the AP margin of stability
@@ -808,28 +841,28 @@ def _calculate_AP_margin_of_stability(self,
808841
xcom = self._get_marker_data(trial, xcom_marker).sel(
809842
time=event_times[0], method="nearest"
810843
)
811-
844+
812845
progress_axis = self._get_progression_vector(trial)
813846
progress_axis = linalg.normalize_vector(progress_axis)
814-
847+
815848
projected_ipsi = linalg.project_point_on_vector(ipsi_toe, progress_axis)
816849
projected_contra = linalg.project_point_on_vector(contra_toe, progress_axis)
817850
projected_xcom = linalg.project_point_on_vector(xcom, progress_axis)
818-
851+
819852
bos_len = linalg.calculate_distance(projected_ipsi, projected_contra).values
820853
xcom_len = linalg.calculate_distance(projected_contra, projected_xcom).values
821854

822855
mos = bos_len - xcom_len
823-
856+
824857
return {"AP_margin_of_stability": mos}
825-
826-
827-
def _calculate_ML_margin_of_stability(self,
828-
trial: model.Trial,
829-
ipsi_ankle_marker: mapping.MappedMarkers,
830-
contra_ankle_marker: mapping.MappedMarkers,
831-
xcom_marker: mapping.MappedMarkers
832-
) -> dict[str, np.ndarray]:
858+
859+
def _calculate_ML_margin_of_stability(
860+
self,
861+
trial: model.Trial,
862+
ipsi_ankle_marker: mapping.MappedMarkers,
863+
contra_ankle_marker: mapping.MappedMarkers,
864+
xcom_marker: mapping.MappedMarkers,
865+
) -> dict[str, np.ndarray]:
833866
"""Calculate the medio-lateral margin of stability at heel strike
834867
Args:
835868
trial: The trial for which to calculate the AP margin of stability
@@ -858,10 +891,10 @@ def _calculate_ML_margin_of_stability(self,
858891
projected_ipsi = linalg.project_point_on_vector(ipsi_ankle, sagittal_axis)
859892
projected_contra = linalg.project_point_on_vector(contra_ankle, sagittal_axis)
860893
projected_xcom = linalg.project_point_on_vector(xcom, sagittal_axis)
861-
894+
862895
bos_len = linalg.calculate_distance(projected_contra, projected_ipsi).values
863896
xcom_len = linalg.calculate_distance(projected_contra, projected_xcom).values
864-
897+
865898
mos = bos_len - xcom_len
866899

867-
return {"ML_margin_of_stability": mos}
900+
return {"ML_margin_of_stability": mos}

0 commit comments

Comments
 (0)