Skip to content

Commit c3ffd0b

Browse files
rgaudinbenoit74
authored and committed
Using beartype on all of scraperlib
This enables beartype on all of scraperlib, raising exceptions on all calls to our API that violate the requested types (and on functions returning incorrect types). Changes: - removed tests purposely testing incorrect input types - fixed some return types or call params to match intent - logger console arg expecting io.StringIO - turned a couple of NamedTuple into dataclass to ease type declarations - Removed some unreachable code that was expecting invalid types - Introducing new protocols for IO-based type inputs, based on typeshed's (as protocols, those are ignored in test/coverage) - Image-related IO are declared as io.BytesIO instead of `IO[bytes]`. Somewhat equivalent but typing.IO is strongly discouraged everywhere. Cannot harmonize with rest of our code base as we pass this to Pillow. - Same goes for logger which eventually accepts TextIO - stream_file behavior changed a bit. Code assumed that if fpath is not there we want byte_stream. Given it's not properly tested, I changed it to accept both fpath and byte_stream simultaneously. I believe we should change the API to have a single input that supports the byte stream and the path and adapt behavior. We should do that throughout the code base though, so that would be separate. I'll open a ticket if we agree on going this way
1 parent a0a225b commit c3ffd0b

20 files changed

+156
-145
lines changed

pyproject.toml

+2-1
Original file line numberDiff line numberDiff line change
@@ -254,7 +254,7 @@ ban-relative-imports = "all"
254254
# _libkiwix mimics libkiwix C++ code, names obey C++ conventions
255255
"src/zimscraperlib/zim/_libkiwix.py" = ["N802", "N803", "N806"]
256256
# beartype must be first
257-
"src/zimscraperlib/zim/__init__.py" = ["E402"]
257+
"src/zimscraperlib/__init__.py" = ["E402"]
258258

259259
[tool.pytest.ini_options]
260260
minversion = "7.3"
@@ -278,6 +278,7 @@ exclude_lines = [
278278
"no cov",
279279
"if __name__ == .__main__.:",
280280
"if TYPE_CHECKING:",
281+
"class .*Protocol.*",
281282
]
282283

283284
[tool.pyright]

src/zimscraperlib/__init__.py

+4
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,10 @@
44
import logging as stdlogging
55
import os
66

7+
from beartype.claw import beartype_this_package
8+
9+
beartype_this_package()
10+
711
from zimscraperlib.constants import NAME
812
from zimscraperlib.logging import getLogger
913

src/zimscraperlib/download.py

+18-17
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import pathlib
77
import subprocess
88
from concurrent.futures import Future, ThreadPoolExecutor
9-
from typing import IO, ClassVar
9+
from typing import ClassVar
1010

1111
import requests
1212
import requests.adapters
@@ -16,6 +16,7 @@
1616

1717
from zimscraperlib import logger
1818
from zimscraperlib.constants import DEFAULT_WEB_REQUESTS_TIMEOUT
19+
from zimscraperlib.typing import SupportsSeekableWrite, SupportsWrite
1920

2021

