Skip to content

Commit 8c43ffe

Browse files
committed
TYP: Annotate openers
Opener proxy methods now match the io.BufferedIOBase method prototypes. Remove some version checks for indexed-gzip < 0.8, which supported Python 3.6 while our minimum is now 3.8. A runtime-checkable protocol requiring .read() and .write() was the easiest way to accommodate unusual file-like objects that are not io.IOBase subclasses. Once indexed-gzip itself is typed, we may need to adjust the output of _gzip_open.
1 parent cbd7690 commit 8c43ffe

File tree

2 files changed

+116
-67
lines changed

2 files changed

+116
-67
lines changed

nibabel/openers.py

+115-66
Original file line numberDiff line numberDiff line change
@@ -7,34 +7,48 @@
77
#
88
### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ##
99
"""Context manager openers for various fileobject types"""
10+
from __future__ import annotations
11+
1012
import gzip
11-
import warnings
13+
import io
14+
import typing as ty
1215
from bz2 import BZ2File
1316
from os.path import splitext
1417

15-
from packaging.version import Version
16-
1718
from nibabel.optpkg import optional_package
1819

19-
# is indexed_gzip present and modern?
20-
try:
21-
import indexed_gzip as igzip # type: ignore
20+
if ty.TYPE_CHECKING: # pragma: no cover
21+
from types import TracebackType
2222

23-
version = igzip.__version__
23+
import pyzstd
24+
from _typeshed import WriteableBuffer
2425

25-
HAVE_INDEXED_GZIP = True
26+
ModeRT = ty.Literal['r', 'rt']
27+
ModeRB = ty.Literal['rb']
28+
ModeWT = ty.Literal['w', 'wt']
29+
ModeWB = ty.Literal['wb']
30+
ModeR = ty.Union[ModeRT, ModeRB]
31+
ModeW = ty.Union[ModeWT, ModeWB]
32+
Mode = ty.Union[ModeR, ModeW]
2633

27-
# < 0.7 - no good
28-
if Version(version) < Version('0.7.0'):
29-
warnings.warn(f'indexed_gzip is present, but too old (>= 0.7.0 required): {version})')
30-
HAVE_INDEXED_GZIP = False
31-
# >= 0.8 SafeIndexedGzipFile renamed to IndexedGzipFile
32-
elif Version(version) < Version('0.8.0'):
33-
IndexedGzipFile = igzip.SafeIndexedGzipFile
34-
else:
35-
IndexedGzipFile = igzip.IndexedGzipFile
36-
del igzip, version
34+
OpenerDef = tuple[ty.Callable[..., io.IOBase], tuple[str, ...]]
35+
else:
36+
pyzstd = optional_package('pyzstd')[0]
37+
38+
39+
@ty.runtime_checkable
40+
class Fileish(ty.Protocol):
41+
def read(self, size: int = -1, /) -> bytes:
42+
... # pragma: no cover
43+
44+
def write(self, b: bytes, /) -> int | None:
45+
... # pragma: no cover
46+
47+
48+
try:
49+
from indexed_gzip import IndexedGzipFile # type: ignore
3750

