
Commit f4b908c

Returned svd (#3847)
* WIP
* Example of how to use SVD to estimate templates in SC2
* Patching to get a working example
* WIP
* WIP
* WIP
* WIP
* WIP
* WIP
* Cosmetic
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* Patch
* WIP
* WIP
* Fix
* WIP
* WIP
* WIP
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* Make delete_mixtures optional
* Better logs
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent a4a938d commit f4b908c
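
The headline change is the return signature of the clustering entry points: they now hand the fitted SVD model and the per-peak SVD features back to the caller (hence "Returned svd"), so downstream code such as the SpyKING CIRCUS 2 sorter can reuse them for template estimation instead of recomputing. Below is a minimal sketch of how a caller might unpack the new 5-tuple; recording, peaks, params and job_kwargs are placeholders assumed to be prepared elsewhere, not values taken from this commit.

# Sketch only: consuming the new return value of the clustering classes
# touched in this commit (see the circus.py and graph_clustering.py diffs below).
from spikeinterface.sortingcomponents.clustering.circus import CircusClustering

labels, peak_labels, svd_model, peaks_svd, sparse_mask = CircusClustering.main_function(
    recording, peaks, params, job_kwargs=job_kwargs  # placeholders prepared by the caller
)
# svd_model   : fitted sklearn TruncatedSVD used to project the waveforms
# peaks_svd   : per-peak SVD features computed by extract_peaks_svd
# sparse_mask : boolean channel mask used when extracting the sparse waveforms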

File tree

6 files changed: +300 -326 lines


src/spikeinterface/sorters/internal/spyking_circus2.py

Lines changed: 165 additions & 148 deletions
Large diffs are not rendered by default.

src/spikeinterface/sortingcomponents/clustering/circus.py

Lines changed: 92 additions & 164 deletions
@@ -19,17 +19,10 @@
 from .clustering_tools import remove_duplicates_via_matching
 from spikeinterface.core.recording_tools import get_noise_levels, get_channel_distances
 from spikeinterface.sortingcomponents.peak_selection import select_peaks
-from spikeinterface.sortingcomponents.waveforms.temporal_pca import TemporalPCAProjection
-from spikeinterface.sortingcomponents.waveforms.hanning_filter import HanningFilter
 from spikeinterface.core.template import Templates
 from spikeinterface.core.sparsity import compute_sparsity
 from spikeinterface.sortingcomponents.tools import remove_empty_templates
-import pickle, json
-from spikeinterface.core.node_pipeline import (
-    run_node_pipeline,
-    ExtractSparseWaveforms,
-    PeakRetriever,
-)
+from spikeinterface.sortingcomponents.clustering.peak_svd import extract_peaks_svd


 from spikeinterface.sortingcomponents.tools import extract_waveform_at_max_channel
@@ -48,20 +41,24 @@ class CircusClustering:
             "allow_single_cluster": True,
         },
         "cleaning_kwargs": {},
+        "remove_mixtures": False,
         "waveforms": {"ms_before": 2, "ms_after": 2},
         "sparsity": {"method": "snr", "amplitude_mode": "peak_to_peak", "threshold": 0.25},
         "recursive_kwargs": {
             "recursive": True,
             "recursive_depth": 3,
             "returns_split_count": True,
         },
+        "split_kwargs": {"projection_mode": "tsvd", "n_pca_features": 0.9},
         "radius_um": 100,
+        "neighbors_radius_um": 50,
         "n_svd": 5,
         "few_waveforms": None,
         "ms_before": 0.5,
         "ms_after": 0.5,
         "noise_threshold": 4,
         "rank": 5,
+        "templates_from_svd": False,
         "noise_levels": None,
         "tmp_folder": None,
         "verbose": True,
@@ -78,6 +75,8 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()):
         fs = recording.get_sampling_frequency()
         ms_before = params["ms_before"]
         ms_after = params["ms_after"]
+        radius_um = params["radius_um"]
+        neighbors_radius_um = params["neighbors_radius_um"]
         nbefore = int(ms_before * fs / 1000.0)
         nafter = int(ms_after * fs / 1000.0)
         if params["tmp_folder"] is None:
@@ -108,210 +107,139 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()):
         valid = np.argmax(np.abs(wfs), axis=1) == nbefore
         wfs = wfs[valid]

-        # Perform Hanning filtering
-        hanning_before = np.hanning(2 * nbefore)
-        hanning_after = np.hanning(2 * nafter)
-        hanning = np.concatenate((hanning_before[:nbefore], hanning_after[nafter:]))
-        wfs *= hanning
-
         from sklearn.decomposition import TruncatedSVD

-        tsvd = TruncatedSVD(params["n_svd"])
-        tsvd.fit(wfs)
-
-        model_folder = tmp_folder / "tsvd_model"
-
-        model_folder.mkdir(exist_ok=True)
-        with open(model_folder / "pca_model.pkl", "wb") as f:
-            pickle.dump(tsvd, f)
-
-        model_params = {
-            "ms_before": ms_before,
-            "ms_after": ms_after,
-            "sampling_frequency": float(fs),
-        }
-
-        with open(model_folder / "params.json", "w") as f:
-            json.dump(model_params, f)
+        svd_model = TruncatedSVD(params["n_svd"])
+        svd_model.fit(wfs)
+        features_folder = tmp_folder / "tsvd_features"
+        features_folder.mkdir(exist_ok=True)

-        # features
-        node0 = PeakRetriever(recording, peaks)
-
-        radius_um = params["radius_um"]
-        node1 = ExtractSparseWaveforms(
+        peaks_svd, sparse_mask, svd_model = extract_peaks_svd(
             recording,
-            parents=[node0],
-            return_output=False,
+            peaks,
             ms_before=ms_before,
             ms_after=ms_after,
+            svd_model=svd_model,
             radius_um=radius_um,
+            folder=features_folder,
+            **job_kwargs,
         )

-        node2 = HanningFilter(recording, parents=[node0, node1], return_output=False)
+        neighbours_mask = get_channel_distances(recording) <= neighbors_radius_um

-        node3 = TemporalPCAProjection(
-            recording, parents=[node0, node2], return_output=True, model_folder_path=model_folder
-        )
+        if params["debug"]:
+            np.save(features_folder / "sparse_mask.npy", sparse_mask)
+            np.save(features_folder / "peaks.npy", peaks)

-        pipeline_nodes = [node0, node1, node2, node3]
+        original_labels = peaks["channel_index"]
+        from spikeinterface.sortingcomponents.clustering.split import split_clusters

-        if len(params["recursive_kwargs"]) == 0:
-            from sklearn.decomposition import PCA
+        split_kwargs = params["split_kwargs"].copy()
+        split_kwargs["neighbours_mask"] = neighbours_mask
+        split_kwargs["waveforms_sparse_mask"] = sparse_mask
+        split_kwargs["min_size_split"] = 2 * params["hdbscan_kwargs"].get("min_cluster_size", 50)
+        split_kwargs["clusterer_kwargs"] = params["hdbscan_kwargs"]

-            all_pc_data = run_node_pipeline(
-                recording,
-                pipeline_nodes,
-                job_kwargs,
-                job_name="extracting features",
-            )
-
-            peak_labels = -1 * np.ones(len(peaks), dtype=int)
-            nb_clusters = 0
-            for c in np.unique(peaks["channel_index"]):
-                mask = peaks["channel_index"] == c
-                sub_data = all_pc_data[mask]
-                sub_data = sub_data.reshape(len(sub_data), -1)
-
-                if all_pc_data.shape[1] > params["n_svd"]:
-                    tsvd = PCA(params["n_svd"], whiten=True)
-                else:
-                    tsvd = PCA(all_pc_data.shape[1], whiten=True)
-
-                hdbscan_data = tsvd.fit_transform(sub_data)
-                try:
-                    clustering = hdbscan.hdbscan(hdbscan_data, **d["hdbscan_kwargs"])
-                    local_labels = clustering[0]
-                except Exception:
-                    local_labels = np.zeros(len(hdbscan_data))
-                valid_clusters = local_labels > -1
-                if np.sum(valid_clusters) > 0:
-                    local_labels[valid_clusters] += nb_clusters
-                    peak_labels[mask] = local_labels
-                    nb_clusters += len(np.unique(local_labels[valid_clusters]))
+        if params["debug"]:
+            debug_folder = tmp_folder / "split"
         else:
+            debug_folder = None

-            features_folder = tmp_folder / "tsvd_features"
-            features_folder.mkdir(exist_ok=True)
-
-            _ = run_node_pipeline(
-                recording,
-                pipeline_nodes,
-                job_kwargs,
-                job_name="extracting features",
-                gather_mode="npy",
-                gather_kwargs=dict(exist_ok=True),
-                folder=features_folder,
-                names=["sparse_tsvd"],
-            )
-
-            sparse_mask = node1.neighbours_mask
-            neighbours_mask = get_channel_distances(recording) <= radius_um
-
-            # np.save(features_folder / "sparse_mask.npy", sparse_mask)
-            np.save(features_folder / "peaks.npy", peaks)
-
-            original_labels = peaks["channel_index"]
-            from spikeinterface.sortingcomponents.clustering.split import split_clusters
+        peak_labels, _ = split_clusters(
+            original_labels,
+            recording,
+            {"peaks": peaks, "sparse_tsvd": peaks_svd},
+            method="local_feature_clustering",
+            method_kwargs=split_kwargs,
+            debug_folder=debug_folder,
+            **params["recursive_kwargs"],
+            **job_kwargs,
+        )

-            min_size = 2 * params["hdbscan_kwargs"].get("min_cluster_size", 20)
+        if params["noise_levels"] is None:
+            params["noise_levels"] = get_noise_levels(recording, return_scaled=False, **job_kwargs)

-            if params["debug"]:
-                debug_folder = tmp_folder / "split"
-            else:
-                debug_folder = None
+        if not params["templates_from_svd"]:
+            from spikeinterface.sortingcomponents.clustering.tools import get_templates_from_peaks_and_recording

-            peak_labels, _ = split_clusters(
-                original_labels,
+            templates = get_templates_from_peaks_and_recording(
                 recording,
-                features_folder,
-                method="local_feature_clustering",
-                method_kwargs=dict(
-                    clusterer="hdbscan",
-                    feature_name="sparse_tsvd",
-                    neighbours_mask=neighbours_mask,
-                    waveforms_sparse_mask=sparse_mask,
-                    min_size_split=min_size,
-                    clusterer_kwargs=d["hdbscan_kwargs"],
-                    n_pca_features=5,
-                ),
-                debug_folder=debug_folder,
-                **params["recursive_kwargs"],
+                peaks,
+                peak_labels,
+                ms_before,
+                ms_after,
                 **job_kwargs,
             )
+        else:
+            from spikeinterface.sortingcomponents.clustering.tools import get_templates_from_peaks_and_svd

-        non_noise = peak_labels > -1
-        labels, inverse = np.unique(peak_labels[non_noise], return_inverse=True)
-        peak_labels[non_noise] = inverse
-        labels = np.unique(inverse)
-
-        spikes = np.zeros(non_noise.sum(), dtype=minimum_spike_dtype)
-        spikes["sample_index"] = peaks[non_noise]["sample_index"]
-        spikes["segment_index"] = peaks[non_noise]["segment_index"]
-        spikes["unit_index"] = peak_labels[non_noise]
-
-        unit_ids = labels
-
-        nbefore = int(params["waveforms"]["ms_before"] * fs / 1000.0)
-        nafter = int(params["waveforms"]["ms_after"] * fs / 1000.0)
-
-        if params["noise_levels"] is None:
-            params["noise_levels"] = get_noise_levels(recording, return_scaled=False, **job_kwargs)
-
-        templates_array = estimate_templates(
-            recording,
-            spikes,
-            unit_ids,
-            nbefore,
-            nafter,
-            return_scaled=False,
-            job_name=None,
-            **job_kwargs,
-        )
+            templates = get_templates_from_peaks_and_svd(
+                recording,
+                peaks,
+                peak_labels,
+                ms_before,
+                ms_after,
+                svd_model,
+                peaks_svd,
+                sparse_mask,
+                operator="median",
+            )

+        templates_array = templates.templates_array
         best_channels = np.argmax(np.abs(templates_array[:, nbefore, :]), axis=1)
         peak_snrs = np.abs(templates_array[:, nbefore, :])
         best_snrs_ratio = (peak_snrs / params["noise_levels"])[np.arange(len(peak_snrs)), best_channels]
+        old_unit_ids = templates.unit_ids.copy()
         valid_templates = best_snrs_ratio > params["noise_threshold"]

-        if d["rank"] is not None:
-            from spikeinterface.sortingcomponents.matching.circus import compress_templates
+        mask = np.isin(peak_labels, old_unit_ids[~valid_templates])
+        peak_labels[mask] = -1

-            _, _, _, templates_array = compress_templates(templates_array, d["rank"])
+        from spikeinterface.core.template import Templates

         templates = Templates(
             templates_array=templates_array[valid_templates],
             sampling_frequency=fs,
-            nbefore=nbefore,
+            nbefore=templates.nbefore,
             sparsity_mask=None,
             channel_ids=recording.channel_ids,
-            unit_ids=unit_ids[valid_templates],
+            unit_ids=templates.unit_ids[valid_templates],
             probe=recording.get_probe(),
             is_scaled=False,
         )

+        if params["debug"]:
+            templates_folder = tmp_folder / "dense_templates"
+            templates.to_zarr(folder_path=templates_folder)
+
         sparsity = compute_sparsity(templates, noise_levels=params["noise_levels"], **params["sparsity"])
         templates = templates.to_sparse(sparsity)
         empty_templates = templates.sparsity_mask.sum(axis=1) == 0
+        old_unit_ids = templates.unit_ids.copy()
         templates = remove_empty_templates(templates)

-        mask = np.isin(peak_labels, np.where(empty_templates)[0])
+        mask = np.isin(peak_labels, old_unit_ids[empty_templates])
         peak_labels[mask] = -1

-        mask = np.isin(peak_labels, np.where(~valid_templates)[0])
-        peak_labels[mask] = -1
+        labels = np.unique(peak_labels)
+        labels = labels[labels >= 0]

-        if verbose:
-            print("Found %d raw clusters, starting to clean with matching" % (len(templates.unit_ids)))
+        if params["remove_mixtures"]:
+            if verbose:
+                print("Found %d raw clusters, starting to clean with matching" % (len(templates.unit_ids)))

-        cleaning_job_kwargs = job_kwargs.copy()
-        cleaning_job_kwargs["progress_bar"] = False
-        cleaning_params = params["cleaning_kwargs"].copy()
+            cleaning_job_kwargs = job_kwargs.copy()
+            cleaning_job_kwargs["progress_bar"] = False
+            cleaning_params = params["cleaning_kwargs"].copy()

-        labels, peak_labels = remove_duplicates_via_matching(
-            templates, peak_labels, job_kwargs=cleaning_job_kwargs, **cleaning_params
-        )
+            labels, peak_labels = remove_duplicates_via_matching(
+                templates, peak_labels, job_kwargs=cleaning_job_kwargs, **cleaning_params
+            )

-        if verbose:
-            print("Kept %d non-duplicated clusters" % len(labels))
+            if verbose:
+                print("Kept %d non-duplicated clusters" % len(labels))
+        else:
+            if verbose:
+                print("Kept %d raw clusters" % len(labels))

-        return labels, peak_labels
+        return labels, peak_labels, svd_model, peaks_svd, sparse_mask
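
Taken together, the rewritten main_function delegates to three reusable components instead of building a node pipeline by hand. The following is a condensed, hedged sketch of the new flow using only calls that appear in the added lines above; recording, peaks, params and job_kwargs are placeholders, and the keyword values are illustrative rather than the class defaults.

# Condensed sketch of the new CircusClustering flow (all calls taken from the
# added lines above; surrounding variables are assumed placeholders).
from spikeinterface.core.recording_tools import get_channel_distances
from spikeinterface.sortingcomponents.clustering.peak_svd import extract_peaks_svd
from spikeinterface.sortingcomponents.clustering.split import split_clusters
from spikeinterface.sortingcomponents.clustering.tools import get_templates_from_peaks_and_svd

# 1) project sparse waveforms around each peak onto a TruncatedSVD basis
peaks_svd, sparse_mask, svd_model = extract_peaks_svd(
    recording, peaks, ms_before=0.5, ms_after=0.5, radius_um=100, **job_kwargs
)

# 2) recursively split the initial per-channel clusters on those SVD features
split_kwargs = dict(params["split_kwargs"])
split_kwargs["neighbours_mask"] = get_channel_distances(recording) <= 50  # neighbors_radius_um
split_kwargs["waveforms_sparse_mask"] = sparse_mask
split_kwargs["clusterer_kwargs"] = params["hdbscan_kwargs"]
peak_labels, _ = split_clusters(
    peaks["channel_index"],
    recording,
    {"peaks": peaks, "sparse_tsvd": peaks_svd},
    method="local_feature_clustering",
    method_kwargs=split_kwargs,
    **params["recursive_kwargs"],
    **job_kwargs,
)

# 3) with templates_from_svd=True, templates are built directly from the SVD
#    features instead of reading the recording a second time
templates = get_templates_from_peaks_and_svd(
    recording, peaks, peak_labels, 0.5, 0.5, svd_model, peaks_svd, sparse_mask, operator="median"
)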

src/spikeinterface/sortingcomponents/clustering/graph_clustering.py

Lines changed: 7 additions & 3 deletions
@@ -59,6 +59,8 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()):
         radius_um = params["radius_um"]
         motion = params["motion"]
         seed = params["seed"]
+        ms_before = params["ms_before"]
+        ms_after = params["ms_after"]
         clustering_method = params["clustering_method"]
         clustering_kwargs = params["clustering_kwargs"]
         graph_kwargs = params["graph_kwargs"]
@@ -70,9 +72,11 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()):
         elif graph_kwargs["bin_mode"] == "vertical_bins":
             assert radius_um >= graph_kwargs["bin_um"] * 3

-        peaks_svd, sparse_mask, _ = extract_peaks_svd(
+        peaks_svd, sparse_mask, svd_model = extract_peaks_svd(
             recording,
             peaks,
+            ms_before=ms_before,
+            ms_after=ms_after,
             radius_um=radius_um,
             motion_aware=motion_aware,
             motion=None,
@@ -98,7 +102,7 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()):
         # print(distances.shape)
         # print("sparsity: ", distances.indices.size / (distances.shape[0]**2))

-        print("clustering_method", clustering_method)
+        # print("clustering_method", clustering_method)

         if clustering_method == "networkx-louvain":
             # using networkx : very slow (possible backend with cude backend="cugraph",)
@@ -191,7 +195,7 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()):
         labels_set = np.unique(peak_labels)
         labels_set = labels_set[labels_set >= 0]

-        return labels_set, peak_labels
+        return labels_set, peak_labels, svd_model, peaks_svd, sparse_mask


 def _remove_small_cluster(peak_labels, min_size=1):
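
graph_clustering.py gets the symmetric treatment: the waveform window is now forwarded to extract_peaks_svd and the SVD artifacts are surfaced in the return value. A sketch of the matching call site follows, assuming the class in this module is named GraphClustering (the class name is not visible in this excerpt) and reusing the same placeholder inputs as in the earlier sketch.

# Sketch: the graph-based clustering now returns the same extra SVD artifacts
# as CircusClustering. GraphClustering is the assumed class name; recording,
# peaks, params and job_kwargs are placeholders as before.
from spikeinterface.sortingcomponents.clustering.graph_clustering import GraphClustering

labels_set, peak_labels, svd_model, peaks_svd, sparse_mask = GraphClustering.main_function(
    recording, peaks, params, job_kwargs=job_kwargs
)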

src/spikeinterface/sortingcomponents/clustering/graph_tools.py

Lines changed: 1 addition & 1 deletion
@@ -118,7 +118,7 @@ def create_graph_from_peak_features(
         raise ValueError("create_graph_from_peak_features : wrong bin_mode")

     if progress_bar:
-        loop = tqdm(loop, desc=f"Construct distance graph looping over {bin_mode}")
+        loop = tqdm(loop, desc=f"Build distance graph over {bin_mode}")

     local_graphs = []
     row_indices = []

0 commit comments
