Skip to content

Commit 191a047

Browse files
authored
Merge pull request #3956 from kushaangupta/extend-ks1-params
Expose hidden KiloSort1 parameters and descriptions
2 parents 43fe8bd + 1d8db7d commit 191a047

File tree

1 file changed

+106
-32
lines changed

1 file changed

+106
-32
lines changed

src/spikeinterface/sorters/external/kilosort.py

Lines changed: 106 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from pathlib import Path
44
import os
55
from typing import Union
6+
from warnings import warn
67
import numpy as np
78

89
from spikeinterface.sorters.basesorter import BaseSorter
@@ -46,6 +47,31 @@ class KilosortSorter(KilosortBase, BaseSorter):
4647
"wave_length": 61,
4748
"delete_tmp_files": ("matlab_files",),
4849
"delete_recording_dat": False,
50+
"parfor": 0.0,
51+
"nNeighPC": None,
52+
"nNeigh": 16.0,
53+
"whitening": "full",
54+
"nSkipCov": 1.0,
55+
"whiteningRange": 32.0,
56+
"Nrank": 3.0,
57+
"nfullpasses": 6.0,
58+
"maxFR": 20000,
59+
"Th": [4.0, 10.0, 10.0],
60+
"lam": [5.0, 5.0, 5.0],
61+
"nannealpasses": 4.0,
62+
"momentum": [1 / 20, 1 / 400],
63+
"shuffle_clusters": 1.0,
64+
"mergeT": 0.1,
65+
"splitT": 0.1,
66+
"initialize": "fromData",
67+
"loc_range": [3.0, 1.0],
68+
"long_range": [30.0, 6.0],
69+
"maskMaxChannels": 5.0,
70+
"crit": 0.65,
71+
"nFiltMax": 10000.0,
72+
"fracse": 0.1,
73+
"epu": np.inf,
74+
"ForceMaxRAMforDat": 20e9,
4975
}
5076

5177
_params_description = {
@@ -62,6 +88,31 @@ class KilosortSorter(KilosortBase, BaseSorter):
6288
"contains kilosort-preprocessed data). Accepts `False` (deletes no files), `True` (deletes all files) "
6389
"or a Tuple containing the files to delete. Options are: ('temp_wh.dat', 'matlab_files')",
6490
"delete_recording_dat": "Whether to delete the 'recording.dat' file after a successful run",
91+
"parfor": "Whether to use parfor to accelerate some parts of the algorithm. (0.0 or 1.0)",
92+
"nNeighPC": "Number of channels to mask the PCs for visualization (Phy). None to skip, default is min(12, Nchan).",
93+
"nNeigh": "Number of neighboring templates to retain projections of for visualization (Phy). (Default 16.0)",
94+
"whitening": "Type of whitening. (Default 'full', or 'noSpikes')",
95+
"nSkipCov": "Compute whitening matrix from every N-th batch. (Default 1.0)",
96+
"whiteningRange": "How many channels to whiten together. (Inf for whole probe whitening, default 32.0)",
97+
"Nrank": "Matrix rank of spike template model. (Default 3.0)",
98+
"nfullpasses": "Number of complete passes through data during optimization. (Default 6.0)",
99+
"maxFR": "Maximum number of spikes to extract per batch. (Default 20000)",
100+
"Th": "Threshold for detecting spikes on template-filtered data. Array of 3 values: [initial, final, final pass]. (Default [4.0, 10.0, 10.0])",
101+
"lam": "Regularization parameter for template amplitudes. Large means amplitudes are forced around the mean. Array of 3 values: [initial, final, final pass]. (Default [5.0, 5.0, 5.0])",
102+
"nannealpasses": "Number of annealing passes. Should be less than nfullpasses. (Default 4.0)",
103+
"momentum": "Momentum for optimization. Array of 2 values: [initial, final]. (Default [1/20, 1/400])",
104+
"shuffle_clusters": "Allow merges and splits during optimization. (Default 1.0 or True)",
105+
"mergeT": "Upper threshold for merging clusters. (Default 0.1)",
106+
"splitT": "Lower threshold for splitting clusters. (Default 0.1)",
107+
"initialize": "How to initialize templates. ('fromData' or 'no') (Default 'fromData')",
108+
"loc_range": "Range (time x channels) to detect peaks. (Default [3.0, 1.0])",
109+
"long_range": "Range (time x channels) to detect isolated peaks. (Default [30.0, 6.0])",
110+
"maskMaxChannels": "How many channels to mask up/down when extracting PCs. (Default 5.0)",
111+
"crit": "Upper criterion for discarding spike repeats. (Default 0.65)",
112+
"nFiltMax": "Maximum number of 'unique' spikes to consider for template initialization. (Default 10000.0)",
113+
"fracse": "Binning step along discriminant axis for posthoc merges (in units of sd). (Default 0.1)",
114+
"epu": "Drift correction parameter ( Inf = no drift correction). (Default np.inf)",
115+
"ForceMaxRAMforDat": "Maximum RAM the algorithm will try to use (in bytes). (Default 20e9)",
65116
}
66117

