1
1
import os
2
+ import time
2
3
import shutil
3
4
import tarfile
4
5
from pathlib import Path
@@ -42,9 +43,7 @@ def _get_model_description(cls, model_name: str) -> Dict[str, Any]:
42
43
raise ValueError (f"Model { model_name } is not supported in { cls .__name__ } ." )
43
44
44
45
@classmethod
45
- def download_file_from_gcs (
46
- cls , url : str , output_path : str , show_progress : bool = True
47
- ) -> str :
46
+ def download_file_from_gcs (cls , url : str , output_path : str , show_progress : bool = True ) -> str :
48
47
"""
49
48
Downloads a file from Google Cloud Storage.
50
49
@@ -73,9 +72,7 @@ def download_file_from_gcs(
73
72
74
73
# Warn if the total size is zero
75
74
if total_size_in_bytes == 0 :
76
- print (
77
- f"Warning: Content-length header is missing or zero in the response from { url } ."
78
- )
75
+ print (f"Warning: Content-length header is missing or zero in the response from { url } ." )
79
76
80
77
show_progress = total_size_in_bytes and show_progress
81
78
@@ -163,9 +160,7 @@ def decompress_to_cache(cls, targz_path: str, cache_dir: str):
163
160
return cache_dir
164
161
165
162
@classmethod
166
- def retrieve_model_gcs (
167
- cls , model_name : str , source_url : str , cache_dir : str
168
- ) -> Path :
163
+ def retrieve_model_gcs (cls , model_name : str , source_url : str , cache_dir : str ) -> Path :
169
164
fast_model_name = f"fast-{ model_name .split ('/' )[- 1 ]} "
170
165
171
166
cache_tmp_dir = Path (cache_dir ) / "tmp"
@@ -191,12 +186,8 @@ def retrieve_model_gcs(
191
186
output_path = str (model_tar_gz ),
192
187
)
193
188
194
- cls .decompress_to_cache (
195
- targz_path = str (model_tar_gz ), cache_dir = str (cache_tmp_dir )
196
- )
197
- assert (
198
- model_tmp_dir .exists ()
199
- ), f"Could not find { model_tmp_dir } in { cache_tmp_dir } "
189
+ cls .decompress_to_cache (targz_path = str (model_tar_gz ), cache_dir = str (cache_tmp_dir ))
190
+ assert model_tmp_dir .exists (), f"Could not find { model_tmp_dir } in { cache_tmp_dir } "
200
191
201
192
model_tar_gz .unlink ()
202
193
# Rename from tmp to final name is atomic
@@ -205,7 +196,7 @@ def retrieve_model_gcs(
205
196
return model_dir
206
197
207
198
@classmethod
208
- def download_model (cls , model : Dict [str , Any ], cache_dir : Path , ** kwargs ) -> Path :
199
+ def download_model (cls , model : Dict [str , Any ], cache_dir : Path , retries = 3 , ** kwargs ) -> Path :
209
200
"""
210
201
Downloads a model from HuggingFace Hub or Google Cloud Storage.
211
202
@@ -225,6 +216,7 @@ def download_model(cls, model: Dict[str, Any], cache_dir: Path, **kwargs) -> Pat
225
216
}
226
217
```
227
218
cache_dir (str): The path to the cache directory.
219
+ retries: (int): The number of times to retry (including the first attempt)
228
220
229
221
Returns:
230
222
Path: The path to the downloaded model directory.
@@ -233,26 +225,38 @@ def download_model(cls, model: Dict[str, Any], cache_dir: Path, **kwargs) -> Pat
233
225
hf_source = model .get ("sources" , {}).get ("hf" )
234
226
url_source = model .get ("sources" , {}).get ("url" )
235
227
236
- if hf_source :
237
- extra_patterns = [model ["model_file" ]]
238
- extra_patterns .extend (model .get ("additional_files" , []))
239
-
240
- try :
241
- return Path (
242
- cls .download_files_from_huggingface (
243
- hf_source ,
244
- cache_dir = str (cache_dir ),
245
- extra_patterns = extra_patterns ,
246
- local_files_only = kwargs .get ("local_files_only" , False ),
228
+ sleep = 3.0
229
+ while retries > 0 :
230
+ retries -= 1
231
+
232
+ if hf_source :
233
+ extra_patterns = [model ["model_file" ]]
234
+ extra_patterns .extend (model .get ("additional_files" , []))
235
+
236
+ try :
237
+ return Path (
238
+ cls .download_files_from_huggingface (
239
+ hf_source ,
240
+ cache_dir = str (cache_dir ),
241
+ extra_patterns = extra_patterns ,
242
+ local_files_only = kwargs .get ("local_files_only" , False ),
243
+ )
247
244
)
248
- )
249
- except (EnvironmentError , RepositoryNotFoundError , ValueError ) as e :
250
- logger .error (
251
- f"Could not download model from HuggingFace: { e } "
252
- "Falling back to other sources."
253
- )
254
-
255
- if url_source :
256
- return cls .retrieve_model_gcs (model ["model" ], url_source , str (cache_dir ))
245
+ except (EnvironmentError , RepositoryNotFoundError , ValueError ) as e :
246
+ logger .error (
247
+ f"Could not download model from HuggingFace: { e } "
248
+ "Falling back to other sources."
249
+ )
250
+ if url_source :
251
+ try :
252
+ return cls .retrieve_model_gcs (model ["model" ], url_source , str (cache_dir ))
253
+ except Exception :
254
+ logger .error (f"Could not download model from url: { url_source } " )
255
+
256
+ logger .error (
257
+ f"Could not download model from either source, sleeping for { sleep } seconds, { retries } retries left."
258
+ )
259
+ time .sleep (sleep )
260
+ sleep *= 3
257
261
258
262
raise ValueError (f"Could not download model { model ['model' ]} from any source." )
0 commit comments