|
6 | 6 | import pathlib
|
7 | 7 | import subprocess
|
8 | 8 | from concurrent.futures import Future, ThreadPoolExecutor
|
9 |
| -from typing import IO, ClassVar |
| 9 | +from typing import ClassVar |
10 | 10 |
|
11 | 11 | import requests
|
12 | 12 | import requests.adapters
|
|
16 | 16 |
|
17 | 17 | from zimscraperlib import logger
|
18 | 18 | from zimscraperlib.constants import DEFAULT_WEB_REQUESTS_TIMEOUT
|
| 19 | +from zimscraperlib.typing import SupportsSeekableWrite, SupportsWrite |
19 | 20 |
|
20 | 21 |
|
21 | 22 | class YoutubeDownloader:
|
@@ -59,11 +60,10 @@ def download(
|
59 | 60 | future = self.executor.submit(self._run_youtube_dl, url, options or {})
|
60 | 61 | if not wait:
|
61 | 62 | return future
|
62 |
| - if not future.exception(): |
63 |
| - # return the result |
64 |
| - return future.result() # pyright: ignore |
65 |
| - # raise the exception |
66 |
| - raise future.exception() # pyright: ignore |
| 63 | + exc = future.exception() |
| 64 | + if isinstance(exc, BaseException): |
| 65 | + raise exc |
| 66 | + return True |
67 | 67 |
|
68 | 68 |
|
69 | 69 | class YoutubeConfig(dict):
|
@@ -176,7 +176,7 @@ def get_session(max_retries: int | None = 5) -> requests.Session:
|
176 | 176 | def stream_file(
|
177 | 177 | url: str,
|
178 | 178 | fpath: pathlib.Path | None = None,
|
179 |
| - byte_stream: IO[bytes] | None = None, |
| 179 | + byte_stream: SupportsWrite[bytes] | SupportsSeekableWrite[bytes] | None = None, |
180 | 180 | block_size: int | None = 1024,
|
181 | 181 | proxies: dict[str, str] | None = None,
|
182 | 182 | max_retries: int | None = 5,
|
@@ -216,24 +216,25 @@ def stream_file(
|
216 | 216 |
|
217 | 217 | total_downloaded = 0
|
218 | 218 | if fpath is not None:
|
219 |
| - fp = open(fpath, "wb") |
220 |
| - elif ( |
221 |
| - byte_stream is not None |
222 |
| - ): # pragma: no branch (we use a precise condition to help type checker) |
223 |
| - fp = byte_stream |
| 219 | + fpath_handler = open(fpath, "wb") |
| 220 | + else: |
| 221 | + fpath_handler = None |
224 | 222 |
|
225 | 223 | for data in resp.iter_content(block_size):
|
226 | 224 | total_downloaded += len(data)
|
227 |
| - fp.write(data) |
| 225 | + if fpath_handler: |
| 226 | + fpath_handler.write(data) |
| 227 | + if byte_stream: |
| 228 | + byte_stream.write(data) |
228 | 229 |
|
229 | 230 | # stop downloading/reading if we're just testing first block
|
230 | 231 | if only_first_block:
|
231 | 232 | break
|
232 | 233 |
|
233 | 234 | logger.debug(f"Downloaded {total_downloaded} bytes from {url}")
|
234 | 235 |
|
235 |
| - if fpath: |
236 |
| - fp.close() |
237 |
| - else: |
238 |
| - fp.seek(0) |
| 236 | + if fpath_handler: |
| 237 | + fpath_handler.close() |
| 238 | + elif isinstance(byte_stream, SupportsSeekableWrite) and byte_stream.seekable(): |
| 239 | + byte_stream.seek(0) |
239 | 240 | return total_downloaded, resp.headers
|
0 commit comments