diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 5c0c8af53..505fc2361 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -239,22 +239,18 @@ jobs: continue-on-error: true strategy: matrix: - check: ['style', 'doctest', 'typecheck', 'spellcheck'] + check: ['style', 'doctest', 'typecheck', 'spellcheck', 'type-inference'] steps: - uses: actions/checkout@v4 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v5 - with: - python-version: 3 - - name: Display Python version - run: python -c "import sys; print(sys.version)" + - name: Install the latest version of uv + uses: astral-sh/setup-uv@v5 + - name: Install tox + run: uv tool install tox --with=tox-uv - name: Show tox config - run: pipx run tox c - - name: Show tox config (this call) - run: pipx run tox c -e ${{ matrix.check }} + run: tox c -e ${{ matrix.check }} - name: Run check - run: pipx run tox -e ${{ matrix.check }} + run: tox -e ${{ matrix.check }} publish: runs-on: ubuntu-latest diff --git a/nibabel/_typing.py b/nibabel/_typing.py new file mode 100644 index 000000000..8b6203181 --- /dev/null +++ b/nibabel/_typing.py @@ -0,0 +1,25 @@ +"""Helpers for typing compatibility across Python versions""" + +import sys + +if sys.version_info < (3, 10): + from typing_extensions import ParamSpec +else: + from typing import ParamSpec + +if sys.version_info < (3, 11): + from typing_extensions import Self +else: + from typing import Self + +if sys.version_info < (3, 13): + from typing_extensions import TypeVar +else: + from typing import TypeVar + + +__all__ = [ + 'ParamSpec', + 'Self', + 'TypeVar', +] diff --git a/nibabel/analyze.py b/nibabel/analyze.py index d02363c79..453fe83df 100644 --- a/nibabel/analyze.py +++ b/nibabel/analyze.py @@ -84,13 +84,15 @@ from __future__ import annotations +import typing as ty + import numpy as np from .arrayproxy import ArrayProxy from .arraywriters import ArrayWriter, WriterError, get_slope_inter, 
make_array_writer from .batteryrunners import Report from .fileholders import copy_file_map -from .spatialimages import HeaderDataError, HeaderTypeError, SpatialHeader, SpatialImage +from .spatialimages import AffT, HeaderDataError, HeaderTypeError, SpatialHeader, SpatialImage from .volumeutils import ( apply_read_scaling, array_from_file, @@ -102,6 +104,13 @@ ) from .wrapstruct import LabeledWrapStruct +if ty.TYPE_CHECKING: + from collections.abc import Mapping + + from .arrayproxy import ArrayLike + from .filebasedimages import FileBasedHeader + from .fileholders import FileMap + # Sub-parts of standard analyze header from # Mayo dbh.h file header_key_dtd = [ @@ -893,11 +902,12 @@ def may_contain_header(klass, binaryblock): return 348 in (hdr_struct['sizeof_hdr'], bs_hdr_struct['sizeof_hdr']) -class AnalyzeImage(SpatialImage): +class AnalyzeImage(SpatialImage[AffT]): """Class for basic Analyze format image""" header_class: type[AnalyzeHeader] = AnalyzeHeader header: AnalyzeHeader + _header: AnalyzeHeader _meta_sniff_len = header_class.sizeof_hdr files_types: tuple[tuple[str, str], ...] = (('image', '.img'), ('header', '.hdr')) valid_exts: tuple[str, ...] 
= ('.img', '.hdr') @@ -908,7 +918,15 @@ class AnalyzeImage(SpatialImage): ImageArrayProxy = ArrayProxy - def __init__(self, dataobj, affine, header=None, extra=None, file_map=None, dtype=None): + def __init__( + self, + dataobj: ArrayLike, + affine: AffT, + header: FileBasedHeader | Mapping | None = None, + extra: Mapping | None = None, + file_map: FileMap | None = None, + dtype=None, + ) -> None: super().__init__(dataobj, affine, header, extra, file_map) # Reset consumable values self._header.set_data_offset(0) diff --git a/nibabel/arrayproxy.py b/nibabel/arrayproxy.py index ed2310519..82713f639 100644 --- a/nibabel/arrayproxy.py +++ b/nibabel/arrayproxy.py @@ -59,10 +59,11 @@ if ty.TYPE_CHECKING: import numpy.typing as npt - from typing_extensions import Self # PY310 + + from ._typing import Self, TypeVar # Taken from numpy/__init__.pyi - _DType = ty.TypeVar('_DType', bound=np.dtype[ty.Any]) + _DType = TypeVar('_DType', bound=np.dtype[ty.Any]) class ArrayLike(ty.Protocol): diff --git a/nibabel/brikhead.py b/nibabel/brikhead.py index cd791adac..71ad7f479 100644 --- a/nibabel/brikhead.py +++ b/nibabel/brikhead.py @@ -35,7 +35,7 @@ from .arrayproxy import ArrayProxy from .fileslice import strided_scalar -from .spatialimages import HeaderDataError, ImageDataError, SpatialHeader, SpatialImage +from .spatialimages import Affine, HeaderDataError, ImageDataError, SpatialHeader, SpatialImage from .volumeutils import Recoder # used for doc-tests @@ -453,7 +453,7 @@ def get_volume_labels(self): return labels -class AFNIImage(SpatialImage): +class AFNIImage(SpatialImage[Affine]): """ AFNI Image file diff --git a/nibabel/dataobj_images.py b/nibabel/dataobj_images.py index 565a22879..fdbd30530 100644 --- a/nibabel/dataobj_images.py +++ b/nibabel/dataobj_images.py @@ -14,6 +14,7 @@ import numpy as np +from ._typing import TypeVar from .deprecated import deprecate_with_version from .filebasedimages import FileBasedHeader, FileBasedImage @@ -24,7 +25,13 @@ from .fileholders 
import FileMap from .filename_parser import FileSpec -ArrayImgT = ty.TypeVar('ArrayImgT', bound='DataobjImage') + FT = TypeVar('FT', bound=np.floating) + F16 = ty.Literal['float16', 'f2', '|f2', '=f2', '<f2', '>f2'] + F32 = ty.Literal['float32', 'f4', '|f4', '=f4', '<f4', '>f4'] + F64 = ty.Literal['float64', 'f8', '|f8', '=f8', '<f8', '>f8'] + Caching = ty.Literal['fill', 'unchanged'] + +ArrayImgT = TypeVar('ArrayImgT', bound='DataobjImage') class DataobjImage(FileBasedImage): @@ -39,7 +46,7 @@ def __init__( header: FileBasedHeader | ty.Mapping | None = None, extra: ty.Mapping | None = None, file_map: FileMap | None = None, - ): + ) -> None: """Initialize dataobj image The datobj image is a combination of (dataobj, header), with optional @@ -224,11 +231,33 @@ def get_data(self, caching='fill'): self._data_cache = data return data + # Types and dtypes, e.g., np.float64 or np.dtype('f8') + @ty.overload + def get_fdata( + self, *, caching: Caching = 'fill', dtype: type[FT] | np.dtype[FT] + ) -> npt.NDArray[FT]: ... + @ty.overload + def get_fdata(self, caching: Caching, dtype: type[FT] | np.dtype[FT]) -> npt.NDArray[FT]: ... + # Support string literals + @ty.overload + def get_fdata(self, caching: Caching, dtype: F16) -> npt.NDArray[np.float16]: ... + @ty.overload + def get_fdata(self, caching: Caching, dtype: F32) -> npt.NDArray[np.float32]: ... + @ty.overload + def get_fdata(self, *, caching: Caching = 'fill', dtype: F16) -> npt.NDArray[np.float16]: ... + @ty.overload + def get_fdata(self, *, caching: Caching = 'fill', dtype: F32) -> npt.NDArray[np.float32]: ... + # Double-up on float64 literals and the default (no arguments) case + @ty.overload + def get_fdata( + self, caching: Caching = 'fill', dtype: F64 = 'f8' + ) -> npt.NDArray[np.float64]: ...
+ def get_fdata( self, - caching: ty.Literal['fill', 'unchanged'] = 'fill', + caching: Caching = 'fill', dtype: npt.DTypeLike = np.float64, - ) -> np.ndarray[ty.Any, np.dtype[np.floating]]: + ) -> npt.NDArray[np.floating]: """Return floating point image data with necessary scaling applied The image ``dataobj`` property can be an array proxy or an array. An diff --git a/nibabel/deprecated.py b/nibabel/deprecated.py index d39c0624d..394fb0799 100644 --- a/nibabel/deprecated.py +++ b/nibabel/deprecated.py @@ -5,15 +5,11 @@ import typing as ty import warnings +from ._typing import ParamSpec from .deprecator import Deprecator from .pkg_info import cmp_pkg_version -if ty.TYPE_CHECKING: - # PY39: ParamSpec is available in Python 3.10+ - P = ty.ParamSpec('P') -else: - # Just to keep the runtime happy - P = ty.TypeVar('P') +P = ParamSpec('P') class ModuleProxy: diff --git a/nibabel/deprecator.py b/nibabel/deprecator.py index 972e5f2a8..cf2b525a3 100644 --- a/nibabel/deprecator.py +++ b/nibabel/deprecator.py @@ -10,8 +10,10 @@ from textwrap import dedent if ty.TYPE_CHECKING: - T = ty.TypeVar('T') - P = ty.ParamSpec('P') + from ._typing import ParamSpec, TypeVar + + T = TypeVar('T') + P = ParamSpec('P') _LEADING_WHITE = re.compile(r'^(\s*)') diff --git a/nibabel/ecat.py b/nibabel/ecat.py index f634bcd8a..f5b76a50a 100644 --- a/nibabel/ecat.py +++ b/nibabel/ecat.py @@ -43,17 +43,30 @@ below). It's not clear what the licenses are for these files. 
""" +from __future__ import annotations + import warnings from numbers import Integral +from typing import TYPE_CHECKING import numpy as np from .arraywriters import make_array_writer from .fileslice import canonical_slicers, predict_shape, slice2outax -from .spatialimages import SpatialHeader, SpatialImage +from .spatialimages import Affine, AffT, SpatialHeader, SpatialImage from .volumeutils import array_from_file, make_dt_codes, native_code, swapped_code from .wrapstruct import WrapStruct +if TYPE_CHECKING: + from collections.abc import Mapping + from typing import Literal as L + + import numpy.typing as npt + + from .arrayproxy import ArrayLike + from .filebasedimages import FileBasedHeader + from .fileholders import FileMap + BLOCK_SIZE = 512 main_header_dtd = [ @@ -743,7 +756,7 @@ def __getitem__(self, sliceobj): return out_data -class EcatImage(SpatialImage): +class EcatImage(SpatialImage[AffT]): """Class returns a list of Ecat images, with one image(hdr/data) per frame""" header_class = EcatHeader @@ -756,7 +769,16 @@ class EcatImage(SpatialImage): ImageArrayProxy = EcatImageArrayProxy - def __init__(self, dataobj, affine, header, subheader, mlist, extra=None, file_map=None): + def __init__( + self, + dataobj: ArrayLike, + affine: AffT, + header: FileBasedHeader | Mapping | None, + subheader: EcatSubHeader, + mlist: npt.NDArray[np.integer], + extra: Mapping | None = None, + file_map: FileMap | None = None, + ) -> None: """Initialize Image The image is a combination of @@ -798,40 +820,38 @@ def __init__(self, dataobj, affine, header, subheader, mlist, extra=None, file_m >>> data4d.shape == (10, 10, 3, 1) True """ + super().__init__( + dataobj=dataobj, + affine=affine, + header=header, + extra=extra, + file_map=file_map, + ) self._subheader = subheader self._mlist = mlist - self._dataobj = dataobj - if affine is not None: - # Check that affine is array-like 4,4. 
Maybe this is too strict at - # this abstract level, but so far I think all image formats we know - # do need 4,4. - affine = np.array(affine, dtype=np.float64, copy=True) - if not affine.shape == (4, 4): - raise ValueError('Affine should be shape 4,4') - self._affine = affine - if extra is None: - extra = {} - self.extra = extra - self._header = header - if file_map is None: - file_map = self.__class__.make_file_map() - self.file_map = file_map - self._data_cache = None - self._fdata_cache = None + + # Override SpatialImage default, which attempts to set the + # affine in the header. + def update_header(self) -> None: + """Does nothing""" @property - def affine(self): + def affine(self) -> AffT: if not self._subheader._check_affines(): warnings.warn( 'Affines different across frames, loading affine from FIRST frame', UserWarning ) return self._affine - def get_frame_affine(self, frame): + def get_frame_affine(self, frame: int) -> Affine: """returns 4X4 affine""" return self._subheader.get_frame_affine(frame=frame) - def get_frame(self, frame, orientation=None): + def get_frame( + self, + frame: int, + orientation: L['neurological', 'radiological'] | None = None, + ) -> np.ndarray: """ Get full volume for a time frame @@ -847,16 +867,16 @@ def get_data_dtype(self, frame): return dt @property - def shape(self): + def shape(self) -> tuple[int, int, int, int]: x, y, z = self._subheader.get_shape() nframes = self._subheader.get_nframes() return (x, y, z, nframes) - def get_mlist(self): + def get_mlist(self) -> npt.NDArray[np.integer]: """get access to the mlist""" return self._mlist - def get_subheaders(self): + def get_subheaders(self) -> EcatSubHeader: """get access to subheaders""" return self._subheader diff --git a/nibabel/filebasedimages.py b/nibabel/filebasedimages.py index 086e31f12..ede381e40 100644 --- a/nibabel/filebasedimages.py +++ b/nibabel/filebasedimages.py @@ -16,6 +16,7 @@ from urllib import request from ._compression import COMPRESSION_ERRORS +from 
._typing import TypeVar from .fileholders import FileHolder, FileMap from .filename_parser import TypesFilenamesError, _stringify_path, splitext_addext, types_filenames from .openers import ImageOpener @@ -25,10 +26,10 @@ FileSniff = tuple[bytes, str] -ImgT = ty.TypeVar('ImgT', bound='FileBasedImage') -HdrT = ty.TypeVar('HdrT', bound='FileBasedHeader') +ImgT = TypeVar('ImgT', bound='FileBasedImage') +HdrT = TypeVar('HdrT', bound='FileBasedHeader') -StreamImgT = ty.TypeVar('StreamImgT', bound='SerializableImage') +StreamImgT = TypeVar('StreamImgT', bound='SerializableImage') class ImageFileError(Exception): diff --git a/nibabel/freesurfer/mghformat.py b/nibabel/freesurfer/mghformat.py index 0adcb88e2..28a30dc9a 100644 --- a/nibabel/freesurfer/mghformat.py +++ b/nibabel/freesurfer/mghformat.py @@ -22,7 +22,7 @@ from ..fileholders import FileHolder from ..filename_parser import _stringify_path from ..openers import ImageOpener -from ..spatialimages import HeaderDataError, SpatialHeader, SpatialImage +from ..spatialimages import Affine, HeaderDataError, SpatialHeader, SpatialImage from ..volumeutils import Recoder, array_from_file, array_to_file, endian_codes from ..wrapstruct import LabeledWrapStruct @@ -459,7 +459,7 @@ def diagnose_binaryblock(klass, binaryblock, endianness=None): return '\n'.join([report.message for report in reports if report.message]) -class MGHImage(SpatialImage, SerializableImage): +class MGHImage(SpatialImage[Affine], SerializableImage): """Class for MGH format image""" header_class = MGHHeader diff --git a/nibabel/loadsave.py b/nibabel/loadsave.py index e39aeceba..7ba02d826 100644 --- a/nibabel/loadsave.py +++ b/nibabel/loadsave.py @@ -27,10 +27,11 @@ if ty.TYPE_CHECKING: + from ._typing import ParamSpec from .filebasedimages import FileBasedImage from .filename_parser import FileSpec - P = ty.ParamSpec('P') + P = ParamSpec('P') class Signature(ty.TypedDict): signature: bytes diff --git a/nibabel/minc1.py b/nibabel/minc1.py index 
d0b9fd537..c85647216 100644 --- a/nibabel/minc1.py +++ b/nibabel/minc1.py @@ -16,7 +16,7 @@ from .externals.netcdf import netcdf_file from .fileslice import canonical_slicers -from .spatialimages import SpatialHeader, SpatialImage +from .spatialimages import Affine, SpatialHeader, SpatialImage _dt_dict = { ('b', 'unsigned'): np.uint8, @@ -299,7 +299,7 @@ def may_contain_header(klass, binaryblock): return binaryblock[:4] == b'CDF\x01' -class Minc1Image(SpatialImage): +class Minc1Image(SpatialImage[Affine]): """Class for MINC1 format images The MINC1 image class uses the default header type, rather than a specific diff --git a/nibabel/nifti1.py b/nibabel/nifti1.py index 5ea3041fc..ffd8a1bb4 100644 --- a/nibabel/nifti1.py +++ b/nibabel/nifti1.py @@ -14,7 +14,6 @@ from __future__ import annotations import json -import sys import typing as ty import warnings from io import BytesIO @@ -22,12 +21,8 @@ import numpy as np import numpy.linalg as npl -if sys.version_info < (3, 13): - from typing_extensions import Self, TypeVar # PY312 -else: - from typing import Self, TypeVar - from . 
import analyze # module import +from ._typing import Self, TypeVar from .arrayproxy import get_obj_dtype from .batteryrunners import Report from .casting import have_binary128 @@ -35,13 +30,19 @@ from .filebasedimages import ImageFileError, SerializableImage from .optpkg import optional_package from .quaternions import fillpositive, mat2quat, quat2mat -from .spatialimages import HeaderDataError +from .spatialimages import AffT, HeaderDataError from .spm99analyze import SpmAnalyzeHeader from .volumeutils import Recoder, endian_codes, make_dt_codes if ty.TYPE_CHECKING: + from collections.abc import Mapping + import pydicom as pdcm + from .arrayproxy import ArrayLike + from .filebasedimages import FileBasedHeader + from .fileholders import FileMap + have_dicom = True DicomDataset = pdcm.Dataset else: @@ -1971,11 +1972,12 @@ class Nifti1PairHeader(Nifti1Header): is_single = False -class Nifti1Pair(analyze.AnalyzeImage): +class Nifti1Pair(analyze.AnalyzeImage[AffT]): """Class for NIfTI1 format image, header pair""" header_class: type[Nifti1Header] = Nifti1PairHeader header: Nifti1Header + _header: Nifti1Header _meta_sniff_len = header_class.sizeof_hdr rw = True @@ -1983,7 +1985,15 @@ class Nifti1Pair(analyze.AnalyzeImage): # the data at serialization time _dtype_alias = None - def __init__(self, dataobj, affine, header=None, extra=None, file_map=None, dtype=None): + def __init__( + self, + dataobj: ArrayLike, + affine: AffT, + header: FileBasedHeader | Mapping | None = None, + extra: Mapping | None = None, + file_map: FileMap | None = None, + dtype=None, + ) -> None: # Special carve-out for 64 bit integers # See GitHub issues # * https://github.com/nipy/nibabel/issues/1046 @@ -1994,7 +2004,7 @@ def __init__(self, dataobj, affine, header=None, extra=None, file_map=None, dtyp danger_dts = (np.dtype('int64'), np.dtype('uint64')) if header is None and dtype is None and get_obj_dtype(dataobj) in danger_dts: alert_future_error( - f'Image data has type {dataobj.dtype}, which 
may cause ' + f'Image data has type {get_obj_dtype(dataobj)}, which may cause ' 'incompatibilities with other tools.', '5.0', warning_rec='This warning can be silenced by passing the dtype argument' @@ -2410,7 +2420,7 @@ def as_reoriented(self, ornt): return img -class Nifti1Image(Nifti1Pair, SerializableImage): +class Nifti1Image(Nifti1Pair[AffT], SerializableImage): """Class for single file NIfTI1 format image""" header_class = Nifti1Header diff --git a/nibabel/nifti2.py b/nibabel/nifti2.py index 9c898b47b..4fa3bc7fb 100644 --- a/nibabel/nifti2.py +++ b/nibabel/nifti2.py @@ -19,7 +19,7 @@ from .batteryrunners import Report from .filebasedimages import ImageFileError from .nifti1 import Nifti1Header, Nifti1Image, Nifti1Pair -from .spatialimages import HeaderDataError +from .spatialimages import AffT, HeaderDataError r""" Header struct from : https://www.nitrc.org/forum/message.php?msg_id=3738 @@ -240,17 +240,19 @@ class Nifti2PairHeader(Nifti2Header): is_single = False -class Nifti2Pair(Nifti1Pair): +class Nifti2Pair(Nifti1Pair[AffT]): """Class for NIfTI2 format image, header pair""" - header_class = Nifti2PairHeader + header_class: type[Nifti2Header] = Nifti2PairHeader + header: Nifti2Header _meta_sniff_len = header_class.sizeof_hdr -class Nifti2Image(Nifti1Image): +class Nifti2Image(Nifti1Image[AffT]): """Class for single file NIfTI2 format image""" - header_class = Nifti2Header + header_class: type[Nifti2Header] = Nifti2Header + header: Nifti2Header _meta_sniff_len = header_class.sizeof_hdr diff --git a/nibabel/openers.py b/nibabel/openers.py index 35b10c20a..fa9e42a8d 100644 --- a/nibabel/openers.py +++ b/nibabel/openers.py @@ -22,7 +22,8 @@ from types import TracebackType from _typeshed import WriteableBuffer - from typing_extensions import Self + + from ._typing import Self ModeRT = ty.Literal['r', 'rt'] ModeRB = ty.Literal['rb'] diff --git a/nibabel/parrec.py b/nibabel/parrec.py index 22520a603..b0f1ebda8 100644 --- a/nibabel/parrec.py +++ 
b/nibabel/parrec.py @@ -134,7 +134,7 @@ from .fileslice import fileslice, strided_scalar from .nifti1 import unit_codes from .openers import ImageOpener -from .spatialimages import SpatialHeader, SpatialImage +from .spatialimages import Affine, SpatialHeader, SpatialImage from .volumeutils import Recoder, array_from_file # PSL to RAS affine @@ -1248,7 +1248,7 @@ def get_volume_labels(self): return sort_info -class PARRECImage(SpatialImage): +class PARRECImage(SpatialImage[Affine]): """PAR/REC image""" header_class = PARRECHeader diff --git a/nibabel/pointset.py b/nibabel/pointset.py index 759a0b15e..1d20b82fe 100644 --- a/nibabel/pointset.py +++ b/nibabel/pointset.py @@ -31,9 +31,9 @@ from nibabel.spatialimages import SpatialImage if ty.TYPE_CHECKING: - from typing_extensions import Self + from ._typing import Self, TypeVar - _DType = ty.TypeVar('_DType', bound=np.dtype[ty.Any]) + _DType = TypeVar('_DType', bound=np.dtype[ty.Any]) class CoordinateArray(ty.Protocol): diff --git a/nibabel/spatialimages.py b/nibabel/spatialimages.py index a8e899359..9a7d8f5e9 100644 --- a/nibabel/spatialimages.py +++ b/nibabel/spatialimages.py @@ -137,7 +137,9 @@ from typing import Literal import numpy as np +import numpy.typing as npt +from ._typing import TypeVar from .casting import sctypes_aliases from .dataobj_images import DataobjImage from .filebasedimages import FileBasedHeader, FileBasedImage @@ -150,13 +152,17 @@ import io from collections.abc import Sequence - import numpy.typing as npt - from .arrayproxy import ArrayLike from .fileholders import FileMap -SpatialImgT = ty.TypeVar('SpatialImgT', bound='SpatialImage') -SpatialHdrT = ty.TypeVar('SpatialHdrT', bound='SpatialHeader') +# Track whether the image is initialized with an affine or not +# This will almost always be the case, but there are some exceptions +# and some functions that will fail if the affine is not present +Affine = npt.NDArray[np.floating] +AffT = TypeVar('AffT', covariant=True, bound=ty.Union[Affine, 
None], default=Affine) +SpatialImgT = TypeVar('SpatialImgT', bound='SpatialImage[Affine]') +SpatialHdrT = TypeVar('SpatialHdrT', bound='SpatialHeader') +AnySpatialImgT = TypeVar('AnySpatialImgT', bound='SpatialImage[Affine | None]') class HasDtype(ty.Protocol): @@ -194,7 +200,7 @@ def __init__( data_dtype: npt.DTypeLike = np.float32, shape: Sequence[int] = (0,), zooms: Sequence[float] | None = None, - ): + ) -> None: self.set_data_dtype(data_dtype) self._zooms = () self.set_data_shape(shape) @@ -461,7 +467,7 @@ def slice_affine(self, slicer: object) -> np.ndarray: return self.img.affine.dot(transform) -class SpatialImage(DataobjImage): +class SpatialImage(DataobjImage, ty.Generic[AffT]): """Template class for volumetric (3D/4D) images""" header_class: type[SpatialHeader] = SpatialHeader @@ -473,11 +479,11 @@ class SpatialImage(DataobjImage): def __init__( self, dataobj: ArrayLike, - affine: np.ndarray | None, + affine: AffT, header: FileBasedHeader | ty.Mapping | None = None, extra: ty.Mapping | None = None, file_map: FileMap | None = None, - ): + ) -> None: """Initialize image The image is a combination of (array-like, affine matrix, header), with @@ -510,7 +516,7 @@ def __init__( # do need 4,4. # Copy affine to isolate from environment. 
Specify float type to # avoid surprising integer rounding when setting values into affine - affine = np.array(affine, dtype=np.float64, copy=True) + affine = np.array(affine, dtype=np.float64, copy=True) # type: ignore[assignment] if not affine.shape == (4, 4): raise ValueError('Affine should be shape 4,4') self._affine = affine @@ -524,7 +530,7 @@ def __init__( self._data_cache = None @property - def affine(self): + def affine(self) -> AffT: return self._affine def update_header(self) -> None: @@ -586,7 +592,7 @@ def set_data_dtype(self, dtype: npt.DTypeLike) -> None: self._header.set_data_dtype(dtype) @classmethod - def from_image(klass: type[SpatialImgT], img: SpatialImage | FileBasedImage) -> SpatialImgT: + def from_image(klass: type[AnySpatialImgT], img: FileBasedImage) -> AnySpatialImgT: """Class method to create new instance of own class from `img` Parameters @@ -629,7 +635,7 @@ def slicer(self: SpatialImgT) -> SpatialFirstSlicer[SpatialImgT]: """ return self.ImageSlicer(self) - def __getitem__(self, idx: object) -> None: + def __getitem__(self, idx: object) -> ty.Never: """No slicing or dictionary interface for images Use the slicer attribute to perform cropping and subsampling at your diff --git a/nibabel/spm2analyze.py b/nibabel/spm2analyze.py index 9c4c544cf..43fd791ea 100644 --- a/nibabel/spm2analyze.py +++ b/nibabel/spm2analyze.py @@ -10,7 +10,8 @@ import numpy as np -from . import spm99analyze as spm99 # module import +from . 
import spm99analyze as spm99 +from .spatialimages import AffT image_dimension_dtd = spm99.image_dimension_dtd.copy() image_dimension_dtd[image_dimension_dtd.index(('funused2', 'f4'))] = ('scl_inter', 'f4') @@ -125,7 +126,7 @@ def may_contain_header(klass, binaryblock): ) -class Spm2AnalyzeImage(spm99.Spm99AnalyzeImage): +class Spm2AnalyzeImage(spm99.Spm99AnalyzeImage[AffT]): """Class for SPM2 variant of basic Analyze image""" header_class = Spm2AnalyzeHeader diff --git a/nibabel/spm99analyze.py b/nibabel/spm99analyze.py index cdedf223e..6fbd8d2d7 100644 --- a/nibabel/spm99analyze.py +++ b/nibabel/spm99analyze.py @@ -16,7 +16,7 @@ from . import analyze # module import from .batteryrunners import Report from .optpkg import optional_package -from .spatialimages import HeaderDataError, HeaderTypeError +from .spatialimages import AffT, HeaderDataError, HeaderTypeError have_scipy = optional_package('scipy')[1] @@ -224,7 +224,7 @@ def _chk_origin(hdr, fix=False): return hdr, rep -class Spm99AnalyzeImage(analyze.AnalyzeImage): +class Spm99AnalyzeImage(analyze.AnalyzeImage[AffT]): """Class for SPM99 variant of basic Analyze image""" header_class = Spm99AnalyzeHeader diff --git a/nibabel/volumeutils.py b/nibabel/volumeutils.py index cf23d905f..d2411f1c6 100644 --- a/nibabel/volumeutils.py +++ b/nibabel/volumeutils.py @@ -28,11 +28,13 @@ import numpy.typing as npt + from ._typing import TypeVar + Scalar = np.number | float - K = ty.TypeVar('K') - V = ty.TypeVar('V') - DT = ty.TypeVar('DT', bound=np.generic) + K = TypeVar('K') + V = TypeVar('V') + DT = TypeVar('DT', bound=np.generic) sys_is_le = sys.byteorder == 'little' native_code: ty.Literal['<', '>'] = '<' if sys_is_le else '>' diff --git a/pyproject.toml b/pyproject.toml index 73f01b66e..61868daef 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -76,6 +76,7 @@ test = [ "pytest-httpserver >=1.0.7", "pytest-xdist >=3.5", "coverage[toml]>=7.2", + "pytest-mypy-testing>=0.1.3", ] # Remaining: Simpler to centralize in tox 
dev = ["tox"] diff --git a/tests/ruff.toml b/tests/ruff.toml new file mode 100644 index 000000000..d461c74d7 --- /dev/null +++ b/tests/ruff.toml @@ -0,0 +1 @@ +line-length = 200 diff --git a/tests/typing/test_spatialimage_api.py b/tests/typing/test_spatialimage_api.py new file mode 100644 index 000000000..d144c4e19 --- /dev/null +++ b/tests/typing/test_spatialimage_api.py @@ -0,0 +1,80 @@ +import typing as ty + +import numpy as np +import pytest + +from nibabel import AnalyzeImage, Spm99AnalyzeImage, Spm2AnalyzeImage, Nifti1Image, Nifti2Image, MGHImage +from nibabel.spatialimages import SpatialImage + +if ty.TYPE_CHECKING: + from typing import reveal_type +else: + + def reveal_type(x: ty.Any) -> None: + pass + + +@pytest.mark.mypy_testing +def test_affine_tracking() -> None: + img_with_affine = SpatialImage(np.empty((5, 5, 5)), np.eye(4)) + img_without_affine = SpatialImage(np.empty((5, 5, 5)), None) + + reveal_type(img_with_affine) # R: nibabel.spatialimages.SpatialImage[numpy.ndarray[builtins.tuple[builtins.int, ...], numpy.dtype[numpy.float64]]] + reveal_type(img_without_affine) # R: nibabel.spatialimages.SpatialImage[None] + + +@pytest.mark.mypy_testing +def test_SpatialImageAPI() -> None: + img = SpatialImage(np.empty((5, 5, 5)), np.eye(4)) + + # Affine + reveal_type(img.affine) # R: numpy.ndarray[builtins.tuple[builtins.int, ...], numpy.dtype[numpy.float64]] + reveal_type(SpatialImage(np.empty((5, 5, 5)), None).affine) # R: None + + # Data + reveal_type(img.dataobj) # R: nibabel.arrayproxy.ArrayLike + reveal_type(img.get_fdata()) # R: numpy.ndarray[builtins.tuple[builtins.int, ...], numpy.dtype[numpy.float64]] + reveal_type(img.get_fdata(dtype=np.float32)) # R: numpy.ndarray[builtins.tuple[builtins.int, ...], numpy.dtype[numpy.floating[numpy._typing._nbit_base._32Bit]]] + reveal_type(img.get_fdata(dtype=np.float64)) # R: numpy.ndarray[builtins.tuple[builtins.int, ...], numpy.dtype[numpy.float64]] + reveal_type(img.get_fdata(dtype=np.dtype(np.float32))) # R: 
numpy.ndarray[builtins.tuple[builtins.int, ...], numpy.dtype[numpy.floating[numpy._typing._nbit_base._32Bit]]] + reveal_type(img.get_fdata(dtype=np.dtype(np.float64))) # R: numpy.ndarray[builtins.tuple[builtins.int, ...], numpy.dtype[numpy.float64]] + reveal_type(img.get_fdata(dtype=np.dtype("f4"))) # R: numpy.ndarray[builtins.tuple[builtins.int, ...], numpy.dtype[numpy.floating[numpy._typing._nbit_base._32Bit]]] + reveal_type(img.get_fdata(dtype=np.dtype("f8"))) # R: numpy.ndarray[builtins.tuple[builtins.int, ...], numpy.dtype[numpy.float64]] + reveal_type(img.get_fdata(dtype="f4")) # R: numpy.ndarray[builtins.tuple[builtins.int, ...], numpy.dtype[numpy.floating[numpy._typing._nbit_base._32Bit]]] + reveal_type(img.get_fdata(dtype="f8")) # R: numpy.ndarray[builtins.tuple[builtins.int, ...], numpy.dtype[numpy.float64]] + + # Indirect header + reveal_type(img.shape) # R: builtins.tuple[builtins.int, ...] + reveal_type(img.ndim) # R: builtins.int + + # SpatialHeader fields + reveal_type(img.header.get_data_dtype()) # R: numpy.dtype[Any] + reveal_type(img.header.get_data_shape()) # R: builtins.tuple[builtins.int, ...] + reveal_type(img.header.get_zooms()) # R: builtins.tuple[builtins.float, ...] 
+ + +@pytest.mark.mypy_testing +def test_image_and_header_types() -> None: + analyze_img = AnalyzeImage(np.empty((5, 5, 5)), np.eye(4)) + reveal_type(analyze_img) # R: nibabel.analyze.AnalyzeImage[numpy.ndarray[builtins.tuple[builtins.int, ...], numpy.dtype[numpy.float64]]] + reveal_type(analyze_img.header) # R: nibabel.analyze.AnalyzeHeader + + spm99_img = Spm99AnalyzeImage(np.empty((5, 5, 5)), np.eye(4)) + reveal_type(spm99_img) # R: nibabel.spm99analyze.Spm99AnalyzeImage[numpy.ndarray[builtins.tuple[builtins.int, ...], numpy.dtype[numpy.float64]]] + reveal_type(spm99_img.header) # R: nibabel.spm99analyze.Spm99AnalyzeHeader + + spm2_img = Spm2AnalyzeImage(np.empty((5, 5, 5)), np.eye(4)) + reveal_type(spm2_img) # R: nibabel.spm2analyze.Spm2AnalyzeImage[numpy.ndarray[builtins.tuple[builtins.int, ...], numpy.dtype[numpy.float64]]] + reveal_type(spm2_img.header) # R: nibabel.spm2analyze.Spm2AnalyzeHeader + + ni1_img = Nifti1Image(np.empty((5, 5, 5)), np.eye(4)) + reveal_type(ni1_img) # R: nibabel.nifti1.Nifti1Image[numpy.ndarray[builtins.tuple[builtins.int, ...], numpy.dtype[numpy.float64]]] + reveal_type(ni1_img.header) # R: nibabel.nifti1.Nifti1Header + + ni2_img = Nifti2Image(np.empty((5, 5, 5)), np.eye(4)) + reveal_type(ni2_img) # R: nibabel.nifti2.Nifti2Image[numpy.ndarray[builtins.tuple[builtins.int, ...], numpy.dtype[numpy.float64]]] + reveal_type(ni2_img.header) # R: nibabel.nifti2.Nifti2Header + + mgh_img = MGHImage(np.empty((5, 5, 5), dtype=np.float32), np.eye(4)) + reveal_type(mgh_img) # R: nibabel.freesurfer.mghformat.MGHImage + reveal_type(mgh_img.header) # R: nibabel.freesurfer.mghformat.MGHHeader diff --git a/tox.ini b/tox.ini index 05d977951..72e4e563e 100644 --- a/tox.ini +++ b/tox.ini @@ -17,6 +17,7 @@ envlist = doctest style typecheck + type-inference skip_missing_interpreters = true # Configuration that allows us to split tests across GitHub runners effectively @@ -105,6 +106,7 @@ commands = --cov nibabel --cov-report xml:cov.xml \ --junitxml 
test-results.xml \ --durations=20 --durations-min=1.0 \ + --dist worksteal \ --pyargs nibabel {posargs:-n auto} [testenv:install] @@ -179,6 +181,19 @@ skip_install = true commands = mypy nibabel +[testenv:type-inference] +description = Check type inference +labels = test +deps = + pytest-mypy-testing @ git+https://github.com/effigies/pytest-mypy-testing@rf/global-mypy-run +commands = + python -m pytest \ + --cov tests --cov nibabel --cov-report xml:cov.xml \ + --junitxml test-results.xml \ + --durations=20 --durations-min=1.0 \ + --dist loadgroup \ + tests/ {posargs:-n auto} + [testenv:build{,-strict}] labels = check