Skip to content

Commit 7a64189

Browse files
committed
chore: exclude keys list_supported_models
1 parent d1969c5 commit 7a64189

File tree

1 file changed

+19
-8
lines changed

1 file changed

+19
-8
lines changed

fastembed/embedding.py

+19-8
Original file line numberDiff line numberDiff line change
@@ -202,12 +202,15 @@ def embed(self, texts: Iterable[str], batch_size: int = 256, parallel: int = Non
202202
raise NotImplementedError
203203

204204
@classmethod
205-
def list_supported_models(cls) -> List[Dict[str, Union[str, Union[int, float]]]]:
205+
def list_supported_models(cls, exclude: List[str] = []) -> List[Dict[str, Union[str, Union[int, float]]]]:
206206
"""
207207
Lists the supported models.
208208
"""
209209
models_file_path = Path(__file__).with_name("models.json")
210-
models = json.load(open(str(models_file_path)))
210+
with open(models_file_path, "r") as file:
211+
models = json.load(file)
212+
213+
models = [{k: v for k, v in model.items() if k not in exclude} for model in models]
211214

212215
return models
213216

@@ -264,7 +267,7 @@ def download_files_from_huggingface(cls, model_name: str, cache_dir: Optional[st
264267
Returns:
265268
Path: The path to the model directory.
266269
"""
267-
models = cls.list_supported_models()
270+
models = cls.list_supported_models(exclude=["gcs_sources"])
268271

269272
hf_sources = [item for model in models if model["model"] == model_name for item in model["hf_sources"]]
270273

@@ -343,7 +346,7 @@ def retrieve_model_gcs(self, model_name: str, cache_dir: str) -> Path:
343346

344347
model_tar_gz = Path(cache_dir) / f"{fast_model_name}.tar.gz"
345348

346-
models = self.list_supported_models()
349+
models = self.list_supported_models(exclude=["hf_sources"])
347350

348351
gcs_sources = [item for model in models if model["model"] == model_name for item in model["gcs_sources"]]
349352

@@ -520,12 +523,16 @@ def embed(
520523
yield from normalize(embeddings[:, 0]).astype(np.float32)
521524

522525
@classmethod
523-
def list_supported_models(cls) -> List[Dict[str, Union[str, Union[int, float]]]]:
526+
def list_supported_models(
527+
cls, exclude: List[str] = ["gcs_sources", "hf_sources"]
528+
) -> List[Dict[str, Union[str, Union[int, float]]]]:
524529
"""
525530
Lists the supported models.
526531
"""
527532
# jina models are not supported by this class
528-
return [model for model in super().list_supported_models() if not model["model"].startswith("jinaai")]
533+
return [
534+
model for model in super().list_supported_models(exclude=exclude) if not model["model"].startswith("jinaai")
535+
]
529536

530537

531538
class DefaultEmbedding(FlagEmbedding):
@@ -638,12 +645,16 @@ def embed(
638645
yield from normalize(self.mean_pooling(embeddings, attn_mask)).astype(np.float32)
639646

640647
@classmethod
641-
def list_supported_models(cls) -> List[Dict[str, Union[str, Union[int, float]]]]:
648+
def list_supported_models(
649+
cls, exclude: List[str] = ["gcs_sources", "hf_sources"]
650+
) -> List[Dict[str, Union[str, Union[int, float]]]]:
642651
"""
643652
Lists the supported models.
644653
"""
645654
# only jina models are supported by this class
646-
return [model for model in Embedding.list_supported_models() if model["model"].startswith("jinaai")]
655+
return [
656+
model for model in Embedding.list_supported_models(exclude=exclude) if model["model"].startswith("jinaai")
657+
]
647658

648659
@staticmethod
649660
def mean_pooling(model_output, attention_mask):

0 commit comments

Comments
 (0)