@@ -19,17 +19,10 @@
 from .clustering_tools import remove_duplicates_via_matching
 from spikeinterface.core.recording_tools import get_noise_levels, get_channel_distances
 from spikeinterface.sortingcomponents.peak_selection import select_peaks
-from spikeinterface.sortingcomponents.waveforms.temporal_pca import TemporalPCAProjection
-from spikeinterface.sortingcomponents.waveforms.hanning_filter import HanningFilter
 from spikeinterface.core.template import Templates
 from spikeinterface.core.sparsity import compute_sparsity
 from spikeinterface.sortingcomponents.tools import remove_empty_templates
-import pickle, json
-from spikeinterface.core.node_pipeline import (
-    run_node_pipeline,
-    ExtractSparseWaveforms,
-    PeakRetriever,
-)
+from spikeinterface.sortingcomponents.clustering.peak_svd import extract_peaks_svd


 from spikeinterface.sortingcomponents.tools import extract_waveform_at_max_channel
@@ -48,20 +41,24 @@ class CircusClustering:
             "allow_single_cluster": True,
         },
         "cleaning_kwargs": {},
+        "remove_mixtures": False,
         "waveforms": {"ms_before": 2, "ms_after": 2},
         "sparsity": {"method": "snr", "amplitude_mode": "peak_to_peak", "threshold": 0.25},
         "recursive_kwargs": {
             "recursive": True,
             "recursive_depth": 3,
             "returns_split_count": True,
         },
+        "split_kwargs": {"projection_mode": "tsvd", "n_pca_features": 0.9},
         "radius_um": 100,
+        "neighbors_radius_um": 50,
         "n_svd": 5,
         "few_waveforms": None,
         "ms_before": 0.5,
         "ms_after": 0.5,
         "noise_threshold": 4,
         "rank": 5,
+        "templates_from_svd": False,
         "noise_levels": None,
         "tmp_folder": None,
         "verbose": True,
@@ -78,6 +75,8 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()):
         fs = recording.get_sampling_frequency()
         ms_before = params["ms_before"]
         ms_after = params["ms_after"]
+        radius_um = params["radius_um"]
+        neighbors_radius_um = params["neighbors_radius_um"]
         nbefore = int(ms_before * fs / 1000.0)
         nafter = int(ms_after * fs / 1000.0)
         if params["tmp_folder"] is None:
@@ -108,210 +107,139 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()):
         valid = np.argmax(np.abs(wfs), axis=1) == nbefore
         wfs = wfs[valid]

-        # Perform Hanning filtering
-        hanning_before = np.hanning(2 * nbefore)
-        hanning_after = np.hanning(2 * nafter)
-        hanning = np.concatenate((hanning_before[:nbefore], hanning_after[nafter:]))
-        wfs *= hanning
-
         from sklearn.decomposition import TruncatedSVD

-        tsvd = TruncatedSVD(params["n_svd"])
-        tsvd.fit(wfs)
-
-        model_folder = tmp_folder / "tsvd_model"
-
-        model_folder.mkdir(exist_ok=True)
-        with open(model_folder / "pca_model.pkl", "wb") as f:
-            pickle.dump(tsvd, f)
-
-        model_params = {
-            "ms_before": ms_before,
-            "ms_after": ms_after,
-            "sampling_frequency": float(fs),
-        }
-
-        with open(model_folder / "params.json", "w") as f:
-            json.dump(model_params, f)
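+        # Fit one global SVD model on the aligned max-channel waveforms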
+        svd_model = TruncatedSVD(params["n_svd"])
+        svd_model.fit(wfs)
+        features_folder = tmp_folder / "tsvd_features"
+        features_folder.mkdir(exist_ok=True)

-        # features
-        node0 = PeakRetriever(recording, peaks)
-
-        radius_um = params["radius_um"]
-        node1 = ExtractSparseWaveforms(
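+        # Extract sparse waveforms around every peak and project them onto the SVD basis in a single pass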
+        peaks_svd, sparse_mask, svd_model = extract_peaks_svd(
             recording,
-            parents=[node0],
-            return_output=False,
+            peaks,
             ms_before=ms_before,
             ms_after=ms_after,
+            svd_model=svd_model,
             radius_um=radius_um,
+            folder=features_folder,
+            **job_kwargs,
         )

-        node2 = HanningFilter(recording, parents=[node0, node1], return_output=False)
+        neighbours_mask = get_channel_distances(recording) <= neighbors_radius_um

-        node3 = TemporalPCAProjection(
-            recording, parents=[node0, node2], return_output=True, model_folder_path=model_folder
-        )
+        if params["debug"]:
+            np.save(features_folder / "sparse_mask.npy", sparse_mask)
+            np.save(features_folder / "peaks.npy", peaks)

-        pipeline_nodes = [node0, node1, node2, node3]
+        original_labels = peaks["channel_index"]
+        from spikeinterface.sortingcomponents.clustering.split import split_clusters

-        if len(params["recursive_kwargs"]) == 0:
-            from sklearn.decomposition import PCA
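+        # Seed one cluster per detection channel, then split recursively with local clustering in SVD space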
+        split_kwargs = params["split_kwargs"].copy()
+        split_kwargs["neighbours_mask"] = neighbours_mask
+        split_kwargs["waveforms_sparse_mask"] = sparse_mask
+        split_kwargs["min_size_split"] = 2 * params["hdbscan_kwargs"].get("min_cluster_size", 50)
+        split_kwargs["clusterer_kwargs"] = params["hdbscan_kwargs"]

-            all_pc_data = run_node_pipeline(
-                recording,
-                pipeline_nodes,
-                job_kwargs,
-                job_name="extracting features",
-            )
-
-            peak_labels = -1 * np.ones(len(peaks), dtype=int)
-            nb_clusters = 0
-            for c in np.unique(peaks["channel_index"]):
-                mask = peaks["channel_index"] == c
-                sub_data = all_pc_data[mask]
-                sub_data = sub_data.reshape(len(sub_data), -1)
-
-                if all_pc_data.shape[1] > params["n_svd"]:
-                    tsvd = PCA(params["n_svd"], whiten=True)
-                else:
-                    tsvd = PCA(all_pc_data.shape[1], whiten=True)
-
-                hdbscan_data = tsvd.fit_transform(sub_data)
-                try:
-                    clustering = hdbscan.hdbscan(hdbscan_data, **d["hdbscan_kwargs"])
-                    local_labels = clustering[0]
-                except Exception:
-                    local_labels = np.zeros(len(hdbscan_data))
-                valid_clusters = local_labels > -1
-                if np.sum(valid_clusters) > 0:
-                    local_labels[valid_clusters] += nb_clusters
-                peak_labels[mask] = local_labels
-                nb_clusters += len(np.unique(local_labels[valid_clusters]))
+        if params["debug"]:
+            debug_folder = tmp_folder / "split"
         else:
+            debug_folder = None

-            features_folder = tmp_folder / "tsvd_features"
-            features_folder.mkdir(exist_ok=True)
-
-            _ = run_node_pipeline(
-                recording,
-                pipeline_nodes,
-                job_kwargs,
-                job_name="extracting features",
-                gather_mode="npy",
-                gather_kwargs=dict(exist_ok=True),
-                folder=features_folder,
-                names=["sparse_tsvd"],
-            )
-
-            sparse_mask = node1.neighbours_mask
-            neighbours_mask = get_channel_distances(recording) <= radius_um
-
-            # np.save(features_folder / "sparse_mask.npy", sparse_mask)
-            np.save(features_folder / "peaks.npy", peaks)
-
-            original_labels = peaks["channel_index"]
-            from spikeinterface.sortingcomponents.clustering.split import split_clusters
+        peak_labels, _ = split_clusters(
+            original_labels,
+            recording,
+            {"peaks": peaks, "sparse_tsvd": peaks_svd},
+            method="local_feature_clustering",
+            method_kwargs=split_kwargs,
+            debug_folder=debug_folder,
+            **params["recursive_kwargs"],
+            **job_kwargs,
+        )

-        min_size = 2 * params["hdbscan_kwargs"].get("min_cluster_size", 20)
+        if params["noise_levels"] is None:
+            params["noise_levels"] = get_noise_levels(recording, return_scaled=False, **job_kwargs)

-        if params["debug"]:
-            debug_folder = tmp_folder / "split"
-        else:
-            debug_folder = None
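+        # Templates are estimated either from raw recording snippets or reconstructed directly from the SVD features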
+        if not params["templates_from_svd"]:
+            from spikeinterface.sortingcomponents.clustering.tools import get_templates_from_peaks_and_recording

-            peak_labels, _ = split_clusters(
-                original_labels,
+            templates = get_templates_from_peaks_and_recording(
                 recording,
-                features_folder,
-                method="local_feature_clustering",
-                method_kwargs=dict(
-                    clusterer="hdbscan",
-                    feature_name="sparse_tsvd",
-                    neighbours_mask=neighbours_mask,
-                    waveforms_sparse_mask=sparse_mask,
-                    min_size_split=min_size,
-                    clusterer_kwargs=d["hdbscan_kwargs"],
-                    n_pca_features=5,
-                ),
-                debug_folder=debug_folder,
-                **params["recursive_kwargs"],
+                peaks,
+                peak_labels,
+                ms_before,
+                ms_after,
                 **job_kwargs,
             )
+        else:
+            from spikeinterface.sortingcomponents.clustering.tools import get_templates_from_peaks_and_svd

-        non_noise = peak_labels > -1
-        labels, inverse = np.unique(peak_labels[non_noise], return_inverse=True)
-        peak_labels[non_noise] = inverse
-        labels = np.unique(inverse)
-
-        spikes = np.zeros(non_noise.sum(), dtype=minimum_spike_dtype)
-        spikes["sample_index"] = peaks[non_noise]["sample_index"]
-        spikes["segment_index"] = peaks[non_noise]["segment_index"]
-        spikes["unit_index"] = peak_labels[non_noise]
-
-        unit_ids = labels
-
-        nbefore = int(params["waveforms"]["ms_before"] * fs / 1000.0)
-        nafter = int(params["waveforms"]["ms_after"] * fs / 1000.0)
-
-        if params["noise_levels"] is None:
-            params["noise_levels"] = get_noise_levels(recording, return_scaled=False, **job_kwargs)
-
-        templates_array = estimate_templates(
-            recording,
-            spikes,
-            unit_ids,
-            nbefore,
-            nafter,
-            return_scaled=False,
-            job_name=None,
-            **job_kwargs,
-        )
+            templates = get_templates_from_peaks_and_svd(
+                recording,
+                peaks,
+                peak_labels,
+                ms_before,
+                ms_after,
+                svd_model,
+                peaks_svd,
+                sparse_mask,
+                operator="median",
+            )

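+        # Keep only templates whose peak-channel SNR exceeds noise_threshold; mark the other peaks as noise (-1)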
+        templates_array = templates.templates_array
         best_channels = np.argmax(np.abs(templates_array[:, nbefore, :]), axis=1)
         peak_snrs = np.abs(templates_array[:, nbefore, :])
         best_snrs_ratio = (peak_snrs / params["noise_levels"])[np.arange(len(peak_snrs)), best_channels]
+        old_unit_ids = templates.unit_ids.copy()
         valid_templates = best_snrs_ratio > params["noise_threshold"]

-        if d["rank"] is not None:
-            from spikeinterface.sortingcomponents.matching.circus import compress_templates
+        mask = np.isin(peak_labels, old_unit_ids[~valid_templates])
+        peak_labels[mask] = -1

-            _, _, _, templates_array = compress_templates(templates_array, d["rank"])
+        from spikeinterface.core.template import Templates

         templates = Templates(
             templates_array=templates_array[valid_templates],
             sampling_frequency=fs,
-            nbefore=nbefore,
+            nbefore=templates.nbefore,
             sparsity_mask=None,
             channel_ids=recording.channel_ids,
-            unit_ids=unit_ids[valid_templates],
+            unit_ids=templates.unit_ids[valid_templates],
             probe=recording.get_probe(),
             is_scaled=False,
         )

+        if params["debug"]:
+            templates_folder = tmp_folder / "dense_templates"
+            templates.to_zarr(folder_path=templates_folder)
+
         sparsity = compute_sparsity(templates, noise_levels=params["noise_levels"], **params["sparsity"])
         templates = templates.to_sparse(sparsity)
         empty_templates = templates.sparsity_mask.sum(axis=1) == 0
+        old_unit_ids = templates.unit_ids.copy()
         templates = remove_empty_templates(templates)

-        mask = np.isin(peak_labels, np.where(empty_templates)[0])
+        mask = np.isin(peak_labels, old_unit_ids[empty_templates])
         peak_labels[mask] = -1

-        mask = np.isin(peak_labels, np.where(~valid_templates)[0])
-        peak_labels[mask] = -1
+        labels = np.unique(peak_labels)
+        labels = labels[labels >= 0]

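+        # Optional cleaning step: remove duplicate or mixture clusters via template matching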
-        if verbose:
-            print("Found %d raw clusters, starting to clean with matching" % (len(templates.unit_ids)))
+        if params["remove_mixtures"]:
+            if verbose:
+                print("Found %d raw clusters, starting to clean with matching" % (len(templates.unit_ids)))

-        cleaning_job_kwargs = job_kwargs.copy()
-        cleaning_job_kwargs["progress_bar"] = False
-        cleaning_params = params["cleaning_kwargs"].copy()
+            cleaning_job_kwargs = job_kwargs.copy()
+            cleaning_job_kwargs["progress_bar"] = False
+            cleaning_params = params["cleaning_kwargs"].copy()

-        labels, peak_labels = remove_duplicates_via_matching(
-            templates, peak_labels, job_kwargs=cleaning_job_kwargs, **cleaning_params
-        )
+            labels, peak_labels = remove_duplicates_via_matching(
+                templates, peak_labels, job_kwargs=cleaning_job_kwargs, **cleaning_params
+            )

-        if verbose:
-            print("Kept %d non-duplicated clusters" % len(labels))
+            if verbose:
+                print("Kept %d non-duplicated clusters" % len(labels))
+        else:
+            if verbose:
+                print("Kept %d raw clusters" % len(labels))

-        return labels, peak_labels
+        return labels, peak_labels, svd_model, peaks_svd, sparse_mask