14
14
import h5py
15
15
import numpy as np
16
16
import numpy .ma as ma
17
- from scipy .interpolate import UnivariateSpline
17
+ from scipy .interpolate import interp1d
18
18
from scipy .signal import savgol_filter
19
19
20
20
from wormpose .commands import _log_parameters
@@ -75,61 +75,32 @@ def _get_valid_segments(is_valid_series: np.ndarray, max_gap_size: int, min_segm
75
75
76
76
77
77
class _SplineInterpolation (object ):
78
- def __init__ (self , frames_around : int = 3 , spline_degree : int = 3 ):
79
-
80
- self .spline_degree = spline_degree
81
- self .frames_around = frames_around
82
- self .slope = np .linspace (0 , 0.5 , frames_around + 1 )
83
-
84
- def interpolate_tseries (
85
- self , tseries : np .ndarray , segments_boundaries : Sequence , std_fraction : float
86
- ) -> np .ndarray :
87
-
88
- weight = 1 / (std_fraction * np .nanstd (tseries ))
78
+ def interpolate_tseries (self , tseries : np .ndarray , segments_boundaries : Sequence ) -> np .ndarray :
89
79
90
80
tseries [~ np .isnan (tseries )] = np .unwrap (tseries [~ np .isnan (tseries )])
91
81
new_tseries = np .full_like (tseries , np .nan )
92
82
93
83
for t0 , tf in segments_boundaries :
94
- new_tseries [t0 :tf ] = self ._interpolate_segment (tseries [t0 :tf ], weight )
84
+ new_tseries [t0 :tf ] = self ._interpolate_segment (tseries [t0 :tf ])
95
85
96
86
return new_tseries
97
87
98
- def _interpolate_segment (self , tseries : np .ndarray , weight : float ) -> np .ndarray :
88
+ def _interpolate_segment (self , tseries : np .ndarray ) -> np .ndarray :
99
89
new_tseries = np .copy (tseries )
100
90
101
- nan_y = np .isnan (new_tseries )
102
- indices_nan = np .any (nan_y , axis = 1 )
103
- series_len = len (new_tseries )
104
- x = np .arange (series_len )
105
- new_tseries [nan_y ] = 0.0
106
-
107
- w = self .build_weights (indices_nan , series_len , weight )
108
-
109
91
# perform spline interpolation separately for each dimension
110
92
for dim in range (new_tseries .shape [1 ]):
111
- y = new_tseries [:, dim ]
112
- spl = UnivariateSpline (x , y , w = w , k = self .spline_degree , s = len (x ))
113
- new_x = x
114
- new_weighted_y = spl (new_x )
115
- new_tseries [:, dim ] = new_weighted_y
93
+ y0 = new_tseries [:, dim ]
94
+ xn = np .arange (len (new_tseries ))
95
+ sel = ~ np .isnan (y0 )
96
+ x = xn [sel ]
97
+ y = y0 [sel ]
98
+ f = interp1d (x , y , kind = "cubic" )
99
+ yn = f (xn )
100
+ new_tseries [:, dim ] = yn
116
101
117
102
return new_tseries
118
103
119
- def build_weights (self , indices_nan : np .ndarray , series_len : int , weight : float ) -> np .ndarray :
120
- # setup weights: lower the weights closer to the edges
121
- w = np .full (series_len , weight )
122
- where_nan = np .where (indices_nan )[0 ]
123
- if len (where_nan ) == 0 :
124
- return w
125
-
126
- for idx in range (series_len ):
127
- closest_nan_distance = np .min (np .abs (where_nan - idx ))
128
- if closest_nan_distance <= self .frames_around :
129
- w [idx ] = self .slope [closest_nan_distance ] * weight
130
-
131
- return w
132
-
133
104
134
105
def _smooth_tseries (
135
106
tseries : np .ndarray ,
@@ -211,15 +182,13 @@ def _parse_arguments(dataset_path: str, kwargs: dict):
211
182
if kwargs .get ("work_dir" ) is None :
212
183
kwargs ["work_dir" ] = default_paths .WORK_DIR
213
184
if kwargs .get ("max_gap_size" ) is None :
214
- kwargs ["max_gap_size" ] = 3
185
+ kwargs ["max_gap_size" ] = 4
215
186
if kwargs .get ("min_segment_size" ) is None :
216
- kwargs ["min_segment_size" ] = 11
187
+ kwargs ["min_segment_size" ] = 8
217
188
if kwargs .get ("smoothing_window" ) is None :
218
- kwargs ["smoothing_window" ] = 7
189
+ kwargs ["smoothing_window" ] = 8
219
190
if kwargs .get ("poly_order" ) is None :
220
191
kwargs ["poly_order" ] = 3
221
- if kwargs .get ("std_fraction" ) is None :
222
- kwargs ["std_fraction" ] = 0.001
223
192
if kwargs .get ("eigenworms_matrix_path" ) is None :
224
193
kwargs ["eigenworms_matrix_path" ] = None
225
194
if kwargs .get ("num_process" ) is None :
@@ -294,9 +263,7 @@ def post_process(dataset_path: str, **kwargs):
294
263
min_segment_size = args .min_segment_size ,
295
264
)
296
265
# interpolate and smooth in angles space
297
- thetas_interp = spline_interpolation .interpolate_tseries (
298
- results_raw .theta , segments_boundaries , args .std_fraction
299
- )
266
+ thetas_interp = spline_interpolation .interpolate_tseries (results_raw .theta , segments_boundaries )
300
267
results_interp = _calculate_skeleton (thetas_interp , args , dataset , video_name )
301
268
302
269
thetas_smooth = _smooth_tseries (
@@ -322,6 +289,8 @@ def post_process(dataset_path: str, **kwargs):
322
289
setattr (results_interp , "modes" , _thetas_to_modes (results_interp .theta , eigenworms_matrix ))
323
290
setattr (results_smooth , "modes" , _thetas_to_modes (results_smooth .theta , eigenworms_matrix ))
324
291
292
+ frame_rate = features .frame_rate
293
+
325
294
# save results
326
295
results_saver = ResultsSaver (
327
296
temp_dir = args .temp_dir , results_root_dir = results_root_dir , results_filename = POSTPROCESSED_RESULTS_FILENAME
@@ -332,8 +301,8 @@ def post_process(dataset_path: str, **kwargs):
332
301
"min_segment_size" : args .min_segment_size ,
333
302
"smoothing_window" : args .smoothing_window ,
334
303
"poly_order" : args .poly_order ,
335
- "std_fraction" : args .std_fraction ,
336
304
"dorsal_ventral_flip" : flipped ,
305
+ "frame_rate" : frame_rate ,
337
306
}
338
307
339
308
results_saver .save (
@@ -369,11 +338,6 @@ def main():
369
338
help = "Only segments of valid values of length greater than min_segment_size (frames)"
370
339
"will be interpolated and smoothed" ,
371
340
)
372
- parser .add_argument (
373
- "--std_fraction" ,
374
- type = float ,
375
- help = "The higher the guessed noise to signal ratio is, the smoother the interpolation will be" ,
376
- )
377
341
parser .add_argument ("--smoothing_window" , type = int , help = "smoothing window in frames" )
378
342
parser .add_argument ("--poly_order" , type = int , help = "polynomial order in smoothing" )
379
343
parser .add_argument ("--temp_dir" , type = str , help = "Where to store temporary intermediate results" )
0 commit comments