Skip to content

Commit 3a8985b

Browse files
authored
new: add retry logic for model downloading (#293)
* new: add retry logic for model downloading. * fix: add sleep between retries.
1 parent f0ff09c commit 3a8985b

File tree

1 file changed

+40
-36
lines changed

1 file changed

+40
-36
lines changed

fastembed/common/model_management.py

Lines changed: 40 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import os
2+
import time
23
import shutil
34
import tarfile
45
from pathlib import Path
@@ -42,9 +43,7 @@ def _get_model_description(cls, model_name: str) -> Dict[str, Any]:
4243
raise ValueError(f"Model {model_name} is not supported in {cls.__name__}.")
4344

4445
@classmethod
45-
def download_file_from_gcs(
46-
cls, url: str, output_path: str, show_progress: bool = True
47-
) -> str:
46+
def download_file_from_gcs(cls, url: str, output_path: str, show_progress: bool = True) -> str:
4847
"""
4948
Downloads a file from Google Cloud Storage.
5049
@@ -73,9 +72,7 @@ def download_file_from_gcs(
7372

7473
# Warn if the total size is zero
7574
if total_size_in_bytes == 0:
76-
print(
77-
f"Warning: Content-length header is missing or zero in the response from {url}."
78-
)
75+
print(f"Warning: Content-length header is missing or zero in the response from {url}.")
7976

8077
show_progress = total_size_in_bytes and show_progress
8178

@@ -163,9 +160,7 @@ def decompress_to_cache(cls, targz_path: str, cache_dir: str):
163160
return cache_dir
164161

165162
@classmethod
166-
def retrieve_model_gcs(
167-
cls, model_name: str, source_url: str, cache_dir: str
168-
) -> Path:
163+
def retrieve_model_gcs(cls, model_name: str, source_url: str, cache_dir: str) -> Path:
169164
fast_model_name = f"fast-{model_name.split('/')[-1]}"
170165

171166
cache_tmp_dir = Path(cache_dir) / "tmp"
@@ -191,12 +186,8 @@ def retrieve_model_gcs(
191186
output_path=str(model_tar_gz),
192187
)
193188

194-
cls.decompress_to_cache(
195-
targz_path=str(model_tar_gz), cache_dir=str(cache_tmp_dir)
196-
)
197-
assert (
198-
model_tmp_dir.exists()
199-
), f"Could not find {model_tmp_dir} in {cache_tmp_dir}"
189+
cls.decompress_to_cache(targz_path=str(model_tar_gz), cache_dir=str(cache_tmp_dir))
190+
assert model_tmp_dir.exists(), f"Could not find {model_tmp_dir} in {cache_tmp_dir}"
200191

201192
model_tar_gz.unlink()
202193
# Rename from tmp to final name is atomic
@@ -205,7 +196,7 @@ def retrieve_model_gcs(
205196
return model_dir
206197

207198
@classmethod
208-
def download_model(cls, model: Dict[str, Any], cache_dir: Path, **kwargs) -> Path:
199+
def download_model(cls, model: Dict[str, Any], cache_dir: Path, retries=3, **kwargs) -> Path:
209200
"""
210201
Downloads a model from HuggingFace Hub or Google Cloud Storage.
211202
@@ -225,6 +216,7 @@ def download_model(cls, model: Dict[str, Any], cache_dir: Path, **kwargs) -> Pat
225216
}
226217
```
227218
cache_dir (str): The path to the cache directory.
219+
retries (int): The number of times to retry (including the first attempt).
228220
229221
Returns:
230222
Path: The path to the downloaded model directory.
@@ -233,26 +225,38 @@ def download_model(cls, model: Dict[str, Any], cache_dir: Path, **kwargs) -> Pat
233225
hf_source = model.get("sources", {}).get("hf")
234226
url_source = model.get("sources", {}).get("url")
235227

236-
if hf_source:
237-
extra_patterns = [model["model_file"]]
238-
extra_patterns.extend(model.get("additional_files", []))
239-
240-
try:
241-
return Path(
242-
cls.download_files_from_huggingface(
243-
hf_source,
244-
cache_dir=str(cache_dir),
245-
extra_patterns=extra_patterns,
246-
local_files_only=kwargs.get("local_files_only", False),
228+
sleep = 3.0
229+
while retries > 0:
230+
retries -= 1
231+
232+
if hf_source:
233+
extra_patterns = [model["model_file"]]
234+
extra_patterns.extend(model.get("additional_files", []))
235+
236+
try:
237+
return Path(
238+
cls.download_files_from_huggingface(
239+
hf_source,
240+
cache_dir=str(cache_dir),
241+
extra_patterns=extra_patterns,
242+
local_files_only=kwargs.get("local_files_only", False),
243+
)
247244
)
248-
)
249-
except (EnvironmentError, RepositoryNotFoundError, ValueError) as e:
250-
logger.error(
251-
f"Could not download model from HuggingFace: {e}"
252-
"Falling back to other sources."
253-
)
254-
255-
if url_source:
256-
return cls.retrieve_model_gcs(model["model"], url_source, str(cache_dir))
245+
except (EnvironmentError, RepositoryNotFoundError, ValueError) as e:
246+
logger.error(
247+
f"Could not download model from HuggingFace: {e} "
248+
"Falling back to other sources."
249+
)
250+
if url_source:
251+
try:
252+
return cls.retrieve_model_gcs(model["model"], url_source, str(cache_dir))
253+
except Exception:
254+
logger.error(f"Could not download model from url: {url_source}")
255+
256+
logger.error(
257+
f"Could not download model from either source, sleeping for {sleep} seconds, {retries} retries left."
258+
)
259+
time.sleep(sleep)
260+
sleep *= 3
257261

258262
raise ValueError(f"Could not download model {model['model']} from any source.")

0 commit comments

Comments
 (0)