1
- #!/usr/bin/env python3
2
- # vim: ai ts=4 sts=4 et sw=4 nu
3
-
4
1
from __future__ import annotations
5
2
6
3
import pathlib
7
4
import subprocess
8
5
from concurrent .futures import Future , ThreadPoolExecutor
9
- from typing import IO , ClassVar
6
+ from typing import IO , Any , ClassVar
10
7
11
8
import requests
12
9
import requests .adapters
13
10
import requests .structures
14
11
import urllib3 .util
15
- import yt_dlp as youtube_dl
12
+ import yt_dlp as youtube_dl # pyright: ignore[reportMissingTypeStubs]
16
13
17
14
from zimscraperlib import logger
18
15
@@ -29,24 +26,24 @@ def __init__(self, threads: int | None = 1) -> None:
29
26
def __enter__(self):
    """Enter the ``with`` block, handing back this very instance."""
    return self
31
28
32
def __exit__(self, *_: Any) -> None:
    """Leave the ``with`` block by shutting this instance down.

    The exception details supplied by the ``with`` statement are ignored and
    nothing is returned, so an exception raised inside the block propagates.
    """
    self.shutdown()
34
31
35
32
def shutdown(self) -> None:
    """Shut down the executor, awaiting completion."""
    # wait=True: do not return before the executor has finished shutting down
    self.executor.shutdown(wait=True)
38
35
39
def _run_youtube_dl(self, url: str, options: dict[str, Any]) -> None:
    """Download ``url`` through a YoutubeDL instance built from ``options``.

    url: passed to YoutubeDL.download as a single-item list
    options: passed as-is to the YoutubeDL constructor
    """
    # YoutubeDL is used as a context manager so it is released once done
    with youtube_dl.YoutubeDL(options) as ydl:
        ydl.download([url])  # pyright: ignore[reportUnknownMemberType]
42
39
43
40
def download (
44
41
self ,
45
42
url : str ,
46
- options : dict | None ,
43
+ options : dict [ str , Any ] | None ,
47
44
* ,
48
45
wait : bool | None = True ,
49
- ) -> bool | Future :
46
+ ) -> bool | Future [ Any ] :
50
47
"""Downloads video using initialized executor.
51
48
52
49
url: URL or Video ID
@@ -65,7 +62,7 @@ def download(
65
62
raise future .exception () # pyright: ignore
66
63
67
64
68
- class YoutubeConfig (dict ):
65
+ class YoutubeConfig (dict [ str , str | bool | int | None ] ):
69
66
options : ClassVar [dict [str , str | bool | int | None ]] = {}
70
67
defaults : ClassVar [dict [str , str | bool | int | None ]] = {
71
68
"writethumbnail" : True ,
@@ -81,7 +78,7 @@ class YoutubeConfig(dict):
81
78
"outtmpl" : "video.%(ext)s" ,
82
79
}
83
80
84
- def __init__ (self , ** kwargs ):
81
+ def __init__ (self , ** kwargs : str | bool | int | None ):
85
82
super ().__init__ (self , ** type (self ).defaults )
86
83
self .update (self .options )
87
84
self .update (kwargs )
@@ -91,7 +88,7 @@ def get_options(
91
88
cls ,
92
89
target_dir : pathlib .Path | None = None ,
93
90
filepath : pathlib .Path | None = None ,
94
- ** options ,
91
+ ** options : str | bool | int | None ,
95
92
):
96
93
if "outtmpl" not in options :
97
94
outtmpl = cls .options .get ("outtmpl" , cls .defaults ["outtmpl" ])
@@ -142,9 +139,10 @@ def save_large_file(url: str, fpath: pathlib.Path) -> None:
142
139
)
143
140
144
141
145
- def _get_retry_adapter (
142
+ def get_retry_adapter (
146
143
max_retries : int | None = 5 ,
147
144
) -> requests .adapters .BaseAdapter :
145
+ """A requests adapter to automatically retry on known HTTP status that can be"""
148
146
retries = urllib3 .util .retry .Retry (
149
147
total = max_retries , # total number of retries
150
148
connect = max_retries , # connection errors
@@ -168,7 +166,7 @@ def _get_retry_adapter(
168
166
def get_session(max_retries: int | None = 5) -> requests.Session:
    """Session to hold cookies and connection pool together

    The retry adapter is mounted on the full ``http://`` and ``https://``
    prefixes. requests selects the adapter with the longest matching prefix,
    and a fresh Session already has adapters on those two prefixes, so mounting
    on the bare ``"http"`` prefix would never be selected and the retry policy
    would not apply.

    max_retries: forwarded to get_retry_adapter

    Returns the configured session"""
    session = requests.Session()
    adapter = get_retry_adapter(max_retries)
    # replace the default adapters for both schemes with the retrying one
    for prefix in ("http://", "https://"):
        session.mount(prefix, adapter)
    return session
173
171
174
172
@@ -198,7 +196,11 @@ def stream_file(
198
196
Returns the total number of bytes downloaded and the response headers"""
199
197
200
198
# if no output option is supplied
201
- if fpath is None and byte_stream is None :
199
+ if fpath is not None :
200
+ fp = open (fpath , "wb" )
201
+ elif byte_stream is not None :
202
+ fp = byte_stream
203
+ else :
202
204
raise ValueError ("Either file path or a bytesIO object is needed" )
203
205
204
206
if not session :
@@ -212,12 +214,6 @@ def stream_file(
212
214
resp .raise_for_status ()
213
215
214
216
total_downloaded = 0
215
- if fpath is not None :
216
- fp = open (fpath , "wb" )
217
- elif (
218
- byte_stream is not None
219
- ): # pragma: no branch (we use a precise condition to help type checker)
220
- fp = byte_stream
221
217
222
218
for data in resp .iter_content (block_size ):
223
219
total_downloaded += len (data )
0 commit comments