@@ -202,12 +202,15 @@ def embed(self, texts: Iterable[str], batch_size: int = 256, parallel: int = Non
202
202
raise NotImplementedError
203
203
204
204
@classmethod
205
- def list_supported_models (cls ) -> List [Dict [str , Union [str , Union [int , float ]]]]:
205
+ def list_supported_models (cls , exclude : List [ str ] = [] ) -> List [Dict [str , Union [str , Union [int , float ]]]]:
206
206
"""
207
207
Lists the supported models.
208
208
"""
209
209
models_file_path = Path (__file__ ).with_name ("models.json" )
210
- models = json .load (open (str (models_file_path )))
210
+ with open (models_file_path , "r" ) as file :
211
+ models = json .load (file )
212
+
213
+ models = [{k : v for k , v in model .items () if k not in exclude } for model in models ]
211
214
212
215
return models
213
216
@@ -525,7 +528,11 @@ def list_supported_models(cls) -> List[Dict[str, Union[str, Union[int, float]]]]
525
528
Lists the supported models.
526
529
"""
527
530
# jina models are not supported by this class
528
- return [model for model in super ().list_supported_models () if not model ["model" ].startswith ("jinaai" )]
531
+ return [
532
+ model
533
+ for model in super ().list_supported_models (exclude = ["gcs_sources" , "hf_sources" ])
534
+ if not model ["model" ].startswith ("jinaai" )
535
+ ]
529
536
530
537
531
538
class DefaultEmbedding (FlagEmbedding ):
@@ -643,7 +650,11 @@ def list_supported_models(cls) -> List[Dict[str, Union[str, Union[int, float]]]]
643
650
Lists the supported models.
644
651
"""
645
652
# only jina models are supported by this class
646
- return [model for model in Embedding .list_supported_models () if model ["model" ].startswith ("jinaai" )]
653
+ return [
654
+ model
655
+ for model in Embedding .list_supported_models (exclude = ["gcs_sources" , "hf_sources" ])
656
+ if model ["model" ].startswith ("jinaai" )
657
+ ]
647
658
648
659
@staticmethod
649
660
def mean_pooling (model_output , attention_mask ):
0 commit comments