Skip to content

Commit c8890e8

Browse files
committed
chore: exclude keys list_supported_models
1 parent d1969c5 commit c8890e8

File tree

1 file changed

+15
-4
lines changed

1 file changed

+15
-4
lines changed

fastembed/embedding.py

+15-4
Original file line numberDiff line numberDiff line change
@@ -202,12 +202,15 @@ def embed(self, texts: Iterable[str], batch_size: int = 256, parallel: int = Non
202202
raise NotImplementedError
203203

204204
@classmethod
205-
def list_supported_models(cls) -> List[Dict[str, Union[str, Union[int, float]]]]:
205+
def list_supported_models(cls, exclude: List[str] = []) -> List[Dict[str, Union[str, Union[int, float]]]]:
206206
"""
207207
Lists the supported models.
208208
"""
209209
models_file_path = Path(__file__).with_name("models.json")
210-
models = json.load(open(str(models_file_path)))
210+
with open(models_file_path, "r") as file:
211+
models = json.load(file)
212+
213+
models = [{k: v for k, v in model.items() if k not in exclude} for model in models]
211214

212215
return models
213216

@@ -525,7 +528,11 @@ def list_supported_models(cls) -> List[Dict[str, Union[str, Union[int, float]]]]
525528
Lists the supported models.
526529
"""
527530
# jina models are not supported by this class
528-
return [model for model in super().list_supported_models() if not model["model"].startswith("jinaai")]
531+
return [
532+
model
533+
for model in super().list_supported_models(exclude=["gcs_sources", "hf_sources"])
534+
if not model["model"].startswith("jinaai")
535+
]
529536

530537

531538
class DefaultEmbedding(FlagEmbedding):
@@ -643,7 +650,11 @@ def list_supported_models(cls) -> List[Dict[str, Union[str, Union[int, float]]]]
643650
Lists the supported models.
644651
"""
645652
# only jina models are supported by this class
646-
return [model for model in Embedding.list_supported_models() if model["model"].startswith("jinaai")]
653+
return [
654+
model
655+
for model in Embedding.list_supported_models(exclude=["gcs_sources", "hf_sources"])
656+
if model["model"].startswith("jinaai")
657+
]
647658

648659
@staticmethod
649660
def mean_pooling(model_output, attention_mask):

0 commit comments

Comments
 (0)