 import shutil
 import tarfile
 from pathlib import Path
-from typing import List, Optional, Dict, Any
+from typing import Any, Dict, List, Optional
 
 import requests
 from huggingface_hub import snapshot_download
 from huggingface_hub.utils import RepositoryNotFoundError
-from tqdm import tqdm
 from loguru import logger
+from tqdm import tqdm
 
 
 class ModelManagement:
@@ -42,7 +42,9 @@ def _get_model_description(cls, model_name: str) -> Dict[str, Any]:
         raise ValueError(f"Model {model_name} is not supported in {cls.__name__}.")
 
     @classmethod
-    def download_file_from_gcs(cls, url: str, output_path: str, show_progress: bool = True) -> str:
+    def download_file_from_gcs(
+        cls, url: str, output_path: str, show_progress: bool = True
+    ) -> str:
         """
         Downloads a file from Google Cloud Storage.
 
@@ -71,12 +73,17 @@ def download_file_from_gcs(cls, url: str, output_path: str, show_progress: bool
 
         # Warn if the total size is zero
        if total_size_in_bytes == 0:
-            print(f"Warning: Content-length header is missing or zero in the response from {url}.")
+            print(
+                f"Warning: Content-length header is missing or zero in the response from {url}."
+            )
 
         show_progress = total_size_in_bytes and show_progress
 
         with tqdm(
-            total=total_size_in_bytes, unit="iB", unit_scale=True, disable=not show_progress
+            total=total_size_in_bytes,
+            unit="iB",
+            unit_scale=True,
+            disable=not show_progress,
         ) as progress_bar:
             with open(output_path, "wb") as file:
                 for chunk in response.iter_content(chunk_size=1024):
@@ -156,7 +163,9 @@ def decompress_to_cache(cls, targz_path: str, cache_dir: str):
         return cache_dir
 
     @classmethod
-    def retrieve_model_gcs(cls, model_name: str, source_url: str, cache_dir: str) -> Path:
+    def retrieve_model_gcs(
+        cls, model_name: str, source_url: str, cache_dir: str
+    ) -> Path:
         fast_model_name = f"fast-{model_name.split('/')[-1]}"
 
         cache_tmp_dir = Path(cache_dir) / "tmp"
@@ -182,8 +191,12 @@ def retrieve_model_gcs(cls, model_name: str, source_url: str, cache_dir: str) ->
             output_path=str(model_tar_gz),
         )
 
-        cls.decompress_to_cache(targz_path=str(model_tar_gz), cache_dir=str(cache_tmp_dir))
-        assert model_tmp_dir.exists(), f"Could not find {model_tmp_dir} in {cache_tmp_dir}"
+        cls.decompress_to_cache(
+            targz_path=str(model_tar_gz), cache_dir=str(cache_tmp_dir)
+        )
+        assert (
+            model_tmp_dir.exists()
+        ), f"Could not find {model_tmp_dir} in {cache_tmp_dir}"
 
         model_tar_gz.unlink()
         # Rename from tmp to final name is atomic