Skip to content

Commit 3fabc03

Browse files
authored
Merge pull request #1188 from effigies/type/internal_tools
TYP: Annotate deprecation, one-time property, and optional package tooling
2 parents aa0bfff + 4a676c5 commit 3fabc03

File tree

8 files changed

+112
-56
lines changed

8 files changed

+112
-56
lines changed

nibabel/deprecated.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,15 @@
22
"""
33
from __future__ import annotations
44

5+
import typing as ty
56
import warnings
6-
from typing import Type
77

88
from .deprecator import Deprecator
99
from .pkg_info import cmp_pkg_version
1010

11+
if ty.TYPE_CHECKING: # pragma: no cover
12+
P = ty.ParamSpec('P')
13+
1114

1215
class ModuleProxy:
1316
"""Proxy for module that may not yet have been imported
@@ -30,14 +33,14 @@ class ModuleProxy:
3033
module.
3134
"""
3235

33-
def __init__(self, module_name):
36+
def __init__(self, module_name: str) -> None:
3437
self._module_name = module_name
3538

36-
def __getattr__(self, key):
39+
def __getattr__(self, key: str) -> ty.Any:
3740
mod = __import__(self._module_name, fromlist=[''])
3841
return getattr(mod, key)
3942

40-
def __repr__(self):
43+
def __repr__(self) -> str:
4144
return f'<module proxy for {self._module_name}>'
4245

4346

@@ -60,7 +63,7 @@ class FutureWarningMixin:
6063

6164
warn_message = 'This class will be removed in future versions'
6265

63-
def __init__(self, *args, **kwargs):
66+
def __init__(self, *args: P.args, **kwargs: P.kwargs) -> None:
6467
warnings.warn(self.warn_message, FutureWarning, stacklevel=2)
6568
super().__init__(*args, **kwargs)
6669

@@ -85,12 +88,12 @@ def alert_future_error(
8588
msg: str,
8689
version: str,
8790
*,
88-
warning_class: Type[Warning] = FutureWarning,
89-
error_class: Type[Exception] = RuntimeError,
91+
warning_class: type[Warning] = FutureWarning,
92+
error_class: type[Exception] = RuntimeError,
9093
warning_rec: str = '',
9194
error_rec: str = '',
9295
stacklevel: int = 2,
93-
):
96+
) -> None:
9497
"""Warn or error with appropriate messages for changing functionality.
9598
9699
Parameters

nibabel/deprecator.py

Lines changed: 42 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,16 @@
11
"""Class for recording and reporting deprecations
22
"""
3+
from __future__ import annotations
34

45
import functools
56
import re
7+
import typing as ty
68
import warnings
79

10+
if ty.TYPE_CHECKING: # pragma: no cover
11+
T = ty.TypeVar('T')
12+
P = ty.ParamSpec('P')
13+
814
_LEADING_WHITE = re.compile(r'^(\s*)')
915

1016
TESTSETUP = """
@@ -38,15 +44,20 @@ class ExpiredDeprecationError(RuntimeError):
3844
pass
3945

4046

41-
def _ensure_cr(text):
47+
def _ensure_cr(text: str) -> str:
4248
"""Remove trailing whitespace and add carriage return
4349
4450
Ensures that `text` always ends with a carriage return
4551
"""
4652
return text.rstrip() + '\n'
4753

4854

49-
def _add_dep_doc(old_doc, dep_doc, setup='', cleanup=''):
55+
def _add_dep_doc(
56+
old_doc: str,
57+
dep_doc: str,
58+
setup: str = '',
59+
cleanup: str = '',
60+
) -> str:
5061
"""Add deprecation message `dep_doc` to docstring in `old_doc`
5162
5263
Parameters
@@ -55,6 +66,10 @@ def _add_dep_doc(old_doc, dep_doc, setup='', cleanup=''):
5566
Docstring from some object.
5667
dep_doc : str
5768
Deprecation warning to add to top of docstring, after initial line.
69+
setup : str, optional
70+
Doctest setup text
71+
cleanup : str, optional
72+
Doctest teardown text
5873
5974
Returns
6075
-------
@@ -76,7 +91,9 @@ def _add_dep_doc(old_doc, dep_doc, setup='', cleanup=''):
7691
if next_line >= len(old_lines):
7792
# nothing following first paragraph, just append message
7893
return old_doc + '\n' + dep_doc
79-
indent = _LEADING_WHITE.match(old_lines[next_line]).group()
94+
leading_white = _LEADING_WHITE.match(old_lines[next_line])
95+
assert leading_white is not None # Type narrowing, since this always matches
96+
indent = leading_white.group()
8097
setup_lines = [indent + L for L in setup.splitlines()]
8198
dep_lines = [indent + L for L in [''] + dep_doc.splitlines() + ['']]
8299
cleanup_lines = [indent + L for L in cleanup.splitlines()]
@@ -113,15 +130,15 @@ class Deprecator:
113130

