Skip to content

Commit c528a89

Browse files
authored
Merge pull request #1179 from effigies/type/primary_image_api
ENH: Annotate SpatialImage API, improve superclass annotations
2 parents 7327927 + a253451 commit c528a89

21 files changed

+326
-208
lines changed

nibabel/analyze.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1064,5 +1064,5 @@ def to_file_map(self, file_map=None, dtype=None):
10641064
hdr['scl_inter'] = inter
10651065

10661066

1067-
load = AnalyzeImage.load
1067+
load = AnalyzeImage.from_filename
10681068
save = AnalyzeImage.instance_to_filename

nibabel/arrayproxy.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,9 @@
5959
if ty.TYPE_CHECKING: # pragma: no cover
6060
import numpy.typing as npt
6161

62+
# Taken from numpy/__init__.pyi
63+
_DType = ty.TypeVar('_DType', bound=np.dtype[ty.Any])
64+
6265

6366
class ArrayLike(ty.Protocol):
6467
"""Protocol for numpy ndarray-like objects
@@ -68,9 +71,19 @@ class ArrayLike(ty.Protocol):
6871
"""
6972

7073
shape: tuple[int, ...]
71-
ndim: int
7274

73-
def __array__(self, dtype: npt.DTypeLike | None = None, /) -> npt.NDArray:
75+
@property
76+
def ndim(self) -> int:
77+
... # pragma: no cover
78+
79+
# If no dtype is passed, any dtype might be returned, depending on the array-like
80+
@ty.overload
81+
def __array__(self, dtype: None = ..., /) -> np.ndarray[ty.Any, np.dtype[ty.Any]]:
82+
... # pragma: no cover
83+
84+
# Any dtype might be passed, and *that* dtype must be returned
85+
@ty.overload
86+
def __array__(self, dtype: _DType, /) -> np.ndarray[ty.Any, _DType]:
7487
... # pragma: no cover
7588

7689
def __getitem__(self, key, /) -> npt.NDArray:

nibabel/brikhead.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -564,4 +564,4 @@ def filespec_to_file_map(klass, filespec):
564564
return file_map
565565

566566

567-
load = AFNIImage.load
567+
load = AFNIImage.from_filename

nibabel/dataobj_images.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,14 @@
2020
if ty.TYPE_CHECKING: # pragma: no cover
2121
import numpy.typing as npt
2222

23+
ArrayImgT = ty.TypeVar('ArrayImgT', bound='DataobjImage')
24+
2325

2426
class DataobjImage(FileBasedImage):
2527
"""Template class for images that have dataobj data stores"""
2628

2729
_data_cache: np.ndarray | None
28-
_fdata_cache: np.ndarray | None
30+
_fdata_cache: np.ndarray[ty.Any, np.dtype[np.floating]] | None
2931

