70
70
"nearest_chans" : 8 ,
71
71
"nearest_templates" : 35 ,
72
72
"max_channel_distance" : 5 ,
73
- "templates_from_data" : False ,
74
73
"n_templates" : 10 ,
75
74
"n_pcs" : 3 ,
76
75
"Th_single_ch" : 4 ,
109
108
# max_peels is not affecting the results in this short dataset
110
109
PARAMETERS_NOT_AFFECTING_RESULTS .append ("max_peels" )
111
110
111
+ if parse (kilosort .__version__ ) >= parse ("4.0.33" ):
112
+ PARAMS_TO_TEST_DICT .update ({"cluster_neighbors" : 11 })
113
+ PARAMETERS_NOT_AFFECTING_RESULTS .append ("cluster_neighbors" )
114
+
112
115
113
116
PARAMS_TO_TEST = list (PARAMS_TO_TEST_DICT .keys ())
114
117
@@ -178,11 +181,11 @@ def _save_ground_truth_recording(self, recording, tmp_path):
178
181
"""
179
182
paths = {
180
183
"session_scope_tmp_path" : tmp_path ,
181
- "recording_path" : tmp_path / "my_test_recording" ,
184
+ "recording_path" : tmp_path / "my_test_recording" / "traces_cached_seg0.raw" ,
182
185
"probe_path" : tmp_path / "my_test_probe.prb" ,
183
186
}
184
187
185
- recording .save (folder = paths ["recording_path" ], overwrite = True )
188
+ recording .save (folder = paths ["recording_path" ]. parent , overwrite = True )
186
189
187
190
probegroup = recording .get_probegroup ()
188
191
write_prb (paths ["probe_path" ].as_posix (), probegroup )
@@ -214,7 +217,7 @@ def test_default_settings_all_represented(self):
214
217
tested_keys += additional_non_tested_keys
215
218
216
219
for param_key in DEFAULT_SETTINGS :
217
- if param_key not in ["n_chan_bin" , "fs" , "tmin" , "tmax" ]:
220
+ if param_key not in ["n_chan_bin" , "fs" , "tmin" , "tmax" , "templates_from_data" ]:
218
221
assert param_key in tested_keys , f"param: { param_key } in DEFAULT SETTINGS but not tested."
219
222
220
223
def test_spikeinterface_defaults_against_kilsort (self ):
@@ -234,8 +237,11 @@ def test_spikeinterface_defaults_against_kilsort(self):
234
237
235
238
# Testing Arguments ###
236
239
def test_set_files_arguments (self ):
240
+ expected_arguments = ["settings" , "filename" , "probe" , "probe_name" , "data_dir" , "results_dir" , "bad_channels" ]
241
+ if parse (kilosort .__version__ ) >= parse ("4.0.34" ):
242
+ expected_arguments += ["shank_idx" ]
237
243
self ._check_arguments (
238
- set_files , [ "settings" , "filename" , "probe" , "probe_name" , "data_dir" , "results_dir" , "bad_channels" ]
244
+ set_files , expected_arguments
239
245
)
240
246
241
247
def test_initialize_ops_arguments (self ):
@@ -533,33 +539,60 @@ def test_kilosort4_skip_preprocessing_correction(self, tmp_path, monkeypatch, pa
533
539
kilosort_output_dir = tmp_path / "kilosort_output_dir"
534
540
spikeinterface_output_dir = tmp_path / "spikeinterface_output_dir"
535
541
536
- def monkeypatch_filter_function (self , X , ops = None , ibatch = None ):
537
- """
538
- This is a direct copy of the kilosort io.BinaryFiltered.filter
539
- function, with hp_filter and whitening matrix code sections, and
540
- comments removed. This is the easiest way to monkeypatch (tried a few approaches)
541
- """
542
- if self .chan_map is not None :
543
- X = X [self .chan_map ]
544
-
545
- if self .invert_sign :
546
- X = X * - 1
547
-
548
- X = X - X .mean (1 ).unsqueeze (1 )
549
- if self .do_CAR :
550
- X = X - torch .median (X , 0 )[0 ]
551
-
552
- if self .hp_filter is not None :
553
- pass
554
-
555
- if self .artifact_threshold < np .inf :
556
- if torch .any (torch .abs (X ) >= self .artifact_threshold ):
557
- return torch .zeros_like (X )
558
-
559
- if self .whiten_mat is not None :
560
- pass
561
- return X
562
-
542
+ if parse (kilosort .__version__ ) >= parse ("4.0.33" ):
543
+ def monkeypatch_filter_function (self , X , ops = None , ibatch = None , skip_preproc = False ):
544
+ """
545
+ This is a direct copy of the kilosort io.BinaryFiltered.filter
546
+ function, with hp_filter and whitening matrix code sections, and
547
+ comments removed. This is the easiest way to monkeypatch (tried a few approaches)
548
+ """
549
+ if self .chan_map is not None :
550
+ X = X [self .chan_map ]
551
+
552
+ if self .invert_sign :
553
+ X = X * - 1
554
+
555
+ X = X - X .mean (1 ).unsqueeze (1 )
556
+ if self .do_CAR :
557
+ X = X - torch .median (X , 0 )[0 ]
558
+
559
+ if self .hp_filter is not None :
560
+ pass
561
+
562
+ if self .artifact_threshold < np .inf :
563
+ if torch .any (torch .abs (X ) >= self .artifact_threshold ):
564
+ return torch .zeros_like (X )
565
+
566
+ if self .whiten_mat is not None :
567
+ pass
568
+ return X
569
+ else :
570
+ def monkeypatch_filter_function (self , X , ops = None , ibatch = None ):
571
+ """
572
+ This is a direct copy of the kilosort io.BinaryFiltered.filter
573
+ function, with hp_filter and whitening matrix code sections, and
574
+ comments removed. This is the easiest way to monkeypatch (tried a few approaches)
575
+ """
576
+ if self .chan_map is not None :
577
+ X = X [self .chan_map ]
578
+
579
+ if self .invert_sign :
580
+ X = X * - 1
581
+
582
+ X = X - X .mean (1 ).unsqueeze (1 )
583
+ if self .do_CAR :
584
+ X = X - torch .median (X , 0 )[0 ]
585
+
586
+ if self .hp_filter is not None :
587
+ pass
588
+
589
+ if self .artifact_threshold < np .inf :
590
+ if torch .any (torch .abs (X ) >= self .artifact_threshold ):
591
+ return torch .zeros_like (X )
592
+
593
+ if self .whiten_mat is not None :
594
+ pass
595
+ return X
563
596
monkeypatch .setattr ("kilosort.io.BinaryFiltered.filter" , monkeypatch_filter_function )
564
597
565
598
ks_settings , _ , ks_format_probe = self ._get_kilosort_native_settings (recording , paths , param_key , param_value )
@@ -620,7 +653,7 @@ def _get_kilosort_native_settings(self, recording, paths, param_key, param_value
620
653
are through the function, these are split here.
621
654
"""
622
655
settings = {
623
- "data_dir " : paths ["recording_path" ],
656
+ "filename " : paths ["recording_path" ],
624
657
"n_chan_bin" : recording .get_num_channels (),
625
658
"fs" : recording .get_sampling_frequency (),
626
659
}
0 commit comments