Skip to content

Commit ce72a74

Browse files
committed
chore: exclude keys list_supported_models
1 parent d1969c5 commit ce72a74

File tree

1 file changed

+36
-14
lines changed

1 file changed

+36
-14
lines changed

fastembed/embedding.py

+36-14
Original file line numberDiff line numberDiff line change
@@ -202,12 +202,20 @@ def embed(self, texts: Iterable[str], batch_size: int = 256, parallel: int = Non
202202
raise NotImplementedError
203203

204204
@classmethod
205-
def list_supported_models(cls) -> List[Dict[str, Union[str, Union[int, float]]]]:
206-
"""
207-
Lists the supported models.
205+
def list_supported_models(cls, exclude: List[str] = []) -> List[Dict[str, Any]]:
206+
"""Lists the supported models.
207+
208+
Args:
209+
exclude (List[str], optional): Keys to exclude from the result. Defaults to [].
210+
211+
Returns:
212+
List[Dict[str, Any]]: A list of dictionaries containing the model information.
208213
"""
209214
models_file_path = Path(__file__).with_name("models.json")
210-
models = json.load(open(str(models_file_path)))
215+
with open(models_file_path, "r") as file:
216+
models = json.load(file)
217+
218+
models = [{k: v for k, v in model.items() if k not in exclude} for model in models]
211219

212220
return models
213221

@@ -264,7 +272,7 @@ def download_files_from_huggingface(cls, model_name: str, cache_dir: Optional[st
264272
Returns:
265273
Path: The path to the model directory.
266274
"""
267-
models = cls.list_supported_models()
275+
models = cls.list_supported_models(exclude=["gcs_sources"])
268276

269277
hf_sources = [item for model in models if model["model"] == model_name for item in model["hf_sources"]]
270278

@@ -343,7 +351,7 @@ def retrieve_model_gcs(self, model_name: str, cache_dir: str) -> Path:
343351

344352
model_tar_gz = Path(cache_dir) / f"{fast_model_name}.tar.gz"
345353

346-
models = self.list_supported_models()
354+
models = self.list_supported_models(exclude=["hf_sources"])
347355

348356
gcs_sources = [item for model in models if model["model"] == model_name for item in model["gcs_sources"]]
349357

@@ -520,12 +528,19 @@ def embed(
520528
yield from normalize(embeddings[:, 0]).astype(np.float32)
521529

522530
@classmethod
523-
def list_supported_models(cls) -> List[Dict[str, Union[str, Union[int, float]]]]:
524-
"""
525-
Lists the supported models.
531+
def list_supported_models(cls, exclude: List[str] = ["gcs_sources", "hf_sources"]) -> List[Dict[str, Any]]:
532+
"""Lists the supported models.
533+
534+
Args:
535+
exclude (List[str], optional): Keys to exclude from the result. Defaults to ["gcs_sources", "hf_sources"].
536+
537+
Returns:
538+
List[Dict[str, Any]]: A list of dictionaries containing the model information.
526539
"""
527540
# jina models are not supported by this class
528-
return [model for model in super().list_supported_models() if not model["model"].startswith("jinaai")]
541+
return [
542+
model for model in super().list_supported_models(exclude=exclude) if not model["model"].startswith("jinaai")
543+
]
529544

530545

531546
class DefaultEmbedding(FlagEmbedding):
@@ -638,12 +653,19 @@ def embed(
638653
yield from normalize(self.mean_pooling(embeddings, attn_mask)).astype(np.float32)
639654

640655
@classmethod
641-
def list_supported_models(cls) -> List[Dict[str, Union[str, Union[int, float]]]]:
642-
"""
643-
Lists the supported models.
656+
def list_supported_models(cls, exclude: List[str] = ["gcs_sources", "hf_sources"]) -> List[Dict[str, Any]]:
657+
"""Lists the supported models.
658+
659+
Args:
660+
exclude (List[str], optional): Keys to exclude from the result. Defaults to ["gcs_sources", "hf_sources"].
661+
662+
Returns:
663+
List[Dict[str, Any]]: A list of dictionaries containing the model information.
644664
"""
645665
# only jina models are supported by this class
646-
return [model for model in Embedding.list_supported_models() if model["model"].startswith("jinaai")]
666+
return [
667+
model for model in Embedding.list_supported_models(exclude=exclude) if model["model"].startswith("jinaai")
668+
]
647669

648670
@staticmethod
649671
def mean_pooling(model_output, attention_mask):

0 commit comments

Comments
 (0)