3032
def __init__(
3133
self,
@@ -222,7 +224,7 @@ def get_fdata(
222224
self,
223225
caching: ty.Literal['fill', 'unchanged'] = 'fill',
224226
dtype: npt.DTypeLike = np.float64,
225-
) -> np.ndarray:
227+
) -> np.ndarray[ty.Any, np.dtype[np.floating]]:
226228
"""Return floating point image data with necessary scaling applied
227229
228230
The image ``dataobj`` property can be an array proxy or an array. An
@@ -421,12 +423,12 @@ def ndim(self) -> int:
421423

422424
@classmethod
423425
def from_file_map(
424-
klass,
426+
klass: type[ArrayImgT],
425427
file_map: FileMap,
426428
*,
427429
mmap: bool | ty.Literal['c', 'r'] = True,
428430
keep_file_open: bool | None = None,
429-
):
431+
) -> ArrayImgT:
430432
"""Class method to create image from mapping in ``file_map``
431433
432434
Parameters
@@ -460,12 +462,12 @@ def from_file_map(
460462

461463
@classmethod
462464
def from_filename(
463-
klass,
465+
klass: type[ArrayImgT],
464466
filename: FileSpec,
465467
*,
466468
mmap: bool | ty.Literal['c', 'r'] = True,
467469
keep_file_open: bool | None = None,
468-
):
470+
) -> ArrayImgT:
469471
"""Class method to create image from filename `filename`
470472
471473
Parameters

nibabel/ecat.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -747,12 +747,14 @@ def __getitem__(self, sliceobj):
747747
class EcatImage(SpatialImage):
748748
"""Class returns a list of Ecat images, with one image(hdr/data) per frame"""
749749

750-
_header = EcatHeader
751-
header_class = _header
750+
header_class = EcatHeader
751+
subheader_class = EcatSubHeader
752752
valid_exts = ('.v',)
753-
_subheader = EcatSubHeader
754753
files_types = (('image', '.v'), ('header', '.v'))
755754

755+
_header: EcatHeader
756+
_subheader: EcatSubHeader
757+
756758
ImageArrayProxy = EcatImageArrayProxy
757759

758760
def __init__(self, dataobj, affine, header, subheader, mlist, extra=None, file_map=None):
@@ -879,14 +881,14 @@ def from_file_map(klass, file_map, *, mmap=True, keep_file_open=None):
879881
hdr_file, img_file = klass._get_fileholders(file_map)
880882
# note header and image are in same file
881883
hdr_fid = hdr_file.get_prepare_fileobj(mode='rb')
882-
header = klass._header.from_fileobj(hdr_fid)
884+
header = klass.header_class.from_fileobj(hdr_fid)
883885
hdr_copy = header.copy()
884886
# LOAD MLIST
885887
mlist = np.zeros((header['num_frames'], 4), dtype=np.int32)
886888
mlist_data = read_mlist(hdr_fid, hdr_copy.endianness)
887889
mlist[: len(mlist_data)] = mlist_data
888890
# LOAD SUBHEADERS
889-
subheaders = klass._subheader(hdr_copy, mlist, hdr_fid)
891+
subheaders = klass.subheader_class(hdr_copy, mlist, hdr_fid)
890892
# LOAD DATA
891893
# Class level ImageArrayProxy
892894
data = klass.ImageArrayProxy(subheaders)

nibabel/filebasedimages.py

Lines changed: 40 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,11 @@
2424
FileMap = ty.Mapping[str, FileHolder]
2525
FileSniff = ty.Tuple[bytes, str]
2626

27+
ImgT = ty.TypeVar('ImgT', bound='FileBasedImage')
28+
HdrT = ty.TypeVar('HdrT', bound='FileBasedHeader')
29+
30+
StreamImgT = ty.TypeVar('StreamImgT', bound='SerializableImage')
31+
2732

2833
class ImageFileError(Exception):
2934
pass
@@ -33,7 +38,7 @@ class FileBasedHeader:
3338
"""Template class to implement header protocol"""
3439

3540
@classmethod
36-
def from_header(klass, header=None):
41+
def from_header(klass: type[HdrT], header: FileBasedHeader | ty.Mapping | None = None) -> HdrT:
3742
if header is None:
3843
return klass()
3944
# I can't do isinstance here because it is not necessarily true
@@ -47,19 +52,19 @@ def from_header(klass, header=None):
4752
)
4853

4954
@classmethod
50-
def from_fileobj(klass, fileobj: io.IOBase):
51-
raise NotImplementedError
55+
def from_fileobj(klass: type[HdrT], fileobj: io.IOBase) -> HdrT:
56+
raise NotImplementedError # pragma: no cover
5257

53-
def write_to(self, fileobj: io.IOBase):
54-
raise NotImplementedError
58+
def write_to(self, fileobj: io.IOBase) -> None:
59+
raise NotImplementedError # pragma: no cover
5560

56-
def __eq__(self, other):
57-
raise NotImplementedError
61+
def __eq__(self, other: object) -> bool:
62+
raise NotImplementedError # pragma: no cover
5863

59-
def __ne__(self, other):
64+
def __ne__(self, other: object) -> bool:
6065
return not self == other
6166

62-
def copy(self) -> FileBasedHeader:
67+
def copy(self: HdrT) -> HdrT:
6368
"""Copy object to independent representation
6469
6570
The copy should not be affected by any changes to the original
@@ -153,6 +158,7 @@ class FileBasedImage:
153158
"""
154159

155160
header_class: Type[FileBasedHeader] = FileBasedHeader
161+
_header: FileBasedHeader
156162
_meta_sniff_len: int = 0
157163
files_types: tuple[tuple[str, str | None], ...] = (('image', None),)
158164
valid_exts: tuple[str, ...] = ()
@@ -186,7 +192,7 @@ def __init__(
186192
self._header = self.header_class.from_header(header)
187193
if extra is None:
188194
extra = {}
189-
self.extra = extra
195+
self.extra = dict(extra)
190196

191197
if file_map is None:
192198
file_map = self.__class__.make_file_map()
@@ -196,7 +202,7 @@ def __init__(
196202
def header(self) -> FileBasedHeader:
197203
return self._header
198204

199-
def __getitem__(self, key):
205+
def __getitem__(self, key) -> None:
200206
"""No slicing or dictionary interface for images"""
201207
raise TypeError('Cannot slice image objects.')
202208

@@ -221,7 +227,7 @@ def get_filename(self) -> str | None:
221227
characteristic_type = self.files_types[0][0]
222228
return self.file_map[characteristic_type].filename
223229

224-
def set_filename(self, filename: str):
230+
def set_filename(self, filename: str) -> None:
225231
"""Sets the files in the object from a given filename
226232
227233
The different image formats may check whether the filename has
@@ -239,16 +245,16 @@ def set_filename(self, filename: str):
239245
self.file_map = self.__class__.filespec_to_file_map(filename)
240246

