Skip to content

Commit 32c3b99

Browse files
authored
Merge pull request #1334 from zm711/typing
Add Typing to `baseio`, `baserawio`, and `exampleio`
2 parents fdf6013 + 89091de commit 32c3b99

File tree

3 files changed

+151
-69
lines changed

3 files changed

+151
-69
lines changed

neo/io/baseio.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
1111
If you want a model for developing a new IO start from exampleIO.
1212
"""
13+
from __future__ import annotations
14+
from pathlib import Path
1315

1416
try:
1517
from collections.abc import Sequence
@@ -96,7 +98,7 @@ class BaseIO:
9698

9799
mode = 'file' # or 'fake' or 'dir' or 'database'
98100

99-
def __init__(self, filename=None, **kargs):
101+
def __init__(self, filename: str | Path = None, **kargs):
100102
self.filename = str(filename)
101103
# create a logger for the IO class
102104
fullname = self.__class__.__module__ + '.' + self.__class__.__name__
@@ -111,7 +113,7 @@ def __init__(self, filename=None, **kargs):
111113
corelogger.addHandler(logging_handler)
112114

113115
######## General read/write methods #######################
114-
def read(self, lazy=False, **kargs):
116+
def read(self, lazy: bool = False, **kargs):
115117
"""
116118
Return all data from the file as a list of Blocks
117119
"""

neo/rawio/baserawio.py

+102-50
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@
6767
constructions of a RawIO for a given set of data.
6868
6969
"""
70+
from __future__ import annotations
7071

7172
import logging
7273
import numpy as np
@@ -133,7 +134,7 @@ class BaseRawIO:
133134

134135
rawmode = None # one key from possible_raw_modes
135136

136-
def __init__(self, use_cache=False, cache_path='same_as_resource', **kargs):
137+
def __init__(self, use_cache: bool = False, cache_path: str = 'same_as_resource', **kargs):
137138
"""
138139
:TODO: Why multi-file would have a single filename is confusing here - shouldn't
139140
the name of this argument be filenames_list or filenames_base or similar?
@@ -369,7 +370,7 @@ def block_count(self):
369370
"""return number of blocks"""
370371
return self.header['nb_block']
371372

372-
def segment_count(self, block_index):
373+
def segment_count(self, block_index: int):
373374
"""return number of segments for a given block"""
374375
return self.header['nb_segment'][block_index]
375376

@@ -379,7 +380,7 @@ def signal_streams_count(self):
379380
"""
380381
return len(self.header['signal_streams'])
381382

382-
def signal_channels_count(self, stream_index):
383+
def signal_channels_count(self, stream_index: int):
383384
"""Return the number of signal channels for a given stream.
384385
This number is the same for all Blocks and Segments.
385386
"""
@@ -400,7 +401,7 @@ def event_channels_count(self):
400401
"""
401402
return len(self.header['event_channels'])
402403

403-
def segment_t_start(self, block_index, seg_index):
404+
def segment_t_start(self, block_index: int, seg_index: int):
404405
"""Global t_start of a Segment in s. Shared by all objects except
405406
for AnalogSignal.
406407
"""
@@ -445,7 +446,7 @@ def _check_stream_signal_channel_characteristics(self):
445446

446447
self._several_channel_groups = signal_streams.size > 1
447448

448-
def channel_name_to_index(self, stream_index, channel_names):
449+
def channel_name_to_index(self, stream_index: int, channel_names: list[str]):
449450
"""
450451
Inside a stream, transform channel_names to channel_indexes.
451452
Based on self.header['signal_channels']
@@ -459,7 +460,7 @@ def channel_name_to_index(self, stream_index, channel_names):
459460
channel_indexes = np.array([chan_names.index(name) for name in channel_names])
460461
return channel_indexes
461462

