Skip to content

Commit 1c474e1

Browse files
authored
Merge pull request #3685 from chrishalcrow/add-RemoveBadChannel-class
Add `DetectAndRemoveBadChannelsRecording` and `DetectAndInterpolateBadChannelsRecording` classes
2 parents cd8f8c5 + a2a6f18 commit 1c474e1

File tree

7 files changed

+297
-68
lines changed

7 files changed

+297
-68
lines changed

doc/how_to/process_by_channel_group.rst

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,12 +97,13 @@ to any preprocessing function.
9797
shifted_recordings = spre.phase_shift(split_recording_dict)
9898
filtered_recording = spre.bandpass_filter(shifted_recording)
9999
referenced_recording = spre.common_reference(filtered_recording)
100+
good_channels_recording = spre.detect_and_remove_bad_channels(filtered_recording)
100101
101102
We can then aggregate the recordings back together using the ``aggregate_channels`` function
102103

103104
.. code-block:: python
104105
105-
combined_preprocessed_recording = aggregate_channels(referenced_recording)
106+
combined_preprocessed_recording = aggregate_channels(good_channels_recording)
106107
107108
Now, when ``combined_preprocessed_recording`` is used in sorting, plotting, or whenever
108109
calling its :py:func:`~get_traces` method, the data will have been

