@@ -52,17 +52,13 @@ def run(self):
52
52
try :
53
53
logger .info (f"Kilosort version { kilosort .__version__ } " )
54
54
logger .info (f"Sorting { self .data_path } " )
55
+ clear_cache = settings ['clear_cache' ]
56
+ if clear_cache :
57
+ logger .info ('clear_cache=True' )
55
58
logger .info ('-' * 40 )
56
59
57
60
tic0 = time .time ()
58
61
59
- # TODO: make these options in GUI
60
- do_CAR = True
61
- invert_sign = False
62
-
63
- if not do_CAR :
64
- logger .info ("Skipping common average reference." )
65
-
66
62
if probe ['chanMap' ].max () >= settings ['n_chan_bin' ]:
67
63
raise ValueError (
68
64
f'Largest value of chanMap exceeds channel count of data, '
@@ -74,9 +70,13 @@ def run(self):
74
70
data_dtype = settings ['data_dtype' ]
75
71
device = self .device
76
72
save_preprocessed_copy = settings ['save_preprocessed_copy' ]
73
+ do_CAR = settings ['do_CAR' ]
74
+ invert_sign = settings ['invert_sign' ]
75
+ if not do_CAR :
76
+ logger .info ("Skipping common average reference." )
77
77
78
78
ops = initialize_ops (settings , probe , data_dtype , do_CAR ,
79
- invert_sign , device , save_preprocessed_copy )
79
+ invert_sign , device , save_preprocessed_copy )
80
80
# Remove some stuff that doesn't need to be printed twice,
81
81
# then pretty-print format for log file.
82
82
ops_copy = ops .copy ()
@@ -94,7 +94,7 @@ def run(self):
94
94
torch .random .manual_seed (1 )
95
95
ops , bfile , st0 = compute_drift_correction (
96
96
ops , self .device , tic0 = tic0 , progress_bar = self .progress_bar ,
97
- file_object = self .file_object
97
+ file_object = self .file_object , clear_cache = clear_cache
98
98
)
99
99
100
100
# Check scale of data for log file
@@ -113,7 +113,7 @@ def run(self):
113
113
# Sort spikes and save results
114
114
st , tF , Wall0 , clu0 = detect_spikes (
115
115
ops , self .device , bfile , tic0 = tic0 ,
116
- progress_bar = self .progress_bar
116
+ progress_bar = self .progress_bar , clear_cache = clear_cache
117
117
)
118
118
119
119
self .Wall0 = Wall0
@@ -123,7 +123,7 @@ def run(self):
123
123
124
124
clu , Wall = cluster_spikes (
125
125
st , tF , ops , self .device , bfile , tic0 = tic0 ,
126
- progress_bar = self .progress_bar
126
+ progress_bar = self .progress_bar , clear_cache = clear_cache
127
127
)
128
128
ops , similar_templates , is_ref , est_contam_rate , kept_spikes = \
129
129
save_sorting (ops , results_dir , st , clu , tF , Wall , bfile .imin , tic0 )
0 commit comments