@@ -17,10 +17,10 @@ class BaseSorting(BaseExtractor):
17
17
Abstract class representing several segment several units and relative spiketrains.
18
18
"""
19
19
20
- def __init__ (self , sampling_frequency : float , unit_ids : List ):
20
+ def __init__ (self , sampling_frequency : float , unit_ids : list ):
21
21
BaseExtractor .__init__ (self , unit_ids )
22
22
self ._sampling_frequency = float (sampling_frequency )
23
- self ._sorting_segments : List [BaseSortingSegment ] = []
23
+ self ._sorting_segments : list [BaseSortingSegment ] = []
24
24
# this weak link is to handle times from a recording object
25
25
self ._recording = None
26
26
self ._sorting_info = None
@@ -212,7 +212,7 @@ def set_sorting_info(self, recording_dict, params_dict, log_dict):
212
212
sorting_info = dict (recording = recording_dict , params = params_dict , log = log_dict )
213
213
self .annotate (__sorting_info__ = sorting_info )
214
214
215
- def has_recording (self ):
215
+ def has_recording (self ) -> bool :
216
216
return self ._recording is not None
217
217
218
218
def has_time_vector (self , segment_index = None ) -> bool :
@@ -302,14 +302,6 @@ def get_unit_property(self, unit_id, key):
302
302
v = values [self .id_to_index (unit_id )]
303
303
return v
304
304
305
- def get_total_num_spikes (self ):
306
- warnings .warn (
307
- "Sorting.get_total_num_spikes() is deprecated and will be removed in spikeinterface 0.102, use sorting.count_num_spikes_per_unit()" ,
308
- DeprecationWarning ,
309
- stacklevel = 2 ,
310
- )
311
- return self .count_num_spikes_per_unit (outputs = "dict" )
312
-
313
305
def count_num_spikes_per_unit (self , outputs = "dict" ):
314
306
"""
315
307
For each unit : get number of spikes across segments.
@@ -451,12 +443,34 @@ def remove_empty_units(self):
451
443
non_empty_units = self .get_non_empty_unit_ids ()
452
444
return self .select_units (non_empty_units )
453
445
454
- def get_non_empty_unit_ids (self ):
446
+ def get_non_empty_unit_ids (self ) -> np .ndarray :
447
+ """
448
+ Return the unit IDs that have at least one spike across all segments.
449
+
450
+ This method computes the number of spikes for each unit using
451
+ `count_num_spikes_per_unit` and filters out units with zero spikes.
452
+
453
+ Returns
454
+ -------
455
+ np.ndarray
456
+ Array of unit IDs (same dtype as self.unit_ids) for which at least one spike exists.
457
+ """
455
458
num_spikes_per_unit = self .count_num_spikes_per_unit ()
456
459
457
460
return np .array ([unit_id for unit_id in self .unit_ids if num_spikes_per_unit [unit_id ] != 0 ])
458
461
459
- def get_empty_unit_ids (self ):
462
+ def get_empty_unit_ids (self ) -> np .ndarray :
463
+ """
464
+ Return the unit IDs that have zero spikes across all segments.
465
+
466
+ This method returns the complement of `get_non_empty_unit_ids` with respect
467
+ to all unit IDs in the sorting.
468
+
469
+ Returns
470
+ -------
471
+ np.ndarray
472
+ Array of unit IDs (same dtype as self.unit_ids) for which no spikes exist.
473
+ """
460
474
unit_ids = self .unit_ids
461
475
empty_units = unit_ids [~ np .isin (unit_ids , self .get_non_empty_unit_ids ())]
462
476
return empty_units
@@ -506,44 +520,6 @@ def time_to_sample_index(self, time, segment_index=0):
506
520
507
521
return sample_index
508
522
509
- def get_all_spike_trains (self , outputs = "unit_id" ):
510
- """
511
- Return all spike trains concatenated.
512
- This is deprecated and will be removed in spikeinterface 0.102 use sorting.to_spike_vector() instead
513
- """
514
-
515
- warnings .warn (
516
- "Sorting.get_all_spike_trains() will be deprecated. Sorting.to_spike_vector() instead" ,
517
- DeprecationWarning ,
518
- stacklevel = 2 ,
519
- )
520
-
521
- assert outputs in ("unit_id" , "unit_index" )
522
- spikes = []
523
- for segment_index in range (self .get_num_segments ()):
524
- spike_times = []
525
- spike_labels = []
526
- for i , unit_id in enumerate (self .unit_ids ):
527
- st = self .get_unit_spike_train (unit_id = unit_id , segment_index = segment_index )
528
- spike_times .append (st )
529
- if outputs == "unit_id" :
530
- spike_labels .append (np .array ([unit_id ] * st .size ))
531
- elif outputs == "unit_index" :
532
- spike_labels .append (np .zeros (st .size , dtype = "int64" ) + i )
533
-
534
- if len (spike_times ) > 0 :
535
- spike_times = np .concatenate (spike_times )
536
- spike_labels = np .concatenate (spike_labels )
537
- order = np .argsort (spike_times )
538
- spike_times = spike_times [order ]
539
- spike_labels = spike_labels [order ]
540
- else :
541
- spike_times = np .array ([], dtype = np .int64 )
542
- spike_labels = np .array ([], dtype = np .int64 )
543
-
544
- spikes .append ((spike_times , spike_labels ))
545
- return spikes
546
-
547
523
def precompute_spike_trains (self , from_spike_vector = None ):
548
524
"""
549
525
Pre-computes and caches all spike trains for this sorting
0 commit comments