241247
@classmethod
242-
def from_filename(klass, filename: FileSpec):
248+
def from_filename(klass: type[ImgT], filename: FileSpec) -> ImgT:
243249
file_map = klass.filespec_to_file_map(filename)
244250
return klass.from_file_map(file_map)
245251

246252
@classmethod
247-
def from_file_map(klass, file_map: FileMap):
248-
raise NotImplementedError
253+
def from_file_map(klass: type[ImgT], file_map: FileMap) -> ImgT:
254+
raise NotImplementedError # pragma: no cover
249255

250256
@classmethod
251-
def filespec_to_file_map(klass, filespec: FileSpec):
257+
def filespec_to_file_map(klass, filespec: FileSpec) -> FileMap:
252258
"""Make `file_map` for this class from filename `filespec`
253259
254260
Class method
@@ -282,7 +288,7 @@ def filespec_to_file_map(klass, filespec: FileSpec):
282288
file_map[key] = FileHolder(filename=fname)
283289
return file_map
284290

285-
def to_filename(self, filename: FileSpec, **kwargs):
291+
def to_filename(self, filename: FileSpec, **kwargs) -> None:
286292
r"""Write image to files implied by filename string
287293
288294
Parameters
@@ -301,11 +307,11 @@ def to_filename(self, filename: FileSpec, **kwargs):
301307
self.file_map = self.filespec_to_file_map(filename)
302308
self.to_file_map(**kwargs)
303309

304-
def to_file_map(self, file_map: FileMap | None = None, **kwargs):
305-
raise NotImplementedError
310+
def to_file_map(self, file_map: FileMap | None = None, **kwargs) -> None:
311+
raise NotImplementedError # pragma: no cover
306312

307313
@classmethod
308-
def make_file_map(klass, mapping: ty.Mapping[str, str | io.IOBase] | None = None):
314+
def make_file_map(klass, mapping: ty.Mapping[str, str | io.IOBase] | None = None) -> FileMap:
309315
"""Class method to make files holder for this image type
310316
311317
Parameters
@@ -338,7 +344,7 @@ def make_file_map(klass, mapping: ty.Mapping[str, str | io.IOBase] | None = None
338344
load = from_filename
339345

340346
@classmethod
341-
def instance_to_filename(klass, img: FileBasedImage, filename: FileSpec):
347+
def instance_to_filename(klass, img: FileBasedImage, filename: FileSpec) -> None:
342348
"""Save `img` in our own format, to name implied by `filename`
343349
344350
This is a class method
@@ -354,28 +360,28 @@ def instance_to_filename(klass, img: FileBasedImage, filename: FileSpec):
354360
img.to_filename(filename)
355361

356362
@classmethod
357-
def from_image(klass, img: FileBasedImage):
363+
def from_image(klass: type[ImgT], img: FileBasedImage) -> ImgT:
358364
"""Class method to create new instance of own class from `img`
359365
360366
Parameters
361367
----------
362-
img : ``spatialimage`` instance
368+
img : ``FileBasedImage`` instance
363369
In fact, an object with the API of ``FileBasedImage``.
364370
365371
Returns
366372
-------
367-
cimg : ``spatialimage`` instance
373+
img : ``FileBasedImage`` instance
368374
Image, of our own class
369375
"""
370-
raise NotImplementedError()
376+
raise NotImplementedError # pragma: no cover
371377

