@@ -380,19 +380,7 @@ def _get_spike_timestamps(self, block_index, seg_index, spike_channel_index, t_s
380
380
381
381
spike_timestamps , unit_ids , waveforms = self ._spike_channel_cache [channel_name ]
382
382
383
- if t_start is not None or t_stop is not None :
384
- # restrict spikes to given limits (in seconds)
385
- timestamp_frequency = self .pl2reader .pl2_file_info .m_TimestampFrequency
386
- lim0 = int (t_start * timestamp_frequency )
387
- lim1 = int (t_stop * self .pl2reader .pl2_file_info .m_TimestampFrequency )
388
-
389
- # limits are with respect to segment t_start and not to time 0
390
- lim0 -= self .pl2reader .pl2_file_info .m_StartRecordingTime
391
- lim1 -= self .pl2reader .pl2_file_info .m_StartRecordingTime
392
-
393
- time_mask = (spike_timestamps >= lim0 ) & (spike_timestamps <= lim1 )
394
- else :
395
- time_mask = slice (None , None )
383
+ time_mask = self ._get_timestamp_time_mask (t_start , t_stop , spike_timestamps )
396
384
397
385
unit_mask = unit_ids [time_mask ] == channel_unit_id
398
386
spike_timestamps = spike_timestamps [time_mask ][unit_mask ]
@@ -425,25 +413,33 @@ def _get_spike_raw_waveforms(self, block_index, seg_index, spike_channel_index,
425
413
426
414
spike_timestamps , unit_ids , waveforms = self ._spike_channel_cache [channel_name ]
427
415
416
+ time_mask = self ._get_timestamp_time_mask (t_start , t_stop , spike_timestamps )
417
+
418
+ unit_mask = unit_ids [time_mask ] == int (channel_unit_id )
419
+ waveforms = waveforms [time_mask ][unit_mask ]
420
+
421
+ # add tetrode dimension
422
+ waveforms = np .expand_dims (waveforms , axis = 1 )
423
+ return waveforms
424
+
425
+ def _get_timestamp_time_mask (self , t_start , t_stop , timestamps ):
426
+
428
427
if t_start is not None or t_stop is not None :
429
428
# restrict spikes to given limits (in seconds)
430
429
timestamp_frequency = self .pl2reader .pl2_file_info .m_TimestampFrequency
431
430
lim0 = int (t_start * timestamp_frequency )
432
431
lim1 = int (t_stop * self .pl2reader .pl2_file_info .m_TimestampFrequency )
433
- time_mask = (spike_timestamps >= lim0 ) & (spike_timestamps <= lim1 )
434
432
435
433
# limits are with respect to segment t_start and not to time 0
436
434
lim0 -= self .pl2reader .pl2_file_info .m_StartRecordingTime
437
435
lim1 -= self .pl2reader .pl2_file_info .m_StartRecordingTime
436
+
437
+ time_mask = (timestamps >= lim0 ) & (timestamps <= lim1 )
438
+
438
439
else :
439
440
time_mask = slice (None , None )
440
441
441
- unit_mask = unit_ids [time_mask ] == int (channel_unit_id )
442
- waveforms = waveforms [time_mask ][unit_mask ]
443
-
444
- # add tetrode dimension
445
- waveforms = np .expand_dims (waveforms , axis = 1 )
446
- return waveforms
442
+ return time_mask
447
443
448
444
def _event_count (self , block_index , seg_index , event_channel_index ):
449
445
@@ -474,18 +470,7 @@ def _get_event_timestamps(self, block_index, seg_index, event_channel_index, t_s
474
470
event_times , labels = self ._event_channel_cache [channel_name ]
475
471
labels = np .asarray (labels , dtype = 'U' )
476
472
477
- if t_start is not None or t_stop is not None :
478
- # restrict events to given limits (in seconds)
479
- timestamp_frequency = self .pl2reader .pl2_file_info .m_TimestampFrequency
480
- lim0 = int (t_start * timestamp_frequency )
481
- lim1 = int (t_stop * self .pl2reader .pl2_file_info .m_TimestampFrequency )
482
- time_mask = (event_times >= lim0 ) & (event_times <= lim1 )
483
-
484
- # limits are with respect to segment t_start and not to time 0
485
- lim0 -= self .pl2reader .pl2_file_info .m_StartRecordingTime
486
- lim1 -= self .pl2reader .pl2_file_info .m_StartRecordingTime
487
- else :
488
- time_mask = np .ones_like (event_times )
473
+ time_mask = self ._get_timestamp_time_mask (t_start , t_stop , event_times )
489
474
490
475
# events don't have a duration. Epochs are not supported
491
476
durations = None
0 commit comments