Skip to content

Commit e99dc6c

Browse files
author
André Böni
committed
+ fix typing bugs
1 parent ed54b61 commit e99dc6c

File tree

2 files changed

+9
-9
lines changed

2 files changed

+9
-9
lines changed

gaitalytics/features.py

+8-8
Original file line numberDiff line numberDiff line change
@@ -831,7 +831,7 @@ def _calculate_minimal_toe_clearance(
831831

832832
@staticmethod
833833
def _find_mtc_index(
834-
toe_position: xr.DataArray, heel_position: xr.DataArray, toes_vel: np.ndarray
834+
toe_position: xr.DataArray, heel_position: xr.DataArray, toes_vel: xr.DataArray
835835
):
836836
"""Find the time corresponding to minimal toe clearance of a specific toe.
837837
Valid minimal toe clearance point must pass conditions
@@ -900,9 +900,9 @@ def _calculate_ap_margin_of_stability(
900900
mos = bos_proj - xcom_proj
901901

902902
return {
903-
"AP_margin_of_stability": mos,
904-
"AP_base_of_support": bos_proj,
905-
"AP_xcom": xcom_proj,
903+
"AP_margin_of_stability": mos.to_numpy(),
904+
"AP_base_of_support": bos_proj.to_numpy(),
905+
"AP_xcom": xcom_proj.to_numpy(),
906906
}
907907

908908
def _calculate_ml_margin_of_stability(
@@ -940,7 +940,7 @@ def _calculate_ml_margin_of_stability(
940940
sagittal_axis = self._get_sagittal_vector(trial)
941941
sagittal_axis = linalg.normalize_vector(sagittal_axis)
942942

943-
if trial.events.attrs["context"] == "Left":
943+
if trial.events is not None and trial.events.attrs["context"] == "Left":
944944
# Rotate sagittal axis so it points towards the left side of the body
945945
sagittal_axis = -sagittal_axis
946946

@@ -959,7 +959,7 @@ def _calculate_ml_margin_of_stability(
959959
mos = bos_proj - xcom_proj
960960

961961
return {
962-
"ML_margin_of_stability": mos,
963-
"ML_base_of_support": bos_proj,
964-
"ML_xcom": xcom_proj,
962+
"ML_margin_of_stability": mos.to_numpy(),
963+
"ML_base_of_support": bos_proj.to_numpy(),
964+
"ML_xcom": xcom_proj.to_numpy(),
965965
}

gaitalytics/utils/linalg.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def normalize_vector(vector: xr.DataArray) -> xr.DataArray:
8080
return vector / vector.meca.norm(dim="axis")
8181

8282

83-
def calculate_speed_norm(position: xr.DataArray, dt: float = 0.01) -> np.ndarray:
83+
def calculate_speed_norm(position: xr.DataArray, dt: float = 0.01) -> xr.DataArray:
8484
"""
8585
Compute the speed from a 3xN position data array obtained with constant sampling rate
8686

0 commit comments

Comments
 (0)