Skip to content

Commit fbac290

Browse files
committed
refactor timestamp masking
1 parent e7181b9 commit fbac290

File tree

1 file changed

+17
-32
lines changed

1 file changed

+17
-32
lines changed

neo/rawio/plexon2rawio/plexon2rawio.py

+17-32
Original file line numberDiff line numberDiff line change
@@ -380,19 +380,7 @@ def _get_spike_timestamps(self, block_index, seg_index, spike_channel_index, t_s
380380

381381
spike_timestamps, unit_ids, waveforms = self._spike_channel_cache[channel_name]
382382

383-
if t_start is not None or t_stop is not None:
384-
# restrict spikes to given limits (in seconds)
385-
timestamp_frequency = self.pl2reader.pl2_file_info.m_TimestampFrequency
386-
lim0 = int(t_start * timestamp_frequency)
387-
lim1 = int(t_stop * self.pl2reader.pl2_file_info.m_TimestampFrequency)
388-
389-
# limits are with respect to segment t_start and not to time 0
390-
lim0 -= self.pl2reader.pl2_file_info.m_StartRecordingTime
391-
lim1 -= self.pl2reader.pl2_file_info.m_StartRecordingTime
392-
393-
time_mask = (spike_timestamps >= lim0) & (spike_timestamps <= lim1)
394-
else:
395-
time_mask = slice(None, None)
383+
time_mask = self._get_timestamp_time_mask(t_start, t_stop, spike_timestamps)
396384

397385
unit_mask = unit_ids[time_mask] == channel_unit_id
398386
spike_timestamps = spike_timestamps[time_mask][unit_mask]
@@ -425,25 +413,33 @@ def _get_spike_raw_waveforms(self, block_index, seg_index, spike_channel_index,
425413

426414
spike_timestamps, unit_ids, waveforms = self._spike_channel_cache[channel_name]
427415

416+
time_mask = self._get_timestamp_time_mask(t_start, t_stop, spike_timestamps)
417+
418+
unit_mask = unit_ids[time_mask] == int(channel_unit_id)
419+
waveforms = waveforms[time_mask][unit_mask]
420+
421+
# add tetrode dimension
422+
waveforms = np.expand_dims(waveforms, axis=1)
423+
return waveforms
424+
425+
def _get_timestamp_time_mask(self, t_start, t_stop, timestamps):
426+
428427
if t_start is not None or t_stop is not None:
429428
# restrict spikes to given limits (in seconds)
430429
timestamp_frequency = self.pl2reader.pl2_file_info.m_TimestampFrequency
431430
lim0 = int(t_start * timestamp_frequency)
432431
lim1 = int(t_stop * self.pl2reader.pl2_file_info.m_TimestampFrequency)
433-
time_mask = (spike_timestamps >= lim0) & (spike_timestamps <= lim1)
434432

435433
# limits are with respect to segment t_start and not to time 0
436434
lim0 -= self.pl2reader.pl2_file_info.m_StartRecordingTime
437435
lim1 -= self.pl2reader.pl2_file_info.m_StartRecordingTime
436+
437+
time_mask = (timestamps >= lim0) & (timestamps <= lim1)
438+
438439
else:
439440
time_mask = slice(None, None)
440441

441-
unit_mask = unit_ids[time_mask] == int(channel_unit_id)
442-
waveforms = waveforms[time_mask][unit_mask]
443-
444-
# add tetrode dimension
445-
waveforms = np.expand_dims(waveforms, axis=1)
446-
return waveforms
442+
return time_mask
447443

448444
def _event_count(self, block_index, seg_index, event_channel_index):
449445

@@ -474,18 +470,7 @@ def _get_event_timestamps(self, block_index, seg_index, event_channel_index, t_s
474470
event_times, labels = self._event_channel_cache[channel_name]
475471
labels = np.asarray(labels, dtype='U')
476472

477-
if t_start is not None or t_stop is not None:
478-
# restrict events to given limits (in seconds)
479-
timestamp_frequency = self.pl2reader.pl2_file_info.m_TimestampFrequency
480-
lim0 = int(t_start * timestamp_frequency)
481-
lim1 = int(t_stop * self.pl2reader.pl2_file_info.m_TimestampFrequency)
482-
time_mask = (event_times >= lim0) & (event_times <= lim1)
483-
484-
# limits are with respect to segment t_start and not to time 0
485-
lim0 -= self.pl2reader.pl2_file_info.m_StartRecordingTime
486-
lim1 -= self.pl2reader.pl2_file_info.m_StartRecordingTime
487-
else:
488-
time_mask = np.ones_like(event_times)
473+
time_mask = self._get_timestamp_time_mask(t_start, t_stop, event_times)
489474

490475
# events don't have a duration. Epochs are not supported
491476
durations = None

0 commit comments

Comments
 (0)