-
Notifications
You must be signed in to change notification settings - Fork 210
Handle automatic chunks duration for SC2 #3721
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 15 commits
1ebb6db
0ce092c
cf31e7f
b299dd3
7255530
b6dc572
26fc226
3819f8c
42d1c2c
d156f60
29eb160
ed319bf
8c1ca39
0a2bd23
b913be2
5d88e72
f746d41
d6b3bdb
a891d13
f2a3ac4
b72ee86
151fefb
0445d4e
6409824
b9b3457
5d4e516
03ccc7e
fcea1b6
3812ece
8c400f4
d6917c9
12b5edb
297d9b9
9bb2f33
c86a587
0e9468c
b7de6a1
8282461
ce56760
7cc16ce
c4faee6
d25105f
08213c4
eae719c
ba96c50
68c03f3
0138b17
bd07dc3
b68e2fb
b569023
08f579e
c26b092
16e9a51
7559588
51f753a
b7fb2bf
5b47453
456e257
71c9eec
c97dff3
a2add36
fa95a33
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,7 +12,7 @@ | |
from spikeinterface.core.sparsity import ChannelSparsity | ||
from spikeinterface.core.template import Templates | ||
from spikeinterface.core.waveform_tools import extract_waveforms_to_single_buffer | ||
from spikeinterface.core.job_tools import split_job_kwargs | ||
from spikeinterface.core.job_tools import split_job_kwargs, fix_job_kwargs | ||
from spikeinterface.core.sortinganalyzer import create_sorting_analyzer | ||
from spikeinterface.core.sparsity import ChannelSparsity | ||
from spikeinterface.core.analyzer_extension_core import ComputeTemplates | ||
|
@@ -249,19 +249,132 @@ def check_probe_for_drift_correction(recording, dist_x_max=60): | |
return True | ||
|
||
|
||
def cache_preprocessing(recording, mode="memory", memory_limit=0.5, delete_cache=True, **extra_kwargs): | ||
save_kwargs, job_kwargs = split_job_kwargs(extra_kwargs) | ||
def set_optimal_chunk_size(recording, job_kwargs, memory_limit=0.5, total_memory=None):
    """
    Set the optimal chunk duration for a job given a memory budget and the number of jobs.

    The chunk size (in samples) is chosen so that `n_jobs` concurrent chunks of
    `num_channels * dtype.itemsize` bytes per sample fit within the budget.

    Parameters
    ----------
    recording : Recording
        The recording object
    job_kwargs : dict
        The job kwargs
    memory_limit : float, default: 0.5
        The memory budget as a fraction of currently available memory
        (only used when total_memory is None; requires psutil)
    total_memory : str | int | None, default: None
        The total memory to use for the job, as bytes or a string like "1G";
        overrides memory_limit when given

    Returns
    -------
    job_kwargs : dict
        The updated job kwargs with a chunk_duration matching the memory budget
    """
    import warnings

    job_kwargs = fix_job_kwargs(job_kwargs)
    n_jobs = job_kwargs["n_jobs"]
    if total_memory is None:
        if HAVE_PSUTIL:
            assert 0 < memory_limit < 1, "memory_limit should be in ]0, 1["
            memory_usage = memory_limit * psutil.virtual_memory().available
            if recording.get_total_memory_size() < memory_usage:
                num_channels = recording.get_num_channels()
                dtype_size_bytes = recording.get_dtype().itemsize
                # samples per chunk such that n_jobs concurrent chunks fit in the budget
                chunk_size = memory_usage / ((num_channels * dtype_size_bytes) * n_jobs)
                chunk_duration = chunk_size / recording.get_sampling_frequency()
                # keep the caller's n_jobs instead of resetting job_kwargs from scratch
                job_kwargs = fix_job_kwargs(dict(n_jobs=n_jobs, chunk_duration=f"{chunk_duration}s"))
        else:
            # reviewer-requested: warn rather than print
            warnings.warn("psutil is required to use only a fraction of available memory")
    else:
        from spikeinterface.core.job_tools import convert_string_to_bytes

        total_memory = convert_string_to_bytes(total_memory)
        num_channels = recording.get_num_channels()
        dtype_size_bytes = recording.get_dtype().itemsize
        # BUGFIX: the ratio was inverted ((num_channels * dtype_size_bytes) * n_jobs / total_memory),
        # which yields a tiny fraction instead of a sample count. Mirror the
        # HAVE_PSUTIL branch: samples per chunk that fit in the memory budget.
        chunk_size = total_memory / ((num_channels * dtype_size_bytes) * n_jobs)
        chunk_duration = chunk_size / recording.get_sampling_frequency()
        job_kwargs = fix_job_kwargs(dict(n_jobs=n_jobs, chunk_duration=f"{chunk_duration}s"))
    return job_kwargs
