Skip to content

Commit 27eeb39

Browse files
authored
new: add custom models (#479)
* fix: fix onnx text embedding list supported models, do not add already registered models, add tests * fix: autouse fixture in custom model tests * Refactor custom models (#482) * refactor: refactor custom models * fix: fix types * remove commented out code
1 parent 4e527b1 commit 27eeb39

File tree

7 files changed

+342
-12
lines changed

7 files changed

+342
-12
lines changed

fastembed/common/model_description.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from dataclasses import dataclass, field
2+
from enum import Enum
23
from typing import Optional, Any
34

45

@@ -28,7 +29,7 @@ class BaseModelDescription:
2829
@dataclass(frozen=True)
2930
class DenseModelDescription(BaseModelDescription):
3031
dim: Optional[int] = None
31-
tasks: Optional[dict[str, Any]] = None
32+
tasks: Optional[dict[str, Any]] = field(default_factory=dict)
3233

3334
def __post_init__(self) -> None:
3435
assert self.dim is not None, "dim is required for dense model description"
@@ -38,3 +39,9 @@ def __post_init__(self) -> None:
3839
class SparseModelDescription(BaseModelDescription):
3940
requires_idf: Optional[bool] = None
4041
vocab_size: Optional[int] = None
42+
43+
44+
class PoolingType(str, Enum):
45+
CLS = "CLS"
46+
MEAN = "MEAN"
47+
DISABLED = "DISABLED"

fastembed/common/model_management.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,31 @@ def list_supported_models(cls) -> list[dict[str, Any]]:
3333
"""
3434
raise NotImplementedError()
3535

36+
@classmethod
37+
def add_custom_model(
38+
cls,
39+
*args: Any,
40+
**kwargs: Any,
41+
) -> None:
42+
"""Add a custom model to the existing embedding classes based on the passed model descriptions
43+
44+
Model description dict should contain the fields same as in one of the model descriptions presented
45+
in fastembed.common.model_description
46+
47+
E.g. for BaseModelDescription:
48+
model: str
49+
sources: ModelSource
50+
model_file: str
51+
description: str
52+
license: str
53+
size_in_GB: float
54+
additional_files: list[str]
55+
56+
Returns:
57+
None
58+
"""
59+
raise NotImplementedError()
60+
3661
@classmethod
3762
def _list_supported_models(cls) -> list[T]:
3863
raise NotImplementedError()

fastembed/common/utils.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from typing import Iterable, Optional, TypeVar
99

1010
import numpy as np
11+
from numpy.typing import NDArray
1112

1213
from fastembed.common.types import NumpyArray
1314

@@ -22,6 +23,15 @@ def normalize(input_array: NumpyArray, p: int = 2, dim: int = 1, eps: float = 1e
2223
return normalized_array
2324

2425

26+
def mean_pooling(input_array: NumpyArray, attention_mask: NDArray[np.int64]) -> NumpyArray:
27+
input_mask_expanded = np.expand_dims(attention_mask, axis=-1).astype(np.int64)
28+
input_mask_expanded = np.tile(input_mask_expanded, (1, 1, input_array.shape[-1]))
29+
sum_embeddings = np.sum(input_array * input_mask_expanded, axis=1)
30+
sum_mask = np.sum(input_mask_expanded, axis=1)
31+
pooled_embeddings = sum_embeddings / np.maximum(sum_mask, 1e-9)
32+
return pooled_embeddings
33+
34+
2535
def iter_batch(iterable: Iterable[T], size: int) -> Iterable[list[T]]:
2636
"""
2737
>>> list(iter_batch([1,2,3,4,5], 3))
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
from typing import Optional, Sequence, Any, Iterable
2+
3+
from dataclasses import dataclass
4+
5+
import numpy as np
6+
from numpy.typing import NDArray
7+
8+
from fastembed.common import OnnxProvider
9+
from fastembed.common.model_description import (
10+
PoolingType,
11+
DenseModelDescription,
12+
)
13+
from fastembed.common.onnx_model import OnnxOutputContext
14+
from fastembed.common.types import NumpyArray
15+
from fastembed.common.utils import normalize, mean_pooling
16+
from fastembed.text.onnx_embedding import OnnxTextEmbedding
17+
18+
19+
@dataclass(frozen=True)
20+
class PostprocessingConfig:
21+
pooling: PoolingType
22+
normalization: bool
23+
24+
25+
class CustomTextEmbedding(OnnxTextEmbedding):
26+
SUPPORTED_MODELS: list[DenseModelDescription] = []
27+
POSTPROCESSING_MAPPING: dict[str, PostprocessingConfig] = {}
28+
29+
def __init__(
30+
self,
31+
model_name: str,
32+
cache_dir: Optional[str] = None,
33+
threads: Optional[int] = None,
34+
providers: Optional[Sequence[OnnxProvider]] = None,
35+
cuda: bool = False,
36+
device_ids: Optional[list[int]] = None,
37+
lazy_load: bool = False,
38+
device_id: Optional[int] = None,
39+
specific_model_path: Optional[str] = None,
40+
**kwargs: Any,
41+
):
42+
super().__init__(
43+
model_name=model_name,
44+
cache_dir=cache_dir,
45+
threads=threads,
46+
providers=providers,
47+
cuda=cuda,
48+
device_ids=device_ids,
49+
lazy_load=lazy_load,
50+
device_id=device_id,
51+
specific_model_path=specific_model_path,
52+
**kwargs,
53+
)
54+
self._pooling = self.POSTPROCESSING_MAPPING[model_name].pooling
55+
self._normalization = self.POSTPROCESSING_MAPPING[model_name].normalization
56+
57+
@classmethod
58+
def _list_supported_models(cls) -> list[DenseModelDescription]:
59+
return cls.SUPPORTED_MODELS
60+
61+
def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[NumpyArray]:
62+
return self._normalize(self._pool(output.model_output, output.attention_mask))
63+
64+
def _pool(
65+
self, embeddings: NumpyArray, attention_mask: Optional[NDArray[np.int64]] = None
66+
) -> NumpyArray:
67+
if self._pooling == PoolingType.CLS:
68+
return embeddings[:, 0]
69+
70+
if self._pooling == PoolingType.MEAN:
71+
if attention_mask is None:
72+
raise ValueError("attention_mask must be provided for mean pooling")
73+
return mean_pooling(embeddings, attention_mask)
74+
75+
if self._pooling == PoolingType.DISABLED:
76+
return embeddings
77+
78+
def _normalize(self, embeddings: NumpyArray) -> NumpyArray:
79+
return normalize(embeddings) if self._normalization else embeddings
80+
81+
@classmethod
82+
def add_model(
83+
cls,
84+
model_description: DenseModelDescription,
85+
pooling: PoolingType,
86+
normalization: bool,
87+
) -> None:
88+
cls.SUPPORTED_MODELS.append(model_description)
89+
cls.POSTPROCESSING_MAPPING[model_description.model] = PostprocessingConfig(
90+
pooling=pooling, normalization=normalization
91+
)

fastembed/text/pooled_embedding.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
from typing import Any, Iterable, Type
22

33
import numpy as np
4+
from numpy.typing import NDArray
45

56
from fastembed.common.types import NumpyArray
67
from fastembed.common.onnx_model import OnnxOutputContext
8+
from fastembed.common.utils import mean_pooling
79
from fastembed.text.onnx_embedding import OnnxTextEmbedding, OnnxTextEmbeddingWorker
810
from fastembed.common.model_description import DenseModelDescription, ModelSource
911

@@ -93,16 +95,10 @@ def _get_worker_class(cls) -> Type[OnnxTextEmbeddingWorker]:
9395
return PooledEmbeddingWorker
9496

9597
@classmethod
96-
def mean_pooling(cls, model_output: NumpyArray, attention_mask: NumpyArray) -> NumpyArray:
97-
token_embeddings = model_output.astype(np.float32)
98-
attention_mask = attention_mask.astype(np.float32)
99-
input_mask_expanded = np.expand_dims(attention_mask, axis=-1)
100-
input_mask_expanded = np.tile(input_mask_expanded, (1, 1, token_embeddings.shape[-1]))
101-
input_mask_expanded = input_mask_expanded.astype(np.float32)
102-
sum_embeddings = np.sum(token_embeddings * input_mask_expanded, axis=1)
103-
sum_mask = np.sum(input_mask_expanded, axis=1)
104-
pooled_embeddings = sum_embeddings / np.maximum(sum_mask, 1e-9)
105-
return pooled_embeddings
98+
def mean_pooling(
99+
cls, model_output: NumpyArray, attention_mask: NDArray[np.int64]
100+
) -> NumpyArray:
101+
return mean_pooling(model_output, attention_mask)
106102

107103
@classmethod
108104
def _list_supported_models(cls) -> list[DenseModelDescription]:

fastembed/text/text_embedding.py

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,13 @@
44

55
from fastembed.common.types import NumpyArray, OnnxProvider
66
from fastembed.text.clip_embedding import CLIPOnnxEmbedding
7+
from fastembed.text.custom_text_embedding import CustomTextEmbedding
78
from fastembed.text.pooled_normalized_embedding import PooledNormalizedEmbedding
89
from fastembed.text.pooled_embedding import PooledEmbedding
910
from fastembed.text.multitask_embedding import JinaEmbeddingV3
1011
from fastembed.text.onnx_embedding import OnnxTextEmbedding
1112
from fastembed.text.text_embedding_base import TextEmbeddingBase
12-
from fastembed.common.model_description import DenseModelDescription
13+
from fastembed.common.model_description import DenseModelDescription, ModelSource, PoolingType
1314

1415

1516
class TextEmbedding(TextEmbeddingBase):
@@ -19,6 +20,7 @@ class TextEmbedding(TextEmbeddingBase):
1920
PooledNormalizedEmbedding,
2021
PooledEmbedding,
2122
JinaEmbeddingV3,
23+
CustomTextEmbedding,
2224
]
2325

2426
@classmethod
@@ -37,6 +39,43 @@ def _list_supported_models(cls) -> list[DenseModelDescription]:
3739
result.extend(embedding._list_supported_models())
3840
return result
3941

42+
@classmethod
43+
def add_custom_model(
44+
cls,
45+
model: str,
46+
pooling: PoolingType,
47+
normalization: bool,
48+
sources: ModelSource,
49+
dim: int,
50+
model_file: str = "onnx/model.onnx",
51+
description: str = "",
52+
license: str = "",
53+
size_in_gb: float = 0.0,
54+
additional_files: Optional[list[str]] = None,
55+
) -> None:
56+
registered_models = cls._list_supported_models()
57+
for registered_model in registered_models:
58+
if model == registered_model.model:
59+
raise ValueError(
60+
f"Model {model} is already registered in TextEmbedding, if you still want to add this model, "
61+
f"please use another model name"
62+
)
63+
64+
CustomTextEmbedding.add_model(
65+
DenseModelDescription(
66+
model=model,
67+
sources=sources,
68+
dim=dim,
69+
model_file=model_file,
70+
description=description,
71+
license=license,
72+
size_in_GB=size_in_gb,
73+
additional_files=additional_files or [],
74+
),
75+
pooling=pooling,
76+
normalization=normalization,
77+
)
78+
4079
def __init__(
4180
self,
4281
model_name: str = "BAAI/bge-small-en-v1.5",

0 commit comments

Comments
 (0)