Skip to content

Commit 81fb2a5

Browse files
Added clear_cache, do_CAR, invert_sign to GUI
1 parent 1d11e34 commit 81fb2a5

File tree

3 files changed

+48
-14
lines changed

3 files changed

+48
-14
lines changed

kilosort/gui/main.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -340,6 +340,9 @@ def set_parameters(self):
340340

341341
params = settings.copy()
342342
params['save_preprocessed_copy'] = self.run_box.save_preproc_check.isChecked()
343+
params['clear_cache'] = self.run_box.clear_cache_check.isChecked()
344+
params['do_CAR'] = self.run_box.do_CAR_check.isChecked()
345+
params['invert_sign'] = self.run_box.invert_sign_check.isChecked()
343346

344347
assert params
345348

kilosort/gui/run_box.py

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,9 @@ def __init__(self, parent):
2424
self.run_all_button = QtWidgets.QPushButton("Run")
2525
self.spike_sort_button = QtWidgets.QPushButton("Spikesort")
2626
self.save_preproc_check = QtWidgets.QCheckBox("Save Preprocessed Copy")
27+
self.clear_cache_check = QtWidgets.QCheckBox("Clear PyTorch Cache")
28+
self.do_CAR_check = QtWidgets.QCheckBox("CAR")
29+
self.invert_sign_check = QtWidgets.QCheckBox("Invert Sign")
2730

2831
self.buttons = [
2932
self.run_all_button
@@ -44,7 +47,7 @@ def __init__(self, parent):
4447
self.remote_widgets = None
4548

4649
self.progress_bar = QtWidgets.QProgressBar()
47-
self.layout.addWidget(self.progress_bar, 3, 0, 2, 2)
50+
self.layout.addWidget(self.progress_bar, 5, 0, 3, 4)
4851

4952
self.setup()
5053

@@ -64,8 +67,36 @@ def setup(self):
6467
"""
6568
self.save_preproc_check.setToolTip(preproc_text)
6669

67-
self.layout.addWidget(self.run_all_button, 0, 0, 2, 2)
68-
self.layout.addWidget(self.save_preproc_check, 2, 0, 1, 2)
70+
self.clear_cache_check.setCheckState(QtCore.Qt.CheckState.Unchecked)
71+
cache_text = """
72+
If enabled, force pytorch to free up memory reserved for its cache in
73+
between memory-intensive operations.
74+
Note that setting `clear_cache=True` is NOT recommended unless you
75+
encounter GPU out-of-memory errors, since this can result in slower
76+
sorting.
77+
"""
78+
self.clear_cache_check.setToolTip(cache_text)
79+
80+
self.do_CAR_check.setCheckState(QtCore.Qt.CheckState.Checked)
81+
car_text = """
82+
If enabled, apply common average reference during preprocessing
83+
(recommended).
84+
"""
85+
self.do_CAR_check.setToolTip(car_text)
86+
87+
self.invert_sign_check.setCheckState(QtCore.Qt.CheckState.Unchecked)
88+
invert_sign_text = """
89+
If enabled, flip positive/negative values in data to conform to
90+
standard expected by Kilosort4. This is NOT recommended unless you
91+
know your data is using the opposite sign.
92+
"""
93+
self.invert_sign_check.setToolTip(invert_sign_text)
94+
95+
self.layout.addWidget(self.run_all_button, 0, 0, 3, 4)
96+
self.layout.addWidget(self.save_preproc_check, 3, 0, 1, 2)
97+
self.layout.addWidget(self.clear_cache_check, 3, 2, 1, 2)
98+
self.layout.addWidget(self.do_CAR_check, 4, 0, 1, 2)
99+
self.layout.addWidget(self.invert_sign_check, 4, 2, 1, 2)
69100

70101
self.setLayout(self.layout)
71102

kilosort/gui/sorter.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -52,17 +52,13 @@ def run(self):
5252
try:
5353
logger.info(f"Kilosort version {kilosort.__version__}")
5454
logger.info(f"Sorting {self.data_path}")
55+
clear_cache = settings['clear_cache']
56+
if clear_cache:
57+
logger.info('clear_cache=True')
5558
logger.info('-'*40)
5659

5760
tic0 = time.time()
5861

59-
# TODO: make these options in GUI
60-
do_CAR=True
61-
invert_sign=False
62-
63-
if not do_CAR:
64-
logger.info("Skipping common average reference.")
65-
6662
if probe['chanMap'].max() >= settings['n_chan_bin']:
6763
raise ValueError(
6864
f'Largest value of chanMap exceeds channel count of data, '
@@ -74,9 +70,13 @@ def run(self):
7470
data_dtype = settings['data_dtype']
7571
device = self.device
7672
save_preprocessed_copy = settings['save_preprocessed_copy']
73+
do_CAR = settings['do_CAR']
74+
invert_sign = settings['invert_sign']
75+
if not do_CAR:
76+
logger.info("Skipping common average reference.")
7777

7878
ops = initialize_ops(settings, probe, data_dtype, do_CAR,
79-
invert_sign, device, save_preprocessed_copy)
79+
invert_sign, device, save_preprocessed_copy)
8080
# Remove some stuff that doesn't need to be printed twice,
8181
# then pretty-print format for log file.
8282
ops_copy = ops.copy()
@@ -94,7 +94,7 @@ def run(self):
9494
torch.random.manual_seed(1)
9595
ops, bfile, st0 = compute_drift_correction(
9696
ops, self.device, tic0=tic0, progress_bar=self.progress_bar,
97-
file_object=self.file_object
97+
file_object=self.file_object, clear_cache=clear_cache
9898
)
9999

100100
# Check scale of data for log file
@@ -113,7 +113,7 @@ def run(self):
113113
# Sort spikes and save results
114114
st, tF, Wall0, clu0 = detect_spikes(
115115
ops, self.device, bfile, tic0=tic0,
116-
progress_bar=self.progress_bar
116+
progress_bar=self.progress_bar, clear_cache=clear_cache
117117
)
118118

119119
self.Wall0 = Wall0
@@ -123,7 +123,7 @@ def run(self):
123123

124124
clu, Wall = cluster_spikes(
125125
st, tF, ops, self.device, bfile, tic0=tic0,
126-
progress_bar=self.progress_bar
126+
progress_bar=self.progress_bar, clear_cache=clear_cache
127127
)
128128
ops, similar_templates, is_ref, est_contam_rate, kept_spikes = \
129129
save_sorting(ops, results_dir, st, clu, tF, Wall, bfile.imin, tic0)

0 commit comments

Comments
 (0)