11import  os 
22import  time 
3+ import  json 
34import  shutil 
45import  tarfile 
56from  pathlib  import  Path 
6- from  typing  import  Any ,  Optional 
7+ from  typing  import  Any 
78
89import  requests 
9- from  huggingface_hub  import  snapshot_download 
10+ from  huggingface_hub  import  snapshot_download , model_info , list_repo_tree 
11+ from  huggingface_hub .hf_api  import  RepoFile 
1012from  huggingface_hub .utils  import  (
1113    RepositoryNotFoundError ,
1214    disable_progress_bars ,
1719
1820
1921class  ModelManagement :
22+     METADATA_FILE  =  "files_metadata.json" 
23+ 
2024    @classmethod  
2125    def  list_supported_models (cls ) ->  list [dict [str , Any ]]:
2226        """Lists the supported models. 
@@ -98,7 +102,7 @@ def download_files_from_huggingface(
98102        cls ,
99103        hf_source_repo : str ,
100104        cache_dir : str ,
101-         extra_patterns : Optional [ list [str ]]  =   None ,
105+         extra_patterns : list [str ],
102106        local_files_only : bool  =  False ,
103107        ** kwargs ,
104108    ) ->  str :
@@ -107,36 +111,148 @@ def download_files_from_huggingface(
107111        Args: 
108112            hf_source_repo (str): Name of the model on HuggingFace Hub, e.g. "qdrant/all-MiniLM-L6-v2-onnx". 
109113            cache_dir (Optional[str]): The path to the cache directory. 
110-             extra_patterns (Optional[ list[str] ]): extra patterns to allow in the snapshot download, typically 
114+             extra_patterns (list[str]): extra patterns to allow in the snapshot download, typically 
111115                includes the required model files. 
112116            local_files_only (bool, optional): Whether to only use local files. Defaults to False. 
113117        Returns: 
114118            Path: The path to the model directory. 
115119        """ 
120+ 
121+         def  _verify_files_from_metadata (
122+             model_dir : Path , stored_metadata : dict [str , Any ], repo_files : list [RepoFile ]
123+         ) ->  bool :
124+             try :
125+                 for  rel_path , meta  in  stored_metadata .items ():
126+                     file_path  =  model_dir  /  rel_path 
127+ 
128+                     if  not  file_path .exists ():
129+                         return  False 
130+ 
131+                     if  repo_files :  # online verification 
132+                         file_info  =  next ((f  for  f  in  repo_files  if  f .path  ==  file_path .name ), None )
133+                         if  (
134+                             not  file_info 
135+                             or  file_info .size  !=  meta ["size" ]
136+                             or  file_info .blob_id  !=  meta ["blob_id" ]
137+                         ):
138+                             return  False 
139+ 
140+                     else :  # offline verification 
141+                         if  file_path .stat ().st_size  !=  meta ["size" ]:
142+                             return  False 
143+                 return  True 
144+             except  (OSError , KeyError ) as  e :
145+                 logger .error (f"Error verifying files: { str (e )}  )
146+                 return  False 
147+ 
148+         def  _collect_file_metadata (
149+             model_dir : Path , repo_files : list [RepoFile ]
150+         ) ->  dict [str , dict [str , int ]]:
151+             meta  =  {}
152+             file_info_map  =  {f .path : f  for  f  in  repo_files }
153+             for  file_path  in  model_dir .rglob ("*" ):
154+                 if  file_path .is_file () and  file_path .name  !=  cls .METADATA_FILE :
155+                     repo_file  =  file_info_map .get (file_path .name )
156+                     if  repo_file :
157+                         meta [str (file_path .relative_to (model_dir ))] =  {
158+                             "size" : repo_file .size ,
159+                             "blob_id" : repo_file .blob_id ,
160+                         }
161+             return  meta 
162+ 
163+         def  _save_file_metadata (model_dir : Path , meta : dict [str , dict [str , int ]]) ->  None :
164+             try :
165+                 if  not  model_dir .exists ():
166+                     model_dir .mkdir (parents = True , exist_ok = True )
167+                 (model_dir  /  cls .METADATA_FILE ).write_text (json .dumps (meta ))
168+             except  (OSError , ValueError ) as  e :
169+                 logger .warning (f"Error saving metadata: { str (e )}  )
170+ 
116171        allow_patterns  =  [
117172            "config.json" ,
118173            "tokenizer.json" ,
119174            "tokenizer_config.json" ,
120175            "special_tokens_map.json" ,
121176            "preprocessor_config.json" ,
122177        ]
123-          if   extra_patterns   is   not   None : 
124-              allow_patterns .extend (extra_patterns )
178+ 
179+         allow_patterns .extend (extra_patterns )
125180
126181        snapshot_dir  =  Path (cache_dir ) /  f"models--{ hf_source_repo .replace ('/' , '--' )}  
127-         is_cached  =  snapshot_dir .exists ()
182+         metadata_file  =  snapshot_dir  /  cls .METADATA_FILE 
183+ 
184+         if  local_files_only :
185+             disable_progress_bars ()
186+             if  metadata_file .exists ():
187+                 metadata  =  json .loads (metadata_file .read_text ())
188+                 verified  =  _verify_files_from_metadata (snapshot_dir , metadata , repo_files = [])
189+                 if  not  verified :
190+                     logger .warning (
191+                         "Local file sizes do not match the metadata." 
192+                     )  # do not raise, still make an attempt to load the model 
193+             else :
194+                 logger .warning (
195+                     "Metadata file not found. Proceeding without checking local files." 
196+                 )  # if users have downloaded models from hf manually, or they're updating from previous versions of 
197+                 # fastembed 
198+             result  =  snapshot_download (
199+                 repo_id = hf_source_repo ,
200+                 allow_patterns = allow_patterns ,
201+                 cache_dir = cache_dir ,
202+                 local_files_only = local_files_only ,
203+                 ** kwargs ,
204+             )
205+             return  result 
206+ 
207+         repo_revision  =  model_info (hf_source_repo ).sha 
208+         repo_tree  =  list (list_repo_tree (hf_source_repo , revision = repo_revision , repo_type = "model" ))
209+ 
210+         allowed_extensions  =  {".json" , ".onnx" , ".txt" }
211+         repo_files  =  (
212+             [
213+                 f 
214+                 for  f  in  repo_tree 
215+                 if  isinstance (f , RepoFile ) and  Path (f .path ).suffix  in  allowed_extensions 
216+             ]
217+             if  repo_tree 
218+             else  []
219+         )
220+ 
221+         verified_metadata  =  False 
222+ 
223+         if  snapshot_dir .exists () and  metadata_file .exists ():
224+             metadata  =  json .loads (metadata_file .read_text ())
225+             verified_metadata  =  _verify_files_from_metadata (snapshot_dir , metadata , repo_files )
128226
129-         if  is_cached :
227+         if  verified_metadata :
130228            disable_progress_bars ()
131229
132-         return  snapshot_download (
230+         result   =  snapshot_download (
133231            repo_id = hf_source_repo ,
134232            allow_patterns = allow_patterns ,
135233            cache_dir = cache_dir ,
136234            local_files_only = local_files_only ,
137235            ** kwargs ,
138236        )
139237
238+         if  (
239+             not  verified_metadata 
240+         ):  # metadata is not up-to-date, update it and check whether the files have been 
241+             # downloaded correctly 
242+             metadata  =  _collect_file_metadata (snapshot_dir , repo_files )
243+ 
244+             download_successful  =  _verify_files_from_metadata (
245+                 snapshot_dir , metadata , repo_files = []
246+             )  # offline verification 
247+             if  not  download_successful :
248+                 raise  ValueError (
249+                     "Files have been corrupted during downloading process. " 
250+                     "Please check your internet connection and try again." 
251+                 )
252+             _save_file_metadata (snapshot_dir , metadata )
253+ 
254+         return  result 
255+ 
140256    @classmethod  
141257    def  decompress_to_cache (cls , targz_path : str , cache_dir : str ):
142258        """ 
0 commit comments