Skip to content
This repository was archived by the owner on Jun 6, 2023. It is now read-only.

Commit 3be474e

Browse files
authored
Merge pull request #600 from catalystneuro/fix_issue_587
fix issue 587
2 parents 322cb44 + 31717d6 commit 3be474e

File tree

1 file changed

+29
-14
lines changed

1 file changed

+29
-14
lines changed

spikeextractors/extractors/nwbextractors/nwbextractors.py

Lines changed: 29 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@
1111
from spikeextractors.extraction_tools import check_get_traces_args, check_get_unit_spike_train
1212

1313
try:
14-
import pynwb
1514
import pandas as pd
15+
import pynwb
1616
from pynwb import NWBHDF5IO
1717
from pynwb import NWBFile
1818
from pynwb.ecephys import ElectricalSeries, LFP
@@ -95,10 +95,14 @@ def get_nspikes(units_table, unit_id):
9595
return units_table['spike_times_index'].data[index] - units_table['spike_times_index'].data[index - 1]
9696

9797

98-
def most_relevant_ch(traces):
98+
def most_relevant_ch(traces: ArrayType):
9999
"""
100-
Calculates the most relevant channel for an Unit.
100+
Calculate the most relevant channel for a given Unit.
101+
101102
Estimates the channel where the max-min difference of the average traces is greatest.
103+
104+
Parameters
105+
----------
102106
traces : ndarray
103107
ndarray of shape (nSpikes, nChannels, nSamples)
104108
"""
@@ -113,7 +117,8 @@ def most_relevant_ch(traces):
113117
return relevant_ch
114118

115119

116-
def update_dict(d, u):
120+
def update_dict(d: dict, u: dict):
121+
"""Smart dictionary updates."""
117122
for k, v in u.items():
118123
if isinstance(v, abc.Mapping):
119124
d[k] = update_dict(d.get(k, {}), v)
@@ -122,9 +127,10 @@ def update_dict(d, u):
122127
return d
123128

124129

125-
def list_get(l, idx, default):
130+
def list_get(li: list, idx: int, default):
131+
"""Safe index retrieval from list."""
126132
try:
127-
return l[idx]
133+
return li[idx]
128134
except IndexError:
129135
return default
130136

@@ -153,6 +159,8 @@ def check_module(nwbfile, name: str, description: str = None):
153159

154160

155161
class NwbRecordingExtractor(se.RecordingExtractor):
162+
"""Primary class for interfacing between NWBFiles and RecordingExtractors."""
163+
156164
extractor_name = 'NwbRecording'
157165
has_default_locations = True
158166
has_unscaled = False
@@ -161,8 +169,10 @@ class NwbRecordingExtractor(se.RecordingExtractor):
161169
mode = 'file'
162170
installation_mesg = "To use the Nwb extractors, install pynwb: \n\n pip install pynwb\n\n"
163171

164-
def __init__(self, file_path, electrical_series_name='ElectricalSeries'):
172+
def __init__(self, file_path: PathType, electrical_series_name: str = None):
165173
"""
174+
Load an NWBFile as a RecordingExtractor.
175+
166176
Parameters
167177
----------
168178
file_path: path to NWB file
@@ -178,9 +188,9 @@ def __init__(self, file_path, electrical_series_name='ElectricalSeries'):
178188
else:
179189
a_names = list(nwbfile.acquisition)
180190
if len(a_names) > 1:
181-
raise ValueError('More than one acquisition found. You must specify electrical_series.')
191+
raise ValueError("More than one acquisition found! You must specify 'electrical_series_name'.")
182192
if len(a_names) == 0:
183-
raise ValueError('No acquisitions found in the .nwb file.')
193+
raise ValueError("No acquisitions found in the .nwb file.")
184194
self._electrical_series_name = a_names[0]
185195
es = nwbfile.acquisition[self._electrical_series_name]
186196
if hasattr(es, 'timestamps') and es.timestamps:
@@ -194,7 +204,7 @@ def __init__(self, file_path, electrical_series_name='ElectricalSeries'):
194204
self.recording_start_time = 0.
195205

196206
self.num_frames = int(es.data.shape[0])
197-
num_channels = len(es.electrodes.table.id[:])
207+
num_channels = len(es.electrodes.data)
198208

199209
# Channels gains - for RecordingExtractor, these are values to cast traces to uV
200210
if es.channel_conversion is not None:
@@ -206,7 +216,7 @@ def __init__(self, file_path, electrical_series_name='ElectricalSeries'):
206216
unique_grp_names = list(np.unique(nwbfile.electrodes['group_name'][:]))
207217

208218
# Fill channel properties dictionary from electrodes table
209-
self.channel_ids = es.electrodes.table.id[es.electrodes.data]
219+
self.channel_ids = [es.electrodes.table.id[x] for x in es.electrodes.data]
210220

211221
# If gains are not 1, set has_scaled to True
212222
if np.any(gains != 1):
@@ -294,7 +304,13 @@ def make_nwb_metadata(self, nwbfile, es):
294304
})
295305

296306
@check_get_traces_args
297-
def get_traces(self, channel_ids=None, start_frame=None, end_frame=None, return_scaled=True):
307+
def get_traces(
308+
self,
309+
channel_ids: ArrayType = None,
310+
start_frame: int = None,
311+
end_frame: int = None,
312+
return_scaled: bool = True
313+
):
298314
with NWBHDF5IO(self._path, 'r') as io:
299315
nwbfile = io.read()
300316
es = nwbfile.acquisition[self._electrical_series_name]
@@ -322,7 +338,7 @@ def get_num_frames(self):
322338
return self.num_frames
323339

324340
def get_channel_ids(self):
325-
return self.channel_ids.tolist()
341+
return self.channel_ids
326342

327343
@staticmethod
328344
def add_devices(recording: se.RecordingExtractor, nwbfile=None, metadata: dict = None):
@@ -494,7 +510,6 @@ def add_electrodes(recording: se.RecordingExtractor, nwbfile=None, metadata: dic
494510
nwbfile.add_electrode_column('rel_y', 'y position of electrode in electrode group')
495511

496512
defaults = dict(
497-
id=np.nan,
498513
x=np.nan,
499514
y=np.nan,
500515
z=np.nan,

0 commit comments

Comments
 (0)