2122
class YoutubeDownloader:
@@ -59,11 +60,10 @@ def download(
5960
future = self.executor.submit(self._run_youtube_dl, url, options or {})
6061
if not wait:
6162
return future
62-
if not future.exception():
63-
# return the result
64-
return future.result() # pyright: ignore
65-
# raise the exception
66-
raise future.exception() # pyright: ignore
63+
exc = future.exception()
64+
if isinstance(exc, BaseException):
65+
raise exc
66+
return True
6767

6868

6969
class YoutubeConfig(dict):
@@ -176,7 +176,7 @@ def get_session(max_retries: int | None = 5) -> requests.Session:
176176
def stream_file(
177177
url: str,
178178
fpath: pathlib.Path | None = None,
179-
byte_stream: IO[bytes] | None = None,
179+
byte_stream: SupportsWrite[bytes] | SupportsSeekableWrite[bytes] | None = None,
180180
block_size: int | None = 1024,
181181
proxies: dict[str, str] | None = None,
182182
max_retries: int | None = 5,
@@ -216,24 +216,25 @@ def stream_file(
216216

217217
total_downloaded = 0
218218
if fpath is not None:
219-
fp = open(fpath, "wb")
220-
elif (
221-
byte_stream is not None
222-
): # pragma: no branch (we use a precise condition to help type checker)
223-
fp = byte_stream
219+
fpath_handler = open(fpath, "wb")
220+
else:
221+
fpath_handler = None
224222

225223
for data in resp.iter_content(block_size):
226224
total_downloaded += len(data)
227-
fp.write(data)
225+
if fpath_handler:
226+
fpath_handler.write(data)
227+
if byte_stream:
228+
byte_stream.write(data)
228229

229230
# stop downloading/reading if we're just testing first block
230231
if only_first_block:
231232
break
232233

233234
logger.debug(f"Downloaded {total_downloaded} bytes from {url}")
234235

235-
if fpath:
236-
fp.close()
237-
else:
238-
fp.seek(0)
236+
if fpath_handler:
237+
fpath_handler.close()
238+
elif isinstance(byte_stream, SupportsSeekableWrite) and byte_stream.seekable():
239+
byte_stream.seek(0)
239240
return total_downloaded, resp.headers

src/zimscraperlib/i18n.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -59,12 +59,12 @@ def iso_types(self) -> list[str]:
5959
return self["iso_types"]
6060

6161
@property
62-
def query(self) -> list[str]:
62+
def query(self) -> str:
6363
"""Query issued for these language details"""
6464
return self["query"]
6565

6666
@property
67-
def querytype(self) -> list[str]:
67+
def querytype(self) -> str:
6868
"""Type of query issued to retrieve language details"""
6969
return self["querytype"]
7070

src/zimscraperlib/image/conversion.py

+7-6
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55

66
import io
77
import pathlib
8-
from typing import IO
98

109
import cairosvg.svg
1110
from PIL.Image import open as pilopen
@@ -17,9 +16,9 @@
1716

1817

1918
def convert_image(
20-
src: pathlib.Path | IO[bytes],
21-
dst: pathlib.Path | IO[bytes],
22-
**params: str,
19+
src: pathlib.Path | io.BytesIO,
20+
dst: pathlib.Path | io.BytesIO,
21+
**params: str | None,
2322
) -> None:
2423
"""convert an image file from one format to another
2524
params: Image.save() parameters. Depends on dest format.
@@ -31,7 +30,9 @@ def convert_image(
3130
to RGB. ex: RGB, ARGB, CMYK (and other PIL colorspaces)"""
3231

3332
colorspace = params.get("colorspace") # requested colorspace
34-
fmt = params.pop("fmt").upper() if "fmt" in params else None # requested format
33+
fmt = (
34+
str(params.pop("fmt")).upper() if params.get("fmt") else None
35+
) # requested format
3536
if not fmt:
3637
fmt = format_for(dst)
3738
if not fmt:
@@ -44,7 +45,7 @@ def convert_image(
4445

4546
def convert_svg2png(
4647
src: str | pathlib.Path | io.BytesIO,
47-
dst: pathlib.Path | IO[bytes],
48+
dst: pathlib.Path | io.BytesIO,
4849
width: int | None = None,
4950
height: int | None = None,
5051
):

src/zimscraperlib/image/optimization.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -210,14 +210,14 @@ def optimize_webp(
210210
else:
211211
try:
212212
save_image(webp_image, dst, fmt="WEBP", **params)
213-
except Exception as exc:
213+
except Exception as exc: # pragma: no cover
214214
if (
215215
isinstance(src, pathlib.Path)
216216
and isinstance(dst, pathlib.Path)
217217
and src.resolve() != dst.resolve()
218218
and dst.exists()
219219
):
220-
dst.unlink() # pragma: no cover
220+
dst.unlink()
221221
raise exc
222222
return dst
223223

src/zimscraperlib/image/probing.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
import io
88
import pathlib
99
import re
10-
from typing import IO
1110

1211
import colorthief
1312
import PIL.Image
@@ -55,7 +54,7 @@ def is_hex_color(text: str) -> bool:
5554

5655

5756
def format_for(
58-
src: pathlib.Path | IO[bytes],
57+
src: pathlib.Path | io.BytesIO,
5958
*,
6059
from_suffix: bool = True,
6160
) -> str | None:
@@ -95,7 +94,7 @@ def format_for(
9594

9695

9796
def is_valid_image(
98-
image: pathlib.Path | IO[bytes] | bytes,
97+
image: pathlib.Path | bytes | io.BytesIO,
9998
imformat: str,
10099
size: tuple[int, int] | None = None,
101100
) -> bool:

src/zimscraperlib/image/utils.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,19 @@
22
# vim: ai ts=4 sts=4 et sw=4 nu
33
from __future__ import annotations
44

5+
import io
56
import pathlib
6-
from typing import IO
7+
from typing import IO, Any
78

89
from PIL.Image import Image
910
from PIL.ImageFile import ImageFile
1011

1112

1213
def save_image(
1314
src: Image | ImageFile,
14-
dst: pathlib.Path | IO[bytes],
15+
dst: pathlib.Path | IO[bytes] | io.BytesIO,
1516
fmt: str,
16-
**params: str,
17+
**params: Any,
1718
) -> None:
1819
"""PIL.Image.save() wrapper setting default parameters"""
1920
args = {"JPEG": {"quality": 100}, "PNG": {}}.get(fmt, {})

src/zimscraperlib/logging.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
from __future__ import annotations
55

6+
import io
67
import logging
78
import pathlib
89
import sys
@@ -22,7 +23,7 @@
2223
def getLogger( # noqa: N802 (intentionally matches the stdlib getLogger name)
2324
name: str,
2425
level: int = logging.INFO,
25-
console: TextIO | None = sys.stdout,
26+
console: TextIO | io.StringIO | None = sys.stdout,
2627
log_format: str | None = DEFAULT_FORMAT,
2728
file: pathlib.Path | None = None,
2829
file_level: int | None = None,

src/zimscraperlib/rewriting/url_rewriting.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,9 @@
4141
from __future__ import annotations
4242

4343
import re
44+
from dataclasses import dataclass
4445
from pathlib import PurePosixPath
45-
from typing import ClassVar, NamedTuple
46+
from typing import ClassVar
4647
from urllib.parse import quote, unquote, urljoin, urlsplit
4748

4849
import idna
@@ -51,7 +52,8 @@
5152
from zimscraperlib.rewriting.rules import FUZZY_RULES
5253

5354

54-
class AdditionalRule(NamedTuple):
55+
@dataclass
56+
class AdditionalRule:
5557
match: re.Pattern[str]
5658
replace: str
5759

@@ -147,7 +149,8 @@ def check_validity(cls, value: str) -> None:
147149
raise ValueError(f"Unexpected password in value: {value} {parts.password}")
148150

149151

150-
class RewriteResult(NamedTuple):
152+
@dataclass
153+
class RewriteResult:
151154
absolute_url: str
152155
rewriten_url: str
153156
zim_path: ZimPath | None
@@ -382,9 +385,6 @@ def normalize(cls, url: HttpUrl) -> ZimPath:
382385
passed to python-libzim for UTF-8 encoding.
383386
"""
384387

385-
if not isinstance(url, HttpUrl):
386-
raise ValueError("Bad argument type passed, HttpUrl expected")
387-
388388
url_parts = urlsplit(url.value)
389389

390390
if not url_parts.hostname:

src/zimscraperlib/types.py

+4-7
Original file line numberDiff line numberDiff line change
@@ -47,13 +47,10 @@ def get_mime_for_name(
4747
MIME only guessed from file extension and not actual content.
4848
4949
Filename with no extension are mapped to `no_ext_to`"""
50-
try:
51-
filename = pathlib.Path(filename)
52-
if not filename.suffix:
53-
return no_ext_to
54-
return mimetypes.guess_type(f"{filename.stem}{filename.suffix}")[0] or fallback
55-
except Exception:
56-
return fallback
50+
filename = pathlib.Path(filename)
51+
if not filename.suffix:
52+
return no_ext_to
53+
return mimetypes.guess_type(f"{filename.stem}{filename.suffix}")[0] or fallback
5754

5855

5956
def init_types():

src/zimscraperlib/typing.py

+38-2
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,15 @@
11
from __future__ import annotations
22

33
from collections.abc import Callable
4-
from typing import Any, NamedTuple
4+
from dataclasses import dataclass
5+
from typing import Any, Protocol, TypeVar, runtime_checkable
56

7+
_T_co = TypeVar("_T_co", covariant=True)
8+
_T_contra = TypeVar("_T_contra", contravariant=True)
69

7-
class Callback(NamedTuple):
10+
11+
@dataclass
12+
class Callback:
813
func: Callable
914
args: tuple[Any, ...] | None = None
1015
kwargs: dict[str, Any] | None = None
@@ -24,3 +29,34 @@ def call_with(self, *args, **kwargs):
2429

2530
def call(self):
2631
self.call_with(*self.get_args(), **self.get_kwargs())
32+
33+
34+
@runtime_checkable
35+
class SupportsWrite(Protocol[_T_contra]):
36+
"""Protocol exposing an expected write method"""
37+
38+
def write(self, s: _T_contra, /) -> object: ...
39+
40+
41+
@runtime_checkable
42+
class SupportsRead(Protocol[_T_co]):
43+
def read(self, length: int = ..., /) -> _T_co: ...
44+
45+
46+
@runtime_checkable
47+
class SupportsSeeking(Protocol):
48+
def seekable(self) -> bool: ...
49+
50+
def seek(self, target: int, whence: int = 0) -> int: ...
51+
52+
def tell(self) -> int: ...
53+
54+
def truncate(self, pos: int) -> int: ...
55+
56+
57+
@runtime_checkable
58+
class SupportsSeekableRead(SupportsRead[_T_co], SupportsSeeking, Protocol): ...
59+
60+
61+
@runtime_checkable
62+
class SupportsSeekableWrite(SupportsWrite[_T_contra], SupportsSeeking, Protocol): ...

0 commit comments

Comments
 (0)