|
6 | 6 | import pathlib |
7 | 7 | import subprocess |
8 | 8 | from concurrent.futures import Future, ThreadPoolExecutor |
9 | | -from typing import IO, ClassVar |
| 9 | +from typing import ClassVar |
10 | 10 |
|
11 | 11 | import requests |
12 | 12 | import requests.adapters |
|
16 | 16 |
|
17 | 17 | from zimscraperlib import logger |
18 | 18 | from zimscraperlib.constants import DEFAULT_WEB_REQUESTS_TIMEOUT |
| 19 | +from zimscraperlib.typing import SupportsSeekableWrite, SupportsWrite |
19 | 20 |
|
20 | 21 |
|
21 | 22 | class YoutubeDownloader: |
@@ -59,11 +60,10 @@ def download( |
59 | 60 | future = self.executor.submit(self._run_youtube_dl, url, options or {}) |
60 | 61 | if not wait: |
61 | 62 | return future |
62 | | - if not future.exception(): |
63 | | - # return the result |
64 | | - return future.result() # pyright: ignore |
65 | | - # raise the exception |
66 | | - raise future.exception() # pyright: ignore |
| 63 | + exc = future.exception() |
| 64 | + if isinstance(exc, BaseException): |
| 65 | + raise exc |
| 66 | + return True |
67 | 67 |
|
68 | 68 |
|
69 | 69 | class YoutubeConfig(dict): |
@@ -176,7 +176,7 @@ def get_session(max_retries: int | None = 5) -> requests.Session: |
176 | 176 | def stream_file( |
177 | 177 | url: str, |
178 | 178 | fpath: pathlib.Path | None = None, |
179 | | - byte_stream: IO[bytes] | None = None, |
| 179 | + byte_stream: SupportsWrite[bytes] | SupportsSeekableWrite[bytes] | None = None, |
180 | 180 | block_size: int | None = 1024, |
181 | 181 | proxies: dict[str, str] | None = None, |
182 | 182 | max_retries: int | None = 5, |
@@ -216,24 +216,25 @@ def stream_file( |
216 | 216 |
|
217 | 217 | total_downloaded = 0 |
218 | 218 | if fpath is not None: |
219 | | - fp = open(fpath, "wb") |
220 | | - elif ( |
221 | | - byte_stream is not None |
222 | | - ): # pragma: no branch (we use a precise condition to help type checker) |
223 | | - fp = byte_stream |
| 219 | + fpath_handler = open(fpath, "wb") |
| 220 | + else: |
| 221 | + fpath_handler = None |
224 | 222 |
|
225 | 223 | for data in resp.iter_content(block_size): |
226 | 224 | total_downloaded += len(data) |
227 | | - fp.write(data) |
| 225 | + if fpath_handler: |
| 226 | + fpath_handler.write(data) |
| 227 | + if byte_stream: |
| 228 | + byte_stream.write(data) |
228 | 229 |
|
229 | 230 | # stop downloading/reading if we're just testing first block |
230 | 231 | if only_first_block: |
231 | 232 | break |
232 | 233 |
|
233 | 234 | logger.debug(f"Downloaded {total_downloaded} bytes from {url}") |
234 | 235 |
|
235 | | - if fpath: |
236 | | - fp.close() |
237 | | - else: |
238 | | - fp.seek(0) |
| 236 | + if fpath_handler: |
| 237 | + fpath_handler.close() |
| 238 | + elif isinstance(byte_stream, SupportsSeekableWrite) and byte_stream.seekable(): |
| 239 | + byte_stream.seek(0) |
239 | 240 | return total_downloaded, resp.headers |
0 commit comments