Skip to content

Commit 3072aa4

Browse files
committed
Added filtering.py module with draft interp_pose() and filter_confidence() functions
1 parent d5b1335 commit 3072aa4

File tree

1 file changed

+136
-0
lines changed

1 file changed

+136
-0
lines changed

movement/analysis/filtering.py

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
from datetime import datetime
2+
from typing import Union
3+
4+
import numpy as np
5+
import xarray as xr
6+
7+
8+
def interp_pose(
    ds: xr.Dataset,
    method: str = "linear",
    limit: Union[int, None] = None,
    max_gap: Union[int, None] = None,
    inplace: bool = False,
) -> Union[xr.Dataset, None]:
    """
    Fill in NaN values by interpolating over the time dimension.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset containing pose tracks, confidence scores, and metadata.
    method : str
        String indicating which method to use for interpolation.
        Default is `linear`. See documentation for
        `xarray.Dataset.interpolate_na` for complete list of options.
    limit : int | None
        Maximum number of consecutive NaNs to fill. Per the xarray
        documentation, this filling is done regardless of the size of
        the gap, so a gap longer than `limit` is only partially filled.
        `None` indicates no limit, and is the default value.
    max_gap : int | None
        Maximum size of a gap (a continuous run of NaNs) that will be
        filled; per the xarray documentation, larger gaps are left
        untouched and the gap length is measured in units of the `time`
        coordinate. `None` indicates no limit, and is the default value.
    inplace : bool
        If true, updates the provided Dataset in place and returns
        `None`.

    Returns
    -------
    ds_interpolated : xarray.Dataset | None
        A new dataset in which NaN values have been interpolated over
        using the parameters provided, or `None` if `inplace` is true.

    Notes
    -----
    A record of the operation is appended to the `log` attribute of the
    resulting dataset; the attribute is created if it does not exist.
    """
    # NOTE: `interpolate_na` is applied to the whole dataset, so the
    # confidence values are interpolated as well.
    # TODO: Figure out whether this is the desired default behavior.
    ds_interpolated = ds.interpolate_na(
        dim="time", method=method, limit=limit, max_gap=max_gap
    )

    # Logging
    log_entry = {
        "operation": "interp_pose",
        "method": method,
        "limit": limit,
        "max_gap": max_gap,
        "inplace": inplace,
        "datetime": str(datetime.now()),
    }

    if inplace:
        ds["pose_tracks"] = ds_interpolated["pose_tracks"]
        ds["confidence"] = ds_interpolated["confidence"]
        # Record the operation on the dataset that was actually modified,
        # creating the log on first use instead of raising KeyError.
        ds.attrs.setdefault("log", []).append(log_entry)
        return None

    # Attach a fresh list rather than appending in place: the attrs carried
    # over from `ds` may reference the very same list object, in which case
    # an in-place append would also alter the caller's (unmodified) dataset.
    ds_interpolated.attrs["log"] = [*ds.attrs.get("log", []), log_entry]
    return ds_interpolated
65+
66+
67+
def filter_confidence(
    ds: xr.Dataset,
    threshold: float = 0.6,
    inplace: bool = False,
    interp: bool = False,
) -> Union[xr.Dataset, None]:
    """
    Mask all datapoints whose confidence falls below a threshold.

    Datapoints are not removed from the dataset: values for which
    `ds.confidence >= threshold` does not hold are replaced by NaN.
    A per-keypoint summary of the masked datapoints is printed.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset containing pose tracks, confidence scores, and metadata.
    threshold : float
        The confidence threshold below which datapoints are filtered.
        A default value of `0.6` is used.
    inplace : bool
        If true, updates the provided Dataset in place and returns
        `None`.
    interp : bool
        If true, NaNs are interpolated over using `interp_pose` with
        default parameters.

    Returns
    -------
    ds_thresholded : xarray.Dataset | None
        A new dataset in which datapoints with a confidence value below
        the user-defined threshold have been converted to NaNs, or
        `None` if `inplace` is true.

    Notes
    -----
    A record of the operation is appended to the `log` attribute of the
    resulting dataset; the attribute is created if it does not exist.
    """
    ds_thresholded = ds.where(ds.confidence >= threshold)

    # Diagnostics
    print("\nDatapoints Filtered:\n")
    for kp in ds.keypoints.values:
        kp_confidence = ds_thresholded.confidence.sel(keypoints=kp).values
        n_nans = np.count_nonzero(np.isnan(kp_confidence))
        # Count every entry for this keypoint, not only the time points, so
        # the percentage stays correct when confidence has further
        # dimensions besides time and keypoints.
        n_points = kp_confidence.size
        prop_nans = round((n_nans / n_points) * 100, 2) if n_points else 0.0
        print(f"{kp}: {n_nans}/{n_points} ({prop_nans}%)")

    # TODO: Is this enough diagnostics? Should I write logic to allow
    # users to optionally plot out confidence distributions + imposed
    # threshold?

    # Logging
    log_entry = {
        "operation": "filter_confidence",
        "threshold": threshold,
        "inplace": inplace,
        "datetime": str(datetime.now()),
    }
    # Attach a fresh list rather than appending in place: the attrs carried
    # over from `ds` may reference the very same list object, in which case
    # an in-place append would alter the caller's dataset even when
    # `inplace` is false.
    ds_thresholded.attrs["log"] = [*ds.attrs.get("log", []), log_entry]

    # Interpolation
    if interp:
        interp_pose(ds_thresholded, inplace=True)

    if inplace:
        ds["pose_tracks"] = ds_thresholded["pose_tracks"]
        ds["confidence"] = ds_thresholded["confidence"]
        # Carry the updated log over too; otherwise the record of this
        # operation would be lost together with the temporary dataset.
        ds.attrs["log"] = ds_thresholded.attrs["log"]
        return None

    return ds_thresholded

0 commit comments

Comments
 (0)