Skip to content

Commit e0a1337

Browse files
authored
Merge branch 'main' into import-cleanup-again
2 parents 9488bd2 + 2edc731 commit e0a1337

File tree

7 files changed

+247
-37
lines changed

7 files changed

+247
-37
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,4 +134,4 @@ If you find SpikeInterface useful in your research, please cite:
134134
```
135135

136136
Please also cite other relevant papers for the specific components you use.
137-
For a ful list of references, please check the [references](https://spikeinterface.readthedocs.io/en/latest/references.html) page.
137+
For a full list of references, please check the [references](https://spikeinterface.readthedocs.io/en/latest/references.html) page.

src/spikeinterface/extractors/neoextractors/openephys.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -160,8 +160,6 @@ def __init__(
160160
):
161161

162162
if load_sync_channel:
163-
import warnings
164-
165163
warning_message = (
166164
"OpenEphysBinaryRecordingExtractor: load_sync_channel is deprecated and will"
167165
"be removed in version 0.104, use the stream_name or stream_id to load the sync stream if needed"

src/spikeinterface/sorters/external/kilosort.py

Lines changed: 106 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from pathlib import Path
44
import os
55
from typing import Union
6+
from warnings import warn
67
import numpy as np
78

89
from spikeinterface.sorters.basesorter import BaseSorter
@@ -46,6 +47,31 @@ class KilosortSorter(KilosortBase, BaseSorter):
4647
"wave_length": 61,
4748
"delete_tmp_files": ("matlab_files",),
4849
"delete_recording_dat": False,
50+
"parfor": 0.0,
51+
"nNeighPC": None,
52+
"nNeigh": 16.0,
53+
"whitening": "full",
54+
"nSkipCov": 1.0,
55+
"whiteningRange": 32.0,
56+
"Nrank": 3.0,
57+
"nfullpasses": 6.0,
58+
"maxFR": 20000,
59+
"Th": [4.0, 10.0, 10.0],
60+
"lam": [5.0, 5.0, 5.0],
61+
"nannealpasses": 4.0,
62+
"momentum": [1 / 20, 1 / 400],
63+
"shuffle_clusters": 1.0,
64+
"mergeT": 0.1,
65+
"splitT": 0.1,
66+
"initialize": "fromData",
67+
"loc_range": [3.0, 1.0],
68+
"long_range": [30.0, 6.0],
69+
"maskMaxChannels": 5.0,
70+
"crit": 0.65,
71+
"nFiltMax": 10000.0,
72+
"fracse": 0.1,
73+
"epu": np.inf,
74+
"ForceMaxRAMforDat": 20e9,
4975
}
5076

5177
_params_description = {
@@ -62,6 +88,31 @@ class KilosortSorter(KilosortBase, BaseSorter):
6288
"contains kilosort-preprocessed data). Accepts `False` (deletes no files), `True` (deletes all files) "
6389
"or a Tuple containing the files to delete. Options are: ('temp_wh.dat', 'matlab_files')",
6490
"delete_recording_dat": "Whether to delete the 'recording.dat' file after a successful run",
91+
"parfor": "Whether to use parfor to accelerate some parts of the algorithm. (0.0 or 1.0)",
92+
"nNeighPC": "Number of channels to mask the PCs for visualization (Phy). None to skip, default is min(12, Nchan).",
93+
"nNeigh": "Number of neighboring templates to retain projections of for visualization (Phy). (Default 16.0)",
94+
"whitening": "Type of whitening. (Default 'full', or 'noSpikes')",
95+
"nSkipCov": "Compute whitening matrix from every N-th batch. (Default 1.0)",
96+
"whiteningRange": "How many channels to whiten together. (Inf for whole probe whitening, default 32.0)",
97+
"Nrank": "Matrix rank of spike template model. (Default 3.0)",
98+
"nfullpasses": "Number of complete passes through data during optimization. (Default 6.0)",
99+
"maxFR": "Maximum number of spikes to extract per batch. (Default 20000)",
100+
"Th": "Threshold for detecting spikes on template-filtered data. Array of 3 values: [initial, final, final pass]. (Default [4.0, 10.0, 10.0])",
101+
"lam": "Regularization parameter for template amplitudes. Large means amplitudes are forced around the mean. Array of 3 values: [initial, final, final pass]. (Default [5.0, 5.0, 5.0])",
102+
"nannealpasses": "Number of annealing passes. Should be less than nfullpasses. (Default 4.0)",
103+
"momentum": "Momentum for optimization. Array of 2 values: [initial, final]. (Default [1/20, 1/400])",
104+
"shuffle_clusters": "Allow merges and splits during optimization. (Default 1.0 or True)",
105+
"mergeT": "Upper threshold for merging clusters. (Default 0.1)",
106+
"splitT": "Lower threshold for splitting clusters. (Default 0.1)",
107+
"initialize": "How to initialize templates. ('fromData' or 'no') (Default 'fromData')",
108+
"loc_range": "Range (time x channels) to detect peaks. (Default [3.0, 1.0])",
109+
"long_range": "Range (time x channels) to detect isolated peaks. (Default [30.0, 6.0])",
110+
"maskMaxChannels": "How many channels to mask up/down when extracting PCs. (Default 5.0)",
111+
"crit": "Upper criterion for discarding spike repeats. (Default 0.65)",
112+
"nFiltMax": "Maximum number of 'unique' spikes to consider for template initialization. (Default 10000.0)",
113+
"fracse": "Binning step along discriminant axis for posthoc merges (in units of sd). (Default 0.1)",
114+
"epu": "Drift correction parameter ( Inf = no drift correction). (Default np.inf)",
115+
"ForceMaxRAMforDat": "Maximum RAM the algorithm will try to use (in bytes). (Default 20e9)",
65116
}
66117

67118
sorter_description = """Kilosort is a GPU-accelerated and efficient template-matching spike sorter.
@@ -150,33 +201,50 @@ def _get_specific_options(cls, ops, params) -> dict:
150201

151202
# TODO: Check GPU option!
152203
ops["GPU"] = params["useGPU"] # whether to run this code on an Nvidia GPU (much faster, mexGPUall first)
153-
ops["parfor"] = 0.0 # whether to use parfor to accelerate some parts of the algorithm
204+
ops["parfor"] = params["parfor"] # whether to use parfor to accelerate some parts of the algorithm
154205
ops["verbose"] = 1.0 # whether to print command line progress
155206
ops["showfigures"] = 0.0 # whether to plot figures during optimization
156207

157208
ops["Nfilt"] = params[
158209
"Nfilt"
159210
] # number of clusters to use (2-4 times more than Nchan, should be a multiple of 32)
160-
ops["nNeighPC"] = min(
161-
12.0, ops["Nchan"]
162-
) # visualization only (Phy): number of channnels to mask the PCs, leave empty to skip (12)
163-
ops["nNeigh"] = 16.0 # visualization only (Phy): number of neighboring templates to retain projections of (16)
211+
212+
# ops["nNeighPC"] = min(12.0, ops["Nchan"]) # Original Kilosort default logic
213+
if params["nNeighPC"] is not None:
214+
ops["nNeighPC"] = params["nNeighPC"]
215+
else:
216+
# Kilosort's default behavior if nNeighPC is None in params (from _default_params)
217+
ops["nNeighPC"] = min(
218+
12.0, ops["Nchan"]
219+
) # visualization only (Phy): number of channnels to mask the PCs, leave empty to skip (12)
220+
221+
ops["nNeigh"] = params[
222+
"nNeigh"
223+
] # visualization only (Phy): number of neighboring templates to retain projections of (16)
164224

165225
# options for channel whitening
166-
ops["whitening"] = (
167-
"full" # type of whitening (default 'full', for 'noSpikes' set options for spike detection below)
168-
)
169-
ops["nSkipCov"] = 1.0 # compute whitening matrix from every N-th batch (1)
170-
ops["whiteningRange"] = (
171-
32.0 # how many channels to whiten together (Inf for whole probe whitening, should be fine if Nchan<=32)
172-
)
226+
ops["whitening"] = params[
227+
"whitening"
228+
] # type of whitening (default 'full', for 'noSpikes' set options for spike detection below)
229+
ops["nSkipCov"] = params["nSkipCov"] # compute whitening matrix from every N-th batch (1)
230+
231+
if params["whiteningRange"] > 32:
232+
n_channels_whitening = params["whiteningRange"] if np.isfinite(params["whiteningRange"]) else "all"
233+
warn(
234+
"Kilosort recommends whitening with 32 or fewer channels. "
235+
f"However, you are whitening with {n_channels_whitening} channels."
236+
)
237+
238+
ops["whiteningRange"] = params[
239+
"whiteningRange"
240+
] # how many channels to whiten together (Inf for whole probe whitening, should be fine if Nchan<=32)
173241

174242
# ops['criterionNoiseChannels'] = 0.2 # fraction of "noise" templates allowed to span all channel groups (see createChannelMapFile for more info).
175243

176244
# other options for controlling the model and optimization
177-
ops["Nrank"] = 3.0 # matrix rank of spike template model (3)
178-
ops["nfullpasses"] = 6.0 # number of complete passes through data during optimization (6)
179-
ops["maxFR"] = 20000 # maximum number of spikes to extract per batch (20000)
245+
ops["Nrank"] = params["Nrank"] # matrix rank of spike template model (3)
246+
ops["nfullpasses"] = params["nfullpasses"] # number of complete passes through data during optimization (6)
247+
ops["maxFR"] = params["maxFR"] # maximum number of spikes to extract per batch (20000)
180248
ops["fshigh"] = params["freq_min"] # frequency for high pass filtering
181249
ops["fslow"] = params["freq_max"] # frequency for low pass filtering (optional)
182250
ops["ntbuff"] = params["ntbuff"] # samples of symmetrical buffer for whitening and spike detection
@@ -188,27 +256,33 @@ def _get_specific_options(cls, ops, params) -> dict:
188256
# the following options can improve/deteriorate results.
189257
# when multiple values are provided for an option, the first two are beginning and ending anneal values,
190258
# the third is the value used in the final pass.
191-
ops["Th"] = [4.0, 10.0, 10.0] # threshold for detecting spikes on template-filtered data ([6 12 12])
192-
ops["lam"] = [5.0, 5.0, 5.0] # large means amplitudes are forced around the mean ([10 30 30])
193-
ops["nannealpasses"] = 4.0 # should be less than nfullpasses (4)
194-
ops["momentum"] = [1 / 20, 1 / 400] # start with high momentum and anneal (1./[20 1000])
195-
ops["shuffle_clusters"] = 1.0 # allow merges and splits during optimization (1)
196-
ops["mergeT"] = 0.1 # upper threshold for merging (.1)
197-
ops["splitT"] = 0.1 # lower threshold for splitting (.1)
198-
199-
ops["initialize"] = "fromData" # 'fromData' or 'no'
259+
ops["Th"] = params["Th"] # threshold for detecting spikes on template-filtered data ([6 12 12])
260+
ops["lam"] = params["lam"] # large means amplitudes are forced around the mean ([10 30 30])
261+
ops["nannealpasses"] = params["nannealpasses"] # should be less than nfullpasses (4)
262+
assert (
263+
ops["nannealpasses"] < ops["nfullpasses"]
264+
), f"{ops['nannealpasses']=} should be less than {ops['nfullpasses']=}"
265+
ops["momentum"] = params["momentum"] # start with high momentum and anneal (1./[20 1000])
266+
ops["shuffle_clusters"] = params["shuffle_clusters"] # allow merges and splits during optimization (1)
267+
ops["mergeT"] = params["mergeT"] # upper threshold for merging (.1)
268+
ops["splitT"] = params["splitT"] # lower threshold for splitting (.1)
269+
270+
ops["initialize"] = params["initialize"] # 'fromData' or 'no'
271+
200272
ops["spkTh"] = -params["detect_threshold"] # spike threshold in standard deviations (-6)
201-
ops["loc_range"] = [3.0, 1.0] # ranges to detect peaks; plus/minus in time and channel ([3 1])
202-
ops["long_range"] = [30.0, 6.0] # ranges to detect isolated peaks ([30 6])
203-
ops["maskMaxChannels"] = 5.0 # how many channels to mask up/down ([5])
204-
ops["crit"] = 0.65 # upper criterion for discarding spike repeates (0.65)
205-
ops["nFiltMax"] = 10000.0 # maximum "unique" spikes to consider (10000)
273+
ops["loc_range"] = params["loc_range"] # ranges to detect peaks; plus/minus in time and channel ([3 1])
274+
ops["long_range"] = params["long_range"] # ranges to detect isolated peaks ([30 6])
275+
ops["maskMaxChannels"] = params["maskMaxChannels"] # how many channels to mask up/down ([5])
276+
ops["crit"] = params["crit"] # upper criterion for discarding spike repeates (0.65)
277+
ops["nFiltMax"] = params["nFiltMax"] # maximum "unique" spikes to consider (10000)
206278

207279
# options for posthoc merges (under construction)
208-
ops["fracse"] = 0.1 # binning step along discriminant axis for posthoc merges (in units of sd)
209-
ops["epu"] = np.inf
280+
ops["fracse"] = params["fracse"] # binning step along discriminant axis for posthoc merges (in units of sd)
281+
ops["epu"] = params["epu"]
210282

211-
ops["ForceMaxRAMforDat"] = 20e9 # maximum RAM the algorithm will try to use; on Windows it will autodetect.
283+
ops["ForceMaxRAMforDat"] = params[
284+
"ForceMaxRAMforDat"
285+
] # maximum RAM the algorithm will try to use; on Windows it will autodetect.
212286

213287
## option for wavelength
214288
ops["nt0"] = params[

src/spikeinterface/sortingcomponents/matching/tdc.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,6 @@ def __init__(
152152
self.sparse_templates_array_static = None
153153

154154
# interpolation bins edges
155-
156155
self.interpolation_time_bins_s = []
157156
self.interpolation_time_bin_edges_s = []
158157
for segment_index, parent_segment in enumerate(recording._recording_segments):
Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
from __future__ import annotations
2+
3+
from probeinterface import ProbeGroup
4+
5+
from .base import BaseWidget, to_attr
6+
from .utils import get_unit_colors
7+
from spikeinterface.core.sortinganalyzer import SortingAnalyzer
8+
9+
from .unit_templates import UnitTemplatesWidget
10+
from ..core import Templates
11+
12+
13+
class DriftingTemplatesWidget(BaseWidget):
14+
"""
15+
Plot a drifting templates object to explore motion
16+
17+
Parameters
18+
----------
19+
drifting_templates :
20+
A drifting templates object
21+
scale : float, default: 1
22+
Scale factor for the waveforms/templates (matplotlib backend)
23+
"""
24+
25+
def __init__(
26+
self,
27+
drifting_templates: SortingAnalyzer,
28+
scale=1,
29+
backend=None,
30+
**backend_kwargs,
31+
):
32+
self.drifting_templates = drifting_templates
33+
34+
data_plot = dict(
35+
drifting_templates=drifting_templates,
36+
)
37+
38+
BaseWidget.__init__(self, data_plot, backend=backend, **backend_kwargs)
39+
40+
def plot_ipywidgets(self, data_plot, **backend_kwargs):
41+
import matplotlib.pyplot as plt
42+
import ipywidgets.widgets as widgets
43+
from IPython.display import display
44+
from .utils_ipywidgets import check_ipywidget_backend, UnitSelector
45+
46+
check_ipywidget_backend()
47+
48+
# self.next_data_plot = data_plot.copy()
49+
self.drifting_templates = data_plot["drifting_templates"]
50+
51+
cm = 1 / 2.54
52+
53+
width_cm = backend_kwargs["width_cm"]
54+
height_cm = backend_kwargs["height_cm"]
55+
56+
ratios = [0.15, 0.85]
57+
58+
with plt.ioff():
59+
output = widgets.Output()
60+
with output:
61+
fig, self.ax = plt.subplots(figsize=((ratios[1] * width_cm) * cm, height_cm * cm))
62+
plt.show()
63+
64+
unit_ids = self.drifting_templates.unit_ids
65+
self.unit_selector = UnitSelector(unit_ids)
66+
self.unit_selector.value = list(unit_ids)[:1]
67+
68+
arr = self.drifting_templates.templates_array_moved
69+
70+
self.slider = widgets.IntSlider(
71+
orientation="horizontal",
72+
value=arr.shape[0] // 2,
73+
min=0,
74+
max=arr.shape[0] - 1,
75+
readout=False,
76+
continuous_update=True,
77+
layout=widgets.Layout(width=f"100%"),
78+
)
79+
80+
self.widget = widgets.AppLayout(
81+
center=fig.canvas,
82+
left_sidebar=self.unit_selector,
83+
pane_widths=ratios + [0],
84+
footer=self.slider,
85+
)
86+
87+
self._update_ipywidget()
88+
89+
self.unit_selector.observe(self._change_unit, names="value", type="change")
90+
self.slider.observe(self._change_displacement, names="value", type="change")
91+
92+
if backend_kwargs["display"]:
93+
display(self.widget)
94+
95+
def _change_unit(self, change=None):
96+
self._update_ipywidget(keep_lims=False)
97+
98+
def _change_displacement(self, change=None):
99+
self._update_ipywidget(keep_lims=True)
100+
101+
def _update_ipywidget(self, keep_lims=False):
102+
if keep_lims:
103+
xlim = self.ax.get_xlim()
104+
ylim = self.ax.get_ylim()
105+
106+
self.ax.clear()
107+
unit_ids = self.unit_selector.value
108+
109+
displacement_index = self.slider.value
110+
111+
templates_array = self.drifting_templates.templates_array_moved[displacement_index, :, :, :]
112+
templates = Templates(
113+
templates_array,
114+
self.drifting_templates.sampling_frequency,
115+
self.drifting_templates.nbefore,
116+
is_scaled=self.drifting_templates.is_scaled,
117+
sparsity_mask=None,
118+
channel_ids=self.drifting_templates.channel_ids,
119+
unit_ids=self.drifting_templates.unit_ids,
120+
probe=self.drifting_templates.probe,
121+
)
122+
123+
UnitTemplatesWidget(
124+
templates, unit_ids=unit_ids, scale=5, plot_legend=False, backend="matplotlib", ax=self.ax, same_axis=True
125+
)
126+
127+
displacement = self.drifting_templates.displacements[displacement_index]
128+
self.ax.set_title(f"{displacement_index}:{displacement} - untis:{unit_ids}")
129+
130+
if keep_lims:
131+
self.ax.set_xlim(xlim)
132+
self.ax.set_ylim(ylim)
133+
134+
fig = self.ax.get_figure()
135+
fig.canvas.draw()
136+
fig.canvas.flush_events()

src/spikeinterface/widgets/unit_waveforms.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ def __init__(
167167

168168
# get templates
169169
if isinstance(sorting_analyzer_or_templates, Templates):
170-
templates = sorting_analyzer_or_templates.templates_array
170+
templates = sorting_analyzer_or_templates.select_units(unit_ids).templates_array
171171
nbefore = sorting_analyzer_or_templates.nbefore
172172
self.templates_ext = None
173173
templates_shading = None

src/spikeinterface/widgets/widget_list.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from .amplitudes import AmplitudesWidget
1010
from .autocorrelograms import AutoCorrelogramsWidget
1111
from .crosscorrelograms import CrossCorrelogramsWidget
12+
from .drift_templates import DriftingTemplatesWidget
1213
from .isi_distribution import ISIDistributionWidget
1314
from .motion import DriftRasterMapWidget, MotionWidget, MotionInfoWidget
1415
from .multicomparison import MultiCompGraphWidget, MultiCompGlobalAgreementWidget, MultiCompAgreementBySorterWidget
@@ -45,6 +46,7 @@
4546
ConfusionMatrixWidget,
4647
ComparisonCollisionBySimilarityWidget,
4748
CrossCorrelogramsWidget,
49+
DriftingTemplatesWidget,
4850
DriftRasterMapWidget,
4951
ISIDistributionWidget,
5052
LocationsWidget,
@@ -124,6 +126,7 @@
124126
plot_confusion_matrix = ConfusionMatrixWidget
125127
plot_comparison_collision_by_similarity = ComparisonCollisionBySimilarityWidget
126128
plot_crosscorrelograms = CrossCorrelogramsWidget
129+
plot_drifting_templates = DriftingTemplatesWidget
127130
plot_drift_raster_map = DriftRasterMapWidget
128131
plot_isi_distribution = ISIDistributionWidget
129132
plot_locations = LocationsWidget

0 commit comments

Comments
 (0)