462-
def channel_id_to_index(self, stream_index, channel_ids):
463+
def channel_id_to_index(self, stream_index: int, channel_ids: list[str]):
463464
"""
464465
Inside a stream, transform channel_ids to channel_indexes.
465466
Based on self.header['signal_channels']
@@ -473,7 +474,11 @@ def channel_id_to_index(self, stream_index, channel_ids):
473474
channel_indexes = np.array([chan_ids.index(chan_id) for chan_id in channel_ids])
474475
return channel_indexes
475476

476-
def _get_channel_indexes(self, stream_index, channel_indexes, channel_names, channel_ids):
477+
def _get_channel_indexes(self,
478+
stream_index: int,
479+
channel_indexes: list[int] | None,
480+
channel_names: list[str] | None,
481+
channel_ids: list[str] | None):
477482
"""
478483
Select channel_indexes for a stream based on channel_indexes/channel_names/channel_ids
479484
depending which is not None.
@@ -484,7 +489,7 @@ def _get_channel_indexes(self, stream_index, channel_indexes, channel_names, cha
484489
channel_indexes = self.channel_id_to_index(stream_index, channel_ids)
485490
return channel_indexes
486491

487-
def _get_stream_index_from_arg(self, stream_index_arg):
492+
def _get_stream_index_from_arg(self, stream_index_arg: int | None):
488493
if stream_index_arg is None:
489494
assert self.header['signal_streams'].size == 1
490495
stream_index = 0
@@ -493,7 +498,7 @@ def _get_stream_index_from_arg(self, stream_index_arg):
493498
stream_index = stream_index_arg
494499
return stream_index
495500

496-
def get_signal_size(self, block_index, seg_index, stream_index=None):
501+
def get_signal_size(self, block_index: int, seg_index: int, stream_index: int | None = None):
497502
"""
498503
Retrieve the length of a single section of the channels in a stream.
499504
:param block_index:
@@ -504,7 +509,10 @@ def get_signal_size(self, block_index, seg_index, stream_index=None):
504509
stream_index = self._get_stream_index_from_arg(stream_index)
505510
return self._get_signal_size(block_index, seg_index, stream_index)
506511

507-
def get_signal_t_start(self, block_index, seg_index, stream_index=None):
512+
def get_signal_t_start(self,
513+
block_index: int,
514+
seg_index: int,
515+
stream_index: int | None = None):
508516
"""
509517
Retrieve the t_start of a single section of the channels in a stream.
510518
:param block_index:
@@ -515,7 +523,7 @@ def get_signal_t_start(self, block_index, seg_index, stream_index=None):
515523
stream_index = self._get_stream_index_from_arg(stream_index)
516524
return self._get_signal_t_start(block_index, seg_index, stream_index)
517525

518-
def get_signal_sampling_rate(self, stream_index=None):
526+
def get_signal_sampling_rate(self, stream_index: int | None = None):
519527
"""
520528
Retrieve sampling rate for a stream and all channels in that stream.
521529
:param stream_index:
@@ -528,9 +536,16 @@ def get_signal_sampling_rate(self, stream_index=None):
528536
sr = signal_channels[0]['sampling_rate']
529537
return float(sr)
530538

531-
def get_analogsignal_chunk(self, block_index=0, seg_index=0, i_start=None, i_stop=None,
532-
stream_index=None, channel_indexes=None, channel_names=None,
533-
channel_ids=None, prefer_slice=False):
539+
def get_analogsignal_chunk(self,
540+
block_index: int = 0,
541+
seg_index: int = 0,
542+
i_start: int | None = None,
543+
i_stop: int | None = None,
544+
stream_index: int | None = None,
545+
channel_indexes: list[int] | None = None,
546+
channel_names: list[str] | None = None,
547+
channel_ids: list[str] | None = None,
548+
prefer_slice: bool = False):
534549
"""
535550
Return a chunk of raw signal as a Numpy array. columns correspond to samples from a
536551
section of a single channel of recording. The channels are chosen either by channel_names,
@@ -587,8 +602,13 @@ def get_analogsignal_chunk(self, block_index=0, seg_index=0, i_start=None, i_sto
587602

588603
return raw_chunk
589604

590-
def rescale_signal_raw_to_float(self, raw_signal, dtype='float32', stream_index=None,
591-
channel_indexes=None, channel_names=None, channel_ids=None):
605+
def rescale_signal_raw_to_float(self,
606+
raw_signal: np.ndarray,
607+
dtype: np.dtype = 'float32',
608+
stream_index: int | None = None,
609+
channel_indexes: list[int] | None = None,
610+
channel_names: list[str] | None = None,
611+
channel_ids: list[str] | None = None):
592612
"""
593613
Rescale a chunk of raw signals which are provided as a Numpy array. These are normally
594614
returned by a call to get_analogsignal_chunk. The channels are specified either by
@@ -627,11 +647,15 @@ def rescale_signal_raw_to_float(self, raw_signal, dtype='float32', stream_index=
627647
return float_signal
628648

629649
# spiketrain and unit zone
630-
def spike_count(self, block_index=0, seg_index=0, spike_channel_index=0):
650+
def spike_count(self, block_index: int = 0, seg_index: int = 0, spike_channel_index: int = 0):
631651
return self._spike_count(block_index, seg_index, spike_channel_index)
632652

633-
def get_spike_timestamps(self, block_index=0, seg_index=0, spike_channel_index=0,
634-
t_start=None, t_stop=None):
653+
def get_spike_timestamps(self,
654+
block_index: int = 0,
655+
seg_index: int = 0,
656+
spike_channel_index: int = 0,
657+
t_start: float | None = None,
658+
t_stop: float | None = None):
635659
"""
636660
The timestamp datatype is as close to the format itself. Sometimes float/int32/int64.
637661
Sometimes it is the index on the signal but not always.
@@ -643,21 +667,25 @@ def get_spike_timestamps(self, block_index=0, seg_index=0, spike_channel_index=0
643667
spike_channel_index, t_start, t_stop)
644668
return timestamp
645669

646-
def rescale_spike_timestamp(self, spike_timestamps, dtype='float64'):
670+
def rescale_spike_timestamp(self, spike_timestamps: np.ndarray, dtype: np.dtype = 'float64'):
647671
"""
648672
Rescale spike timestamps to seconds.
649673
"""
650674
return self._rescale_spike_timestamp(spike_timestamps, dtype)
651675

652676
# spiketrain waveform zone
653-
def get_spike_raw_waveforms(self, block_index=0, seg_index=0, spike_channel_index=0,
654-
t_start=None, t_stop=None):
677+
def get_spike_raw_waveforms(self,
678+
block_index: int = 0,
679+
seg_index: int = 0,
680+
spike_channel_index: int = 0,
681+
t_start: float | None = None,
682+
t_stop: float | None = None):
655683
wf = self._get_spike_raw_waveforms(block_index, seg_index,
656684
spike_channel_index, t_start, t_stop)
657685
return wf
658686

659-
def rescale_waveforms_to_float(self, raw_waveforms, dtype='float32',
660-
spike_channel_index=0):
687+
def rescale_waveforms_to_float(self, raw_waveforms: np.ndarray, dtype: np.dtype = 'float32',
688+
spike_channel_index: int = 0):
661689
wf_gain = self.header['spike_channels']['wf_gain'][spike_channel_index]
662690
wf_offset = self.header['spike_channels']['wf_offset'][spike_channel_index]
663691

@@ -671,11 +699,15 @@ def rescale_waveforms_to_float(self, raw_waveforms, dtype='float32',
671699
return float_waveforms
672700

673701
# event and epoch zone
674-
def event_count(self, block_index=0, seg_index=0, event_channel_index=0):
702+
def event_count(self, block_index: int = 0, seg_index: int = 0, event_channel_index: int = 0):
675703
return self._event_count(block_index, seg_index, event_channel_index)
676704

677-
def get_event_timestamps(self, block_index=0, seg_index=0, event_channel_index=0,
678-
t_start=None, t_stop=None):
705+
def get_event_timestamps(self,
706+
block_index: int = 0,
707+
seg_index: int = 0,
708+
event_channel_index: int = 0,
709+
t_start: float | None = None,
710+
t_stop: float | None = None):
679711
"""
680712
The timestamp datatype is as close to the format itself. Sometimes float/int32/int64.
681713
Sometimes it is the index on the signal but not always.
@@ -693,21 +725,23 @@ def get_event_timestamps(self, block_index=0, seg_index=0, event_channel_index=0
693725
block_index, seg_index, event_channel_index, t_start, t_stop)
694726
return timestamp, durations, labels
695727

696-
def rescale_event_timestamp(self, event_timestamps, dtype='float64',
697-
event_channel_index=0):
728+
def rescale_event_timestamp(self,
729+
event_timestamps: np.ndarray,
730+
dtype: np.dtype = 'float64',
731+
event_channel_index: int = 0):
698732
"""
699733
Rescale event timestamps to seconds.
700734
"""
701735
return self._rescale_event_timestamp(event_timestamps, dtype, event_channel_index)
702736

703-
def rescale_epoch_duration(self, raw_duration, dtype='float64',
704-
event_channel_index=0):
737+
def rescale_epoch_duration(self, raw_duration: np.ndarray, dtype: np.dtype = 'float64',
738+
event_channel_index: int = 0):
705739
"""
706740
Rescale epoch raw duration to seconds.
707741
"""
708742
return self._rescale_epoch_duration(raw_duration, dtype, event_channel_index)
709743

710-
def setup_cache(self, cache_path, **init_kargs):
744+
def setup_cache(self, cache_path: 'home' | 'same_as_resource', **init_kargs):
711745
try:
712746
import joblib
713747
except ImportError:
@@ -735,7 +769,7 @@ def setup_cache(self, cache_path, **init_kargs):
735769
dirname = os.path.dirname(resource_name)
736770
else:
737771
assert os.path.exists(cache_path), \
738-
'cache_path do not exists use "home" or "same_as_resource" to make this auto'
772+
'cache_path does not exists use "home" or "same_as_resource" to make this auto'
739773

740774
# the hash of the resource (dir of file) is done with filename+datetime
741775
# TODO make something more sophisticated when rawmode='one-dir' that use all
@@ -776,32 +810,37 @@ def _parse_header(self):
776810
def _source_name(self):
777811
raise (NotImplementedError)
778812

779-
def _segment_t_start(self, block_index, seg_index):
813+
def _segment_t_start(self, block_index: int, seg_index: int):
780814
raise (NotImplementedError)
781815

782-
def _segment_t_stop(self, block_index, seg_index):
816+
def _segment_t_stop(self, block_index: int, seg_index: int):
783817
raise (NotImplementedError)
784818

785819
###
786820
# signal and channel zone
787-
def _get_signal_size(self, block_index, seg_index, stream_index):
821+
def _get_signal_size(self, block_index: int, seg_index: int, stream_index: int):
788822
"""
789823
Return the size of a set of AnalogSignals indexed by channel_indexes.
790824
791825
All channels indexed must have the same size and t_start.
792826
"""
793827
raise (NotImplementedError)
794828

795-
def _get_signal_t_start(self, block_index, seg_index, stream_index):
829+
def _get_signal_t_start(self, block_index: int, seg_index: int, stream_index: int):
796830
"""
797831
Return the t_start of a set of AnalogSignals indexed by channel_indexes.
798832
799833
All channels indexed must have the same size and t_start.
800834
"""
801835
raise (NotImplementedError)
802836

803-
def _get_analogsignal_chunk(self, block_index, seg_index, i_start, i_stop,
804-
stream_index, channel_indexes):
837+
def _get_analogsignal_chunk(self,
838+
block_index: int,
839+
seg_index: int,
840+
i_start: int | None,
841+
i_stop: int | None,
842+
stream_index: int,
843+
channel_indexes: list[int] | None):
805844
"""
806845
Return the samples from a set of AnalogSignals indexed
807846
by stream_index and channel_indexes (local index inner stream).
@@ -815,38 +854,51 @@ def _get_analogsignal_chunk(self, block_index, seg_index, i_start, i_stop,
815854

816855
###
817856
# spiketrain and unit zone
818-
def _spike_count(self, block_index, seg_index, spike_channel_index):
857+
def _spike_count(self, block_index: int, seg_index: int, spike_channel_index: int):
819858
raise (NotImplementedError)
820859

821-
def _get_spike_timestamps(self, block_index, seg_index,
822-
spike_channel_index, t_start, t_stop):
860+
def _get_spike_timestamps(self,
861+
block_index: int,
862+
seg_index: int,
863+
spike_channel_index: int,
864+
t_start: float | None,
865+
t_stop: float | None):
823866
raise (NotImplementedError)
824867

825-
def _rescale_spike_timestamp(self, spike_timestamps, dtype):
868+
def _rescale_spike_timestamp(self, spike_timestamps: np.ndarray, dtype: np.dtype):
826869
raise (NotImplementedError)
827870

828871
###
829872
# spike waveforms zone
830-
def _get_spike_raw_waveforms(self, block_index, seg_index,
831-
spike_channel_index, t_start, t_stop):
873+
def _get_spike_raw_waveforms(self,
874+
block_index: int,
875+
seg_index: int,
876+
spike_channel_index: int,
877+
t_start: float | None,
878+
t_stop: float | None):
832879
raise (NotImplementedError)
833880

834881
###
835882
# event and epoch zone
836-
def _event_count(self, block_index, seg_index, event_channel_index):
883+
def _event_count(self, block_index: int, seg_index: int, event_channel_index: int):
837884
raise (NotImplementedError)
838885

839-
def _get_event_timestamps(self, block_index, seg_index, event_channel_index, t_start, t_stop):
886+
def _get_event_timestamps(self,
887+
block_index: int,
888+
seg_index: int,
889+
event_channel_index: int,
890+
t_start: float | None,
891+
t_stop: float | None):
840892
raise (NotImplementedError)
841893

842-
def _rescale_event_timestamp(self, event_timestamps, dtype):
894+
def _rescale_event_timestamp(self, event_timestamps: np.ndarray, dtype: np.dtype):
843895
raise (NotImplementedError)
844896

845-
def _rescale_epoch_duration(self, raw_duration, dtype):
897+
def _rescale_epoch_duration(self, raw_duration: np.ndarray, dtype: np.dtype):
846898
raise (NotImplementedError)
847899

848900

849-
def pprint_vector(vector, lim=8):
901+
def pprint_vector(vector, lim: int = 8):
850902
vector = np.asarray(vector)
851903
assert vector.ndim == 1
852904
if len(vector) > lim:

0 commit comments

Comments
 (0)