diff --git a/src/spikeinterface/core/recording_tools.py b/src/spikeinterface/core/recording_tools.py
index b56bb41eaa..6359b2b448 100644
--- a/src/spikeinterface/core/recording_tools.py
+++ b/src/spikeinterface/core/recording_tools.py
@@ -512,7 +512,7 @@ def get_random_recording_slices(
     chunk_duration : str | float | None, default "500ms"
         The duration of each chunk in 's' or 'ms'
     chunk_size : int | None
-        Size of a chunk in number of frames. This is ued only if chunk_duration is None.
+        Size of a chunk in number of frames. This is used only if chunk_duration is None.
         This is kept for backward compatibility, you should prefer 'chunk_duration=500ms' instead.
     concatenated : bool, default: True
         If True chunk are concatenated along time axis
diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py
index b775e27b64..3e3695b9ef 100644
--- a/src/spikeinterface/sorters/internal/spyking_circus2.py
+++ b/src/spikeinterface/sorters/internal/spyking_circus2.py
@@ -13,6 +13,7 @@
     cache_preprocessing,
     get_prototype_and_waveforms_from_recording,
     get_shuffled_recording_slices,
+    _set_optimal_chunk_size,
 )
 from spikeinterface.core.basesorting import minimum_spike_dtype
 from spikeinterface.core.sparsity import compute_sparsity
@@ -39,6 +40,7 @@ class Spykingcircus2Sorter(ComponentsBasedSorter):
         "apply_preprocessing": True,
         "templates_from_svd": True,
         "cache_preprocessing": {"mode": "memory", "memory_limit": 0.5, "delete_cache": True},
+        "chunk_preprocessing": {"memory_limit": None},
         "multi_units_only": False,
         "job_kwargs": {"n_jobs": 0.75},
         "seed": 42,
@@ -66,6 +68,9 @@ class Spykingcircus2Sorter(ComponentsBasedSorter):
         "matched_filtering": "Boolean to specify whether circus 2 should detect peaks via matched filtering (slightly slower)",
         "cache_preprocessing": "How to cache the preprocessed recording. Mode can be memory, file, zarr, with extra arguments. In case of memory (default), \
 memory_limit will control how much RAM can be used. In case of folder or zarr, delete_cache controls if cache is cleaned after sorting",
+        "chunk_preprocessing": "How much RAM (approximately) should be devoted to loading all data chunks (given n_jobs).\
+            memory_limit will control how much RAM can be used as a fraction of available memory. Otherwise, use total_memory to fix a hard limit, with\
+            a string syntax (e.g. '1G', '500M')",
         "multi_units_only": "Boolean to get only multi units activity (i.e. one template per electrode)",
         "job_kwargs": "A dictionary to specify how many jobs and which parameters they should used",
         "seed": "An int to control how chunks are shuffled while detecting peaks",
@@ -100,8 +105,9 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
         job_kwargs = fix_job_kwargs(params["job_kwargs"])
         job_kwargs.update({"progress_bar": verbose})
-
         recording = cls.load_recording_from_folder(sorter_output_folder.parent, with_warnings=False)
+        if params["chunk_preprocessing"].get("memory_limit", None) is not None:
+            job_kwargs = _set_optimal_chunk_size(recording, job_kwargs, **params["chunk_preprocessing"])
         sampling_frequency = recording.get_sampling_frequency()
         num_channels = recording.get_num_channels()
@@ -401,7 +407,12 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
             # np.save(fitting_folder / "amplitudes", guessed_amplitudes)
 
         if sorting.get_non_empty_unit_ids().size > 0:
-            sorting = final_cleaning_circus(recording_w, sorting, templates, **merging_params, **job_kwargs)
+            final_analyzer = final_cleaning_circus(
+                recording_w, sorting, templates, **merging_params, **job_kwargs
+            )
+            final_analyzer.save_as(format="binary_folder", folder=sorter_output_folder / "final_analyzer")
+
+            sorting = final_analyzer.sorting
 
             if verbose:
                 print(f"Kept {len(sorting.unit_ids)} units after final merging")
@@ -460,4 +471,5 @@ def final_cleaning_circus(
         sparsity_overlap=sparsity_overlap,
         **job_kwargs,
     )
-    return final_sa.sorting
+
+    return final_sa
diff --git a/src/spikeinterface/sortingcomponents/clustering/circus.py b/src/spikeinterface/sortingcomponents/clustering/circus.py
index 0fd58d3011..47d8dcf77c 100644
--- a/src/spikeinterface/sortingcomponents/clustering/circus.py
+++ b/src/spikeinterface/sortingcomponents/clustering/circus.py
@@ -14,17 +14,12 @@
 import random, string
 from spikeinterface.core import get_global_tmp_folder
