@@ -202,12 +202,15 @@ def embed(self, texts: Iterable[str], batch_size: int = 256, parallel: int = Non
202
202
raise NotImplementedError
203
203
204
204
@classmethod
205
- def list_supported_models (cls ) -> List [Dict [str , Union [str , Union [int , float ]]]]:
205
+ def list_supported_models (cls , exclude : List [ str ] = [] ) -> List [Dict [str , Union [str , Union [int , float ]]]]:
206
206
"""
207
207
Lists the supported models.
208
208
"""
209
209
models_file_path = Path (__file__ ).with_name ("models.json" )
210
- models = json .load (open (str (models_file_path )))
210
+ with open (models_file_path , "r" ) as file :
211
+ models = json .load (file )
212
+
213
+ models = [{k : v for k , v in model .items () if k not in exclude } for model in models ]
211
214
212
215
return models
213
216
@@ -264,7 +267,7 @@ def download_files_from_huggingface(cls, model_name: str, cache_dir: Optional[st
264
267
Returns:
265
268
Path: The path to the model directory.
266
269
"""
267
- models = cls .list_supported_models ()
270
+ models = cls .list_supported_models (exclude = [ "gcs_sources" ] )
268
271
269
272
hf_sources = [item for model in models if model ["model" ] == model_name for item in model ["hf_sources" ]]
270
273
@@ -343,7 +346,7 @@ def retrieve_model_gcs(self, model_name: str, cache_dir: str) -> Path:
343
346
344
347
model_tar_gz = Path (cache_dir ) / f"{ fast_model_name } .tar.gz"
345
348
346
- models = self .list_supported_models ()
349
+ models = self .list_supported_models (exclude = [ "hf_sources" ] )
347
350
348
351
gcs_sources = [item for model in models if model ["model" ] == model_name for item in model ["gcs_sources" ]]
349
352
@@ -520,12 +523,16 @@ def embed(
520
523
yield from normalize (embeddings [:, 0 ]).astype (np .float32 )
521
524
522
525
@classmethod
523
- def list_supported_models (cls ) -> List [Dict [str , Union [str , Union [int , float ]]]]:
526
+ def list_supported_models (
527
+ cls , exclude : List [str ] = ["gcs_sources" , "hf_sources" ]
528
+ ) -> List [Dict [str , Union [str , Union [int , float ]]]]:
524
529
"""
525
530
Lists the supported models.
526
531
"""
527
532
# jina models are not supported by this class
528
- return [model for model in super ().list_supported_models () if not model ["model" ].startswith ("jinaai" )]
533
+ return [
534
+ model for model in super ().list_supported_models (exclude = exclude ) if not model ["model" ].startswith ("jinaai" )
535
+ ]
529
536
530
537
531
538
class DefaultEmbedding (FlagEmbedding ):
@@ -638,12 +645,16 @@ def embed(
638
645
yield from normalize (self .mean_pooling (embeddings , attn_mask )).astype (np .float32 )
639
646
640
647
@classmethod
641
- def list_supported_models (cls ) -> List [Dict [str , Union [str , Union [int , float ]]]]:
648
+ def list_supported_models (
649
+ cls , exclude : List [str ] = ["gcs_sources" , "hf_sources" ]
650
+ ) -> List [Dict [str , Union [str , Union [int , float ]]]]:
642
651
"""
643
652
Lists the supported models.
644
653
"""
645
654
# only jina models are supported by this class
646
- return [model for model in Embedding .list_supported_models () if model ["model" ].startswith ("jinaai" )]
655
+ return [
656
+ model for model in Embedding .list_supported_models (exclude = exclude ) if model ["model" ].startswith ("jinaai" )
657
+ ]
647
658
648
659
@staticmethod
649
660
def mean_pooling (model_output , attention_mask ):
0 commit comments