doc/modules/preprocessing.rst

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,18 @@ interpolated with the :code:`interpolate_bad_channels()` function (channels labe
247247
# Case 2 : interpolate then
248248
rec_clean = interpolate_bad_channels(recording=rec, bad_channel_ids=bad_channel_ids)
249249
250+
Once you have tested these functions and decided on your workflow, you can use the `detect_and_*`
251+
functions to do everything at once. These return a Preprocessor class, so are consistent with
252+
the "chain" concept for this module. For example:
253+
254+
.. code-block:: python
255+
256+
# detect and remove bad channels
257+
rec_only_good_channels = detect_and_remove_bad_channels(recording=rec)
258+
259+
# detect and interpolate the bad channels
260+
rec_interpolated_channels = detect_and_interpolate_bad_channels(recording=rec)
261+
250262
251263
* :py:func:`~spikeinterface.preprocessing.detect_bad_channels()`
252264
* :py:func:`~spikeinterface.preprocessing.interpolate_bad_channels()`

src/spikeinterface/preprocessing/detect_bad_channels.py

Lines changed: 144 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,148 @@
44
import numpy as np
55
from typing import Literal
66

7+
from spikeinterface.core.core_tools import define_function_handling_dict_from_class
78
from .filter import highpass_filter
89
from spikeinterface.core import get_random_data_chunks, order_channels_by_depth, BaseRecording
10+
from spikeinterface.core.channelslice import ChannelSliceRecording
11+
12+
from inspect import signature
13+
14+
_bad_channel_detection_kwargs_doc = """Different methods are implemented:
15+
16+
* std : threhshold on channel standard deviations
17+
If the standard deviation of a channel is greater than `std_mad_threshold` times the median of all
18+
channels standard deviations, the channel is flagged as noisy
19+
* mad : same as std, but using median absolute deviations instead
20+
* coeherence+psd : method developed by the International Brain Laboratory that detects bad channels of three types:
21+
* Dead channels are those with low similarity to the surrounding channels (n=`n_neighbors` median)
22+
* Noise channels are those with power at >80% Nyquist above the psd_hf_threshold (default 0.02 uV^2 / Hz)
23+
and a high coherence with "far away" channels"
24+
* Out of brain channels are contigious regions of channels dissimilar to the median of all channels
25+
at the top end of the probe (i.e. large channel number)
26+
* neighborhood_r2
27+
A method tuned for LFP use-cases, where channels should be highly correlated with their spatial
28+
neighbors. This method estimates the correlation of each channel with the median of its spatial
29+
neighbors, and considers channels bad when this correlation is too small.
30+
31+
Parameters
32+
----------
33+
recording : BaseRecording
34+
The recording for which bad channels are detected
35+
method : "coeherence+psd" | "std" | "mad" | "neighborhood_r2", default: "coeherence+psd"
36+
The method to be used for bad channel detection
37+
std_mad_threshold : float, default: 5
38+
The standard deviation/mad multiplier threshold
39+
psd_hf_threshold : float, default: 0.02
40+
For coherence+psd - an absolute threshold (uV^2/Hz) used as a cutoff for noise channels.
41+
Channels with average power at >80% Nyquist larger than this threshold
42+
will be labeled as noise
43+
dead_channel_threshold : float, default: -0.5
44+
For coherence+psd - threshold for channel coherence below which channels are labeled as dead
45+
noisy_channel_threshold : float, default: 1
46+
Threshold for channel coherence above which channels are labeled as noisy (together with psd condition)
47+
outside_channel_threshold : float, default: -0.75
48+
For coherence+psd - threshold for channel coherence above which channels at the edge of the recording are marked as outside
49+
of the brain
50+
outside_channels_location : "top" | "bottom" | "both", default: "top"
51+
For coherence+psd - location of the outside channels. If "top", only the channels at the top of the probe can be
52+
marked as outside channels. If "bottom", only the channels at the bottom of the probe can be
53+
marked as outside channels. If "both", both the channels at the top and bottom of the probe can be
54+
marked as outside channels
55+
n_neighbors : int, default: 11
56+
For coeherence+psd - number of channel neighbors to compute median filter (needs to be odd)
57+
nyquist_threshold : float, default: 0.8
58+
For coherence+psd - frequency with respect to Nyquist (Fn=1) above which the mean of the PSD is calculated and compared
59+
with psd_hf_threshold
60+
direction : "x" | "y" | "z", default: "y"
61+
For coherence+psd - the depth dimension
62+
highpass_filter_cutoff : float, default: 300
63+
If the recording is not filtered, the cutoff frequency of the highpass filter
64+
chunk_duration_s : float, default: 0.5
65+
Duration of each chunk
66+
num_random_chunks : int, default: 100
67+
Number of random chunks
68+
Having many chunks is important for reproducibility.
69+
welch_window_ms : float, default: 10
70+
Window size for the scipy.signal.welch that will be converted to nperseg
71+
neighborhood_r2_threshold : float, default: 0.95
72+
R^2 threshold for the neighborhood_r2 method.
73+
neighborhood_r2_radius_um : float, default: 30
74+
Spatial radius below which two channels are considered neighbors in the neighborhood_r2 method.
75+
seed : int or None, default: None
76+
The random seed to extract chunks
77+
"""
78+
79+
80+
class DetectAndRemoveBadChannelsRecording(ChannelSliceRecording):
81+
"""
82+
Detects and removes bad channels. If `bad_channel_ids` are given,
83+
the detection is skipped and uses these instead.
84+
85+
{}
86+
bad_channel_ids : np.array | list | None, default: None
87+
If given, these are used rather than being detected.
88+
channel_labels : np.array | list | None, default: None
89+
If given, these are labels given to the channels by the
90+
detection process. Only intended for use when loading.
91+
92+
Returns
93+
-------
94+
removed_bad_channels_recording : DetectAndRemoveBadChannelsRecording
95+
The recording with bad channels removed
96+
"""
97+
98+
_precomputable_kwarg_names = ["bad_channel_ids", "channel_labels"]
99+
100+
def __init__(
101+
self,
102+
parent_recording: BaseRecording,
103+
bad_channel_ids=None,
104+
channel_labels=None,
105+
**detect_bad_channels_kwargs,
106+
):
107+
108+
if bad_channel_ids is None:
109+
bad_channel_ids, channel_labels = detect_bad_channels(
110+
recording=parent_recording, **detect_bad_channels_kwargs
111+
)
112+
else:
113+
channel_labels = None
114+
115+
self._main_ids = parent_recording.get_channel_ids()
116+
new_channel_ids = self.channel_ids[~np.isin(self.channel_ids, bad_channel_ids)]
117+
118+
ChannelSliceRecording.__init__(
119+
self,
120+
parent_recording=parent_recording,
121+
channel_ids=new_channel_ids,
122+
)
123+
124+
self._kwargs.update({"bad_channel_ids": bad_channel_ids})
125+
if channel_labels is not None:
126+
self._kwargs.update({"channel_labels": channel_labels})
127+
128+
all_bad_channels_kwargs = _get_all_detect_bad_channel_kwargs(detect_bad_channels_kwargs)
129+
self._kwargs.update(all_bad_channels_kwargs)
130+
131+
132+
detect_and_remove_bad_channels = define_function_handling_dict_from_class(
133+
source_class=DetectAndRemoveBadChannelsRecording, name="detect_and_remove_bad_channels"
134+
)
135+
DetectAndRemoveBadChannelsRecording.__doc__ = DetectAndRemoveBadChannelsRecording.__doc__.format(
136+
_bad_channel_detection_kwargs_doc
137+
)
138+
139+
140+
def _get_all_detect_bad_channel_kwargs(detect_bad_channels_kwargs):
141+
"""Get the default parameters from `detect_bad_channels`, and update with any user-specified parameters."""
142+
143+
sig = signature(detect_bad_channels)
144+
all_detect_bad_channels_kwargs = {
145+
k: v.default for k, v in sig.parameters.items() if k not in ["recording", "parent_recording"]
146+
}
147+
all_detect_bad_channels_kwargs.update(detect_bad_channels_kwargs)
148+
return all_detect_bad_channels_kwargs
9149

10150

11151
def detect_bad_channels(
@@ -32,69 +172,7 @@ def detect_bad_channels(
32172
Perform bad channel detection.
33173
The recording is assumed to be filtered. If not, a highpass filter is applied on the fly.
34174
35-
Different methods are implemented:
36-
37-
* std : threhshold on channel standard deviations
38-
If the standard deviation of a channel is greater than `std_mad_threshold` times the median of all
39-
channels standard deviations, the channel is flagged as noisy
40-
* mad : same as std, but using median absolute deviations instead
41-
* coeherence+psd : method developed by the International Brain Laboratory that detects bad channels of three types:
42-
* Dead channels are those with low similarity to the surrounding channels (n=`n_neighbors` median)
43-
* Noise channels are those with power at >80% Nyquist above the psd_hf_threshold (default 0.02 uV^2 / Hz)
44-
and a high coherence with "far away" channels"
45-
* Out of brain channels are contigious regions of channels dissimilar to the median of all channels
46-
at the top end of the probe (i.e. large channel number)
47-
* neighborhood_r2
48-
A method tuned for LFP use-cases, where channels should be highly correlated with their spatial
49-
neighbors. This method estimates the correlation of each channel with the median of its spatial
50-
neighbors, and considers channels bad when this correlation is too small.
51-
52-
Parameters
53-
----------
54-
recording : BaseRecording
55-
The recording for which bad channels are detected
56-
method : "coeherence+psd" | "std" | "mad" | "neighborhood_r2", default: "coeherence+psd"
57-
The method to be used for bad channel detection
58-
std_mad_threshold : float, default: 5
59-
The standard deviation/mad multiplier threshold
60-
psd_hf_threshold : float, default: 0.02
61-
For coherence+psd - an absolute threshold (uV^2/Hz) used as a cutoff for noise channels.
62-
Channels with average power at >80% Nyquist larger than this threshold
63-
will be labeled as noise
64-
dead_channel_threshold : float, default: -0.5
65-
For coherence+psd - threshold for channel coherence below which channels are labeled as dead
66-
noisy_channel_threshold : float, default: 1
67-
Threshold for channel coherence above which channels are labeled as noisy (together with psd condition)
68-
outside_channel_threshold : float, default: -0.75
69-
For coherence+psd - threshold for channel coherence above which channels at the edge of the recording are marked as outside
70-
of the brain
71-
outside_channels_location : "top" | "bottom" | "both", default: "top"
72-
For coherence+psd - location of the outside channels. If "top", only the channels at the top of the probe can be
73-
marked as outside channels. If "bottom", only the channels at the bottom of the probe can be
74-
marked as outside channels. If "both", both the channels at the top and bottom of the probe can be
75-
marked as outside channels
76-
n_neighbors : int, default: 11
77-
For coeherence+psd - number of channel neighbors to compute median filter (needs to be odd)
78-
nyquist_threshold : float, default: 0.8
79-
For coherence+psd - frequency with respect to Nyquist (Fn=1) above which the mean of the PSD is calculated and compared
80-
with psd_hf_threshold
81-
direction : "x" | "y" | "z", default: "y"
82-
For coherence+psd - the depth dimension
83-
highpass_filter_cutoff : float, default: 300
84-
If the recording is not filtered, the cutoff frequency of the highpass filter
85-
chunk_duration_s : float, default: 0.5
86-
Duration of each chunk
87-
num_random_chunks : int, default: 100
88-
Number of random chunks
89-
Having many chunks is important for reproducibility.
90-
welch_window_ms : float, default: 10
91-
Window size for the scipy.signal.welch that will be converted to nperseg
92-
neighborhood_r2_threshold : float, default: 0.95
93-
R^2 threshold for the neighborhood_r2 method.
94-
neighborhood_r2_radius_um : float, default: 30
95-
Spatial radius below which two channels are considered neighbors in the neighborhood_r2 method.
96-
seed : int or None, default: None
97-
The random seed to extract chunks
175+
{}
98176
99177
Returns
100178
-------
@@ -269,6 +347,9 @@ def detect_bad_channels(
269347
return bad_channel_ids, channel_labels
270348

271349

350+
detect_bad_channels.__doc__ = detect_bad_channels.__doc__.format(_bad_channel_detection_kwargs_doc)
351+
352+
272353
# ----------------------------------------------------------------------------------------------
273354
# IBL Detect Bad Channels
274355
# ----------------------------------------------------------------------------------------------

src/spikeinterface/preprocessing/interpolate_bad_channels.py

Lines changed: 60 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,15 @@
22

33
import numpy as np
44

5-
from .basepreprocessor import BasePreprocessor, BasePreprocessorSegment
5+
from .basepreprocessor import BasePreprocessor, BasePreprocessorSegment, BaseRecording
66
from spikeinterface.core.core_tools import define_function_handling_dict_from_class
77
from spikeinterface.preprocessing import preprocessing_tools
8+
from .detect_bad_channels import (
9+
_bad_channel_detection_kwargs_doc,
10+
detect_bad_channels,
11+
_get_all_detect_bad_channel_kwargs,
12+
)
13+
from inspect import signature
814

915

1016
class InterpolateBadChannelsRecording(BasePreprocessor):
@@ -82,6 +88,59 @@ def check_inputs(self, recording, bad_channel_ids):
8288
raise NotImplementedError("Channel spacing units must be um")
8389

8490

91+
class DetectAndInterpolateBadChannelsRecording(InterpolateBadChannelsRecording):
92+
"""
93+
Detects and interpolates bad channels. If `bad_channel_ids` are given,
94+
the detection is skipped and uses these instead.
95+
96+
{}
97+
bad_channel_ids : np.array | list | None, default: None
98+
If given, these are used rather than being detected.
99+
channel_labels : np.array | list | None, default: None
100+
If given, these are labels given to the channels by the
101+
detection process. Only intended for use when loading.
102+
103+
Returns
104+
-------
105+
interpolated_bad_channels_recording : DetectAndInterpolateBadChannelsRecording
106+
The recording with bad channels removed
107+
"""
108+
109+
_precomputable_kwarg_names = ["bad_channel_ids"]
110+
111+
def __init__(
112+
self,
113+
recording: BaseRecording,
114+
bad_channel_ids=None,
115+
**detect_bad_channels_kwargs,
116+
):
117+
if bad_channel_ids is None:
118+
bad_channel_ids, channel_labels = detect_bad_channels(recording=recording, **detect_bad_channels_kwargs)
119+
else:
120+
channel_labels = None
121+
122+
InterpolateBadChannelsRecording.__init__(
123+
self,
124+
recording,
125+
bad_channel_ids=bad_channel_ids,
126+
)
127+
128+
self._kwargs.update({"bad_channel_ids": bad_channel_ids})
129+
if channel_labels is not None:
130+
self._kwargs.update({"channel_labels": channel_labels})
131+
132+
all_bad_channels_kwargs = _get_all_detect_bad_channel_kwargs(detect_bad_channels_kwargs)
133+
self._kwargs.update(all_bad_channels_kwargs)
134+
135+
136+
detect_and_interpolate_bad_channels = define_function_handling_dict_from_class(
137+
source_class=DetectAndInterpolateBadChannelsRecording, name="detect_and_interpolate_bad_channels"
138+
)
139+
DetectAndInterpolateBadChannelsRecording.__doc__ = DetectAndInterpolateBadChannelsRecording.__doc__.format(
140+
_bad_channel_detection_kwargs_doc
141+
)
142+
143+
85144
class InterpolateBadChannelsSegment(BasePreprocessorSegment):
86145
def __init__(self, parent_recording_segment, good_channel_indices, bad_channel_indices, weights):
87146
BasePreprocessorSegment.__init__(self, parent_recording_segment)

src/spikeinterface/preprocessing/preprocessing_classes.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,13 @@
3838
from .zero_channel_pad import ZeroChannelPaddedRecording, zero_channel_pad
3939
from .deepinterpolation import DeepInterpolatedRecording, deepinterpolate, train_deepinterpolation
4040
from .highpass_spatial_filter import HighpassSpatialFilterRecording, highpass_spatial_filter
41-
from .interpolate_bad_channels import InterpolateBadChannelsRecording, interpolate_bad_channels
41+
from .interpolate_bad_channels import (
42+
DetectAndInterpolateBadChannelsRecording,
43+
detect_and_interpolate_bad_channels,
44+
InterpolateBadChannelsRecording,
45+
interpolate_bad_channels,
46+
)
47+
from .detect_bad_channels import DetectAndRemoveBadChannelsRecording, detect_and_remove_bad_channels
4248
from .average_across_direction import AverageAcrossDirectionRecording, average_across_direction
4349
from .directional_derivative import DirectionalDerivativeRecording, directional_derivative
4450
from .depth_order import DepthOrderRecording, depth_order
@@ -63,6 +69,9 @@
6369
# re-reference
6470
CommonReferenceRecording: common_reference,
6571
PhaseShiftRecording: phase_shift,
72+
# bad channel detection/interpolation
73+
DetectAndRemoveBadChannelsRecording: detect_and_remove_bad_channels,
74+
DetectAndInterpolateBadChannelsRecording: detect_and_interpolate_bad_channels,
6675
# misc
6776
RectifyRecording: rectify,
6877
ClipRecording: clip,

0 commit comments

Comments
 (0)