@@ -202,12 +202,20 @@ def embed(self, texts: Iterable[str], batch_size: int = 256, parallel: int = Non
202
202
raise NotImplementedError
203
203
204
204
@classmethod
205
- def list_supported_models (cls ) -> List [Dict [str , Union [str , Union [int , float ]]]]:
206
- """
207
- Lists the supported models.
205
+ def list_supported_models (cls , exclude : List [str ] = []) -> List [Dict [str , Any ]]:
206
+ """Lists the supported models.
207
+
208
+ Args:
209
+ exclude (List[str], optional): Keys to exclude from the result. Defaults to [].
210
+
211
+ Returns:
212
+ List[Dict[str, Any]]: A list of dictionaries containing the model information.
208
213
"""
209
214
models_file_path = Path (__file__ ).with_name ("models.json" )
210
- models = json .load (open (str (models_file_path )))
215
+ with open (models_file_path , "r" ) as file :
216
+ models = json .load (file )
217
+
218
+ models = [{k : v for k , v in model .items () if k not in exclude } for model in models ]
211
219
212
220
return models
213
221
@@ -264,7 +272,7 @@ def download_files_from_huggingface(cls, model_name: str, cache_dir: Optional[st
264
272
Returns:
265
273
Path: The path to the model directory.
266
274
"""
267
- models = cls .list_supported_models ()
275
+ models = cls .list_supported_models (exclude = [ "gcs_sources" ] )
268
276
269
277
hf_sources = [item for model in models if model ["model" ] == model_name for item in model ["hf_sources" ]]
270
278
@@ -343,7 +351,7 @@ def retrieve_model_gcs(self, model_name: str, cache_dir: str) -> Path:
343
351
344
352
model_tar_gz = Path (cache_dir ) / f"{ fast_model_name } .tar.gz"
345
353
346
- models = self .list_supported_models ()
354
+ models = self .list_supported_models (exclude = [ "hf_sources" ] )
347
355
348
356
gcs_sources = [item for model in models if model ["model" ] == model_name for item in model ["gcs_sources" ]]
349
357
@@ -520,12 +528,19 @@ def embed(
520
528
yield from normalize (embeddings [:, 0 ]).astype (np .float32 )
521
529
522
530
@classmethod
523
- def list_supported_models (cls ) -> List [Dict [str , Union [str , Union [int , float ]]]]:
524
- """
525
- Lists the supported models.
531
+ def list_supported_models (cls , exclude : List [str ] = ["gcs_sources" , "hf_sources" ]) -> List [Dict [str , Any ]]:
532
+ """Lists the supported models.
533
+
534
+ Args:
535
+ exclude (List[str], optional): Keys to exclude from the result. Defaults to ["gcs_sources", "hf_sources"].
536
+
537
+ Returns:
538
+ List[Dict[str, Any]]: A list of dictionaries containing the model information.
526
539
"""
527
540
# jina models are not supported by this class
528
- return [model for model in super ().list_supported_models () if not model ["model" ].startswith ("jinaai" )]
541
+ return [
542
+ model for model in super ().list_supported_models (exclude = exclude ) if not model ["model" ].startswith ("jinaai" )
543
+ ]
529
544
530
545
531
546
class DefaultEmbedding (FlagEmbedding ):
@@ -638,12 +653,19 @@ def embed(
638
653
yield from normalize (self .mean_pooling (embeddings , attn_mask )).astype (np .float32 )
639
654
640
655
@classmethod
641
- def list_supported_models (cls ) -> List [Dict [str , Union [str , Union [int , float ]]]]:
642
- """
643
- Lists the supported models.
656
+ def list_supported_models (cls , exclude : List [str ] = ["gcs_sources" , "hf_sources" ]) -> List [Dict [str , Any ]]:
657
+ """Lists the supported models.
658
+
659
+ Args:
660
+ exclude (List[str], optional): Keys to exclude from the result. Defaults to ["gcs_sources", "hf_sources"].
661
+
662
+ Returns:
663
+ List[Dict[str, Any]]: A list of dictionaries containing the model information.
644
664
"""
645
665
# only jina models are supported by this class
646
- return [model for model in Embedding .list_supported_models () if model ["model" ].startswith ("jinaai" )]
666
+ return [
667
+ model for model in Embedding .list_supported_models (exclude = exclude ) if model ["model" ].startswith ("jinaai" )
668
+ ]
647
669
648
670
@staticmethod
649
671
def mean_pooling (model_output , attention_mask ):
0 commit comments