From d78573c78086e2698d7079671a2964274e38f9a1 Mon Sep 17 00:00:00 2001 From: fededagos Date: Fri, 11 Apr 2025 10:45:26 +0200 Subject: [PATCH 1/4] Added 3D ACGs, references, tests --- doc/api.rst | 1 + doc/modules/postprocessing.rst | 20 + doc/references.rst | 9 + src/spikeinterface/postprocessing/__init__.py | 2 + .../postprocessing/correlograms.py | 382 +++++++++++++++++- .../postprocessing/tests/test_correlograms.py | 135 ++++++- 6 files changed, 536 insertions(+), 13 deletions(-) diff --git a/doc/api.rst b/doc/api.rst index 7d20ed7c19..197c7f7d56 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -214,6 +214,7 @@ spikeinterface.postprocessing .. autofunction:: compute_spike_locations .. autofunction:: compute_template_similarity .. autofunction:: compute_correlograms + .. autofunction:: compute_acgs_3d .. autofunction:: compute_isi_histograms .. autofunction:: get_template_metric_names .. autofunction:: align_sorting diff --git a/doc/modules/postprocessing.rst b/doc/modules/postprocessing.rst index 3c30f248c8..a312633a08 100644 --- a/doc/modules/postprocessing.rst +++ b/doc/modules/postprocessing.rst @@ -358,6 +358,26 @@ with shape (num_units, num_units, num_bins) with all correlograms for each pair For more information, see :py:func:`~spikeinterface.postprocessing.compute_correlograms` +acgs_3d +^^^^^^^ + +This extension computes the 3D Autocorrelograms (3D-ACG) from units' spike times to analyze how a neuron's temporal +firing pattern varies with its firing rate. The 3D-ACG, described in [Beau]_ et al., 2025, provides rich +representations of a unit's spike train statistics while accounting for firing rate modulations. + +.. 
code-block:: python + + acg3d = sorting_analyzer.compute( + input="acgs_3d", + window_ms=50.0, + bin_ms=1.0, + num_firing_rate_quantiles=10, + smoothing_factor=250, + ) + +For more information, see :py:func:`~spikeinterface.postprocessing.compute_acgs_3d` + + isi_histograms ^^^^^^^^^^^^^^ diff --git a/doc/references.rst b/doc/references.rst index ce4672a9ca..6ba6efe6a1 100644 --- a/doc/references.rst +++ b/doc/references.rst @@ -44,6 +44,11 @@ please include the appropriate citation for the :code:`sorter_name` parameter yo - :code:`wavclus` [Chaure]_ - :code:`yass` [Lee]_ +Postprocessing Module +--------------------- + +If you use the :code:`acgs_3d` extension (i.e. :code:`postprocessing.compute_acgs_3d`, :code:`postprocessing.ComputeACG3D`), please cite [Beau]_ + Qualitymetrics Module --------------------- If you use the :code:`qualitymetrics` module, i.e. you use the :code:`analyzer.compute()` @@ -76,6 +81,8 @@ If you use the :code:`get_potential_auto_merge` method from the curation module, References ---------- +.. [Beau] `A deep learning strategy to identify cell types across species from high-density extracellular recordings. 2025. `_ + .. [Buccino] `SpikeInterface, a unified framework for spike sorting. 2020. `_ .. [Buzsáki] `The Log-Dynamic Brain: How Skewed Distributions Affect Network Operations. 2014. `_ @@ -112,6 +119,8 @@ References .. [Niediek] `Reliable Analysis of Single-Unit Recordings from the Human Brain under Noisy Conditions: Tracking Neurons over Hours. 2016. `_ +.. [npyx] `NeuroPyxels: loading, processing and plotting Neuropixels data in python. 2021. `_ + .. [Pachitariu] `Spike sorting with Kilosort4. 2024. `_ .. [Pouzat] `Using noise signature to optimize spike-sorting and to assess neuronal classification quality. 2002. 
`_ diff --git a/src/spikeinterface/postprocessing/__init__.py b/src/spikeinterface/postprocessing/__init__.py index 31bfdf1863..b1adbff281 100644 --- a/src/spikeinterface/postprocessing/__init__.py +++ b/src/spikeinterface/postprocessing/__init__.py @@ -19,7 +19,9 @@ from .spike_amplitudes import compute_spike_amplitudes, ComputeSpikeAmplitudes from .correlograms import ( + ComputeACG3D, ComputeCorrelograms, + compute_acgs_3d, compute_correlograms, correlogram_for_one_segment, ) diff --git a/src/spikeinterface/postprocessing/correlograms.py b/src/spikeinterface/postprocessing/correlograms.py index 5e30d7c68b..0522c511c3 100644 --- a/src/spikeinterface/postprocessing/correlograms.py +++ b/src/spikeinterface/postprocessing/correlograms.py @@ -1,11 +1,21 @@ from __future__ import annotations + import math import warnings -import numpy as np -from spikeinterface.core.sortinganalyzer import register_result_extension, AnalyzerExtension, SortingAnalyzer from copy import deepcopy -from spikeinterface.core.waveforms_extractor_backwards_compatibility import MockWaveformExtractor +import numpy as np +from joblib import Parallel, delayed + +from spikeinterface.core import BaseSorting +from spikeinterface.core.sortinganalyzer import ( + AnalyzerExtension, + SortingAnalyzer, + register_result_extension, +) +from spikeinterface.core.waveforms_extractor_backwards_compatibility import ( + MockWaveformExtractor, +) try: import numba @@ -107,10 +117,8 @@ def _merge_extension_data( if censor_ms is not None: # if censor_ms has no effect, can apply "soft" method. 
Check if any spikes have been removed for new_unit_id, merge_unit_group in zip(new_unit_ids, merge_unit_groups): - num_segments = new_sorting_analyzer.get_num_segments() for segment_index in range(num_segments): - merged_spike_train_length = len( new_sorting_analyzer.sorting.get_unit_spike_train(new_unit_id, segment_index=segment_index) ) @@ -132,7 +140,6 @@ def _merge_extension_data( new_ccgs, new_bins = _compute_correlograms_on_sorting(new_sorting_analyzer.sorting, **self.params) new_data = dict(ccgs=new_ccgs, bins=new_bins) else: - # Make a transformation dict, which tells us how unit_indices from the # old to the new sorter are mapped. old_to_new_unit_index_map = {} @@ -160,7 +167,6 @@ def _merge_extension_data( correlograms, new_bins = deepcopy(self.get_data()) for new_unit_id, merge_unit_group in zip(new_unit_ids, merge_unit_groups): - merge_unit_group_indices = self.sorting_analyzer.sorting.ids_to_indices(merge_unit_group) # Sum unit rows of the correlogram matrix: C_{k,l} = C_{i,l} + C_{j,l} @@ -564,7 +570,6 @@ def _compute_correlograms_one_segment_numba( start_j = 0 for i in range(spike_times.size): for j in range(start_j, spike_times.size): - if i == j: continue @@ -593,3 +598,364 @@ def _compute_correlograms_one_segment_numba( bin = diff // bin_size correlograms[spike_unit_indices[i], spike_unit_indices[j], num_half_bins + bin] += 1 + + +class ComputeACG3D(AnalyzerExtension): + """ + Computes the 3D Autocorrelograms (3D-ACG) from units' spike times to analyze how a neuron's temporal firing + pattern varies with its firing rate. + + The 3D-ACG, originally described in [Beau]_ et al., 2025, provides a rich representation of a unit's + spike train statistics while accounting for firing rate modulations. + The method was developed to normalize for the impact of changes in firing rate on measures of firing statistics, + particularly in awake animals performing behavioral tasks where firing rates naturally vary over time. 
+ + The approach works as follows: + 1. The instantaneous firing rate is calculated at each spike time using inverse ISI + 2. Firing rates are smoothed with a boxcar filter (default 250ms width) + 3. Spikes are grouped into firing rate bins (deciles by default) + 4. Separate ACGs are computed for each firing rate bin + + The result can be visualized as an image where: + - The y-axis represents different firing rate bins + - The x-axis represents time lag from the trigger spike + - The z-axis (color) represents spike probability + + Parameters + ---------- + - spike_times: vector of spike timestamps (in sample units) + - window_ms (float): window size for auto-correlation, in milliseconds + - bin_ms (float): bin size for auto-correlation, in milliseconds + - num_firing_rate_quantiles (integer): number of firing rate quantiles. Default=10 (deciles) + - smoothing_factor (float): width of the boxcar filter for smoothing (in milliseconds). + Default=250ms. Set to None to disable smoothing. + - firing_rate_bins (array-like): Optional predefined firing rate bin edges. + If provided, num_firing_rate_bins is ignored. + - n_jobs (int): The number of parallel jobs spawned to compute the acgs across units. + Defaults to -1 (one job per cpu). + + Returns + ------- + - acg_3d (numpy.ndarray): 2D array with dimension (num_firing_rate_bins x num_timepoints), + where each element is the probability of observing a spike at the given time lag, + conditioned on the neuron's firing rate at the trigger spike time. + - firing_rate_quantiles (numpy.ndarray): The firing rate values that define the quantiles edges + + Notes + ----- + - The central bin (t=0) is set to 0 as it would always be 1 by definition + - Edge spikes are excluded to avoid boundary artifacts + - Spike counts are normalized by the total number of trigger spikes in each rate bin + + References + ---------- + Based on work in [Beau]_ et al., 2025. 
+ + Adapted Python implementation from [npyx]_ : https://github.com/m-beau/NeuroPyxels/ + + Original author: David Herzfeld + """ + + extension_name = "acgs_3d" + depend_on = [] + need_recording = False + use_nodepipeline = False + need_job_kwargs = False + + def __init__(self, sorting_analyzer): + AnalyzerExtension.__init__(self, sorting_analyzer) + + def _set_params( + self, + window_ms: float = 50.0, + bin_ms: float = 1.0, + num_firing_rate_quantiles: int = 10, + smoothing_factor: int = 250, + n_jobs: int = -1, + ): + params = dict( + window_ms=window_ms, + bin_ms=bin_ms, + num_firing_rate_quantiles=num_firing_rate_quantiles, + smoothing_factor=smoothing_factor, + n_jobs=n_jobs, + ) + + return params + + def _select_extension_data(self, unit_ids): + # filter metrics dataframe + unit_indices = self.sorting_analyzer.sorting.ids_to_indices(unit_ids) + new_acgs_3d = self.data["acgs_3d"][unit_indices] + new_firing_quantiles = self.data["firing_quantiles"][unit_indices] + new_bins = self.data["bins"][:] + new_data = dict(acgs_3d=new_acgs_3d, firing_quantiles=new_firing_quantiles, bins=new_bins) + return new_data + + def _merge_extension_data( + self, merge_unit_groups, new_unit_ids, new_sorting_analyzer, censor_ms=None, verbose=False, **job_kwargs + ): + new_sorting = new_sorting_analyzer.sorting + + acgs_3d, firing_rate_quantiles, _ = _compute_acgs_3d( + new_sorting, + unit_ids=new_unit_ids, + window_ms=self.params["window_ms"], + bin_ms=self.params["bin_ms"], + num_firing_rate_quantiles=self.params["num_firing_rate_quantiles"], + smoothing_factor=self.params["smoothing_factor"], + ) + + new_unit_ids_indices = new_sorting.ids_to_indices(new_unit_ids) + old_unit_ids = [unit_id for unit_id in new_sorting_analyzer.unit_ids if unit_id not in new_unit_ids] + old_unit_ids_indices = new_sorting.ids_to_indices(old_unit_ids) + + new_acgs_3d = np.zeros((len(new_sorting.unit_ids), acgs_3d.shape[1], acgs_3d.shape[2])) + new_firing_quantiles = np.zeros((len(new_sorting.unit_ids), 
firing_rate_quantiles.shape[1])) + + new_acgs_3d[new_unit_ids_indices, :, :] = acgs_3d + new_acgs_3d[old_unit_ids_indices, :, :] = self.data["acgs_3d"][old_unit_ids_indices, :, :] + + new_firing_quantiles[new_unit_ids_indices, :] = firing_rate_quantiles + new_firing_quantiles[old_unit_ids_indices, :] = self.data["firing_quantiles"][old_unit_ids_indices, :] + + new_data = dict( + acgs_3d=new_acgs_3d, + firing_quantiles=new_firing_quantiles, + bins=self.data["bins"], + ) + return new_data + + def _run(self, verbose=False): + acgs_3d, firing_quantiles, bins = _compute_acgs_3d(self.sorting_analyzer.sorting, **self.params) + self.data["firing_quantiles"] = firing_quantiles + self.data["acgs_3d"] = acgs_3d + self.data["bins"] = bins + + def _get_data(self): + return self.data["acgs_3d"], self.data["firing_quantiles"], self.data["bins"] + + +def _compute_acgs_3d( + sorting: BaseSorting, + unit_ids=None, + window_ms: float = 50.0, + bin_ms: float = 1.0, + num_firing_rate_quantiles: int = 10, + smoothing_factor: int = 250, + n_jobs: int = -1, +): + """ + Computes the 3D autocorrelogram for a single unit. + + See ComputeACG3D() for more details. + + Parameters + ---------- + sorting : Sorting + A SpikeInterface Sorting object + unit_ids : list of int, default: None + The unit ids to compute the autocorrelogram for. If None, + all units in the sorting are used. + window_ms : float, default: 50.0 + The window around the spike to compute the correlation in ms. For example, + if 50 ms, the correlations will be computed at lags -25 ms ... 25 ms. + bin_ms : float, default: 1.0 + The bin size in ms. This determines the bin size over which to + combine lags. For example, with a window size of -25 ms to 25 ms, and + bin size 1 ms, the correlation will be binned as -25 ms, -24 ms, ... + num_firing_rate_quantiles : int, default: 10 + The number of quantiles to use for firing rate bins. + smoothing_factor : float, default: 250 + The width of the smoothing kernel in milliseconds. 
+ n_jobs : int, default: -1. + Number of parallel jobs to spawn to compute the 3D-ACGs on different units. + + Returns + ------- + acgs_3d : np.array + The autocorrelograms for each unit at each firing rate quantile. + firing_quantiles : np.array + The firing rate quantiles used for each unit. + bins : np.array + The bin edges in ms + + """ + if unit_ids is None: + unit_ids = sorting.unit_ids + + # pre-compute time bins + winsize_bins = 2 * int(0.5 * window_ms * 1.0 / bin_ms) + 1 + bin_times_ms = np.linspace(-window_ms / 2, window_ms / 2, num=winsize_bins) + num_units = len(unit_ids) + winsize_bins = winsize_bins + acgs_3d = np.zeros((num_units, num_firing_rate_quantiles, winsize_bins)) + firing_quantiles = np.zeros((num_units, num_firing_rate_quantiles)) + + time_bins_ms = np.repeat(bin_times_ms, num_units, axis=0) + + # Process units in parallel + results = Parallel(n_jobs=n_jobs)( + delayed(_compute_3d_acg_one_unit)( + sorting, + unit_id, + window_ms, + bin_ms, + num_firing_rate_quantiles=num_firing_rate_quantiles, + smoothing_factor=smoothing_factor, + ) + for unit_id in unit_ids + ) + + # Unpack results + for unit_index, (acg_3d, firing_quantile) in enumerate(results): + acgs_3d[unit_index, :, :] = acg_3d + firing_quantiles[unit_index, :] = firing_quantile + + return acgs_3d, firing_quantiles, time_bins_ms + + +register_result_extension(ComputeACG3D) +compute_acgs_3d_sorting_analyzer = ComputeACG3D.function_factory() + + +def _compute_3d_acg_one_unit( + sorting: BaseSorting, + unit_id: int | str, + win_size: float, + bin_size: float, + num_firing_rate_quantiles: int = 10, + smoothing_factor: int = 250, + use_spikes_around_times1_for_deciles: bool = True, + firing_rate_quantiles: list | None = None, +): + fs = sorting.sampling_frequency + + bin_size = np.clip(bin_size, 1000 * 1.0 / fs, 1e8) # in milliseconds + win_size = np.clip(win_size, 1e-2, 1e8) # in milliseconds + winsize_bins = 2 * int(0.5 * win_size * 1.0 / bin_size) + 1 # Both in millisecond + assert 
winsize_bins >= 1 + assert winsize_bins % 2 == 1 + bin_times_ms = np.linspace(-win_size / 2, win_size / 2, num=winsize_bins) + + if firing_rate_quantiles is not None: + num_firing_rate_quantiles = len(firing_rate_quantiles) + spike_counts = np.zeros( + (num_firing_rate_quantiles, len(bin_times_ms)) + ) # Counts number of occurences of spikes in a given bin in time axis + firing_rate_bin_occurence = np.zeros(num_firing_rate_quantiles, dtype=np.int64) # total occurence + + # Samples per bin + samples_per_bin = int(np.ceil(fs / (1000 / bin_size))) + + segments = sorting.get_num_segments() + if segments > 1 and firing_rate_quantiles is None: + warnings.warn( + "Multiple segments detected. Firing rate quantiles will be automatically computed on the first segment." + " Manually define global firing_rate_quantiles if needed.", + UserWarning, + 2, + ) + + # Convert times_1 and times_2 (which are in units of fs to units of bin_size) + for segment_index in range(segments): + spike_times = sorting.get_unit_spike_train(unit_id, segment_index=segment_index) + + # Convert to bin indices + spike_bins = np.floor(spike_times / samples_per_bin).astype(np.int64) + + if len(spike_bins) <= 1: + continue # Need at least 2 spikes for ACG + + # Create a binary spike train spanning the entire time range + max_bin = int(np.ceil(spike_bins[-1] + 1)) + spiketrain = np.zeros(max_bin + winsize_bins, dtype=bool) # Add extra space for window + spiketrain[spike_bins] = True + + # Convert spikes to firing rate using the inverse ISI method + firing_rate = np.zeros(max_bin) + for i in range(1, len(spike_bins) - 1): + start = 0 if i == 0 else (spike_bins[i - 1] + (spike_bins[i] - spike_bins[i - 1]) // 2) + stop = max_bin if i == len(spike_bins) - 1 else (spike_bins[i] + (spike_bins[i + 1] - spike_bins[i]) // 2) + current_firing_rate = 1.0 / ((stop - start) * (bin_size / 1000)) + firing_rate[start:stop] = current_firing_rate + + # Smooth the firing rate using numpy convolution function if requested + 
if isinstance(smoothing_factor, (int, float)) and smoothing_factor > 0: + kernel_size = int(np.ceil(smoothing_factor / bin_size)) + half_kernel_size = kernel_size // 2 + kernel = np.ones(kernel_size) / kernel_size + smoothed_firing_rate = np.convolve(firing_rate, kernel, mode="same") + + # Correct manually for possible artefacts at the edges + for i in range(kernel_size): + start = max(0, i - half_kernel_size) + stop = min(len(firing_rate), i + half_kernel_size) + smoothed_firing_rate[i] = np.mean(firing_rate[start:stop]) + for i in range(len(firing_rate) - kernel_size, len(firing_rate)): + start = max(0, i - half_kernel_size) + stop = min(len(firing_rate), i + half_kernel_size) + + smoothed_firing_rate[i] = np.mean(firing_rate[start:stop]) + firing_rate = smoothed_firing_rate + + # Get firing rate quantiles + if firing_rate_quantiles is None: + quantile_bins = np.linspace(0, 1, num_firing_rate_quantiles + 2)[1:-1] + if use_spikes_around_times1_for_deciles: + firing_rate_quantiles = np.quantile(firing_rate[spike_bins], quantile_bins) + else: + firing_rate_quantiles = np.quantile(firing_rate, quantile_bins) + for i, spike_index in enumerate(spike_bins): + start = spike_index + int(np.ceil(bin_times_ms[0] / bin_size)) + stop = start + len(bin_times_ms) + if (start < 0) or (stop >= len(spiketrain)) or spike_index < spike_bins[0] or spike_index >= spike_bins[-1]: + continue # Skip these spikes to avoid edge artifacts + current_firing_rate = firing_rate[spike_index] # Firing of neuron 2 at neuron 1's spike index + current_firing_rate_bin_number = np.argmax(firing_rate_quantiles >= current_firing_rate) + if current_firing_rate_bin_number == 0 and current_firing_rate > firing_rate_quantiles[0]: + current_firing_rate_bin_number = len(firing_rate_quantiles) - 1 + spike_counts[current_firing_rate_bin_number, :] += spiketrain[start:stop] + firing_rate_bin_occurence[current_firing_rate_bin_number] += 1 + + acg_3d = spike_counts / (np.ones((len(bin_times_ms), 
num_firing_rate_quantiles)) * firing_rate_bin_occurence).T + # Divison by zero cases will return nans, so we fix this + acg_3d = np.nan_to_num(acg_3d) + # remove bin 0 which will always be 1 + acg_3d[:, acg_3d.shape[1] // 2] = 0 + + return acg_3d, firing_rate_quantiles + + +def compute_acgs_3d( + sorting_analyzer_or_sorting: SortingAnalyzer | BaseSorting, + window_ms: float = 50.0, + bin_ms: float = 1.0, + num_firing_rate_quantiles: int = 10, + smoothing_factor: int = 250, +): + """ + Compute 3D Autocorrelograms. See ComputeACG3D() for a detailed documentation. + """ + if isinstance(sorting_analyzer_or_sorting, MockWaveformExtractor): + sorting_analyzer_or_sorting = sorting_analyzer_or_sorting.sorting + + if isinstance(sorting_analyzer_or_sorting, SortingAnalyzer): + return compute_acgs_3d_sorting_analyzer( + sorting_analyzer_or_sorting, + window_ms=window_ms, + bin_ms=bin_ms, + num_firing_rate_quantiles=num_firing_rate_quantiles, + smoothing_factor=smoothing_factor, + ) + else: + return _compute_acgs_3d( + sorting_analyzer_or_sorting, + window_ms=window_ms, + bin_ms=bin_ms, + num_firing_rate_quantiles=num_firing_rate_quantiles, + smoothing_factor=smoothing_factor, + ) + + +compute_acgs_3d.__doc__ = compute_acgs_3d_sorting_analyzer.__doc__ diff --git a/src/spikeinterface/postprocessing/tests/test_correlograms.py b/src/spikeinterface/postprocessing/tests/test_correlograms.py index 3bdfb38e06..ae2a27cbba 100644 --- a/src/spikeinterface/postprocessing/tests/test_correlograms.py +++ b/src/spikeinterface/postprocessing/tests/test_correlograms.py @@ -7,22 +7,31 @@ except ModuleNotFoundError as err: HAVE_NUMBA = False +try: + import npyx + + HAVE_NPYX = True +except ModuleNotFoundError: + HAVE_NPYX = False + +import pytest +from pytest import param + from spikeinterface import NumpySorting, generate_sorting -from spikeinterface.postprocessing.tests.common_extension_tests import AnalyzerExtensionCommonTestSuite -from spikeinterface.postprocessing import ComputeCorrelograms 
+from spikeinterface.postprocessing import ComputeACG3D, ComputeCorrelograms from spikeinterface.postprocessing.correlograms import ( + _compute_3d_acg_one_unit, _compute_correlograms_on_sorting, _make_bins, + compute_acgs_3d, compute_correlograms, ) -import pytest -from pytest import param +from spikeinterface.postprocessing.tests.common_extension_tests import AnalyzerExtensionCommonTestSuite SKIP_NUMBA = pytest.mark.skipif(not HAVE_NUMBA, reason="Numba not available") class TestComputeCorrelograms(AnalyzerExtensionCommonTestSuite): - @pytest.mark.parametrize( "params", [ @@ -368,3 +377,119 @@ def generate_correlogram_test_dataset(sampling_frequency, fill_all_bins, hit_bin expected_result_corr[int(num_bins / 2)] = num_filled_bins return window_ms, bin_ms, spike_times, spike_unit_indices, expected_bins, expected_result_auto, expected_result_corr + + +######################################### +# 3D ACG Tests +######################################### + + +class TestComputeACG3D(AnalyzerExtensionCommonTestSuite): + @pytest.mark.parametrize( + "params", + [ + dict(window_ms=50.0, bin_ms=1.0, num_firing_rate_quantiles=10, smoothing_factor=250), + ], + ) + def test_extension(self, params): + self.run_extension_tests(ComputeACG3D, params) + + @pytest.mark.parametrize("num_firing_rate_quantiles", [10, 15]) + def test_sortinganalyzer_acgs_3d(self, num_firing_rate_quantiles): + """ + Test the outputs when using SortingAnalyzer against + the output passing sorting directly to `compute_acgs_3d`. 
+ """ + sorting_analyzer = self._prepare_sorting_analyzer("memory", sparse=False, extension_class=ComputeACG3D) + + params = dict(num_firing_rate_quantiles=num_firing_rate_quantiles, window_ms=100, bin_ms=6.5) + ext_numpy = sorting_analyzer.compute(ComputeACG3D.extension_name, **params) + + result_sorting, quantiles_sorting, time_bins = compute_acgs_3d(self.sorting, **params) + + assert np.array_equal(result_sorting, ext_numpy.data["acgs_3d"]) + assert np.array_equal(quantiles_sorting, ext_numpy.data["firing_quantiles"]) + assert np.array_equal(time_bins, ext_numpy.data["bins"]) + + +@pytest.mark.skipif(not HAVE_NPYX, reason="npyx not installed") +@pytest.mark.parametrize("window_ms", [50.0, 100.0]) +@pytest.mark.parametrize("bin_ms", [1.0, 2.0]) +@pytest.mark.parametrize("num_firing_rate_quantiles", [10, 15]) +@pytest.mark.parametrize("smoothing_factor", [250, 500]) +def test_acgs_3d_original_implementation(window_ms, bin_ms, num_firing_rate_quantiles, smoothing_factor): + sorting = generate_sorting(num_units=1, sampling_frequency=30000.0, durations=[100], firing_rates=60, seed=0) + + unit_ids = sorting.get_unit_ids() + + times_1 = times_2 = sorting.get_unit_spike_train(unit_id=unit_ids[0]) + + npyx_quantiles, npyx_3dacg = npyx.corr.crosscorr_vs_firing_rate( + times_1, + times_2, + win_size=window_ms, + bin_size=bin_ms, + fs=sorting.sampling_frequency, + num_firing_rate_bins=num_firing_rate_quantiles, + smooth=smoothing_factor, + ) + + acg_3d, firing_rate_quantiles = _compute_3d_acg_one_unit( + sorting=sorting, + unit_id=unit_ids[0], + win_size=window_ms, + bin_size=bin_ms, + num_firing_rate_quantiles=num_firing_rate_quantiles, + smoothing_factor=smoothing_factor, + ) + + assert np.allclose(npyx_3dacg, acg_3d, rtol=1e-2, atol=1e-2), "3D-ACG result does not match original implementation" + assert np.allclose( + npyx_quantiles, firing_rate_quantiles, rtol=1e-5, atol=1e-5 + ), "3D-ACG quantiles do not match original implementation" + + +# Check in the case where we 
have segments +@pytest.mark.skipif(not HAVE_NPYX, reason="npyx not installed") +@pytest.mark.parametrize("window_ms", [50.0, 100.0]) +@pytest.mark.parametrize("bin_ms", [1.0, 2.0]) +@pytest.mark.parametrize("num_firing_rate_quantiles", [10, 15]) +@pytest.mark.parametrize("segments", [2, 5, 10]) +def test_acgs_3d_original_implementation_with_segments(window_ms, bin_ms, num_firing_rate_quantiles, segments): + durations = [np.random.randint(30, 60) for _ in range(segments)] + sorting = generate_sorting(num_units=1, sampling_frequency=30000.0, durations=durations, firing_rates=50, seed=0) + unit_ids = sorting.get_unit_ids() + + # Simulate single segment by concatenating all segments + segment_spike_trains = [ + sorting.get_unit_spike_train(unit_id=unit_ids[0], segment_index=i) for i in range(sorting.get_num_segments()) + ] + segment_durations = [np.max(segment) + 1 for segment in segment_spike_trains] + segment_offsets = [0] + np.cumsum(segment_durations).tolist() + concat_times = np.concatenate( + [np.array(segment) + segment_offsets[i] for i, segment in enumerate(segment_spike_trains)] + ) + + npyx_quantiles, npyx_3dacg = npyx.corr.crosscorr_vs_firing_rate( + concat_times, + concat_times, + win_size=window_ms, + bin_size=bin_ms, + fs=sorting.sampling_frequency, + num_firing_rate_bins=num_firing_rate_quantiles, + ) + + acg_3d, firing_rate_quantiles = _compute_3d_acg_one_unit( + sorting=sorting, + unit_id=unit_ids[0], + win_size=window_ms, + bin_size=bin_ms, + num_firing_rate_quantiles=num_firing_rate_quantiles, + # use npyx quantiles to check for matching results. 
multiple segments will use first segments' quantiles + firing_rate_quantiles=npyx_quantiles.tolist(), + ) + + # Use a less strict tolerance here as we are only simulating a single segment + assert np.allclose( + npyx_3dacg, acg_3d, rtol=5e-2, atol=5e-2 + ), f"3D-ACG result for {segments} segments does not match original implementation with concatenated times" From 5bb0460e15373ed75cde5824c493704e02be7f7f Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 11 Apr 2025 17:29:01 +0200 Subject: [PATCH 2/4] Use SI parallel machinery --- .../postprocessing/correlograms.py | 122 ++++++++++++------ 1 file changed, 86 insertions(+), 36 deletions(-) diff --git a/src/spikeinterface/postprocessing/correlograms.py b/src/spikeinterface/postprocessing/correlograms.py index 21d8911752..826bf21b87 100644 --- a/src/spikeinterface/postprocessing/correlograms.py +++ b/src/spikeinterface/postprocessing/correlograms.py @@ -2,12 +2,18 @@ import importlib.util import warnings +import platform from copy import deepcopy +from tqdm.auto import tqdm + +from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor +import multiprocessing as mp +from threadpoolctl import threadpool_limits import numpy as np -from joblib import Parallel, delayed from spikeinterface.core import BaseSorting +from spikeinterface.core.job_tools import fix_job_kwargs from spikeinterface.core.sortinganalyzer import ( AnalyzerExtension, SortingAnalyzer, @@ -630,9 +636,7 @@ class ComputeACG3D(AnalyzerExtension): - smoothing_factor (float): width of the boxcar filter for smoothing (in milliseconds). Default=250ms. Set to None to disable smoothing. - firing_rate_bins (array-like): Optional predefined firing rate bin edges. - If provided, num_firing_rate_bins is ignored. - - n_jobs (int): The number of parallel jobs spawned to compute the acgs across units. - Defaults to -1 (one job per cpu). + If provided, num_firing_rate_bins is ignored. 
Returns ------- @@ -660,7 +664,7 @@ class ComputeACG3D(AnalyzerExtension): depend_on = [] need_recording = False use_nodepipeline = False - need_job_kwargs = False + need_job_kwargs = True def __init__(self, sorting_analyzer): AnalyzerExtension.__init__(self, sorting_analyzer) @@ -671,14 +675,12 @@ def _set_params( bin_ms: float = 1.0, num_firing_rate_quantiles: int = 10, smoothing_factor: int = 250, - n_jobs: int = -1, ): params = dict( window_ms=window_ms, bin_ms=bin_ms, num_firing_rate_quantiles=num_firing_rate_quantiles, smoothing_factor=smoothing_factor, - n_jobs=n_jobs, ) return params @@ -704,6 +706,7 @@ def _merge_extension_data( bin_ms=self.params["bin_ms"], num_firing_rate_quantiles=self.params["num_firing_rate_quantiles"], smoothing_factor=self.params["smoothing_factor"], + **job_kwargs, ) new_unit_ids_indices = new_sorting.ids_to_indices(new_unit_ids) @@ -726,8 +729,8 @@ def _merge_extension_data( ) return new_data - def _run(self, verbose=False): - acgs_3d, firing_quantiles, bins = _compute_acgs_3d(self.sorting_analyzer.sorting, **self.params) + def _run(self, verbose=False, **job_kwargs): + acgs_3d, firing_quantiles, bins = _compute_acgs_3d(self.sorting_analyzer.sorting, **self.params, **job_kwargs) self.data["firing_quantiles"] = firing_quantiles self.data["acgs_3d"] = acgs_3d self.data["bins"] = bins @@ -743,7 +746,7 @@ def _compute_acgs_3d( bin_ms: float = 1.0, num_firing_rate_quantiles: int = 10, smoothing_factor: int = 250, - n_jobs: int = -1, + **job_kwargs, ): """ Computes the 3D autocorrelogram for a single unit. 
@@ -784,6 +787,25 @@ def _compute_acgs_3d( if unit_ids is None: unit_ids = sorting.unit_ids + job_kwargs = fix_job_kwargs(job_kwargs) + + n_jobs = job_kwargs["n_jobs"] + progress_bar = job_kwargs["progress_bar"] + max_threads_per_worker = job_kwargs["max_threads_per_worker"] + mp_context = job_kwargs["mp_context"] + pool_engine = job_kwargs["pool_engine"] + if mp_context is not None and platform.system() == "Windows": + assert mp_context != "fork", "'fork' mp_context not supported on Windows!" + elif mp_context == "fork" and platform.system() == "Darwin": + warnings.warn('As of Python 3.8 "fork" is no longer considered safe on macOS') + + num_segments = sorting.get_num_segments() + if num_segments > 1: + warnings.warn( + "Multiple segments detected. Firing rate quantiles will be automatically computed on the first segment. " + "Manually define global firing_rate_quantiles if needed.", + ) + # pre-compute time bins winsize_bins = 2 * int(0.5 * window_ms * 1.0 / bin_ms) + 1 bin_times_ms = np.linspace(-window_ms / 2, window_ms / 2, num=winsize_bins) @@ -794,23 +816,45 @@ def _compute_acgs_3d( time_bins_ms = np.repeat(bin_times_ms, num_units, axis=0) - # Process units in parallel - results = Parallel(n_jobs=n_jobs)( - delayed(_compute_3d_acg_one_unit)( - sorting, - unit_id, - window_ms, - bin_ms, - num_firing_rate_quantiles=num_firing_rate_quantiles, - smoothing_factor=smoothing_factor, - ) + items = [ + (sorting, unit_id, window_ms, bin_ms, num_firing_rate_quantiles, smoothing_factor, max_threads_per_worker) for unit_id in unit_ids - ) - - # Unpack results - for unit_index, (acg_3d, firing_quantile) in enumerate(results): - acgs_3d[unit_index, :, :] = acg_3d - firing_quantiles[unit_index, :] = firing_quantile + ] + + if n_jobs > 1: + job_name = "calculate_acgs_3d" + if pool_engine == "process": + parallel_pool_class = ProcessPoolExecutor + pool_kwargs = dict(mp_context=mp.get_context(mp_context)) + desc = f"{job_name} (workers: {n_jobs} processes)" + else: + 
parallel_pool_class = ThreadPoolExecutor + pool_kwargs = dict() + desc = f"{job_name} (workers: {n_jobs} threads)" + + # Process units in parallel + with parallel_pool_class(max_workers=n_jobs, **pool_kwargs) as executor: + results = executor.map(_compute_3d_acg_one_unit_star, items) + if progress_bar: + results = tqdm(results, total=len(unit_ids), desc=desc) + for unit_index, (acg_3d, firing_quantile) in enumerate(results): + acgs_3d[unit_index, :, :] = acg_3d + firing_quantiles[unit_index, :] = firing_quantile + else: + # Process units in serial + for unit_index, (sorting, unit_id, window_ms, bin_ms, num_firing_rate_quantiles, smoothing_factor) in enumerate( + items + ): + acg_3d, firing_quantile = _compute_3d_acg_one_unit( + sorting, + unit_id, + window_ms, + bin_ms, + num_firing_rate_quantiles, + smoothing_factor, + ) + acgs_3d[unit_index, :, :] = acg_3d + firing_quantiles[unit_index, :] = firing_quantile return acgs_3d, firing_quantiles, time_bins_ms @@ -848,17 +892,9 @@ def _compute_3d_acg_one_unit( # Samples per bin samples_per_bin = int(np.ceil(fs / (1000 / bin_size))) - segments = sorting.get_num_segments() - if segments > 1 and firing_rate_quantiles is None: - warnings.warn( - "Multiple segments detected. Firing rate quantiles will be automatically computed on the first segment." - " Manually define global firing_rate_quantiles if needed.", - UserWarning, - 2, - ) - + num_segments = sorting.get_num_segments() # Convert times_1 and times_2 (which are in units of fs to units of bin_size) - for segment_index in range(segments): + for segment_index in range(num_segments): spike_times = sorting.get_unit_spike_train(unit_id, segment_index=segment_index) # Convert to bin indices @@ -927,6 +963,20 @@ def _compute_3d_acg_one_unit( return acg_3d, firing_rate_quantiles +def _compute_3d_acg_one_unit_star(args): + """ + Helper function to compute the 3D ACG for a single unit. + This is used to parallelize the computation. 
+ """ + max_threads_per_worker = args[-1] + new_args = args[:-1] + if max_threads_per_worker is None: + return _compute_3d_acg_one_unit(*new_args) + else: + with threadpool_limits(limits=int(max_threads_per_worker)): + return _compute_3d_acg_one_unit(*new_args) + + def compute_acgs_3d( sorting_analyzer_or_sorting: SortingAnalyzer | BaseSorting, window_ms: float = 50.0, From d3fddee8e49ee9b46271540ad64fee3300ddae75 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 11 Apr 2025 18:18:42 +0200 Subject: [PATCH 3/4] Fix tests --- src/spikeinterface/postprocessing/correlograms.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/postprocessing/correlograms.py b/src/spikeinterface/postprocessing/correlograms.py index 826bf21b87..06be0252a2 100644 --- a/src/spikeinterface/postprocessing/correlograms.py +++ b/src/spikeinterface/postprocessing/correlograms.py @@ -842,9 +842,15 @@ def _compute_acgs_3d( firing_quantiles[unit_index, :] = firing_quantile else: # Process units in serial - for unit_index, (sorting, unit_id, window_ms, bin_ms, num_firing_rate_quantiles, smoothing_factor) in enumerate( - items - ): + for unit_index, ( + sorting, + unit_id, + window_ms, + bin_ms, + num_firing_rate_quantiles, + smoothing_factor, + _, + ) in enumerate(items): acg_3d, firing_quantile = _compute_3d_acg_one_unit( sorting, unit_id, From 55adef6c4eaa87a2f11a2bcf7896722a2134090e Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 14 Apr 2025 10:34:45 +0200 Subject: [PATCH 4/4] Assertion messages, importlib, and job_kwargs docs --- .../postprocessing/correlograms.py | 20 +++++++++++-------- .../postprocessing/tests/test_correlograms.py | 16 +++++++-------- 2 files changed, 20 insertions(+), 16 deletions(-) diff --git a/src/spikeinterface/postprocessing/correlograms.py b/src/spikeinterface/postprocessing/correlograms.py index 06be0252a2..dd05bc1bd3 100644 --- a/src/spikeinterface/postprocessing/correlograms.py +++ 
b/src/spikeinterface/postprocessing/correlograms.py @@ -13,7 +13,7 @@ import numpy as np from spikeinterface.core import BaseSorting -from spikeinterface.core.job_tools import fix_job_kwargs +from spikeinterface.core.job_tools import fix_job_kwargs, _shared_job_kwargs_doc from spikeinterface.core.sortinganalyzer import ( AnalyzerExtension, SortingAnalyzer, @@ -272,7 +272,7 @@ def _make_bins(sorting, window_ms, bin_ms) -> tuple[np.ndarray, int, int]: bin_size = int(round(fs * bin_ms * 1e-3)) window_size -= window_size % bin_size num_bins = 2 * int(window_size / bin_size) - assert num_bins >= 1 + assert num_bins >= 1, "Number of bins must be >= 1" bins = np.arange(-window_size, window_size + bin_size, bin_size) * 1e3 / fs @@ -330,7 +330,7 @@ def _compute_correlograms_on_sorting(sorting, window_ms, bin_ms, method="auto"): bins : np.array The bins edges in ms """ - assert method in ("auto", "numba", "numpy") + assert method in ("auto", "numba", "numpy"), "method must be 'auto', 'numba' or 'numpy'" if method == "auto": method = "numba" if HAVE_NUMBA else "numpy" @@ -771,8 +771,7 @@ def _compute_acgs_3d( The number of quantiles to use for firing rate bins. smoothing_factor : float, default: 250 The width of the smoothing kernel in milliseconds. - n_jobs : int, default: -1. - Number of parallel jobs to spawn to compute the 3D-ACGS on different units. 
+ {} Returns ------- @@ -841,7 +840,7 @@ def _compute_acgs_3d( acgs_3d[unit_index, :, :] = acg_3d firing_quantiles[unit_index, :] = firing_quantile else: - # Process units in serial + # Process units serially for unit_index, ( sorting, unit_id, @@ -868,6 +867,8 @@ def _compute_acgs_3d( register_result_extension(ComputeACG3D) compute_acgs_3d_sorting_analyzer = ComputeACG3D.function_factory() +_compute_acgs_3d.__doc__ = _compute_acgs_3d.__doc__.format(_shared_job_kwargs_doc) + def _compute_3d_acg_one_unit( sorting: BaseSorting, @@ -884,8 +885,8 @@ def _compute_3d_acg_one_unit( bin_size = np.clip(bin_size, 1000 * 1.0 / fs, 1e8) # in milliseconds win_size = np.clip(win_size, 1e-2, 1e8) # in milliseconds winsize_bins = 2 * int(0.5 * win_size * 1.0 / bin_size) + 1 # Both in millisecond - assert winsize_bins >= 1 - assert winsize_bins % 2 == 1 + assert winsize_bins >= 1, "Number of bins must be >= 1" + assert winsize_bins % 2 == 1, "Number of bins must be odd" bin_times_ms = np.linspace(-win_size / 2, win_size / 2, num=winsize_bins) if firing_rate_quantiles is not None: @@ -989,6 +990,7 @@ def compute_acgs_3d( bin_ms: float = 1.0, num_firing_rate_quantiles: int = 10, smoothing_factor: int = 250, + **job_kwargs, ): """ Compute 3D Autocorrelograms. See ComputeACG3D() for a detailed documentation. 
@@ -1003,6 +1005,7 @@ def compute_acgs_3d( bin_ms=bin_ms, num_firing_rate_quantiles=num_firing_rate_quantiles, smoothing_factor=smoothing_factor, + **job_kwargs, ) else: return _compute_acgs_3d( @@ -1011,6 +1014,7 @@ def compute_acgs_3d( bin_ms=bin_ms, num_firing_rate_quantiles=num_firing_rate_quantiles, smoothing_factor=smoothing_factor, + **job_kwargs, ) diff --git a/src/spikeinterface/postprocessing/tests/test_correlograms.py b/src/spikeinterface/postprocessing/tests/test_correlograms.py index ae2a27cbba..183fba95d3 100644 --- a/src/spikeinterface/postprocessing/tests/test_correlograms.py +++ b/src/spikeinterface/postprocessing/tests/test_correlograms.py @@ -1,19 +1,19 @@ import numpy as np +import importlib -try: - import numba - +numba_spec = importlib.util.find_spec("numba") +if numba_spec is not None: HAVE_NUMBA = True -except ModuleNotFoundError as err: +else: HAVE_NUMBA = False -try: - import npyx - +npyx_spec = importlib.util.find_spec("npyx") +if npyx_spec is not None: HAVE_NPYX = True -except ModuleNotFoundError: +else: HAVE_NPYX = False + import pytest from pytest import param