diff --git a/doc/api.rst b/doc/api.rst
index 7d20ed7c19..197c7f7d56 100644
--- a/doc/api.rst
+++ b/doc/api.rst
@@ -214,6 +214,7 @@ spikeinterface.postprocessing
     .. autofunction:: compute_spike_locations
     .. autofunction:: compute_template_similarity
     .. autofunction:: compute_correlograms
+    .. autofunction:: compute_acgs_3d
     .. autofunction:: compute_isi_histograms
     .. autofunction:: get_template_metric_names
     .. autofunction:: align_sorting
diff --git a/doc/modules/postprocessing.rst b/doc/modules/postprocessing.rst
index 3c30f248c8..a312633a08 100644
--- a/doc/modules/postprocessing.rst
+++ b/doc/modules/postprocessing.rst
@@ -358,6 +358,26 @@ with shape (num_units, num_units, num_bins) with all correlograms for each pair
 For more information, see :py:func:`~spikeinterface.postprocessing.compute_correlograms`
 
+acgs_3d
+^^^^^^^
+
+This extension computes 3D autocorrelograms (3D-ACGs) from each unit's spike times to analyze how a neuron's
+temporal firing pattern varies with its firing rate. The 3D-ACG, described in [Beau]_, provides a rich
+representation of a unit's spike train statistics while accounting for firing rate modulations.
+
+.. code-block:: python
+
+    acg3d = sorting_analyzer.compute(
+        input="acgs_3d",
+        window_ms=50.0,
+        bin_ms=1.0,
+        num_firing_rate_quantiles=10,
+        smoothing_factor=250,
+    )
+
+For more information, see :py:func:`~spikeinterface.postprocessing.compute_acgs_3d`
+
+
 isi_histograms
 ^^^^^^^^^^^^^^
diff --git a/doc/references.rst b/doc/references.rst
index ce4672a9ca..6ba6efe6a1 100644
--- a/doc/references.rst
+++ b/doc/references.rst
@@ -44,6 +44,11 @@ please include the appropriate citation for the :code:`sorter_name` parameter you use:
 - :code:`wavclus` [Chaure]_
 - :code:`yass` [Lee]_
 
+Postprocessing Module
+---------------------
+
+If you use the :code:`acgs_3d` extension (i.e. :code:`postprocessing.compute_acgs_3d` or :code:`postprocessing.ComputeACG3D`), please cite [Beau]_
+
 Qualitymetrics Module
 ---------------------
 If you use the :code:`qualitymetrics` module, i.e. you use the :code:`analyzer.compute()`
@@ -76,6 +81,8 @@ If you use the :code:`get_potential_auto_merge` method from the curation module,
 References
 ----------
 
+.. [Beau] `A deep learning strategy to identify cell types across species from high-density extracellular recordings. 2025. `_
+
 .. [Buccino] `SpikeInterface, a unified framework for spike sorting. 2020. `_
 
 .. [Buzsáki] `The Log-Dynamic Brain: How Skewed Distributions Affect Network Operations. 2014. `_
@@ -112,6 +119,8 @@ References
 .. [Niediek] `Reliable Analysis of Single-Unit Recordings from the Human Brain under Noisy Conditions: Tracking Neurons over Hours. 2016. `_
 
+.. [npyx] `NeuroPyxels: loading, processing and plotting Neuropixels data in python. 2021. `_
+
 .. [Pachitariu] `Spike sorting with Kilosort4. 2024. `_
 
 .. [Pouzat] `Using noise signature to optimize spike-sorting and to assess neuronal classification quality. 2002. `_
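As a reviewing aid, here is a minimal end-to-end sketch (an editorial illustration, not part of this diff) of how the documented extension is computed and read back. It assumes the toy-data generators SpikeInterface ships and that ``matplotlib`` is installed:

.. code-block:: python

    import matplotlib.pyplot as plt
    import spikeinterface.full as si

    # toy recording/sorting for illustration only
    recording, sorting = si.generate_ground_truth_recording(durations=[60.0], num_units=5, seed=0)
    sorting_analyzer = si.create_sorting_analyzer(sorting, recording)

    # compute the extension with the parameters documented above
    sorting_analyzer.compute("acgs_3d", window_ms=50.0, bin_ms=1.0)

    # get_data() returns (acgs_3d, firing_quantiles, bins)
    acgs_3d, firing_quantiles, bins = sorting_analyzer.get_extension("acgs_3d").get_data()

    # 3D-ACG of the first unit: rows = firing rate quantiles, columns = time lags
    plt.imshow(acgs_3d[0], aspect="auto", origin="lower")
    plt.show()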
diff --git a/src/spikeinterface/postprocessing/__init__.py b/src/spikeinterface/postprocessing/__init__.py
index 31bfdf1863..b1adbff281 100644
--- a/src/spikeinterface/postprocessing/__init__.py
+++ b/src/spikeinterface/postprocessing/__init__.py
@@ -19,7 +19,9 @@
 from .spike_amplitudes import compute_spike_amplitudes, ComputeSpikeAmplitudes
 from .correlograms import (
+    ComputeACG3D,
     ComputeCorrelograms,
+    compute_acgs_3d,
     compute_correlograms,
     correlogram_for_one_segment,
 )
diff --git a/src/spikeinterface/postprocessing/correlograms.py b/src/spikeinterface/postprocessing/correlograms.py
index 8afa30f977..dd05bc1bd3 100644
--- a/src/spikeinterface/postprocessing/correlograms.py
+++ b/src/spikeinterface/postprocessing/correlograms.py
@@ -1,13 +1,27 @@
 from __future__ import annotations
 
-from copy import deepcopy
 import importlib.util
+import warnings
+import platform
+from copy import deepcopy
+from tqdm.auto import tqdm
 
-import numpy as np
-from spikeinterface.core.sortinganalyzer import register_result_extension, AnalyzerExtension, SortingAnalyzer
+from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
+import multiprocessing as mp
+from threadpoolctl import threadpool_limits
+import numpy as np
 
-from spikeinterface.core.waveforms_extractor_backwards_compatibility import MockWaveformExtractor
+from spikeinterface.core import BaseSorting
+from spikeinterface.core.job_tools import fix_job_kwargs, _shared_job_kwargs_doc
+from spikeinterface.core.sortinganalyzer import (
+    AnalyzerExtension,
+    SortingAnalyzer,
+    register_result_extension,
+)
+from spikeinterface.core.waveforms_extractor_backwards_compatibility import (
+    MockWaveformExtractor,
+)
 
 numba_spec = importlib.util.find_spec("numba")
 if numba_spec is not None:
@@ -108,10 +122,8 @@ def _merge_extension_data(
         if censor_ms is not None:
             # if censor_ms has no effect, can apply "soft" method. Check if any spikes have been removed
             for new_unit_id, merge_unit_group in zip(new_unit_ids, merge_unit_groups):
-
                 num_segments = new_sorting_analyzer.get_num_segments()
                 for segment_index in range(num_segments):
-
                     merged_spike_train_length = len(
                         new_sorting_analyzer.sorting.get_unit_spike_train(new_unit_id, segment_index=segment_index)
                     )
@@ -133,7 +145,6 @@ def _merge_extension_data(
             new_ccgs, new_bins = _compute_correlograms_on_sorting(new_sorting_analyzer.sorting, **self.params)
             new_data = dict(ccgs=new_ccgs, bins=new_bins)
         else:
-
             # Make a transformation dict, which tells us how unit_indices from the
             # old to the new sorter are mapped.
            old_to_new_unit_index_map = {}
@@ -161,7 +172,6 @@ def _merge_extension_data(
         correlograms, new_bins = deepcopy(self.get_data())
         for new_unit_id, merge_unit_group in zip(new_unit_ids, merge_unit_groups):
-
             merge_unit_group_indices = self.sorting_analyzer.sorting.ids_to_indices(merge_unit_group)
 
             # Sum unit rows of the correlogram matrix: C_{k,l} = C_{i,l} + C_{j,l}
@@ -262,7 +272,7 @@ def _make_bins(sorting, window_ms, bin_ms) -> tuple[np.ndarray, int, int]:
     bin_size = int(round(fs * bin_ms * 1e-3))
     window_size -= window_size % bin_size
     num_bins = 2 * int(window_size / bin_size)
-    assert num_bins >= 1
+    assert num_bins >= 1, "Number of bins must be >= 1"
 
     bins = np.arange(-window_size, window_size + bin_size, bin_size) * 1e3 / fs
@@ -320,7 +330,7 @@ def _compute_correlograms_on_sorting(sorting, window_ms, bin_ms, method="auto"):
     bins : np.array
         The bins edges in ms
     """
-    assert method in ("auto", "numba", "numpy")
+    assert method in ("auto", "numba", "numpy"), "method must be 'auto', 'numba' or 'numpy'"
 
     if method == "auto":
         method = "numba" if HAVE_NUMBA else "numpy"
@@ -566,7 +576,6 @@ def _compute_correlograms_one_segment_numba(
     start_j = 0
     for i in range(spike_times.size):
         for j in range(start_j, spike_times.size):
-
             if i == j:
                 continue
@@ -595,3 +604,418 @@ def _compute_correlograms_one_segment_numba(
             bin = diff // bin_size
 
             correlograms[spike_unit_indices[i], spike_unit_indices[j], num_half_bins + bin] += 1
+
+
+class ComputeACG3D(AnalyzerExtension):
+    """
+    Computes the 3D autocorrelogram (3D-ACG) from unit spike times to analyze how a neuron's temporal firing
+    pattern varies with its firing rate.
+
+    The 3D-ACG, originally described in [Beau]_, provides a rich representation of a unit's
+    spike train statistics while accounting for firing rate modulations.
+    The method was developed to normalize for the impact of firing rate changes on measures of firing statistics,
+    particularly in awake animals performing behavioral tasks, where firing rates naturally vary over time.
+
+    The approach works as follows (steps 1-2 are illustrated in the sketch following this class):
+
+    1. The instantaneous firing rate is calculated at each spike time using the inverse ISI
+    2. Firing rates are smoothed with a boxcar filter (default width 250 ms)
+    3. Spikes are grouped into firing rate bins (deciles by default)
+    4. A separate ACG is computed for each firing rate bin
+
+    The result can be visualized as an image where:
+
+    - the y-axis represents the firing rate bins
+    - the x-axis represents the time lag from the trigger spike
+    - the z-axis (color) represents the spike probability
+
+    Parameters
+    ----------
+    window_ms : float, default: 50.0
+        Window size for the autocorrelation, in milliseconds
+    bin_ms : float, default: 1.0
+        Bin size for the autocorrelation, in milliseconds
+    num_firing_rate_quantiles : int, default: 10
+        Number of firing rate quantiles (deciles by default)
+    smoothing_factor : float, default: 250
+        Width of the boxcar filter used for smoothing, in milliseconds.
+        Set to None to disable smoothing.
+
+    Returns
+    -------
+    acgs_3d : numpy.ndarray
+        Array with dimensions (num_units, num_firing_rate_quantiles, num_timepoints),
+        where each element is the probability of observing a spike at the given time lag,
+        conditioned on the neuron's firing rate at the trigger spike time.
+    firing_rate_quantiles : numpy.ndarray
+        The firing rate values that define the quantile edges
+
+    Notes
+    -----
+    - The central bin (t=0) is set to 0, as it would always be 1 by definition
+    - Edge spikes are excluded to avoid boundary artifacts
+    - Spike counts are normalized by the total number of trigger spikes in each rate bin
+
+    References
+    ----------
+    Based on work in [Beau]_.
+
+    Adapted from the Python implementation in [npyx]_: https://github.com/m-beau/NeuroPyxels/
+
+    Original author: David Herzfeld
+    """
+
+    extension_name = "acgs_3d"
+    depend_on = []
+    need_recording = False
+    use_nodepipeline = False
+    need_job_kwargs = True
+
+    def __init__(self, sorting_analyzer):
+        AnalyzerExtension.__init__(self, sorting_analyzer)
+
+    def _set_params(
+        self,
+        window_ms: float = 50.0,
+        bin_ms: float = 1.0,
+        num_firing_rate_quantiles: int = 10,
+        smoothing_factor: int = 250,
+    ):
+        params = dict(
+            window_ms=window_ms,
+            bin_ms=bin_ms,
+            num_firing_rate_quantiles=num_firing_rate_quantiles,
+            smoothing_factor=smoothing_factor,
+        )
+
+        return params
+
+    def _select_extension_data(self, unit_ids):
+        # keep only the rows of the selected units; the bins are shared
+        unit_indices = self.sorting_analyzer.sorting.ids_to_indices(unit_ids)
+        new_acgs_3d = self.data["acgs_3d"][unit_indices]
+        new_firing_quantiles = self.data["firing_quantiles"][unit_indices]
+        new_bins = self.data["bins"][:]
+        new_data = dict(acgs_3d=new_acgs_3d, firing_quantiles=new_firing_quantiles, bins=new_bins)
+        return new_data
+
+    def _merge_extension_data(
+        self, merge_unit_groups, new_unit_ids, new_sorting_analyzer, censor_ms=None, verbose=False, **job_kwargs
+    ):
+        new_sorting = new_sorting_analyzer.sorting
+
+        acgs_3d, firing_rate_quantiles, _ = _compute_acgs_3d(
+            new_sorting,
+            unit_ids=new_unit_ids,
+            window_ms=self.params["window_ms"],
+            bin_ms=self.params["bin_ms"],
+            num_firing_rate_quantiles=self.params["num_firing_rate_quantiles"],
+            smoothing_factor=self.params["smoothing_factor"],
+            **job_kwargs,
+        )
+
+        new_unit_ids_indices = new_sorting.ids_to_indices(new_unit_ids)
+        old_unit_ids = [unit_id for unit_id in new_sorting_analyzer.unit_ids if unit_id not in new_unit_ids]
+        old_unit_ids_indices = new_sorting.ids_to_indices(old_unit_ids)
+
+        new_acgs_3d = np.zeros((len(new_sorting.unit_ids), acgs_3d.shape[1], acgs_3d.shape[2]))
+        new_firing_quantiles = np.zeros((len(new_sorting.unit_ids), firing_rate_quantiles.shape[1]))
+
+        new_acgs_3d[new_unit_ids_indices, :, :] = acgs_3d
+        new_acgs_3d[old_unit_ids_indices, :, :] = self.data["acgs_3d"][old_unit_ids_indices, :, :]
+
+        new_firing_quantiles[new_unit_ids_indices, :] = firing_rate_quantiles
+        new_firing_quantiles[old_unit_ids_indices, :] = self.data["firing_quantiles"][old_unit_ids_indices, :]
+
+        new_data = dict(
+            acgs_3d=new_acgs_3d,
+            firing_quantiles=new_firing_quantiles,
+            bins=self.data["bins"],
+        )
+        return new_data
+
+    def _run(self, verbose=False, **job_kwargs):
+        acgs_3d, firing_quantiles, bins = _compute_acgs_3d(self.sorting_analyzer.sorting, **self.params, **job_kwargs)
+        self.data["firing_quantiles"] = firing_quantiles
+        self.data["acgs_3d"] = acgs_3d
+        self.data["bins"] = bins
+
+    def _get_data(self):
+        return self.data["acgs_3d"], self.data["firing_quantiles"], self.data["bins"]
+
+
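To make steps 1-2 of the approach concrete, here is a minimal NumPy sketch (an editorial illustration, not part of the diff) of the inverse-ISI rate estimate and boxcar smoothing that ``_compute_3d_acg_one_unit`` implements below; the spike train and the 25-bin kernel are hypothetical:

.. code-block:: python

    import numpy as np

    # hypothetical spike train, as bin indices at 1 ms resolution
    spike_bins = np.array([10, 20, 45, 50, 100])
    bin_s = 0.001  # 1 ms bins, in seconds

    # step 1: inverse-ISI instantaneous rate, piecewise constant between ISI midpoints
    rate = np.zeros(spike_bins[-1] + 1)
    for i in range(1, len(spike_bins) - 1):
        start = spike_bins[i - 1] + (spike_bins[i] - spike_bins[i - 1]) // 2
        stop = spike_bins[i] + (spike_bins[i + 1] - spike_bins[i]) // 2
        rate[start:stop] = 1.0 / ((stop - start) * bin_s)

    # step 2: boxcar smoothing (25 bins here; the extension defaults to a 250 ms kernel)
    kernel = np.ones(25) / 25
    smoothed = np.convolve(rate, kernel, mode="same")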
+def _compute_acgs_3d(
+    sorting: BaseSorting,
+    unit_ids=None,
+    window_ms: float = 50.0,
+    bin_ms: float = 1.0,
+    num_firing_rate_quantiles: int = 10,
+    smoothing_factor: int = 250,
+    **job_kwargs,
+):
+    """
+    Computes the 3D autocorrelograms for the given units.
+
+    See ComputeACG3D() for more details.
+
+    Parameters
+    ----------
+    sorting : Sorting
+        A SpikeInterface Sorting object
+    unit_ids : list of int, default: None
+        The unit ids to compute the autocorrelograms for. If None,
+        all units in the sorting are used.
+    window_ms : float, default: 50.0
+        The window around the spike to compute the correlation in ms. For example,
+        if 50 ms, the correlations will be computed at lags -25 ms ... 25 ms.
+    bin_ms : float, default: 1.0
+        The bin size in ms. This determines the bin size over which to
+        combine lags. For example, with a window size of -25 ms to 25 ms, and
+        bin size 1 ms, the correlation will be binned as -25 ms, -24 ms, ...
+    num_firing_rate_quantiles : int, default: 10
+        The number of quantiles to use for the firing rate bins.
+    smoothing_factor : float, default: 250
+        The width of the smoothing kernel in milliseconds.
+    {}
+
+    Returns
+    -------
+    acgs_3d : np.array
+        The autocorrelograms for each unit at each firing rate quantile.
+    firing_quantiles : np.array
+        The firing rate quantiles used for each unit.
+    bins : np.array
+        The time bins in ms
+    """
+    if unit_ids is None:
+        unit_ids = sorting.unit_ids
+
+    job_kwargs = fix_job_kwargs(job_kwargs)
+
+    n_jobs = job_kwargs["n_jobs"]
+    progress_bar = job_kwargs["progress_bar"]
+    max_threads_per_worker = job_kwargs["max_threads_per_worker"]
+    mp_context = job_kwargs["mp_context"]
+    pool_engine = job_kwargs["pool_engine"]
+    if mp_context is not None and platform.system() == "Windows":
+        assert mp_context != "fork", "'fork' mp_context not supported on Windows!"
+    elif mp_context == "fork" and platform.system() == "Darwin":
+        warnings.warn('As of Python 3.8 "fork" is no longer considered safe on macOS')
+
+    num_segments = sorting.get_num_segments()
+    if num_segments > 1:
+        warnings.warn(
+            "Multiple segments detected. Firing rate quantiles will be automatically computed on the first segment. "
+            "Manually define global firing_rate_quantiles if needed.",
+        )
+
+    # pre-compute time bins
+    winsize_bins = 2 * int(0.5 * window_ms * 1.0 / bin_ms) + 1
+    bin_times_ms = np.linspace(-window_ms / 2, window_ms / 2, num=winsize_bins)
+    num_units = len(unit_ids)
+    acgs_3d = np.zeros((num_units, num_firing_rate_quantiles, winsize_bins))
+    firing_quantiles = np.zeros((num_units, num_firing_rate_quantiles))
+
+    # the time-lag bins are identical for all units
+    time_bins_ms = bin_times_ms
+
+    items = [
+        (sorting, unit_id, window_ms, bin_ms, num_firing_rate_quantiles, smoothing_factor, max_threads_per_worker)
+        for unit_id in unit_ids
+    ]
+
+    if n_jobs > 1:
+        job_name = "calculate_acgs_3d"
+        if pool_engine == "process":
+            parallel_pool_class = ProcessPoolExecutor
+            pool_kwargs = dict(mp_context=mp.get_context(mp_context))
+            desc = f"{job_name} (workers: {n_jobs} processes)"
+        else:
+            parallel_pool_class = ThreadPoolExecutor
+            pool_kwargs = dict()
+            desc = f"{job_name} (workers: {n_jobs} threads)"
+
+        # Process units in parallel
+        with parallel_pool_class(max_workers=n_jobs, **pool_kwargs) as executor:
+            results = executor.map(_compute_3d_acg_one_unit_star, items)
+            if progress_bar:
+                results = tqdm(results, total=len(unit_ids), desc=desc)
+            for unit_index, (acg_3d, firing_quantile) in enumerate(results):
+                acgs_3d[unit_index, :, :] = acg_3d
+                firing_quantiles[unit_index, :] = firing_quantile
+    else:
+        # Process units serially
+        for unit_index, (
+            sorting,
+            unit_id,
+            window_ms,
+            bin_ms,
+            num_firing_rate_quantiles,
+            smoothing_factor,
+            _,
+        ) in enumerate(items):
+            acg_3d, firing_quantile = _compute_3d_acg_one_unit(
+                sorting,
+                unit_id,
+                window_ms,
+                bin_ms,
+                num_firing_rate_quantiles,
+                smoothing_factor,
+            )
+            acgs_3d[unit_index, :, :] = acg_3d
+            firing_quantiles[unit_index, :] = firing_quantile
+
+    return acgs_3d, firing_quantiles, time_bins_ms
+
+
+register_result_extension(ComputeACG3D)
+compute_acgs_3d_sorting_analyzer = ComputeACG3D.function_factory()
+
+_compute_acgs_3d.__doc__ = _compute_acgs_3d.__doc__.format(_shared_job_kwargs_doc)
+
+
+def _compute_3d_acg_one_unit(
+    sorting: BaseSorting,
+    unit_id: int | str,
+    win_size: float,
+    bin_size: float,
+    num_firing_rate_quantiles: int = 10,
+    smoothing_factor: int = 250,
+    use_spikes_around_times1_for_deciles: bool = True,
+    firing_rate_quantiles: list | None = None,
+):
+    fs = sorting.sampling_frequency
+
+    bin_size = np.clip(bin_size, 1000 * 1.0 / fs, 1e8)  # in milliseconds
+    win_size = np.clip(win_size, 1e-2, 1e8)  # in milliseconds
+    winsize_bins = 2 * int(0.5 * win_size * 1.0 / bin_size) + 1  # both in milliseconds
+    assert winsize_bins >= 1, "Number of bins must be >= 1"
+    assert winsize_bins % 2 == 1, "Number of bins must be odd"
+    bin_times_ms = np.linspace(-win_size / 2, win_size / 2, num=winsize_bins)
+
+    if firing_rate_quantiles is not None:
+        num_firing_rate_quantiles = len(firing_rate_quantiles)
+    # counts of spike occurrences for each (rate bin, time lag) pair
+    spike_counts = np.zeros((num_firing_rate_quantiles, len(bin_times_ms)))
+    # total number of trigger spikes per rate bin
+    firing_rate_bin_occurence = np.zeros(num_firing_rate_quantiles, dtype=np.int64)
+
+    # Samples per bin
+    samples_per_bin = int(np.ceil(fs / (1000 / bin_size)))
+
+    num_segments = sorting.get_num_segments()
+    # Convert spike times (in samples) to units of bin_size
+    for segment_index in range(num_segments):
+        spike_times = sorting.get_unit_spike_train(unit_id, segment_index=segment_index)
+
+        # Convert to bin indices
+        spike_bins = np.floor(spike_times / samples_per_bin).astype(np.int64)
+
+        if len(spike_bins) <= 1:
+            continue  # Need at least 2 spikes for an ACG
+
+        # Create a binary spike train spanning the entire time range
+        max_bin = int(np.ceil(spike_bins[-1] + 1))
+        spiketrain = np.zeros(max_bin + winsize_bins, dtype=bool)  # Add extra space for the window
+        spiketrain[spike_bins] = True
+
+        # Convert spikes to a firing rate using the inverse ISI method
+        firing_rate = np.zeros(max_bin)
+        for i in range(1, len(spike_bins) - 1):
+            start = 0 if i == 0 else (spike_bins[i - 1] + (spike_bins[i] - spike_bins[i - 1]) // 2)
+            stop = max_bin if i == len(spike_bins) - 1 else (spike_bins[i] + (spike_bins[i + 1] - spike_bins[i]) // 2)
+            current_firing_rate = 1.0 / ((stop - start) * (bin_size / 1000))
+            firing_rate[start:stop] = current_firing_rate
+
+        # Smooth the firing rate with a boxcar kernel (numpy convolution), if requested
+        if isinstance(smoothing_factor, (int, float)) and smoothing_factor > 0:
+            kernel_size = int(np.ceil(smoothing_factor / bin_size))
+            half_kernel_size = kernel_size // 2
+            kernel = np.ones(kernel_size) / kernel_size
+            smoothed_firing_rate = np.convolve(firing_rate, kernel, mode="same")
+
+            # Correct manually for possible artefacts at the edges
+            for i in range(kernel_size):
+                start = max(0, i - half_kernel_size)
+                stop = min(len(firing_rate), i + half_kernel_size)
+                smoothed_firing_rate[i] = np.mean(firing_rate[start:stop])
+            for i in range(len(firing_rate) - kernel_size, len(firing_rate)):
+                start = max(0, i - half_kernel_size)
+                stop = min(len(firing_rate), i + half_kernel_size)
+                smoothed_firing_rate[i] = np.mean(firing_rate[start:stop])
+            firing_rate = smoothed_firing_rate
+
+        # Get firing rate quantiles
+        if firing_rate_quantiles is None:
+            quantile_bins = np.linspace(0, 1, num_firing_rate_quantiles + 2)[1:-1]
+            if use_spikes_around_times1_for_deciles:
+                firing_rate_quantiles = np.quantile(firing_rate[spike_bins], quantile_bins)
+            else:
+                firing_rate_quantiles = np.quantile(firing_rate, quantile_bins)
+        for i, spike_index in enumerate(spike_bins):
+            start = spike_index + int(np.ceil(bin_times_ms[0] / bin_size))
+            stop = start + len(bin_times_ms)
+            if (start < 0) or (stop >= len(spiketrain)) or spike_index < spike_bins[0] or spike_index >= spike_bins[-1]:
+                continue  # Skip these spikes to avoid edge artifacts
+            current_firing_rate = firing_rate[spike_index]  # instantaneous rate at the trigger spike
+            current_firing_rate_bin_number = np.argmax(firing_rate_quantiles >= current_firing_rate)
+            if current_firing_rate_bin_number == 0 and current_firing_rate > firing_rate_quantiles[0]:
+                # rate above the top quantile edge: assign to the last bin
+                current_firing_rate_bin_number = len(firing_rate_quantiles) - 1
+            spike_counts[current_firing_rate_bin_number, :] += spiketrain[start:stop]
+            firing_rate_bin_occurence[current_firing_rate_bin_number] += 1
+
+    # Normalize by the number of trigger spikes in each rate bin
+    acg_3d = spike_counts / (np.ones((len(bin_times_ms), num_firing_rate_quantiles)) * firing_rate_bin_occurence).T
+    # Division by zero returns NaNs, so replace them with zeros
+    acg_3d = np.nan_to_num(acg_3d)
+    # Zero out the central bin (t=0), which would always be 1 by definition
+    acg_3d[:, acg_3d.shape[1] // 2] = 0
+
+    return acg_3d, firing_rate_quantiles
+
+
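The rate-bin assignment in the loop above is terse; the following numeric sketch (hypothetical values, editorial illustration) shows how ``np.argmax(edges >= rate)`` maps a trigger spike's instantaneous rate onto a quantile bin, including the overflow case handled above:

.. code-block:: python

    import numpy as np

    # hypothetical instantaneous rates (Hz) at the spike times
    rates_at_spikes = np.array([2.0, 4.0, 5.5, 8.0, 12.0, 20.0])

    # interior quantile edges, as computed above for num_firing_rate_quantiles=10
    quantile_bins = np.linspace(0, 1, 10 + 2)[1:-1]
    edges = np.quantile(rates_at_spikes, quantile_bins)

    rate = 6.3
    bin_index = np.argmax(edges >= rate)  # first quantile edge >= rate
    if bin_index == 0 and rate > edges[0]:
        bin_index = len(edges) - 1  # a rate above all edges falls into the last bin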
+def _compute_3d_acg_one_unit_star(args):
+    """
+    Helper to compute the 3D ACG for a single unit.
+    Used to parallelize the computation across units.
+    """
+    max_threads_per_worker = args[-1]
+    new_args = args[:-1]
+    if max_threads_per_worker is None:
+        return _compute_3d_acg_one_unit(*new_args)
+    else:
+        with threadpool_limits(limits=int(max_threads_per_worker)):
+            return _compute_3d_acg_one_unit(*new_args)
+
+
+def compute_acgs_3d(
+    sorting_analyzer_or_sorting: SortingAnalyzer | BaseSorting,
+    window_ms: float = 50.0,
+    bin_ms: float = 1.0,
+    num_firing_rate_quantiles: int = 10,
+    smoothing_factor: int = 250,
+    **job_kwargs,
+):
+    """
+    Compute 3D autocorrelograms. See ComputeACG3D() for detailed documentation.
+    """
+    if isinstance(sorting_analyzer_or_sorting, MockWaveformExtractor):
+        sorting_analyzer_or_sorting = sorting_analyzer_or_sorting.sorting
+
+    if isinstance(sorting_analyzer_or_sorting, SortingAnalyzer):
+        return compute_acgs_3d_sorting_analyzer(
+            sorting_analyzer_or_sorting,
+            window_ms=window_ms,
+            bin_ms=bin_ms,
+            num_firing_rate_quantiles=num_firing_rate_quantiles,
+            smoothing_factor=smoothing_factor,
+            **job_kwargs,
+        )
+    else:
+        return _compute_acgs_3d(
+            sorting_analyzer_or_sorting,
+            window_ms=window_ms,
+            bin_ms=bin_ms,
+            num_firing_rate_quantiles=num_firing_rate_quantiles,
+            smoothing_factor=smoothing_factor,
+            **job_kwargs,
+        )
+
+
+compute_acgs_3d.__doc__ = compute_acgs_3d_sorting_analyzer.__doc__
diff --git a/src/spikeinterface/postprocessing/tests/test_correlograms.py b/src/spikeinterface/postprocessing/tests/test_correlograms.py
index 3bdfb38e06..183fba95d3 100644
--- a/src/spikeinterface/postprocessing/tests/test_correlograms.py
+++ b/src/spikeinterface/postprocessing/tests/test_correlograms.py
@@ -1,28 +1,37 @@
 import numpy as np
+import importlib.util
 
-try:
-    import numba
-
+numba_spec = importlib.util.find_spec("numba")
+if numba_spec is not None:
     HAVE_NUMBA = True
-except ModuleNotFoundError as err:
+else:
     HAVE_NUMBA = False
 
+npyx_spec = importlib.util.find_spec("npyx")
+if npyx_spec is not None:
+    HAVE_NPYX = True
+    import npyx  # used by the comparison tests below
+else:
+    HAVE_NPYX = False
+
+
+import pytest
+from pytest import param
+
 from spikeinterface import NumpySorting, generate_sorting
-from spikeinterface.postprocessing.tests.common_extension_tests import AnalyzerExtensionCommonTestSuite
-from spikeinterface.postprocessing import ComputeCorrelograms
+from spikeinterface.postprocessing import ComputeACG3D, ComputeCorrelograms
 from spikeinterface.postprocessing.correlograms import (
+    _compute_3d_acg_one_unit,
     _compute_correlograms_on_sorting,
     _make_bins,
+    compute_acgs_3d,
     compute_correlograms,
 )
-import pytest
-from pytest import param
+from spikeinterface.postprocessing.tests.common_extension_tests import AnalyzerExtensionCommonTestSuite
 
 SKIP_NUMBA = pytest.mark.skipif(not HAVE_NUMBA, reason="Numba not available")
 
 
 class TestComputeCorrelograms(AnalyzerExtensionCommonTestSuite):
-
     @pytest.mark.parametrize(
         "params",
         [
@@ -368,3 +377,119 @@ def generate_correlogram_test_dataset(sampling_frequency, fill_all_bins, hit_bin
     expected_result_corr[int(num_bins / 2)] = num_filled_bins
 
     return window_ms, bin_ms, spike_times, spike_unit_indices, expected_bins, expected_result_auto, expected_result_corr
+
+
+#########################################
+# 3D ACG Tests
+#########################################
+
+
+class TestComputeACG3D(AnalyzerExtensionCommonTestSuite):
+    @pytest.mark.parametrize(
+        "params",
+        [
+            dict(window_ms=50.0, bin_ms=1.0, num_firing_rate_quantiles=10, smoothing_factor=250),
+        ],
+    )
+    def test_extension(self, params):
+        self.run_extension_tests(ComputeACG3D, params)
+
+    @pytest.mark.parametrize("num_firing_rate_quantiles", [10, 15])
+    def test_sortinganalyzer_acgs_3d(self, num_firing_rate_quantiles):
+        """
+        Test that the outputs obtained through a SortingAnalyzer match
+        those obtained by passing the sorting directly to `compute_acgs_3d`.
+        """
+        sorting_analyzer = self._prepare_sorting_analyzer("memory", sparse=False, extension_class=ComputeACG3D)
+
+        params = dict(num_firing_rate_quantiles=num_firing_rate_quantiles, window_ms=100, bin_ms=6.5)
+        ext_numpy = sorting_analyzer.compute(ComputeACG3D.extension_name, **params)
+
+        result_sorting, quantiles_sorting, time_bins = compute_acgs_3d(self.sorting, **params)
+
+        assert np.array_equal(result_sorting, ext_numpy.data["acgs_3d"])
+        assert np.array_equal(quantiles_sorting, ext_numpy.data["firing_quantiles"])
+        assert np.array_equal(time_bins, ext_numpy.data["bins"])
+
+
+@pytest.mark.skipif(not HAVE_NPYX, reason="npyx not installed")
+@pytest.mark.parametrize("window_ms", [50.0, 100.0])
+@pytest.mark.parametrize("bin_ms", [1.0, 2.0])
+@pytest.mark.parametrize("num_firing_rate_quantiles", [10, 15])
+@pytest.mark.parametrize("smoothing_factor", [250, 500])
+def test_acgs_3d_original_implementation(window_ms, bin_ms, num_firing_rate_quantiles, smoothing_factor):
+    sorting = generate_sorting(num_units=1, sampling_frequency=30000.0, durations=[100], firing_rates=60, seed=0)
+
+    unit_ids = sorting.get_unit_ids()
+
+    times_1 = times_2 = sorting.get_unit_spike_train(unit_id=unit_ids[0])
+
+    npyx_quantiles, npyx_3dacg = npyx.corr.crosscorr_vs_firing_rate(
+        times_1,
+        times_2,
+        win_size=window_ms,
+        bin_size=bin_ms,
+        fs=sorting.sampling_frequency,
+        num_firing_rate_bins=num_firing_rate_quantiles,
+        smooth=smoothing_factor,
+    )
+
+    acg_3d, firing_rate_quantiles = _compute_3d_acg_one_unit(
+        sorting=sorting,
+        unit_id=unit_ids[0],
+        win_size=window_ms,
+        bin_size=bin_ms,
+        num_firing_rate_quantiles=num_firing_rate_quantiles,
+        smoothing_factor=smoothing_factor,
+    )
+
+    assert np.allclose(npyx_3dacg, acg_3d, rtol=1e-2, atol=1e-2), "3D-ACG result does not match original implementation"
+    assert np.allclose(
+        npyx_quantiles, firing_rate_quantiles, rtol=1e-5, atol=1e-5
+    ), "3D-ACG quantiles do not match original implementation"
+
+
+# Check the case where the sorting has multiple segments
+@pytest.mark.skipif(not HAVE_NPYX, reason="npyx not installed")
+@pytest.mark.parametrize("window_ms", [50.0, 100.0])
+@pytest.mark.parametrize("bin_ms", [1.0, 2.0])
+@pytest.mark.parametrize("num_firing_rate_quantiles", [10, 15])
+@pytest.mark.parametrize("segments", [2, 5, 10])
+def test_acgs_3d_original_implementation_with_segments(window_ms, bin_ms, num_firing_rate_quantiles, segments):
+    durations = [np.random.randint(30, 60) for _ in range(segments)]
+    sorting = generate_sorting(num_units=1, sampling_frequency=30000.0, durations=durations, firing_rates=50, seed=0)
+    unit_ids = sorting.get_unit_ids()
+
+    # Simulate a single segment by concatenating all segments
+    segment_spike_trains = [
+        sorting.get_unit_spike_train(unit_id=unit_ids[0], segment_index=i) for i in range(sorting.get_num_segments())
+    ]
+    segment_durations = [np.max(segment) + 1 for segment in segment_spike_trains]
+    segment_offsets = [0] + np.cumsum(segment_durations).tolist()
+    concat_times = np.concatenate(
+        [np.array(segment) + segment_offsets[i] for i, segment in enumerate(segment_spike_trains)]
+    )
+
+    npyx_quantiles, npyx_3dacg = npyx.corr.crosscorr_vs_firing_rate(
+        concat_times,
+        concat_times,
+        win_size=window_ms,
+        bin_size=bin_ms,
+        fs=sorting.sampling_frequency,
+        num_firing_rate_bins=num_firing_rate_quantiles,
+    )
+
+    acg_3d, firing_rate_quantiles = _compute_3d_acg_one_unit(
+        sorting=sorting,
+        unit_id=unit_ids[0],
+        win_size=window_ms,
+        bin_size=bin_ms,
+        num_firing_rate_quantiles=num_firing_rate_quantiles,
+        # use npyx quantiles to check for matching results; multiple segments use the first segment's quantiles
+        firing_rate_quantiles=npyx_quantiles.tolist(),
+    )
+
+    # Use a less strict tolerance here, as we are only simulating a single segment
+    assert np.allclose(
+        npyx_3dacg, acg_3d, rtol=5e-2, atol=5e-2
+    ), f"3D-ACG result for {segments} segments does not match original implementation with concatenated times"
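Beyond the SortingAnalyzer path exercised by the tests above, ``compute_acgs_3d`` also dispatches on a bare sorting; a minimal sketch (editorial illustration, using the same ``generate_sorting`` helper as the tests) of that second code path:

.. code-block:: python

    from spikeinterface import generate_sorting
    from spikeinterface.postprocessing import compute_acgs_3d

    sorting = generate_sorting(num_units=3, sampling_frequency=30000.0, durations=[60.0], seed=0)

    # returned in the same order as the extension's get_data()
    acgs_3d, firing_quantiles, bins = compute_acgs_3d(sorting, window_ms=50.0, bin_ms=1.0)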