51+
HAVE_INDEXED_GZIP = True
3852
except ImportError:
3953
# nibabel.openers.IndexedGzipFile is imported by nibabel.volumeutils
4054
# to detect compressed file types, so we give a fallback value here.
@@ -49,35 +63,63 @@ class DeterministicGzipFile(gzip.GzipFile):
4963
to a modification time (``mtime``) of 0 seconds.
5064
"""
5165

52-
def __init__(self, filename=None, mode=None, compresslevel=9, fileobj=None, mtime=0):
53-
# These two guards are copied from
66+
def __init__(
67+
self,
68+
filename: str | None = None,
69+
mode: Mode | None = None,
70+
compresslevel: int = 9,
71+
fileobj: io.FileIO | None = None,
72+
mtime: int = 0,
73+
):
74+
if mode is None:
75+
mode = 'rb'
76+
modestr: str = mode
77+
78+
# These two guards are adapted from
5479
# https://github.com/python/cpython/blob/6ab65c6/Lib/gzip.py#L171-L174
55-
if mode and 'b' not in mode:
56-
mode += 'b'
80+
if 'b' not in modestr:
81+
modestr = f'{mode}b'
5782
if fileobj is None:
58-
fileobj = self.myfileobj = open(filename, mode or 'rb')
83+
if filename is None:
84+
raise TypeError('Must define either fileobj or filename')
85+
# Cast because GzipFile.myfileobj has type io.FileIO while open returns ty.IO
86+
fileobj = self.myfileobj = ty.cast(io.FileIO, open(filename, modestr))
5987
return super().__init__(
60-
filename='', mode=mode, compresslevel=compresslevel, fileobj=fileobj, mtime=mtime
88+
filename='',
89+
mode=modestr,
90+
compresslevel=compresslevel,
91+
fileobj=fileobj,
92+
mtime=mtime,
6193
)
6294

6395

64-
def _gzip_open(filename, mode='rb', compresslevel=9, mtime=0, keep_open=False):
96+
def _gzip_open(
97+
filename: str,
98+
mode: Mode = 'rb',
99+
compresslevel: int = 9,
100+
mtime: int = 0,
101+
keep_open: bool = False,
102+
) -> gzip.GzipFile:
103+
104+
if not HAVE_INDEXED_GZIP or mode != 'rb':
105+
gzip_file = DeterministicGzipFile(filename, mode, compresslevel, mtime=mtime)
65106

66107
# use indexed_gzip if possible for faster read access. If keep_open ==
67108
# True, we tell IndexedGzipFile to keep the file handle open. Otherwise
68109
# the IndexedGzipFile will close/open the file on each read.
69-
if HAVE_INDEXED_GZIP and mode == 'rb':
70-
gzip_file = IndexedGzipFile(filename, drop_handles=not keep_open)
71-
72-
# Fall-back to built-in GzipFile
73110
else:
74-
gzip_file = DeterministicGzipFile(filename, mode, compresslevel, mtime=mtime)
111+
gzip_file = IndexedGzipFile(filename, drop_handles=not keep_open)
75112

76113
return gzip_file
77114

78115

79-
def _zstd_open(filename, mode='r', *, level_or_option=None, zstd_dict=None):
80-
pyzstd = optional_package('pyzstd')[0]
116+
def _zstd_open(
117+
filename: str,
118+
mode: Mode = 'r',
119+
*,
120+
level_or_option: int | dict | None = None,
121+
zstd_dict: pyzstd.ZstdDict | None = None,
122+
) -> pyzstd.ZstdFile:
81123
return pyzstd.ZstdFile(filename, mode, level_or_option=level_or_option, zstd_dict=zstd_dict)
82124

83125

@@ -104,7 +146,7 @@ class Opener:
104146
gz_def = (_gzip_open, ('mode', 'compresslevel', 'mtime', 'keep_open'))
105147
bz2_def = (BZ2File, ('mode', 'buffering', 'compresslevel'))
106148
zstd_def = (_zstd_open, ('mode', 'level_or_option', 'zstd_dict'))
107-
compress_ext_map = {
149+
compress_ext_map: dict[str | None, OpenerDef] = {
108150
'.gz': gz_def,
109151
'.bz2': bz2_def,
110152
'.zst': zstd_def,
@@ -121,19 +163,19 @@ class Opener:
121163
'w': default_zst_compresslevel,
122164
}
123165
#: whether to ignore case looking for compression extensions
124-
compress_ext_icase = True
166+
compress_ext_icase: bool = True
167+
168+
fobj: io.IOBase
125169

126-
def __init__(self, fileish, *args, **kwargs):
127-
if self._is_fileobj(fileish):
170+
def __init__(self, fileish: str | io.IOBase, *args, **kwargs):
171+
if isinstance(fileish, (io.IOBase, Fileish)):
128172
self.fobj = fileish
129173
self.me_opened = False
130-
self._name = None
174+
self._name = getattr(fileish, 'name', None)
131175
return
132176
opener, arg_names = self._get_opener_argnames(fileish)
133177
# Get full arguments to check for mode and compresslevel
134-
full_kwargs = kwargs.copy()
135-
n_args = len(args)
136-
full_kwargs.update(dict(zip(arg_names[:n_args], args)))
178+
full_kwargs = {**kwargs, **dict(zip(arg_names, args))}
137179
# Set default mode
138180
if 'mode' not in full_kwargs:
139181
mode = 'rb'
@@ -155,7 +197,7 @@ def __init__(self, fileish, *args, **kwargs):
155197
self._name = fileish
156198
self.me_opened = True
157199

158-
def _get_opener_argnames(self, fileish):
200+
def _get_opener_argnames(self, fileish: str) -> OpenerDef:
159201
_, ext = splitext(fileish)
160202
if self.compress_ext_icase:
161203
ext = ext.lower()
@@ -168,16 +210,12 @@ def _get_opener_argnames(self, fileish):
168210
return self.compress_ext_map[ext]
169211
return self.compress_ext_map[None]
170212

171-
def _is_fileobj(self, obj):
172-
"""Is `obj` a file-like object?"""
173-
return hasattr(obj, 'read') and hasattr(obj, 'write')
174-
175213
@property
176-
def closed(self):
214+
def closed(self) -> bool:
177215
return self.fobj.closed
178216

179217
@property
180-
def name(self):
218+
def name(self) -> str | None:
181219
"""Return ``self.fobj.name`` or self._name if not present
182220
183221
self._name will be None if object was created with a fileobj, otherwise
@@ -186,42 +224,53 @@ def name(self):
186224
return self._name
187225