-from spikeinterface.core.basesorting import minimum_spike_dtype
-from spikeinterface.core.waveform_tools import estimate_templates
 from .clustering_tools import remove_duplicates_via_matching
 from spikeinterface.core.recording_tools import get_noise_levels, get_channel_distances
 from spikeinterface.sortingcomponents.peak_selection import select_peaks
-from spikeinterface.core.template import Templates
 from spikeinterface.core.sparsity import compute_sparsity
-from spikeinterface.sortingcomponents.tools import remove_empty_templates
+from spikeinterface.sortingcomponents.tools import remove_empty_templates, _get_optimal_n_jobs
 from spikeinterface.sortingcomponents.clustering.peak_svd import extract_peaks_svd
-
-
 from spikeinterface.sortingcomponents.tools import extract_waveform_at_max_channel
@@ -62,6 +57,7 @@ class CircusClustering:
         "noise_levels": None,
         "tmp_folder": None,
         "verbose": True,
+        "memory_limit": 0.25,
         "debug": False,
     }
@@ -162,13 +158,17 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()):
         if not params["templates_from_svd"]:
             from spikeinterface.sortingcomponents.clustering.tools import get_templates_from_peaks_and_recording
 
+            job_kwargs_local = job_kwargs.copy()
+            unit_ids = np.unique(peak_labels)
+            ram_requested = recording.get_num_channels() * (nbefore + nafter) * len(unit_ids) * 4
+            job_kwargs_local = _get_optimal_n_jobs(job_kwargs_local, ram_requested, params["memory_limit"])
             templates = get_templates_from_peaks_and_recording(
                 recording,
                 peaks,
                 peak_labels,
                 ms_before,
                 ms_after,
-                **job_kwargs,
+                **job_kwargs_local,
             )
         else:
             from spikeinterface.sortingcomponents.clustering.tools import get_templates_from_peaks_and_svd
diff --git a/src/spikeinterface/sortingcomponents/tools.py b/src/spikeinterface/sortingcomponents/tools.py
index e4fd3c2539..6a380b4d56 100644
--- a/src/spikeinterface/sortingcomponents/tools.py
+++ b/src/spikeinterface/sortingcomponents/tools.py
@@ -12,7 +12,7 @@
 from spikeinterface.core.sparsity import ChannelSparsity
 from spikeinterface.core.template import Templates
 from spikeinterface.core.waveform_tools import extract_waveforms_to_single_buffer
-from spikeinterface.core.job_tools import split_job_kwargs
+from spikeinterface.core.job_tools import split_job_kwargs, fix_job_kwargs
 from spikeinterface.core.sortinganalyzer import create_sorting_analyzer
 from spikeinterface.core.sparsity import ChannelSparsity
 from spikeinterface.core.analyzer_extension_core import ComputeTemplates
@@ -249,19 +249,144 @@ def check_probe_for_drift_correction(recording, dist_x_max=60):
     return True
 
 
-def cache_preprocessing(recording, mode="memory", memory_limit=0.5, delete_cache=True, **extra_kwargs):
-    save_kwargs, job_kwargs = split_job_kwargs(extra_kwargs)
+def _set_optimal_chunk_size(recording, job_kwargs, memory_limit=0.5, total_memory=None):
+    """
+    Set the optimal chunk size for a job given the memory_limit and the number of jobs
 
-    if mode == "memory":
+    Parameters
+    ----------
+
+    recording: Recording
+        The recording object
+    job_kwargs: dict
+        The job kwargs
+    memory_limit: float
+        The memory limit as a fraction of available memory
+    total_memory: str, Default None
+        The total memory to use for the job, given as a string (e.g. '500M', '1G')
+
+    Returns
+    -------
+
+    job_kwargs: dict
+        The updated job kwargs
+    """
+    job_kwargs = fix_job_kwargs(job_kwargs)
+    n_jobs = job_kwargs["n_jobs"]
+    if total_memory is None:
         if HAVE_PSUTIL:
             assert 0 < memory_limit < 1, "memory_limit should be in ]0, 1["
             memory_usage = memory_limit * psutil.virtual_memory().available