|
||
|
||
def get_optimal_n_jobs(job_kwargs, ram_requested, memory_limit=0.25):
    """
    Cap the number of jobs so that n_jobs * ram_requested stays within the memory budget.

    Parameters
    ----------
    job_kwargs : dict
        The job kwargs
    ram_requested : int
        The amount of RAM (in bytes) requested per job
    memory_limit : float, default: 0.25
        The memory budget as a fraction of currently available memory (requires psutil)

    Returns
    -------
    job_kwargs : dict
        The updated job kwargs with the (possibly reduced) n_jobs
    """
    import warnings

    job_kwargs = fix_job_kwargs(job_kwargs)
    n_jobs = job_kwargs["n_jobs"]
    if HAVE_PSUTIL:
        assert 0 < memory_limit < 1, "memory_limit should be in ]0, 1["
        memory_usage = memory_limit * psutil.virtual_memory().available
        # keep at least one job even when a single job exceeds the budget,
        # otherwise n_jobs could collapse to 0 and break downstream processing
        n_jobs = max(1, int(min(n_jobs, memory_usage // ram_requested)))
        job_kwargs.update(dict(n_jobs=n_jobs))
    else:
        # reviewer-requested: warn rather than print
        warnings.warn("psutil is required to use only a fraction of available memory")
    return job_kwargs
|
||
|
||
def cache_preprocessing( | ||
recording, mode="memory", memory_limit=0.5, total_memory=None, delete_cache=True, **extra_kwargs | ||
): | ||
""" | ||
Cache the preprocessing of a recording object | ||
|
||
Parameters | ||
---------- | ||
|
||
recording: Recording | ||
The recording object | ||
mode: str | ||
The mode to cache the preprocessing, can be 'memory', 'folder', 'zarr' or 'no-cache' | ||
memory_limit: float | ||
The memory limit in fraction of available memory | ||
total_memory: str, Default None | ||
The total memory to use for the job in bytes | ||
delete_cache: bool | ||
If True, delete the cache after the job | ||
**extra_kwargs: dict | ||
The extra kwargs for the job | ||
|
||
Returns | ||
------- | ||
|
||
recording: Recording | ||
The cached recording object | ||
""" | ||
|
||
save_kwargs, job_kwargs = split_job_kwargs(extra_kwargs) | ||
|
||
if mode == "memory": | ||
if total_memory is None: | ||
if HAVE_PSUTIL: | ||
assert 0 < memory_limit < 1, "memory_limit should be in ]0, 1[" | ||
memory_usage = memory_limit * psutil.virtual_memory().available | ||
if recording.get_total_memory_size() < memory_usage: | ||
recording = recording.save_to_memory(format="memory", shared=True, **job_kwargs) | ||
else: | ||
print("Recording too large to be preloaded in RAM...") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. same for these prints. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. same here |
||
else: | ||
print("psutil is required to preload in memory given only a fraction of available memory") | ||
else: | ||
if recording.get_total_memory_size() < total_memory: | ||
recording = recording.save_to_memory(format="memory", shared=True, **job_kwargs) | ||
else: | ||
print("Recording too large to be preloaded in RAM...") | ||
else: | ||
print("psutil is required to preload in memory") | ||
elif mode == "folder": | ||
recording = recording.save_to_folder(**extra_kwargs) | ||
elif mode == "zarr": | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think you need to say what units of RAM this will work with. I'm confused by this with the way it is currently written. I think users will need a bit more info to really understand how to use this.