188226
@property
189-
def mode(self):
190-
return self.fobj.mode
227+
def mode(self) -> str:
228+
# Check and raise our own error for type narrowing purposes
229+
if hasattr(self.fobj, 'mode'):
230+
return self.fobj.mode
231+
raise AttributeError(f'{self.fobj.__class__.__name__} has no attribute "mode"')
191232

192-
def fileno(self):
233+
def fileno(self) -> int:
193234
return self.fobj.fileno()
194235

195-
def read(self, *args, **kwargs):
196-
return self.fobj.read(*args, **kwargs)
236+
def read(self, size: int = -1, /) -> bytes:
237+
return self.fobj.read(size)
197238

198-
def readinto(self, *args, **kwargs):
199-
return self.fobj.readinto(*args, **kwargs)
239+
def readinto(self, buffer: WriteableBuffer, /) -> int | None:
240+
# Check and raise our own error for type narrowing purposes
241+
if hasattr(self.fobj, 'readinto'):
242+
return self.fobj.readinto(buffer)
243+
raise AttributeError(f'{self.fobj.__class__.__name__} has no attribute "readinto"')
200244

201-
def write(self, *args, **kwargs):
202-
return self.fobj.write(*args, **kwargs)
245+
def write(self, b: bytes, /) -> int | None:
246+
return self.fobj.write(b)
203247

204-
def seek(self, *args, **kwargs):
205-
return self.fobj.seek(*args, **kwargs)
248+
def seek(self, pos: int, whence: int = 0, /) -> int:
249+
return self.fobj.seek(pos, whence)
206250

207-
def tell(self, *args, **kwargs):
208-
return self.fobj.tell(*args, **kwargs)
251+
def tell(self, /) -> int:
252+
return self.fobj.tell()
209253

210-
def close(self, *args, **kwargs):
211-
return self.fobj.close(*args, **kwargs)
254+
def close(self, /) -> None:
255+
return self.fobj.close()
212256

213-
def __iter__(self):
257+
def __iter__(self) -> ty.Iterator[bytes]:
214258
return iter(self.fobj)
215259

216-
def close_if_mine(self):
260+
def close_if_mine(self) -> None:
217261
"""Close ``self.fobj`` iff we opened it in the constructor"""
218262
if self.me_opened:
219263
self.close()
220264

221-
def __enter__(self):
265+
def __enter__(self) -> Opener:
222266
return self
223267

224-
def __exit__(self, exc_type, exc_val, exc_tb):
268+
def __exit__(
269+
self,
270+
exc_type: type[BaseException] | None,
271+
exc_val: BaseException | None,
272+
exc_tb: TracebackType | None,
273+
) -> None:
225274
self.close_if_mine()
226275

227276

nibabel/tests/test_openers.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def __init__(self, message):
3838
def write(self):
3939
pass
4040

41-
def read(self):
41+
def read(self, size=-1, /):
4242
return self.message
4343

4444

0 commit comments

Comments
 (0)