372378
@classmethod
373379
def _sniff_meta_for(
374380
klass,
375381
filename: FileSpec,
376382
sniff_nbytes: int,
377383
sniff: FileSniff | None = None,
378-
):
384+
) -> FileSniff | None:
379385
"""Sniff metadata for image represented by `filename`
380386
381387
Parameters
@@ -425,7 +431,7 @@ def path_maybe_image(
425431
filename: FileSpec,
426432
sniff: FileSniff | None = None,
427433
sniff_max: int = 1024,
428-
):
434+
) -> tuple[bool, FileSniff | None]:
429435
"""Return True if `filename` may be image matching this class
430436
431437
Parameters
@@ -527,14 +533,14 @@ class SerializableImage(FileBasedImage):
527533
"""
528534

529535
@classmethod
530-
def _filemap_from_iobase(klass, io_obj: io.IOBase):
536+
def _filemap_from_iobase(klass, io_obj: io.IOBase) -> FileMap:
531537
"""For single-file image types, make a file map with the correct key"""
532538
if len(klass.files_types) > 1:
533539
raise NotImplementedError('(de)serialization is undefined for multi-file images')
534540
return klass.make_file_map({klass.files_types[0][0]: io_obj})
535541

536542
@classmethod
537-
def from_stream(klass, io_obj: io.IOBase):
543+
def from_stream(klass: type[StreamImgT], io_obj: io.IOBase) -> StreamImgT:
538544
"""Load image from readable IO stream
539545
540546
Convert to BytesIO to enable seeking, if input stream is not seekable
@@ -548,7 +554,7 @@ def from_stream(klass, io_obj: io.IOBase):
548554
io_obj = io.BytesIO(io_obj.read())
549555
return klass.from_file_map(klass._filemap_from_iobase(io_obj))
550556

551-
def to_stream(self, io_obj: io.IOBase, **kwargs):
557+
def to_stream(self, io_obj: io.IOBase, **kwargs) -> None:
552558
r"""Save image to writable IO stream
553559
554560
Parameters
@@ -561,7 +567,7 @@ def to_stream(self, io_obj: io.IOBase, **kwargs):
561567
self.to_file_map(self._filemap_from_iobase(io_obj), **kwargs)
562568

563569
@classmethod
564-
def from_bytes(klass, bytestring: bytes):
570+
def from_bytes(klass: type[StreamImgT], bytestring: bytes) -> StreamImgT:
565571
"""Construct image from a byte string
566572
567573
Class method
@@ -592,7 +598,9 @@ def to_bytes(self, **kwargs) -> bytes:
592598
return bio.getvalue()
593599

594600
@classmethod
595-
def from_url(klass, url: str | request.Request, timeout: float = 5):
601+
def from_url(
602+
klass: type[StreamImgT], url: str | request.Request, timeout: float = 5
603+
) -> StreamImgT:
596604
"""Retrieve and load an image from a URL
597605
598606
Class method

nibabel/freesurfer/mghformat.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -589,5 +589,5 @@ def _affine2header(self):
589589
hdr['Pxyz_c'] = c_ras
590590

591591

592-
load = MGHImage.load
592+
load = MGHImage.from_filename
593593
save = MGHImage.instance_to_filename

nibabel/minc1.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -334,4 +334,4 @@ def from_file_map(klass, file_map, *, mmap=True, keep_file_open=None):
334334
return klass(data, affine, header, extra=None, file_map=file_map)
335335

336336

337-
load = Minc1Image.load
337+
load = Minc1Image.from_filename

nibabel/minc2.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,4 +172,4 @@ def from_file_map(klass, file_map, *, mmap=True, keep_file_open=None):
172172
return klass(data, affine, header, extra=None, file_map=file_map)
173173

174174

175-
load = Minc2Image.load
175+
load = Minc2Image.from_filename

nibabel/nifti1.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,10 @@
2525
from .batteryrunners import Report
2626
from .casting import have_binary128
2727
from .deprecated import alert_future_error
28-
from .filebasedimages import SerializableImage
28+
from .filebasedimages import ImageFileError, SerializableImage
2929
from .optpkg import optional_package
3030
from .quaternions import fillpositive, mat2quat, quat2mat
31-
from .spatialimages import HeaderDataError, ImageFileError
31+
from .spatialimages import HeaderDataError
3232
from .spm99analyze import SpmAnalyzeHeader
3333
from .volumeutils import Recoder, endian_codes, make_dt_codes
3434

0 commit comments

Comments
 (0)