11
11
from spikeextractors .extraction_tools import check_get_traces_args , check_get_unit_spike_train
12
12
13
13
try :
14
- import pynwb
15
14
import pandas as pd
15
+ import pynwb
16
16
from pynwb import NWBHDF5IO
17
17
from pynwb import NWBFile
18
18
from pynwb .ecephys import ElectricalSeries , LFP
@@ -95,10 +95,14 @@ def get_nspikes(units_table, unit_id):
95
95
return units_table ['spike_times_index' ].data [index ] - units_table ['spike_times_index' ].data [index - 1 ]
96
96
97
97
98
- def most_relevant_ch (traces ):
98
+ def most_relevant_ch (traces : ArrayType ):
99
99
"""
100
- Calculates the most relevant channel for an Unit.
100
+ Calculate the most relevant channel for a given Unit.
101
+
101
102
Estimates the channel where the max-min difference of the average traces is greatest.
103
+
104
+ Parameters
105
+ ----------
102
106
traces : ndarray
103
107
ndarray of shape (nSpikes, nChannels, nSamples)
104
108
"""
@@ -113,7 +117,8 @@ def most_relevant_ch(traces):
113
117
return relevant_ch
114
118
115
119
116
- def update_dict (d , u ):
120
+ def update_dict (d : dict , u : dict ):
121
+ """Smart dictionary updates."""
117
122
for k , v in u .items ():
118
123
if isinstance (v , abc .Mapping ):
119
124
d [k ] = update_dict (d .get (k , {}), v )
@@ -122,9 +127,10 @@ def update_dict(d, u):
122
127
return d
123
128
124
129
125
- def list_get (l , idx , default ):
130
+ def list_get (li : list , idx : int , default ):
131
+ """Safe index retrieval from list."""
126
132
try :
127
- return l [idx ]
133
+ return li [idx ]
128
134
except IndexError :
129
135
return default
130
136
@@ -153,6 +159,8 @@ def check_module(nwbfile, name: str, description: str = None):
153
159
154
160
155
161
class NwbRecordingExtractor (se .RecordingExtractor ):
162
+ """Primary class for interfacing between NWBFiles and RecordingExtractors."""
163
+
156
164
extractor_name = 'NwbRecording'
157
165
has_default_locations = True
158
166
has_unscaled = False
@@ -161,8 +169,10 @@ class NwbRecordingExtractor(se.RecordingExtractor):
161
169
mode = 'file'
162
170
installation_mesg = "To use the Nwb extractors, install pynwb: \n \n pip install pynwb\n \n "
163
171
164
- def __init__ (self , file_path , electrical_series_name = 'ElectricalSeries' ):
172
+ def __init__ (self , file_path : PathType , electrical_series_name : str = None ):
165
173
"""
174
+ Load an NWBFile as a RecordingExtractor.
175
+
166
176
Parameters
167
177
----------
168
178
file_path: path to NWB file
@@ -178,9 +188,9 @@ def __init__(self, file_path, electrical_series_name='ElectricalSeries'):
178
188
else :
179
189
a_names = list (nwbfile .acquisition )
180
190
if len (a_names ) > 1 :
181
- raise ValueError (' More than one acquisition found. You must specify electrical_series.' )
191
+ raise ValueError (" More than one acquisition found! You must specify 'electrical_series_name'." )
182
192
if len (a_names ) == 0 :
183
- raise ValueError (' No acquisitions found in the .nwb file.' )
193
+ raise ValueError (" No acquisitions found in the .nwb file." )
184
194
self ._electrical_series_name = a_names [0 ]
185
195
es = nwbfile .acquisition [self ._electrical_series_name ]
186
196
if hasattr (es , 'timestamps' ) and es .timestamps :
@@ -194,7 +204,7 @@ def __init__(self, file_path, electrical_series_name='ElectricalSeries'):
194
204
self .recording_start_time = 0.
195
205
196
206
self .num_frames = int (es .data .shape [0 ])
197
- num_channels = len (es .electrodes .table . id [:] )
207
+ num_channels = len (es .electrodes .data )
198
208
199
209
# Channels gains - for RecordingExtractor, these are values to cast traces to uV
200
210
if es .channel_conversion is not None :
@@ -206,7 +216,7 @@ def __init__(self, file_path, electrical_series_name='ElectricalSeries'):
206
216
unique_grp_names = list (np .unique (nwbfile .electrodes ['group_name' ][:]))
207
217
208
218
# Fill channel properties dictionary from electrodes table
209
- self .channel_ids = es .electrodes .table .id [es .electrodes .data ]
219
+ self .channel_ids = [ es .electrodes .table .id [x ] for x in es .electrodes .data ]
210
220
211
221
# If gains are not 1, set has_scaled to True
212
222
if np .any (gains != 1 ):
@@ -294,7 +304,13 @@ def make_nwb_metadata(self, nwbfile, es):
294
304
})
295
305
296
306
@check_get_traces_args
297
- def get_traces (self , channel_ids = None , start_frame = None , end_frame = None , return_scaled = True ):
307
+ def get_traces (
308
+ self ,
309
+ channel_ids : ArrayType = None ,
310
+ start_frame : int = None ,
311
+ end_frame : int = None ,
312
+ return_scaled : bool = True
313
+ ):
298
314
with NWBHDF5IO (self ._path , 'r' ) as io :
299
315
nwbfile = io .read ()
300
316
es = nwbfile .acquisition [self ._electrical_series_name ]
@@ -322,7 +338,7 @@ def get_num_frames(self):
322
338
return self .num_frames
323
339
324
340
def get_channel_ids (self ):
325
- return self .channel_ids . tolist ()
341
+ return self .channel_ids
326
342
327
343
@staticmethod
328
344
def add_devices (recording : se .RecordingExtractor , nwbfile = None , metadata : dict = None ):
@@ -494,7 +510,6 @@ def add_electrodes(recording: se.RecordingExtractor, nwbfile=None, metadata: dic
494
510
nwbfile .add_electrode_column ('rel_y' , 'y position of electrode in electrode group' )
495
511
496
512
defaults = dict (
497
- id = np .nan ,
498
513
x = np .nan ,
499
514
y = np .nan ,
500
515
z = np .nan ,
0 commit comments