@@ -1,12 +1,14 @@
 import os
 import time
+import json
 import shutil
 import tarfile
 from pathlib import Path
-from typing import Any, Optional
+from typing import Any

 import requests
-from huggingface_hub import snapshot_download
+from huggingface_hub import snapshot_download, model_info, list_repo_tree
+from huggingface_hub.hf_api import RepoFile
 from huggingface_hub.utils import (
     RepositoryNotFoundError,
     disable_progress_bars,
@@ -17,6 +19,8 @@


 class ModelManagement:
+    METADATA_FILE = "files_metadata.json"
+
     @classmethod
     def list_supported_models(cls) -> list[dict[str, Any]]:
         """Lists the supported models.
@@ -98,7 +102,7 @@ def download_files_from_huggingface(
         cls,
         hf_source_repo: str,
         cache_dir: str,
-        extra_patterns: Optional[list[str]] = None,
+        extra_patterns: list[str],
         local_files_only: bool = False,
         **kwargs,
     ) -> str:
@@ -107,36 +111,148 @@ def download_files_from_huggingface(
         Args:
             hf_source_repo (str): Name of the model on HuggingFace Hub, e.g. "qdrant/all-MiniLM-L6-v2-onnx".
             cache_dir (Optional[str]): The path to the cache directory.
-            extra_patterns (Optional[list[str]]): extra patterns to allow in the snapshot download, typically
+            extra_patterns (list[str]): extra patterns to allow in the snapshot download, typically
                 includes the required model files.
             local_files_only (bool, optional): Whether to only use local files. Defaults to False.
         Returns:
             Path: The path to the model directory.
         """
+
+        def _verify_files_from_metadata(
+            model_dir: Path, stored_metadata: dict[str, Any], repo_files: list[RepoFile]
+        ) -> bool:
+            try:
+                for rel_path, meta in stored_metadata.items():
+                    file_path = model_dir / rel_path
+
+                    if not file_path.exists():
+                        return False
+
+                    if repo_files:  # online verification
+                        file_info = next((f for f in repo_files if f.path == file_path.name), None)
+                        if (
+                            not file_info
+                            or file_info.size != meta["size"]
+                            or file_info.blob_id != meta["blob_id"]
+                        ):
+                            return False
+
+                    else:  # offline verification
+                        if file_path.stat().st_size != meta["size"]:
+                            return False
+                return True
+            except (OSError, KeyError) as e:
+                logger.error(f"Error verifying files: {str(e)}")
+                return False
+
+        def _collect_file_metadata(
+            model_dir: Path, repo_files: list[RepoFile]
+        ) -> dict[str, dict[str, Any]]:
+            meta = {}
+            file_info_map = {f.path: f for f in repo_files}
+            for file_path in model_dir.rglob("*"):
+                if file_path.is_file() and file_path.name != cls.METADATA_FILE:
+                    repo_file = file_info_map.get(file_path.name)
+                    if repo_file:
+                        meta[str(file_path.relative_to(model_dir))] = {
+                            "size": repo_file.size,
+                            "blob_id": repo_file.blob_id,
+                        }
+            return meta
+
+        def _save_file_metadata(model_dir: Path, meta: dict[str, dict[str, Any]]) -> None:
+            try:
+                if not model_dir.exists():
+                    model_dir.mkdir(parents=True, exist_ok=True)
+                (model_dir / cls.METADATA_FILE).write_text(json.dumps(meta))
+            except (OSError, ValueError) as e:
+                logger.warning(f"Error saving metadata: {str(e)}")
+
         allow_patterns = [
             "config.json",
             "tokenizer.json",
             "tokenizer_config.json",
             "special_tokens_map.json",
             "preprocessor_config.json",
         ]
-        if extra_patterns is not None:
-            allow_patterns.extend(extra_patterns)
+
+        allow_patterns.extend(extra_patterns)

         snapshot_dir = Path(cache_dir) / f"models--{hf_source_repo.replace('/', '--')}"
-        is_cached = snapshot_dir.exists()
+        metadata_file = snapshot_dir / cls.METADATA_FILE
+
+        if local_files_only:
+            disable_progress_bars()
+            if metadata_file.exists():
+                metadata = json.loads(metadata_file.read_text())
+                verified = _verify_files_from_metadata(snapshot_dir, metadata, repo_files=[])
+                if not verified:
+                    logger.warning(
+                        "Local file sizes do not match the metadata."
+                    )  # do not raise, still make an attempt to load the model
+            else:
+                logger.warning(
+                    "Metadata file not found. Proceeding without checking local files."
+                )  # if users have downloaded models from hf manually, or they're updating from previous versions of
+                # fastembed
+            result = snapshot_download(
+                repo_id=hf_source_repo,
+                allow_patterns=allow_patterns,
+                cache_dir=cache_dir,
+                local_files_only=local_files_only,
+                **kwargs,
+            )
+            return result
+
+        repo_revision = model_info(hf_source_repo).sha
+        repo_tree = list(list_repo_tree(hf_source_repo, revision=repo_revision, repo_type="model"))
+
+        allowed_extensions = {".json", ".onnx", ".txt"}
+        repo_files = (
+            [
+                f
+                for f in repo_tree
+                if isinstance(f, RepoFile) and Path(f.path).suffix in allowed_extensions
+            ]
+            if repo_tree
+            else []
+        )
+
+        verified_metadata = False
+
+        if snapshot_dir.exists() and metadata_file.exists():
+            metadata = json.loads(metadata_file.read_text())
+            verified_metadata = _verify_files_from_metadata(snapshot_dir, metadata, repo_files)

-        if is_cached:
+        if verified_metadata:
             disable_progress_bars()

-        return snapshot_download(
+        result = snapshot_download(
             repo_id=hf_source_repo,
             allow_patterns=allow_patterns,
             cache_dir=cache_dir,
             local_files_only=local_files_only,
             **kwargs,
         )

+        if (
+            not verified_metadata
+        ):  # metadata is not up-to-date, update it and check whether the files have been
+            # downloaded correctly
+            metadata = _collect_file_metadata(snapshot_dir, repo_files)
+
+            download_successful = _verify_files_from_metadata(
+                snapshot_dir, metadata, repo_files=[]
+            )  # offline verification
+            if not download_successful:
+                raise ValueError(
+                    "Files have been corrupted during the download process. "
+                    "Please check your internet connection and try again."
+                )
+            _save_file_metadata(snapshot_dir, metadata)
+
+        return result
+
     @classmethod
     def decompress_to_cache(cls, targz_path: str, cache_dir: str):
         """
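The online verification keys off two RepoFile fields returned by list_repo_tree: size and blob_id. A standalone sketch of fetching that metadata for a pinned revision, mirroring the extension filtering in the commit (the repo id is again illustrative):

    from pathlib import Path

    from huggingface_hub import list_repo_tree, model_info
    from huggingface_hub.hf_api import RepoFile

    repo_id = "qdrant/all-MiniLM-L6-v2-onnx"  # illustrative
    revision = model_info(repo_id).sha  # pin to the current revision, as the commit does

    allowed_extensions = {".json", ".onnx", ".txt"}
    for entry in list_repo_tree(repo_id, revision=revision, repo_type="model"):
        # Only RepoFile entries carry size/blob_id; directory entries are skipped.
        if isinstance(entry, RepoFile) and Path(entry.path).suffix in allowed_extensions:
            print(f"{entry.path}: size={entry.size}, blob_id={entry.blob_id}")

Offline (local_files_only) verification falls back to comparing sizes alone, since blob hashes cannot be recomputed cheaply without fetching the remote tree.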