Skip to content

Commit e45a8be

Browse files
h-mayorquin and zm711 authored
Cleanup base sorting extractor (#3871)
* small cleanup
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* fix zarr test
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* another deprecation
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci

---------

Co-authored-by: Zach McKenzie <[email protected]>
1 parent d34be57 commit e45a8be

File tree

8 files changed

+29
-69
lines changed

8 files changed

+29
-69
lines changed

src/spikeinterface/core/base.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616

1717
from .globals import get_global_tmp_folder, is_set_global_tmp_folder
1818
from .core_tools import (
19-
is_path_remote,
2019
clean_zarr_folder_name,
2120
is_dict_extractor,
2221
SIJsonEncoder,
@@ -50,9 +49,6 @@ class BaseExtractor:
5049
# kwargs which can be precomputed before being used by the extractor
5150
_precomputable_kwarg_names = []
5251

53-
installation_mesg = ""
54-
installed = True
55-
5652
def __init__(self, main_ids: Sequence) -> None:
5753
# store init kwargs for nested serialisation
5854
self._kwargs = {}

src/spikeinterface/core/baserecording.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -764,7 +764,7 @@ def frame_slice(self, start_frame: int | None, end_frame: int | None) -> BaseRec
764764
sub_recording = FrameSliceRecording(self, start_frame=start_frame, end_frame=end_frame)
765765
return sub_recording
766766

767-
def time_slice(self, start_time: float | None, end_time: float) -> BaseRecording:
767+
def time_slice(self, start_time: float | None, end_time: float | None) -> BaseRecording:
768768
"""
769769
Returns a new recording object, restricted to the time interval [start_time, end_time].
770770

src/spikeinterface/core/baserecordingsnippets.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,6 @@ class BaseRecordingSnippets(BaseExtractor):
1616
Mixin that handles all probe and channel operations
1717
"""
1818

19-
has_default_locations = False
20-
2119
def __init__(self, sampling_frequency: float, channel_ids: list[str, int], dtype: np.dtype):
2220
BaseExtractor.__init__(self, channel_ids)
2321
self._sampling_frequency = float(sampling_frequency)

src/spikeinterface/core/basesorting.py

Lines changed: 27 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,10 @@ class BaseSorting(BaseExtractor):
1717
Abstract class representing several segment several units and relative spiketrains.
1818
"""
1919

20-
def __init__(self, sampling_frequency: float, unit_ids: List):
20+
def __init__(self, sampling_frequency: float, unit_ids: list):
2121
BaseExtractor.__init__(self, unit_ids)
2222
self._sampling_frequency = float(sampling_frequency)
23-
self._sorting_segments: List[BaseSortingSegment] = []
23+
self._sorting_segments: list[BaseSortingSegment] = []
2424
# this weak link is to handle times from a recording object
2525
self._recording = None
2626
self._sorting_info = None
@@ -212,7 +212,7 @@ def set_sorting_info(self, recording_dict, params_dict, log_dict):
212212
sorting_info = dict(recording=recording_dict, params=params_dict, log=log_dict)
213213
self.annotate(__sorting_info__=sorting_info)
214214

215-
def has_recording(self):
215+
def has_recording(self) -> bool:
216216
return self._recording is not None
217217

218218
def has_time_vector(self, segment_index=None) -> bool:
@@ -302,14 +302,6 @@ def get_unit_property(self, unit_id, key):
302302
v = values[self.id_to_index(unit_id)]
303303
return v
304304

305-
def get_total_num_spikes(self):
306-
warnings.warn(
307-
"Sorting.get_total_num_spikes() is deprecated and will be removed in spikeinterface 0.102, use sorting.count_num_spikes_per_unit()",
308-
DeprecationWarning,
309-
stacklevel=2,
310-
)
311-
return self.count_num_spikes_per_unit(outputs="dict")
312-
313305
def count_num_spikes_per_unit(self, outputs="dict"):
314306
"""
315307
For each unit : get number of spikes across segments.
@@ -451,12 +443,34 @@ def remove_empty_units(self):
451443
non_empty_units = self.get_non_empty_unit_ids()
452444
return self.select_units(non_empty_units)
453445

454-
def get_non_empty_unit_ids(self):
446+
def get_non_empty_unit_ids(self) -> np.ndarray:
447+
"""
448+
Return the unit IDs that have at least one spike across all segments.
449+
450+
This method computes the number of spikes for each unit using
451+
`count_num_spikes_per_unit` and filters out units with zero spikes.
452+
453+
Returns
454+
-------
455+
np.ndarray
456+
Array of unit IDs (same dtype as self.unit_ids) for which at least one spike exists.
457+
"""
455458
num_spikes_per_unit = self.count_num_spikes_per_unit()
456459

457460
return np.array([unit_id for unit_id in self.unit_ids if num_spikes_per_unit[unit_id] != 0])
458461

459-
def get_empty_unit_ids(self):
462+
def get_empty_unit_ids(self) -> np.ndarray:
463+
"""
464+
Return the unit IDs that have zero spikes across all segments.
465+
466+
This method returns the complement of `get_non_empty_unit_ids` with respect
467+
to all unit IDs in the sorting.
468+
469+
Returns
470+
-------
471+
np.ndarray
472+
Array of unit IDs (same dtype as self.unit_ids) for which no spikes exist.
473+
"""
460474
unit_ids = self.unit_ids
461475
empty_units = unit_ids[~np.isin(unit_ids, self.get_non_empty_unit_ids())]
462476
return empty_units
@@ -506,44 +520,6 @@ def time_to_sample_index(self, time, segment_index=0):
506520

507521
return sample_index
508522

509-
def get_all_spike_trains(self, outputs="unit_id"):
510-
"""
511-
Return all spike trains concatenated.
512-
This is deprecated and will be removed in spikeinterface 0.102 use sorting.to_spike_vector() instead
513-
"""
514-
515-
warnings.warn(
516-
"Sorting.get_all_spike_trains() will be deprecated. Sorting.to_spike_vector() instead",
517-
DeprecationWarning,
518-
stacklevel=2,
519-
)
520-
521-
assert outputs in ("unit_id", "unit_index")
522-
spikes = []
523-
for segment_index in range(self.get_num_segments()):
524-
spike_times = []
525-
spike_labels = []
526-
for i, unit_id in enumerate(self.unit_ids):
527-
st = self.get_unit_spike_train(unit_id=unit_id, segment_index=segment_index)
528-
spike_times.append(st)
529-
if outputs == "unit_id":
530-
spike_labels.append(np.array([unit_id] * st.size))
531-
elif outputs == "unit_index":
532-
spike_labels.append(np.zeros(st.size, dtype="int64") + i)
533-
534-
if len(spike_times) > 0:
535-
spike_times = np.concatenate(spike_times)
536-
spike_labels = np.concatenate(spike_labels)
537-
order = np.argsort(spike_times)
538-
spike_times = spike_times[order]
539-
spike_labels = spike_labels[order]
540-
else:
541-
spike_times = np.array([], dtype=np.int64)
542-
spike_labels = np.array([], dtype=np.int64)
543-
544-
spikes.append((spike_times, spike_labels))
545-
return spikes
546-
547523
def precompute_spike_trains(self, from_spike_vector=None):
548524
"""
549525
Pre-computes and caches all spike trains for this sorting

src/spikeinterface/core/tests/test_basesorting.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -94,8 +94,6 @@ def test_BaseSorting(create_cache_folder):
9494
sorting4 = sorting.save(format="memory")
9595
check_sortings_equal(sorting, sorting4, check_annotations=True, check_properties=True)
9696

97-
with pytest.warns(DeprecationWarning):
98-
num_spikes = sorting.get_all_spike_trains()
9997
# print(spikes)
10098

10199
spikes = sorting.to_spike_vector()
@@ -198,12 +196,6 @@ def test_empty_sorting():
198196

199197
assert len(sorting.unit_ids) == 0
200198

201-
with pytest.warns(DeprecationWarning):
202-
spikes = sorting.get_all_spike_trains()
203-
assert len(spikes) == 1
204-
assert len(spikes[0][0]) == 0
205-
assert len(spikes[0][1]) == 0
206-
207199
spikes = sorting.to_spike_vector()
208200
assert spikes.shape == (0,)
209201

src/spikeinterface/core/zarrextractors.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -382,7 +382,6 @@ def get_default_zarr_compressor(clevel: int = 5):
382382
Blosc.compressor
383383
The compressor object that can be used with the save to zarr function
384384
"""
385-
assert ZarrRecordingExtractor.installed, ZarrRecordingExtractor.installation_mesg
386385
from numcodecs import Blosc
387386

388387
return Blosc(cname="zstd", clevel=clevel, shuffle=Blosc.BITSHUFFLE)

src/spikeinterface/extractors/tridesclousextractors.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@ def __init__(self, folder_path, chan_grp=None):
3030
except ImportError:
3131
raise ImportError(self.installation_mesg)
3232

33-
assert self.installed, self.installation_mesg
3433
tdc_folder = Path(folder_path)
3534

3635
dataio = tdc.DataIO(str(tdc_folder))

src/spikeinterface/postprocessing/tests/test_principal_component.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ def test_get_projections(self, sparse):
100100
some_channel_ids = sorting_analyzer.channel_ids[::2]
101101

102102
random_spikes_indices = sorting_analyzer.get_extension("random_spikes").get_data()
103-
all_num_spikes = sorting_analyzer.sorting.get_total_num_spikes()
103+
all_num_spikes = sorting_analyzer.sorting.count_num_spikes_per_unit()
104104
unit_ids_num_spikes = np.sum(all_num_spikes[unit_id] for unit_id in some_unit_ids)
105105

106106
# this should be all spikes all channels

0 commit comments

Comments (0)