67
67
constructions of a RawIO for a given set of data.
68
68
69
69
"""
70
+ from __future__ import annotations
70
71
71
72
import logging
72
73
import numpy as np
@@ -133,7 +134,7 @@ class BaseRawIO:
133
134
134
135
rawmode = None # one key from possible_raw_modes
135
136
136
- def __init__ (self , use_cache = False , cache_path = 'same_as_resource' , ** kargs ):
137
+ def __init__ (self , use_cache : bool = False , cache_path : str = 'same_as_resource' , ** kargs ):
137
138
"""
138
139
:TODO: Why multi-file would have a single filename is confusing here - shouldn't
139
140
the name of this argument be filenames_list or filenames_base or similar?
@@ -369,7 +370,7 @@ def block_count(self):
369
370
"""return number of blocks"""
370
371
return self .header ['nb_block' ]
371
372
372
- def segment_count (self , block_index ):
373
+ def segment_count (self , block_index : int ):
373
374
"""return number of segments for a given block"""
374
375
return self .header ['nb_segment' ][block_index ]
375
376
@@ -379,7 +380,7 @@ def signal_streams_count(self):
379
380
"""
380
381
return len (self .header ['signal_streams' ])
381
382
382
- def signal_channels_count (self , stream_index ):
383
+ def signal_channels_count (self , stream_index : int ):
383
384
"""Return the number of signal channels for a given stream.
384
385
This number is the same for all Blocks and Segments.
385
386
"""
@@ -400,7 +401,7 @@ def event_channels_count(self):
400
401
"""
401
402
return len (self .header ['event_channels' ])
402
403
403
- def segment_t_start (self , block_index , seg_index ):
404
+ def segment_t_start (self , block_index : int , seg_index : int ):
404
405
"""Global t_start of a Segment in s. Shared by all objects except
405
406
for AnalogSignal.
406
407
"""
@@ -445,7 +446,7 @@ def _check_stream_signal_channel_characteristics(self):
445
446
446
447
self ._several_channel_groups = signal_streams .size > 1
447
448
448
- def channel_name_to_index (self , stream_index , channel_names ):
449
+ def channel_name_to_index (self , stream_index : int , channel_names : list [ str ] ):
449
450
"""
450
451
Inside a stream, transform channel_names to channel_indexes.
451
452
Based on self.header['signal_channels']
@@ -459,7 +460,7 @@ def channel_name_to_index(self, stream_index, channel_names):
459
460
channel_indexes = np .array ([chan_names .index (name ) for name in channel_names ])
460
461
return channel_indexes
461
462
462
- def channel_id_to_index (self , stream_index , channel_ids ):
463
+ def channel_id_to_index (self , stream_index : int , channel_ids : list [ str ] ):
463
464
"""
464
465
Inside a stream, transform channel_ids to channel_indexes.
465
466
Based on self.header['signal_channels']
@@ -473,7 +474,11 @@ def channel_id_to_index(self, stream_index, channel_ids):
473
474
channel_indexes = np .array ([chan_ids .index (chan_id ) for chan_id in channel_ids ])
474
475
return channel_indexes
475
476
476
- def _get_channel_indexes (self , stream_index , channel_indexes , channel_names , channel_ids ):
477
+ def _get_channel_indexes (self ,
478
+ stream_index : int ,
479
+ channel_indexes : list [int ] | None ,
480
+ channel_names : list [str ] | None ,
481
+ channel_ids : list [str ] | None ):
477
482
"""
478
483
Select channel_indexes for a stream based on channel_indexes/channel_names/channel_ids
479
484
depending which is not None.
@@ -484,7 +489,7 @@ def _get_channel_indexes(self, stream_index, channel_indexes, channel_names, cha
484
489
channel_indexes = self .channel_id_to_index (stream_index , channel_ids )
485
490
return channel_indexes
486
491
487
- def _get_stream_index_from_arg (self , stream_index_arg ):
492
+ def _get_stream_index_from_arg (self , stream_index_arg : int | None ):
488
493
if stream_index_arg is None :
489
494
assert self .header ['signal_streams' ].size == 1
490
495
stream_index = 0
@@ -493,7 +498,7 @@ def _get_stream_index_from_arg(self, stream_index_arg):
493
498
stream_index = stream_index_arg
494
499
return stream_index
495
500
496
- def get_signal_size (self , block_index , seg_index , stream_index = None ):
501
+ def get_signal_size (self , block_index : int , seg_index : int , stream_index : int | None = None ):
497
502
"""
498
503
Retrieve the length of a single section of the channels in a stream.
499
504
:param block_index:
@@ -504,7 +509,10 @@ def get_signal_size(self, block_index, seg_index, stream_index=None):
504
509
stream_index = self ._get_stream_index_from_arg (stream_index )
505
510
return self ._get_signal_size (block_index , seg_index , stream_index )
506
511
507
- def get_signal_t_start (self , block_index , seg_index , stream_index = None ):
512
+ def get_signal_t_start (self ,
513
+ block_index : int ,
514
+ seg_index : int ,
515
+ stream_index : int | None = None ):
508
516
"""
509
517
Retrieve the t_start of a single section of the channels in a stream.
510
518
:param block_index:
@@ -515,7 +523,7 @@ def get_signal_t_start(self, block_index, seg_index, stream_index=None):
515
523
stream_index = self ._get_stream_index_from_arg (stream_index )
516
524
return self ._get_signal_t_start (block_index , seg_index , stream_index )
517
525
518
- def get_signal_sampling_rate (self , stream_index = None ):
526
+ def get_signal_sampling_rate (self , stream_index : int | None = None ):
519
527
"""
520
528
Retrieve sampling rate for a stream and all channels in that stream.
521
529
:param stream_index:
@@ -528,9 +536,16 @@ def get_signal_sampling_rate(self, stream_index=None):
528
536
sr = signal_channels [0 ]['sampling_rate' ]
529
537
return float (sr )
530
538
531
- def get_analogsignal_chunk (self , block_index = 0 , seg_index = 0 , i_start = None , i_stop = None ,
532
- stream_index = None , channel_indexes = None , channel_names = None ,
533
- channel_ids = None , prefer_slice = False ):
539
+ def get_analogsignal_chunk (self ,
540
+ block_index : int = 0 ,
541
+ seg_index : int = 0 ,
542
+ i_start : int | None = None ,
543
+ i_stop : int | None = None ,
544
+ stream_index : int | None = None ,
545
+ channel_indexes : list [int ] | None = None ,
546
+ channel_names : list [str ] | None = None ,
547
+ channel_ids : list [str ] | None = None ,
548
+ prefer_slice : bool = False ):
534
549
"""
535
550
Return a chunk of raw signal as a Numpy array. columns correspond to samples from a
536
551
section of a single channel of recording. The channels are chosen either by channel_names,
@@ -587,8 +602,13 @@ def get_analogsignal_chunk(self, block_index=0, seg_index=0, i_start=None, i_sto
587
602
588
603
return raw_chunk
589
604
590
- def rescale_signal_raw_to_float (self , raw_signal , dtype = 'float32' , stream_index = None ,
591
- channel_indexes = None , channel_names = None , channel_ids = None ):
605
+ def rescale_signal_raw_to_float (self ,
606
+ raw_signal : np .ndarray ,
607
+ dtype : np .dtype = 'float32' ,
608
+ stream_index : int | None = None ,
609
+ channel_indexes : list [int ] | None = None ,
610
+ channel_names : list [str ] | None = None ,
611
+ channel_ids : list [str ] | None = None ):
592
612
"""
593
613
Rescale a chunk of raw signals which are provided as a Numpy array. These are normally
594
614
returned by a call to get_analogsignal_chunk. The channels are specified either by
@@ -627,11 +647,15 @@ def rescale_signal_raw_to_float(self, raw_signal, dtype='float32', stream_index=
627
647
return float_signal
628
648
629
649
# spiketrain and unit zone
630
- def spike_count (self , block_index = 0 , seg_index = 0 , spike_channel_index = 0 ):
650
+ def spike_count (self , block_index : int = 0 , seg_index : int = 0 , spike_channel_index : int = 0 ):
631
651
return self ._spike_count (block_index , seg_index , spike_channel_index )
632
652
633
- def get_spike_timestamps (self , block_index = 0 , seg_index = 0 , spike_channel_index = 0 ,
634
- t_start = None , t_stop = None ):
653
+ def get_spike_timestamps (self ,
654
+ block_index : int = 0 ,
655
+ seg_index : int = 0 ,
656
+ spike_channel_index : int = 0 ,
657
+ t_start : float | None = None ,
658
+ t_stop : float | None = None ):
635
659
"""
636
660
The timestamp datatype is as close to the format itself. Sometimes float/int32/int64.
637
661
Sometimes it is the index on the signal but not always.
@@ -643,21 +667,25 @@ def get_spike_timestamps(self, block_index=0, seg_index=0, spike_channel_index=0
643
667
spike_channel_index , t_start , t_stop )
644
668
return timestamp
645
669
646
- def rescale_spike_timestamp (self , spike_timestamps , dtype = 'float64' ):
670
+ def rescale_spike_timestamp (self , spike_timestamps : np . ndarray , dtype : np . dtype = 'float64' ):
647
671
"""
648
672
Rescale spike timestamps to seconds.
649
673
"""
650
674
return self ._rescale_spike_timestamp (spike_timestamps , dtype )
651
675
652
676
# spiketrain waveform zone
653
- def get_spike_raw_waveforms (self , block_index = 0 , seg_index = 0 , spike_channel_index = 0 ,
654
- t_start = None , t_stop = None ):
677
+ def get_spike_raw_waveforms (self ,
678
+ block_index : int = 0 ,
679
+ seg_index : int = 0 ,
680
+ spike_channel_index : int = 0 ,
681
+ t_start : float | None = None ,
682
+ t_stop : float | None = None ):
655
683
wf = self ._get_spike_raw_waveforms (block_index , seg_index ,
656
684
spike_channel_index , t_start , t_stop )
657
685
return wf
658
686
659
- def rescale_waveforms_to_float (self , raw_waveforms , dtype = 'float32' ,
660
- spike_channel_index = 0 ):
687
+ def rescale_waveforms_to_float (self , raw_waveforms : np . ndarray , dtype : np . dtype = 'float32' ,
688
+ spike_channel_index : int = 0 ):
661
689
wf_gain = self .header ['spike_channels' ]['wf_gain' ][spike_channel_index ]
662
690
wf_offset = self .header ['spike_channels' ]['wf_offset' ][spike_channel_index ]
663
691
@@ -671,11 +699,15 @@ def rescale_waveforms_to_float(self, raw_waveforms, dtype='float32',
671
699
return float_waveforms
672
700
673
701
# event and epoch zone
674
- def event_count (self , block_index = 0 , seg_index = 0 , event_channel_index = 0 ):
702
+ def event_count (self , block_index : int = 0 , seg_index : int = 0 , event_channel_index : int = 0 ):
675
703
return self ._event_count (block_index , seg_index , event_channel_index )
676
704
677
- def get_event_timestamps (self , block_index = 0 , seg_index = 0 , event_channel_index = 0 ,
678
- t_start = None , t_stop = None ):
705
+ def get_event_timestamps (self ,
706
+ block_index : int = 0 ,
707
+ seg_index : int = 0 ,
708
+ event_channel_index : int = 0 ,
709
+ t_start : float | None = None ,
710
+ t_stop : float | None = None ):
679
711
"""
680
712
The timestamp datatype is as close to the format itself. Sometimes float/int32/int64.
681
713
Sometimes it is the index on the signal but not always.
@@ -693,21 +725,23 @@ def get_event_timestamps(self, block_index=0, seg_index=0, event_channel_index=0
693
725
block_index , seg_index , event_channel_index , t_start , t_stop )
694
726
return timestamp , durations , labels
695
727
696
- def rescale_event_timestamp (self , event_timestamps , dtype = 'float64' ,
697
- event_channel_index = 0 ):
728
+ def rescale_event_timestamp (self ,
729
+ event_timestamps : np .ndarray ,
730
+ dtype : np .dtype = 'float64' ,
731
+ event_channel_index : int = 0 ):
698
732
"""
699
733
Rescale event timestamps to seconds.
700
734
"""
701
735
return self ._rescale_event_timestamp (event_timestamps , dtype , event_channel_index )
702
736
703
- def rescale_epoch_duration (self , raw_duration , dtype = 'float64' ,
704
- event_channel_index = 0 ):
737
+ def rescale_epoch_duration (self , raw_duration : np . ndarray , dtype : np . dtype = 'float64' ,
738
+ event_channel_index : int = 0 ):
705
739
"""
706
740
Rescale epoch raw duration to seconds.
707
741
"""
708
742
return self ._rescale_epoch_duration (raw_duration , dtype , event_channel_index )
709
743
710
- def setup_cache (self , cache_path , ** init_kargs ):
744
+ def setup_cache (self , cache_path : 'home' | 'same_as_resource' , ** init_kargs ):
711
745
try :
712
746
import joblib
713
747
except ImportError :
@@ -735,7 +769,7 @@ def setup_cache(self, cache_path, **init_kargs):
735
769
dirname = os .path .dirname (resource_name )
736
770
else :
737
771
assert os .path .exists (cache_path ), \
738
- 'cache_path do not exists use "home" or "same_as_resource" to make this auto'
772
+ 'cache_path does not exists use "home" or "same_as_resource" to make this auto'
739
773
740
774
# the hash of the resource (dir of file) is done with filename+datetime
741
775
# TODO make something more sophisticated when rawmode='one-dir' that use all
@@ -776,32 +810,37 @@ def _parse_header(self):
776
810
def _source_name (self ):
777
811
raise (NotImplementedError )
778
812
779
- def _segment_t_start (self , block_index , seg_index ):
813
+ def _segment_t_start (self , block_index : int , seg_index : int ):
780
814
raise (NotImplementedError )
781
815
782
- def _segment_t_stop (self , block_index , seg_index ):
816
+ def _segment_t_stop (self , block_index : int , seg_index : int ):
783
817
raise (NotImplementedError )
784
818
785
819
###
786
820
# signal and channel zone
787
- def _get_signal_size (self , block_index , seg_index , stream_index ):
821
+ def _get_signal_size (self , block_index : int , seg_index : int , stream_index : int ):
788
822
"""
789
823
Return the size of a set of AnalogSignals indexed by channel_indexes.
790
824
791
825
All channels indexed must have the same size and t_start.
792
826
"""
793
827
raise (NotImplementedError )
794
828
795
- def _get_signal_t_start (self , block_index , seg_index , stream_index ):
829
+ def _get_signal_t_start (self , block_index : int , seg_index : int , stream_index : int ):
796
830
"""
797
831
Return the t_start of a set of AnalogSignals indexed by channel_indexes.
798
832
799
833
All channels indexed must have the same size and t_start.
800
834
"""
801
835
raise (NotImplementedError )
802
836
803
- def _get_analogsignal_chunk (self , block_index , seg_index , i_start , i_stop ,
804
- stream_index , channel_indexes ):
837
+ def _get_analogsignal_chunk (self ,
838
+ block_index : int ,
839
+ seg_index : int ,
840
+ i_start : int | None ,
841
+ i_stop : int | None ,
842
+ stream_index : int ,
843
+ channel_indexes : list [int ] | None ):
805
844
"""
806
845
Return the samples from a set of AnalogSignals indexed
807
846
by stream_index and channel_indexes (local index inner stream).
@@ -815,38 +854,51 @@ def _get_analogsignal_chunk(self, block_index, seg_index, i_start, i_stop,
815
854
816
855
###
817
856
# spiketrain and unit zone
818
- def _spike_count (self , block_index , seg_index , spike_channel_index ):
857
+ def _spike_count (self , block_index : int , seg_index : int , spike_channel_index : int ):
819
858
raise (NotImplementedError )
820
859
821
- def _get_spike_timestamps (self , block_index , seg_index ,
822
- spike_channel_index , t_start , t_stop ):
860
+ def _get_spike_timestamps (self ,
861
+ block_index : int ,
862
+ seg_index : int ,
863
+ spike_channel_index : int ,
864
+ t_start : float | None ,
865
+ t_stop : float | None ):
823
866
raise (NotImplementedError )
824
867
825
- def _rescale_spike_timestamp (self , spike_timestamps , dtype ):
868
+ def _rescale_spike_timestamp (self , spike_timestamps : np . ndarray , dtype : np . dtype ):
826
869
raise (NotImplementedError )
827
870
828
871
###
829
872
# spike waveforms zone
830
- def _get_spike_raw_waveforms (self , block_index , seg_index ,
831
- spike_channel_index , t_start , t_stop ):
873
+ def _get_spike_raw_waveforms (self ,
874
+ block_index : int ,
875
+ seg_index : int ,
876
+ spike_channel_index : int ,
877
+ t_start : float | None ,
878
+ t_stop : float | None ):
832
879
raise (NotImplementedError )
833
880
834
881
###
835
882
# event and epoch zone
836
- def _event_count (self , block_index , seg_index , event_channel_index ):
883
+ def _event_count (self , block_index : int , seg_index : int , event_channel_index : int ):
837
884
raise (NotImplementedError )
838
885
839
- def _get_event_timestamps (self , block_index , seg_index , event_channel_index , t_start , t_stop ):
886
+ def _get_event_timestamps (self ,
887
+ block_index : int ,
888
+ seg_index : int ,
889
+ event_channel_index : int ,
890
+ t_start : float | None ,
891
+ t_stop : float | None ):
840
892
raise (NotImplementedError )
841
893
842
- def _rescale_event_timestamp (self , event_timestamps , dtype ):
894
+ def _rescale_event_timestamp (self , event_timestamps : np . ndarray , dtype : np . dtype ):
843
895
raise (NotImplementedError )
844
896
845
- def _rescale_epoch_duration (self , raw_duration , dtype ):
897
+ def _rescale_epoch_duration (self , raw_duration : np . ndarray , dtype : np . dtype ):
846
898
raise (NotImplementedError )
847
899
848
900
849
- def pprint_vector (vector , lim = 8 ):
901
+ def pprint_vector (vector , lim : int = 8 ):
850
902
vector = np .asarray (vector )
851
903
assert vector .ndim == 1
852
904
if len (vector ) > lim :
0 commit comments