114131
def __init__(
115132
self,
116-
version_comparator,
117-
warn_class=DeprecationWarning,
118-
error_class=ExpiredDeprecationError,
119-
):
133+
version_comparator: ty.Callable[[str], int],
134+
warn_class: type[Warning] = DeprecationWarning,
135+
error_class: type[Exception] = ExpiredDeprecationError,
136+
) -> None:
120137
self.version_comparator = version_comparator
121138
self.warn_class = warn_class
122139
self.error_class = error_class
123140

124-
def is_bad_version(self, version_str):
141+
def is_bad_version(self, version_str: str) -> bool:
125142
"""Return True if `version_str` is too high
126143
127144
Tests `version_str` with ``self.version_comparator``
@@ -139,7 +156,14 @@ def is_bad_version(self, version_str):
139156
"""
140157
return self.version_comparator(version_str) == -1
141158

142-
def __call__(self, message, since='', until='', warn_class=None, error_class=None):
159+
def __call__(
160+
self,
161+
message: str,
162+
since: str = '',
163+
until: str = '',
164+
warn_class: type[Warning] | None = None,
165+
error_class: type[Exception] | None = None,
166+
) -> ty.Callable[[ty.Callable[P, T]], ty.Callable[P, T]]:
143167
"""Return decorator function function for deprecation warning / error
144168
145169
Parameters
@@ -164,8 +188,8 @@ def __call__(self, message, since='', until='', warn_class=None, error_class=Non
164188
deprecator : func
165189
Function returning a decorator.
166190
"""
167-
warn_class = warn_class or self.warn_class
168-
error_class = error_class or self.error_class
191+
exception = error_class if error_class is not None else self.error_class
192+
warning = warn_class if warn_class is not None else self.warn_class
169193
messages = [message]
170194
if (since, until) != ('', ''):
171195
messages.append('')
@@ -174,19 +198,21 @@ def __call__(self, message, since='', until='', warn_class=None, error_class=Non
174198
if until:
175199
messages.append(
176200
f"* {'Raises' if self.is_bad_version(until) else 'Will raise'} "
177-
f'{error_class} as of version: {until}'
201+
f'{exception} as of version: {until}'
178202
)
179203
message = '\n'.join(messages)
180204

181-
def deprecator(func):
205+
def deprecator(func: ty.Callable[P, T]) -> ty.Callable[P, T]:
182206
@functools.wraps(func)
183-
def deprecated_func(*args, **kwargs):
207+
def deprecated_func(*args: P.args, **kwargs: P.kwargs) -> T:
184208
if until and self.is_bad_version(until):
185-
raise error_class(message)
186-
warnings.warn(message, warn_class, stacklevel=2)
209+
raise exception(message)
210+
warnings.warn(message, warning, stacklevel=2)
187211
return func(*args, **kwargs)
188212

189213
keep_doc = deprecated_func.__doc__
214+
if keep_doc is None:
215+
keep_doc = ''
190216
setup = TESTSETUP
191217
cleanup = TESTCLEANUP
192218
# After expiration, remove all but the first paragraph.

nibabel/onetime.py

Lines changed: 28 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,12 @@
1919
2020
[2] Python data model, https://docs.python.org/reference/datamodel.html
2121
"""
22+
from __future__ import annotations
23+
24+
import typing as ty
25+
26+
InstanceT = ty.TypeVar('InstanceT')
27+
T = ty.TypeVar('T')
2228

2329
from nibabel.deprecated import deprecate_with_version
2430

@@ -96,26 +102,24 @@ class ResetMixin:
96102
10.0
97103
"""
98104

99-
def reset(self):
105+
def reset(self) -> None:
100106
"""Reset all OneTimeProperty attributes that may have fired already."""
101-
instdict = self.__dict__
102-
classdict = self.__class__.__dict__
103107
# To reset them, we simply remove them from the instance dict. At that
104108
# point, it's as if they had never been computed. On the next access,
105109
# the accessor function from the parent class will be called, simply
106110
# because that's how the python descriptor protocol works.
107-
for mname, mval in classdict.items():
108-
if mname in instdict and isinstance(mval, OneTimeProperty):
111+
for mname, mval in self.__class__.__dict__.items():
112+
if mname in self.__dict__ and isinstance(mval, OneTimeProperty):
109113
delattr(self, mname)
110114

111115

112-
class OneTimeProperty:
116+
class OneTimeProperty(ty.Generic[T]):
113117
"""A descriptor to make special properties that become normal attributes.
114118
115119
This is meant to be used mostly by the auto_attr decorator in this module.
116120
"""
117121

118-
def __init__(self, func):
122+
def __init__(self, func: ty.Callable[[InstanceT], T]) -> None:
119123
"""Create a OneTimeProperty instance.
120124
121125
Parameters
@@ -128,24 +132,35 @@ def __init__(self, func):
128132
"""
129133
self.getter = func
130134
self.name = func.__name__
135+
self.__doc__ = func.__doc__
136+
137+
@ty.overload
138+
def __get__(
139+
self, obj: None, objtype: type[InstanceT] | None = None
140+
) -> ty.Callable[[InstanceT], T]:
141+
... # pragma: no cover
142+
143+
@ty.overload
144+
def __get__(self, obj: InstanceT, objtype: type[InstanceT] | None = None) -> T:
145+
... # pragma: no cover
131146

132-
def __get__(self, obj, type=None):
147+
def __get__(
148+
self, obj: InstanceT | None, objtype: type[InstanceT] | None = None
149+
) -> T | ty.Callable[[InstanceT], T]:
133150
"""This will be called on attribute access on the class or instance."""
134151
if obj is None:
135152
# Being called on the class, return the original function. This
136153
# way, introspection works on the class.
137-
# return func
138154
return self.getter
139155

140-
# Errors in the following line are errors in setting a
141-
# OneTimeProperty
156+
# Errors in the following line are errors in setting a OneTimeProperty
142157
val = self.getter(obj)
143158

144-
setattr(obj, self.name, val)
159+
obj.__dict__[self.name] = val
145160
return val
146161

147162

148-
def auto_attr(func):
163+
def auto_attr(func: ty.Callable[[InstanceT], T]) -> OneTimeProperty[T]:
149164
"""Decorator to create OneTimeProperty attributes.
150165
151166
Parameters

nibabel/optpkg.py

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,31 @@
11
"""Routines to support optional packages"""
2+
from __future__ import annotations
3+
4+
import typing as ty
5+
from types import ModuleType
6+
27
from packaging.version import Version
38

49
from .tripwire import TripWire
510

611

7-
def _check_pkg_version(pkg, min_version):
8-
# Default version checking function
9-
if isinstance(min_version, str):
10-
min_version = Version(min_version)
11-
try:
12-
return min_version <= Version(pkg.__version__)
13-
except AttributeError:
12+
def _check_pkg_version(min_version: str | Version) -> ty.Callable[[ModuleType], bool]:
13+
min_ver = Version(min_version) if isinstance(min_version, str) else min_version
14+
15+
def check(pkg: ModuleType) -> bool:
16+
pkg_ver = getattr(pkg, '__version__', None)
17+
if isinstance(pkg_ver, str):
18+
return min_ver <= Version(pkg_ver)
1419
return False
1520

21+
return check
22+
1623

17-
def optional_package(name, trip_msg=None, min_version=None):
24+
def optional_package(
25+
name: str,
26+
trip_msg: str | None = None,
27+
min_version: str | Version | ty.Callable[[ModuleType], bool] | None = None,
28+
) -> tuple[ModuleType | TripWire, bool, ty.Callable[[], None]]:
1829
"""Return package-like thing and module setup for package `name`
1930
2031
Parameters
@@ -81,7 +92,7 @@ def optional_package(name, trip_msg=None, min_version=None):
8192
elif min_version is None:
8293
check_version = lambda pkg: True
8394
else:
84-
check_version = lambda pkg: _check_pkg_version(pkg, min_version)
95+
check_version = _check_pkg_version(min_version)
8596
# fromlist=[''] results in submodule being returned, rather than the top
8697
# level module. See help(__import__)
8798
fromlist = [''] if '.' in name else []
@@ -107,11 +118,11 @@ def optional_package(name, trip_msg=None, min_version=None):
107118
trip_msg = (
108119
f'We need package {name} for these functions, but ``import {name}`` raised {exc}'
109120
)
110-
pkg = TripWire(trip_msg)
121+
trip = TripWire(trip_msg)
111122

112-
def setup_module():
123+
def setup_module() -> None:
113124
import unittest
114125

115126
raise unittest.SkipTest(f'No {name} for these tests')
116127

117-
return pkg, False, setup_module
128+
return trip, False, setup_module

nibabel/pkg_info.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
COMMIT_HASH = '$Format:%h$'
1515

1616

17-
def _cmp(a, b) -> int:
17+
def _cmp(a: Version, b: Version) -> int:
1818
"""Implementation of ``cmp`` for Python 3"""
1919
return (a > b) - (a < b)
2020

@@ -113,7 +113,7 @@ def pkg_commit_hash(pkg_path: str | None = None) -> tuple[str, str]:
113113
return '(none found)', '<not found>'
114114

115115

116-
def get_pkg_info(pkg_path: str) -> dict:
116+
def get_pkg_info(pkg_path: str) -> dict[str, str]:
117117
"""Return dict describing the context of this package
118118
119119
Parameters

nibabel/processing.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
from .optpkg import optional_package
2222

23-
spnd, _, _ = optional_package('scipy.ndimage')
23+
spnd = optional_package('scipy.ndimage')[0]
2424

2525
from .affines import AffineError, append_diag, from_matvec, rescale_affine, to_matvec
2626
from .imageclasses import spatial_axes_first

nibabel/testing/helpers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from ..optpkg import optional_package
88

9-
_, have_scipy, _ = optional_package('scipy.io')
9+
have_scipy = optional_package('scipy.io')[1]
1010

1111
from numpy.testing import assert_array_equal
1212

0 commit comments

Comments
 (0)