1414import h5py
1515import numpy as np
1616import numpy .ma as ma
17- from scipy .interpolate import UnivariateSpline
17+ from scipy .interpolate import interp1d
1818from scipy .signal import savgol_filter
1919
2020from wormpose .commands import _log_parameters
@@ -75,61 +75,32 @@ def _get_valid_segments(is_valid_series: np.ndarray, max_gap_size: int, min_segm
7575
7676
7777class _SplineInterpolation (object ):
78- def __init__ (self , frames_around : int = 3 , spline_degree : int = 3 ):
79-
80- self .spline_degree = spline_degree
81- self .frames_around = frames_around
82- self .slope = np .linspace (0 , 0.5 , frames_around + 1 )
83-
84- def interpolate_tseries (
85- self , tseries : np .ndarray , segments_boundaries : Sequence , std_fraction : float
86- ) -> np .ndarray :
87-
88- weight = 1 / (std_fraction * np .nanstd (tseries ))
78+ def interpolate_tseries (self , tseries : np .ndarray , segments_boundaries : Sequence ) -> np .ndarray :
8979
9080 tseries [~ np .isnan (tseries )] = np .unwrap (tseries [~ np .isnan (tseries )])
9181 new_tseries = np .full_like (tseries , np .nan )
9282
9383 for t0 , tf in segments_boundaries :
94- new_tseries [t0 :tf ] = self ._interpolate_segment (tseries [t0 :tf ], weight )
84+ new_tseries [t0 :tf ] = self ._interpolate_segment (tseries [t0 :tf ])
9585
9686 return new_tseries
9787
98- def _interpolate_segment (self , tseries : np .ndarray , weight : float ) -> np .ndarray :
88+ def _interpolate_segment (self , tseries : np .ndarray ) -> np .ndarray :
9989 new_tseries = np .copy (tseries )
10090
101- nan_y = np .isnan (new_tseries )
102- indices_nan = np .any (nan_y , axis = 1 )
103- series_len = len (new_tseries )
104- x = np .arange (series_len )
105- new_tseries [nan_y ] = 0.0
106-
107- w = self .build_weights (indices_nan , series_len , weight )
108-
10991 # perform spline interpolation separately for each dimension
11092 for dim in range (new_tseries .shape [1 ]):
111- y = new_tseries [:, dim ]
112- spl = UnivariateSpline (x , y , w = w , k = self .spline_degree , s = len (x ))
113- new_x = x
114- new_weighted_y = spl (new_x )
115- new_tseries [:, dim ] = new_weighted_y
93+ y0 = new_tseries [:, dim ]
94+ xn = np .arange (len (new_tseries ))
95+ sel = ~ np .isnan (y0 )
96+ x = xn [sel ]
97+ y = y0 [sel ]
98+ f = interp1d (x , y , kind = "cubic" )
99+ yn = f (xn )
100+ new_tseries [:, dim ] = yn
116101
117102 return new_tseries
118103
119- def build_weights (self , indices_nan : np .ndarray , series_len : int , weight : float ) -> np .ndarray :
120- # setup weights: lower the weights closer to the edges
121- w = np .full (series_len , weight )
122- where_nan = np .where (indices_nan )[0 ]
123- if len (where_nan ) == 0 :
124- return w
125-
126- for idx in range (series_len ):
127- closest_nan_distance = np .min (np .abs (where_nan - idx ))
128- if closest_nan_distance <= self .frames_around :
129- w [idx ] = self .slope [closest_nan_distance ] * weight
130-
131- return w
132-
133104
134105def _smooth_tseries (
135106 tseries : np .ndarray ,
@@ -211,15 +182,13 @@ def _parse_arguments(dataset_path: str, kwargs: dict):
211182 if kwargs .get ("work_dir" ) is None :
212183 kwargs ["work_dir" ] = default_paths .WORK_DIR
213184 if kwargs .get ("max_gap_size" ) is None :
214- kwargs ["max_gap_size" ] = 3
185+ kwargs ["max_gap_size" ] = 4
215186 if kwargs .get ("min_segment_size" ) is None :
216- kwargs ["min_segment_size" ] = 11
187+ kwargs ["min_segment_size" ] = 8
217188 if kwargs .get ("smoothing_window" ) is None :
218- kwargs ["smoothing_window" ] = 7
189+ kwargs ["smoothing_window" ] = 8
219190 if kwargs .get ("poly_order" ) is None :
220191 kwargs ["poly_order" ] = 3
221- if kwargs .get ("std_fraction" ) is None :
222- kwargs ["std_fraction" ] = 0.001
223192 if kwargs .get ("eigenworms_matrix_path" ) is None :
224193 kwargs ["eigenworms_matrix_path" ] = None
225194 if kwargs .get ("num_process" ) is None :
@@ -294,9 +263,7 @@ def post_process(dataset_path: str, **kwargs):
294263 min_segment_size = args .min_segment_size ,
295264 )
296265 # interpolate and smooth in angles space
297- thetas_interp = spline_interpolation .interpolate_tseries (
298- results_raw .theta , segments_boundaries , args .std_fraction
299- )
266+ thetas_interp = spline_interpolation .interpolate_tseries (results_raw .theta , segments_boundaries )
300267 results_interp = _calculate_skeleton (thetas_interp , args , dataset , video_name )
301268
302269 thetas_smooth = _smooth_tseries (
@@ -322,6 +289,8 @@ def post_process(dataset_path: str, **kwargs):
322289 setattr (results_interp , "modes" , _thetas_to_modes (results_interp .theta , eigenworms_matrix ))
323290 setattr (results_smooth , "modes" , _thetas_to_modes (results_smooth .theta , eigenworms_matrix ))
324291
292+ frame_rate = features .frame_rate
293+
325294 # save results
326295 results_saver = ResultsSaver (
327296 temp_dir = args .temp_dir , results_root_dir = results_root_dir , results_filename = POSTPROCESSED_RESULTS_FILENAME
@@ -332,8 +301,8 @@ def post_process(dataset_path: str, **kwargs):
332301 "min_segment_size" : args .min_segment_size ,
333302 "smoothing_window" : args .smoothing_window ,
334303 "poly_order" : args .poly_order ,
335- "std_fraction" : args .std_fraction ,
336304 "dorsal_ventral_flip" : flipped ,
305+ "frame_rate" : frame_rate ,
337306 }
338307
339308 results_saver .save (
@@ -369,11 +338,6 @@ def main():
369338 help = "Only segments of valid values of length greater than min_segment_size (frames)"
370339 "will be interpolated and smoothed" ,
371340 )
372- parser .add_argument (
373- "--std_fraction" ,
374- type = float ,
375- help = "The higher the guessed noise to signal ratio is, the smoother the interpolation will be" ,
376- )
377341 parser .add_argument ("--smoothing_window" , type = int , help = "smoothing window in frames" )
378342 parser .add_argument ("--poly_order" , type = int , help = "polynomial order in smoothing" )
379343 parser .add_argument ("--temp_dir" , type = str , help = "Where to store temporary intermediate results" )
0 commit comments