-            if recording.get_total_memory_size() < memory_usage:
-                recording = recording.save_to_memory(format="memory", shared=True, **job_kwargs)
+            num_channels = recording.get_num_channels()
+            dtype_size_bytes = recording.get_dtype().itemsize
+            chunk_size = memory_usage / ((num_channels * dtype_size_bytes) * n_jobs)
+            chunk_duration = chunk_size / recording.get_sampling_frequency()
+            job_kwargs.update(dict(chunk_duration=f"{chunk_duration}s"))
+            job_kwargs = fix_job_kwargs(job_kwargs)
+        else:
+            import warnings
+
+            warnings.warn("psutil is required to use only a fraction of available memory")
+    else:
+        from spikeinterface.core.job_tools import convert_string_to_bytes
+
+        total_memory = convert_string_to_bytes(total_memory)
+        num_channels = recording.get_num_channels()
+        dtype_size_bytes = recording.get_dtype().itemsize
+        chunk_size = total_memory / ((num_channels * dtype_size_bytes) * n_jobs)
+        chunk_duration = chunk_size / recording.get_sampling_frequency()
+        job_kwargs.update(dict(chunk_duration=f"{chunk_duration}s"))
+        job_kwargs = fix_job_kwargs(job_kwargs)
+    return job_kwargs
+
+
+def _get_optimal_n_jobs(job_kwargs, ram_requested, memory_limit=0.25):
+    """
+    Set the optimal number of jobs given the requested RAM per job and the memory_limit
+
+    Parameters
+    ----------
+
+    job_kwargs: dict
+        The job kwargs
+    ram_requested: int
+        The amount of RAM (in bytes) requested for the job
+    memory_limit: float
+        The memory limit as a fraction of available memory
+
+    Returns
+    -------
+
+    job_kwargs: dict
+        The updated job kwargs
+    """
+    job_kwargs = fix_job_kwargs(job_kwargs)
+    n_jobs = job_kwargs["n_jobs"]
+    if HAVE_PSUTIL:
+        assert 0 < memory_limit < 1, "memory_limit should be in ]0, 1["
"memory_limit should be in ]0, 1[" + memory_usage = memory_limit * psutil.virtual_memory().available + n_jobs = max(1, int(min(n_jobs, memory_usage // ram_requested))) + job_kwargs.update(dict(n_jobs=n_jobs)) + else: + import warnings + + warnings.warn("psutil is required to use only a fraction of available memory") + return job_kwargs + + +def cache_preprocessing( + recording, mode="memory", memory_limit=0.5, total_memory=None, delete_cache=True, **extra_kwargs +): + """ + Cache the preprocessing of a recording object + + Parameters + ---------- + + recording: Recording + The recording object + mode: str + The mode to cache the preprocessing, can be 'memory', 'folder', 'zarr' or 'no-cache' + memory_limit: float + The memory limit in fraction of available memory + total_memory: str, Default None + The total memory to use for the job in bytes + delete_cache: bool + If True, delete the cache after the job + **extra_kwargs: dict + The extra kwargs for the job + + Returns + ------- + + recording: Recording + The cached recording object + """ + + save_kwargs, job_kwargs = split_job_kwargs(extra_kwargs) + + if mode == "memory": + if total_memory is None: + if HAVE_PSUTIL: + assert 0 < memory_limit < 1, "memory_limit should be in ]0, 1[" + memory_usage = memory_limit * psutil.virtual_memory().available + if recording.get_total_memory_size() < memory_usage: + recording = recording.save_to_memory(format="memory", shared=True, **job_kwargs) + else: + import warnings + + warnings.warn("Recording too large to be preloaded in RAM...") else: - print("Recording too large to be preloaded in RAM...") + import warnings + + warnings.warn("psutil is required to preload in memory given only a fraction of available memory") else: - print("psutil is required to preload in memory") + if recording.get_total_memory_size() < total_memory: + recording = recording.save_to_memory(format="memory", shared=True, **job_kwargs) + else: + import warnings + + warnings.warn("Recording too large to be preloaded in RAM...") elif mode == "folder": recording = recording.save_to_folder(**extra_kwargs) elif mode == "zarr": diff --git a/src/spikeinterface/widgets/crosscorrelograms.py b/src/spikeinterface/widgets/crosscorrelograms.py index ee43f65852..c8e01cc598 100644 --- a/src/spikeinterface/widgets/crosscorrelograms.py +++ b/src/spikeinterface/widgets/crosscorrelograms.py @@ -120,7 +120,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): if i < len(self.axes) - 1: self.axes[i, j].set_xticks([], []) - plt.tight_layout() + self.figure.tight_layout() for i, unit_id in enumerate(unit_ids): self.axes[0, i].set_title(str(unit_id))