67118
sorter_description = """Kilosort is a GPU-accelerated and efficient template-matching spike sorter.
@@ -150,33 +201,50 @@ def _get_specific_options(cls, ops, params) -> dict:
150201

151202
# TODO: Check GPU option!
152203
ops["GPU"] = params["useGPU"] # whether to run this code on an Nvidia GPU (much faster, mexGPUall first)
153-
ops["parfor"] = 0.0 # whether to use parfor to accelerate some parts of the algorithm
204+
ops["parfor"] = params["parfor"] # whether to use parfor to accelerate some parts of the algorithm
154205
ops["verbose"] = 1.0 # whether to print command line progress
155206
ops["showfigures"] = 0.0 # whether to plot figures during optimization
156207

157208
ops["Nfilt"] = params[
158209
"Nfilt"
159210
] # number of clusters to use (2-4 times more than Nchan, should be a multiple of 32)
160-
ops["nNeighPC"] = min(
161-
12.0, ops["Nchan"]
162-
) # visualization only (Phy): number of channnels to mask the PCs, leave empty to skip (12)
163-
ops["nNeigh"] = 16.0 # visualization only (Phy): number of neighboring templates to retain projections of (16)
211+
212+
# ops["nNeighPC"] = min(12.0, ops["Nchan"]) # Original Kilosort default logic
213+
if params["nNeighPC"] is not None:
214+
ops["nNeighPC"] = params["nNeighPC"]
215+
else:
216+
# Kilosort's default behavior if nNeighPC is None in params (from _default_params)
217+
ops["nNeighPC"] = min(
218+
12.0, ops["Nchan"]
219+
) # visualization only (Phy): number of channnels to mask the PCs, leave empty to skip (12)
220+
221+
ops["nNeigh"] = params[
222+
"nNeigh"
223+
] # visualization only (Phy): number of neighboring templates to retain projections of (16)
164224

165225
# options for channel whitening
166-
ops["whitening"] = (
167-
"full" # type of whitening (default 'full', for 'noSpikes' set options for spike detection below)
168-
)
169-
ops["nSkipCov"] = 1.0 # compute whitening matrix from every N-th batch (1)
170-
ops["whiteningRange"] = (
171-
32.0 # how many channels to whiten together (Inf for whole probe whitening, should be fine if Nchan<=32)
172-
)
226+
ops["whitening"] = params[
227+
"whitening"
228+
] # type of whitening (default 'full', for 'noSpikes' set options for spike detection below)
229+
ops["nSkipCov"] = params["nSkipCov"] # compute whitening matrix from every N-th batch (1)
230+
231+
if params["whiteningRange"] > 32:
232+
n_channels_whitening = params["whiteningRange"] if np.isfinite(params["whiteningRange"]) else "all"
233+
warn(
234+
"Kilosort recommends whitening with 32 or fewer channels. "
235+
f"However, you are whitening with {n_channels_whitening} channels."
236+
)
237+
238+
ops["whiteningRange"] = params[
239+
"whiteningRange"
240+
] # how many channels to whiten together (Inf for whole probe whitening, should be fine if Nchan<=32)
173241

174242
# ops['criterionNoiseChannels'] = 0.2 # fraction of "noise" templates allowed to span all channel groups (see createChannelMapFile for more info).
175243

176244
# other options for controlling the model and optimization
177-
ops["Nrank"] = 3.0 # matrix rank of spike template model (3)
178-
ops["nfullpasses"] = 6.0 # number of complete passes through data during optimization (6)
179-
ops["maxFR"] = 20000 # maximum number of spikes to extract per batch (20000)
245+
ops["Nrank"] = params["Nrank"] # matrix rank of spike template model (3)
246+
ops["nfullpasses"] = params["nfullpasses"] # number of complete passes through data during optimization (6)
247+
ops["maxFR"] = params["maxFR"] # maximum number of spikes to extract per batch (20000)
180248
ops["fshigh"] = params["freq_min"] # frequency for high pass filtering
181249
ops["fslow"] = params["freq_max"] # frequency for low pass filtering (optional)
182250
ops["ntbuff"] = params["ntbuff"] # samples of symmetrical buffer for whitening and spike detection
@@ -188,27 +256,33 @@ def _get_specific_options(cls, ops, params) -> dict:
188256
# the following options can improve/deteriorate results.
189257
# when multiple values are provided for an option, the first two are beginning and ending anneal values,
190258
# the third is the value used in the final pass.
191-
ops["Th"] = [4.0, 10.0, 10.0] # threshold for detecting spikes on template-filtered data ([6 12 12])
192-
ops["lam"] = [5.0, 5.0, 5.0] # large means amplitudes are forced around the mean ([10 30 30])
193-
ops["nannealpasses"] = 4.0 # should be less than nfullpasses (4)
194-
ops["momentum"] = [1 / 20, 1 / 400] # start with high momentum and anneal (1./[20 1000])
195-
ops["shuffle_clusters"] = 1.0 # allow merges and splits during optimization (1)
196-
ops["mergeT"] = 0.1 # upper threshold for merging (.1)
197-
ops["splitT"] = 0.1 # lower threshold for splitting (.1)
198-
199-
ops["initialize"] = "fromData" # 'fromData' or 'no'
259+
ops["Th"] = params["Th"] # threshold for detecting spikes on template-filtered data ([6 12 12])
260+
ops["lam"] = params["lam"] # large means amplitudes are forced around the mean ([10 30 30])
261+
ops["nannealpasses"] = params["nannealpasses"] # should be less than nfullpasses (4)
262+
assert (
263+
ops["nannealpasses"] < ops["nfullpasses"]
264+
), f"{ops['nannealpasses']=} should be less than {ops['nfullpasses']=}"
265+
ops["momentum"] = params["momentum"] # start with high momentum and anneal (1./[20 1000])
266+
ops["shuffle_clusters"] = params["shuffle_clusters"] # allow merges and splits during optimization (1)
267+
ops["mergeT"] = params["mergeT"] # upper threshold for merging (.1)
268+
ops["splitT"] = params["splitT"] # lower threshold for splitting (.1)
269+
270+
ops["initialize"] = params["initialize"] # 'fromData' or 'no'
271+
200272
ops["spkTh"] = -params["detect_threshold"] # spike threshold in standard deviations (-6)
201-
ops["loc_range"] = [3.0, 1.0] # ranges to detect peaks; plus/minus in time and channel ([3 1])
202-
ops["long_range"] = [30.0, 6.0] # ranges to detect isolated peaks ([30 6])
203-
ops["maskMaxChannels"] = 5.0 # how many channels to mask up/down ([5])
204-
ops["crit"] = 0.65 # upper criterion for discarding spike repeates (0.65)
205-
ops["nFiltMax"] = 10000.0 # maximum "unique" spikes to consider (10000)
273+
ops["loc_range"] = params["loc_range"] # ranges to detect peaks; plus/minus in time and channel ([3 1])
274+
ops["long_range"] = params["long_range"] # ranges to detect isolated peaks ([30 6])
275+
ops["maskMaxChannels"] = params["maskMaxChannels"] # how many channels to mask up/down ([5])
276+
ops["crit"] = params["crit"] # upper criterion for discarding spike repeates (0.65)
277+
ops["nFiltMax"] = params["nFiltMax"] # maximum "unique" spikes to consider (10000)
206278

207279
# options for posthoc merges (under construction)
208-
ops["fracse"] = 0.1 # binning step along discriminant axis for posthoc merges (in units of sd)
209-
ops["epu"] = np.inf
280+
ops["fracse"] = params["fracse"] # binning step along discriminant axis for posthoc merges (in units of sd)
281+
ops["epu"] = params["epu"]
210282

211-
ops["ForceMaxRAMforDat"] = 20e9 # maximum RAM the algorithm will try to use; on Windows it will autodetect.
283+
ops["ForceMaxRAMforDat"] = params[
284+
"ForceMaxRAMforDat"
285+
] # maximum RAM the algorithm will try to use; on Windows it will autodetect.
212286

213287
## option for wavelength
214288
ops["nt0"] = params[

0 commit comments

Comments
 (0)