Skip to content

Commit e1a8e16

Browse files
authored
Merge pull request #1584 from h-mayorquin/fix_edf_handle
EDFIO: Alleviate EDF single handle problem
2 parents 8e180f4 + bc0f518 commit e1a8e16

File tree

1 file changed

+22
-13
lines changed

1 file changed

+22
-13
lines changed

neo/rawio/edfrawio.py

Lines changed: 22 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -83,15 +83,14 @@ def _parse_header(self):
8383
# or continuous EDF+ files ('EDF+C' in header)
8484
if ("EDF+" in file_version_header) and ("EDF+C" not in file_version_header):
8585
raise ValueError("Only continuous EDF+ files are currently supported.")
86-
87-
self.edf_reader = EdfReader(self.filename)
86+
self._open_reader()
8887
# load headers, signal information and
8988
self.edf_header = self.edf_reader.getHeader()
9089
self.signal_headers = self.edf_reader.getSignalHeaders()
9190

9291
# add annotations to header
93-
annotations = self.edf_reader.readAnnotations()
94-
self.signal_annotations = [[s, d, a] for s, d, a in zip(*annotations)]
92+
self._edf_annotations = self.edf_reader.readAnnotations()
93+
self.signal_annotations = [[s, d, a] for s, d, a in zip(*self._edf_annotations)]
9594

9695
# 1 stream = 1 sampling rate
9796
stream_characteristics = []
@@ -120,7 +119,7 @@ def _parse_header(self):
120119
signal_channels.append((ch_name, chan_id, sr, dtype, units, gain, offset, stream_id, buffer_id))
121120

122121
# convert channel index lists to arrays for indexing
123-
self.stream_idx_to_chidx = {k: np.array(v) for k, v in self.stream_idx_to_chidx.items()}
122+
self.stream_idx_to_chidx = {k: np.asarray(v) for k, v in self.stream_idx_to_chidx.items()}
124123

125124
signal_channels = np.array(signal_channels, dtype=_signal_channel_dtype)
126125

@@ -174,6 +173,15 @@ def _parse_header(self):
174173
for array_key in array_keys:
175174
array_anno = {array_key: [h[array_key] for h in self.signal_headers]}
176175
seg_ann["signals"].append({"__array_annotations__": array_anno})
176+
177+
# We store the following attributes for rapid access without needing the reader
178+
179+
self._t_stop = self.edf_reader.datarecord_duration * self.edf_reader.datarecords_in_file
180+
# use sample count of first signal in stream
181+
self._stream_index_samples = {stream_index : self.edf_reader.getNSamples()[chidx][0] for stream_index, chidx in self.stream_idx_to_chidx.items()}
182+
self._number_of_events = len(self.edf_reader.readAnnotations()[0])
183+
184+
self.close()
177185

178186
def _get_stream_channels(self, stream_index):
179187
return self.header["signal_channels"][self.stream_idx_to_chidx[stream_index]]
@@ -183,14 +191,11 @@ def _segment_t_start(self, block_index, seg_index):
183191
return 0.0 # in seconds
184192

185193
def _segment_t_stop(self, block_index, seg_index):
186-
t_stop = self.edf_reader.datarecord_duration * self.edf_reader.datarecords_in_file
187194
# this must return an float scale in second
188-
return t_stop
195+
return self._t_stop
189196

190197
def _get_signal_size(self, block_index, seg_index, stream_index):
191-
chidx = self.stream_idx_to_chidx[stream_index][0]
192-
# use sample count of first signal in stream
193-
return self.edf_reader.getNSamples()[chidx]
198+
return self._stream_index_samples[stream_index]
194199

195200
def _get_signal_t_start(self, block_index, seg_index, stream_index):
196201
return 0.0 # EDF does not provide temporal offset information
@@ -219,12 +224,13 @@ def _get_analogsignal_chunk(self, block_index, seg_index, i_start, i_stop, strea
219224

220225
# load data into numpy array buffer
221226
data = []
227+
self._open_reader()
222228
for i, channel_idx in enumerate(selected_channel_idxs):
223229
# use int32 for compatibility with pyedflib
224230
buffer = np.empty(n, dtype=np.int32)
225231
self.edf_reader.read_digital_signal(channel_idx, i_start, n, buffer)
226232
data.append(buffer)
227-
233+
self._close_reader()
228234
# downgrade to int16 as this is what is used in the edf file format
229235
# use fortran (column major) order to be more efficient after transposing
230236
data = np.asarray(data, dtype=np.int16, order="F")
@@ -247,11 +253,11 @@ def _get_spike_raw_waveforms(self, block_index, seg_index, spike_channel_index,
247253
return None
248254

249255
def _event_count(self, block_index, seg_index, event_channel_index):
250-
return len(self.edf_reader.readAnnotations()[0])
256+
return self._number_of_events
251257

252258
def _get_event_timestamps(self, block_index, seg_index, event_channel_index, t_start, t_stop):
253259
# these time should be already in seconds
254-
timestamps, durations, labels = self.edf_reader.readAnnotations()
260+
timestamps, durations, labels = self._edf_annotations
255261
if t_start is None:
256262
t_start = self.segment_t_start(block_index, seg_index)
257263
if t_stop is None:
@@ -281,6 +287,9 @@ def _rescale_event_timestamp(self, event_timestamps, dtype, event_channel_index)
281287
def _rescale_epoch_duration(self, raw_duration, dtype, event_channel_index):
282288
return np.asarray(raw_duration, dtype=dtype)
283289

290+
def _open_reader(self):
291+
self.edf_reader = EdfReader(self.filename)
292+
284293
def __enter__(self):
285294
return self
286295

0 commit comments

Comments
 (0)