Commit 54f6cd9

Improve progress bar new (#440)
* improve: Improve progress bar
* fix: Fix error downloading when internet connection down
* new: Added file hash computation to track new versions
* refactor: Removed redundant hash check
* fix: Fix ci
* new: Verify using hf_api
* new: Improve progress bar
* refactor new progress bar (#446)
* refactor
* chore: Remove redundant enable progress bar
* refactor comments

Co-authored-by: hh-space-invader <[email protected]>
Co-authored-by: George <[email protected]>
1 parent ae37da3 commit 54f6cd9
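
For context on the "new: Verify using hf_api" item: the diff below pins the repository to a single revision with model_info and reads each file's size and blob ID from list_repo_tree, then compares those against a locally stored files_metadata.json. A minimal standalone sketch of that lookup, using only public huggingface_hub calls; the repo name is the illustrative example from the method's docstring, not something this commit fixes:

from huggingface_hub import list_repo_tree, model_info
from huggingface_hub.hf_api import RepoFile

repo = "qdrant/all-MiniLM-L6-v2-onnx"  # illustrative repo, taken from the docstring
revision = model_info(repo).sha  # pin the listing to one revision, as the diff does

# size and blob_id are the two fields the new _verify_files_from_metadata helper
# compares against the stored metadata during online verification.
for entry in list_repo_tree(repo, revision=revision, repo_type="model"):
    if isinstance(entry, RepoFile):
        print(entry.path, entry.size, entry.blob_id)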

File tree

1 file changed: +125 −9 lines changed

fastembed/common/model_management.py (+125 −9)
@@ -1,12 +1,14 @@
 import os
 import time
+import json
 import shutil
 import tarfile
 from pathlib import Path
-from typing import Any, Optional
+from typing import Any
 
 import requests
-from huggingface_hub import snapshot_download
+from huggingface_hub import snapshot_download, model_info, list_repo_tree
+from huggingface_hub.hf_api import RepoFile
 from huggingface_hub.utils import (
     RepositoryNotFoundError,
     disable_progress_bars,
@@ -17,6 +19,8 @@
 
 
 class ModelManagement:
+    METADATA_FILE = "files_metadata.json"
+
     @classmethod
     def list_supported_models(cls) -> list[dict[str, Any]]:
         """Lists the supported models.
@@ -98,7 +102,7 @@ def download_files_from_huggingface(
         cls,
         hf_source_repo: str,
         cache_dir: str,
-        extra_patterns: Optional[list[str]] = None,
+        extra_patterns: list[str],
         local_files_only: bool = False,
         **kwargs,
     ) -> str:
@@ -107,36 +111,148 @@ def download_files_from_huggingface(
 
         Args:
             hf_source_repo (str): Name of the model on HuggingFace Hub, e.g. "qdrant/all-MiniLM-L6-v2-onnx".
             cache_dir (Optional[str]): The path to the cache directory.
-            extra_patterns (Optional[list[str]]): extra patterns to allow in the snapshot download, typically
+            extra_patterns (list[str]): extra patterns to allow in the snapshot download, typically
                 includes the required model files.
             local_files_only (bool, optional): Whether to only use local files. Defaults to False.
         Returns:
             Path: The path to the model directory.
         """
+
+        def _verify_files_from_metadata(
+            model_dir: Path, stored_metadata: dict[str, Any], repo_files: list[RepoFile]
+        ) -> bool:
+            try:
+                for rel_path, meta in stored_metadata.items():
+                    file_path = model_dir / rel_path
+
+                    if not file_path.exists():
+                        return False
+
+                    if repo_files:  # online verification
+                        file_info = next((f for f in repo_files if f.path == file_path.name), None)
+                        if (
+                            not file_info
+                            or file_info.size != meta["size"]
+                            or file_info.blob_id != meta["blob_id"]
+                        ):
+                            return False
+
+                    else:  # offline verification
+                        if file_path.stat().st_size != meta["size"]:
+                            return False
+                return True
+            except (OSError, KeyError) as e:
+                logger.error(f"Error verifying files: {str(e)}")
+                return False
+
+        def _collect_file_metadata(
+            model_dir: Path, repo_files: list[RepoFile]
+        ) -> dict[str, dict[str, int]]:
+            meta = {}
+            file_info_map = {f.path: f for f in repo_files}
+            for file_path in model_dir.rglob("*"):
+                if file_path.is_file() and file_path.name != cls.METADATA_FILE:
+                    repo_file = file_info_map.get(file_path.name)
+                    if repo_file:
+                        meta[str(file_path.relative_to(model_dir))] = {
+                            "size": repo_file.size,
+                            "blob_id": repo_file.blob_id,
+                        }
+            return meta
+
+        def _save_file_metadata(model_dir: Path, meta: dict[str, dict[str, int]]) -> None:
+            try:
+                if not model_dir.exists():
+                    model_dir.mkdir(parents=True, exist_ok=True)
+                (model_dir / cls.METADATA_FILE).write_text(json.dumps(meta))
+            except (OSError, ValueError) as e:
+                logger.warning(f"Error saving metadata: {str(e)}")
+
         allow_patterns = [
             "config.json",
             "tokenizer.json",
             "tokenizer_config.json",
             "special_tokens_map.json",
             "preprocessor_config.json",
         ]
-        if extra_patterns is not None:
-            allow_patterns.extend(extra_patterns)
+
+        allow_patterns.extend(extra_patterns)
 
         snapshot_dir = Path(cache_dir) / f"models--{hf_source_repo.replace('/', '--')}"
-        is_cached = snapshot_dir.exists()
+        metadata_file = snapshot_dir / cls.METADATA_FILE
+
+        if local_files_only:
+            disable_progress_bars()
+            if metadata_file.exists():
+                metadata = json.loads(metadata_file.read_text())
+                verified = _verify_files_from_metadata(snapshot_dir, metadata, repo_files=[])
+                if not verified:
+                    logger.warning(
+                        "Local file sizes do not match the metadata."
+                    )  # do not raise, still make an attempt to load the model
+            else:
+                logger.warning(
+                    "Metadata file not found. Proceeding without checking local files."
+                )  # if users have downloaded models from hf manually, or they're
+                # updating from previous versions of fastembed
+            result = snapshot_download(
+                repo_id=hf_source_repo,
+                allow_patterns=allow_patterns,
+                cache_dir=cache_dir,
+                local_files_only=local_files_only,
+                **kwargs,
+            )
+            return result
+
+        repo_revision = model_info(hf_source_repo).sha
+        repo_tree = list(list_repo_tree(hf_source_repo, revision=repo_revision, repo_type="model"))
+
+        allowed_extensions = {".json", ".onnx", ".txt"}
+        repo_files = (
+            [
+                f
+                for f in repo_tree
+                if isinstance(f, RepoFile) and Path(f.path).suffix in allowed_extensions
+            ]
+            if repo_tree
+            else []
+        )
+
+        verified_metadata = False
+
+        if snapshot_dir.exists() and metadata_file.exists():
+            metadata = json.loads(metadata_file.read_text())
+            verified_metadata = _verify_files_from_metadata(snapshot_dir, metadata, repo_files)
 
-        if is_cached:
+        if verified_metadata:
             disable_progress_bars()
 
-        return snapshot_download(
+        result = snapshot_download(
             repo_id=hf_source_repo,
             allow_patterns=allow_patterns,
             cache_dir=cache_dir,
             local_files_only=local_files_only,
             **kwargs,
         )
 
+        if (
+            not verified_metadata
+        ):  # metadata is not up-to-date, update it and check whether the files
+            # have been downloaded correctly
+            metadata = _collect_file_metadata(snapshot_dir, repo_files)
+
+            download_successful = _verify_files_from_metadata(
                snapshot_dir, metadata, repo_files=[]
+            )  # offline verification
+            if not download_successful:
+                raise ValueError(
+                    "Files have been corrupted during downloading process. "
+                    "Please check your internet connection and try again."
+                )
+            _save_file_metadata(snapshot_dir, metadata)
+
+        return result
+
     @classmethod
     def decompress_to_cache(cls, targz_path: str, cache_dir: str):
         """
