Skip to content

Commit ccee6fe

Browse files
committed
draft: move custom models logic to text embedding
1 parent b186301 commit ccee6fe

5 files changed

+12
-27
lines changed

fastembed/text/clip_embedding.py

-6
Original file line number | Diff line number | Diff line change
@@ -22,8 +22,6 @@
2222

2323

2424
class CLIPOnnxEmbedding(OnnxTextEmbedding):
25-
supported_models = supported_clip_models
26-
2725
@classmethod
2826
def _get_worker_class(cls) -> Type[TextEmbeddingWorker]:
2927
return CLIPEmbeddingWorker
@@ -37,10 +35,6 @@ def list_supported_models(cls) -> list[dict[str, Any]]:
3735
"""
3836
return cls.supported_models
3937

40-
@classmethod
41-
def add_custom_model(cls, model_info: dict[str, Any]):
42-
cls.supported_models.append(model_info)
43-
4438
def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[np.ndarray]:
4539
return output.model_output
4640

fastembed/text/multitask_embedding.py

-4
Original file line number | Diff line number | Diff line change
@@ -55,10 +55,6 @@ def _get_worker_class(cls) -> Type["TextEmbeddingWorker"]:
5555
def list_supported_models(cls) -> list[dict[str, Any]]:
5656
return cls.supported_models
5757

58-
@classmethod
59-
def add_custom_model(cls, model_info: dict[str, Any]):
60-
cls.supported_models.append(model_info)
61-
6258
def _preprocess_onnx_input(
6359
self, onnx_input: dict[str, np.ndarray], **kwargs
6460
) -> dict[str, np.ndarray]:

fastembed/text/pooled_embedding.py

-6
Original file line number | Diff line number | Diff line change
@@ -79,8 +79,6 @@
7979

8080

8181
class PooledEmbedding(OnnxTextEmbedding):
82-
supported_models = supported_pooled_models
83-
8482
@classmethod
8583
def _get_worker_class(cls) -> Type[TextEmbeddingWorker]:
8684
return PooledEmbeddingWorker
@@ -105,10 +103,6 @@ def list_supported_models(cls) -> list[dict[str, Any]]:
105103
"""
106104
return cls.supported_models
107105

108-
@classmethod
109-
def add_custom_model(cls, model_info: dict[str, Any]):
110-
cls.supported_models.append(model_info)
111-
112106
def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[np.ndarray]:
113107
if output.attention_mask is None:
114108
raise ValueError("attention_mask must be provided for document post-processing")

fastembed/text/pooled_normalized_embedding.py

-6
Original file line number | Diff line number | Diff line change
@@ -88,8 +88,6 @@
8888

8989

9090
class PooledNormalizedEmbedding(PooledEmbedding):
91-
supported_models = supported_pooled_normalized_models
92-
9391
@classmethod
9492
def _get_worker_class(cls) -> Type[TextEmbeddingWorker]:
9593
return PooledNormalizedEmbeddingWorker
@@ -103,10 +101,6 @@ def list_supported_models(cls) -> list[dict[str, Any]]:
103101
"""
104102
return cls.supported_models
105103

106-
@classmethod
107-
def add_custom_model(cls, model_info: dict[str, Any]):
108-
cls.supported_models.append(model_info)
109-
110104
def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[np.ndarray]:
111105
if output.attention_mask is None:
112106
raise ValueError("attention_mask must be provided for document post-processing")

fastembed/text/text_embedding.py

+12 -5
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,5 @@
11
import warnings
2+
from collections import defaultdict
23
from typing import Any, Iterable, Optional, Sequence, Type, Union
34

45
import numpy as np
@@ -19,6 +20,9 @@ class TextEmbedding(TextEmbeddingBase):
1920
PooledEmbedding,
2021
JinaEmbeddingV3,
2122
]
23+
CUSTOM_EMBEDDINGS_REGISTRY: dict[Type[TextEmbeddingBase], list[dict[str, Any]]] = defaultdict(
24+
list
25+
)
2226

2327
@classmethod
2428
def list_supported_models(cls) -> list[dict[str, Any]]:
@@ -48,6 +52,9 @@ def list_supported_models(cls) -> list[dict[str, Any]]:
4852
result = []
4953
for embedding in cls.EMBEDDINGS_REGISTRY:
5054
result.extend(embedding.list_supported_models())
55+
for embedding, models in cls.CUSTOM_EMBEDDINGS_REGISTRY.items():
56+
for model in models:
57+
result.append(model)
5158
return result
5259

5360
@classmethod
@@ -74,13 +81,13 @@ def add_custom_model(
7481
None
7582
"""
7683
if mean_pooling and not normalization:
77-
PooledEmbedding.add_custom_model(model_info)
84+
cls.CUSTOM_EMBEDDINGS_REGISTRY[PooledEmbedding].append(model_info)
7885
elif mean_pooling and normalization:
79-
PooledNormalizedEmbedding.add_custom_model(model_info)
80-
elif "clip" in model_info["model"].lower():
81-
CLIPOnnxEmbedding.add_custom_model(model_info)
86+
cls.CUSTOM_EMBEDDINGS_REGISTRY[PooledNormalizedEmbedding].append(model_info)
87+
elif not mean_pooling and not normalization:
88+
cls.CUSTOM_EMBEDDINGS_REGISTRY[OnnxTextEmbedding].append(model_info)
8289
else:
83-
OnnxTextEmbedding.add_custom_model(model_info)
90+
cls.CUSTOM_EMBEDDINGS_REGISTRY[PooledNormalizedEmbedding].append(model_info)
8491

8592
def __init__(
8693
self,

0 commit comments

Comments
 (0)