1
1
import warnings
2
+ from collections import defaultdict
2
3
from typing import Any , Iterable , Optional , Sequence , Type , Union
3
4
4
5
import numpy as np
@@ -19,6 +20,9 @@ class TextEmbedding(TextEmbeddingBase):
19
20
PooledEmbedding ,
20
21
JinaEmbeddingV3 ,
21
22
]
23
+ CUSTOM_EMBEDDINGS_REGISTRY : dict [Type [TextEmbeddingBase ], list [dict [str , Any ]]] = defaultdict (
24
+ list
25
+ )
22
26
23
27
@classmethod
24
28
def list_supported_models (cls ) -> list [dict [str , Any ]]:
@@ -48,6 +52,9 @@ def list_supported_models(cls) -> list[dict[str, Any]]:
48
52
result = []
49
53
for embedding in cls .EMBEDDINGS_REGISTRY :
50
54
result .extend (embedding .list_supported_models ())
55
+ for embedding , models in cls .CUSTOM_EMBEDDINGS_REGISTRY .items ():
56
+ for model in models :
57
+ result .append (model )
51
58
return result
52
59
53
60
@classmethod
@@ -74,13 +81,13 @@ def add_custom_model(
74
81
None
75
82
"""
76
83
if mean_pooling and not normalization :
77
- PooledEmbedding . add_custom_model (model_info )
84
+ cls . CUSTOM_EMBEDDINGS_REGISTRY [ PooledEmbedding ]. append (model_info )
78
85
elif mean_pooling and normalization :
79
- PooledNormalizedEmbedding . add_custom_model (model_info )
80
- elif "clip" in model_info [ "model" ]. lower () :
81
- CLIPOnnxEmbedding . add_custom_model (model_info )
86
+ cls . CUSTOM_EMBEDDINGS_REGISTRY [ PooledNormalizedEmbedding ]. append (model_info )
87
+ elif not mean_pooling and not normalization :
88
+ cls . CUSTOM_EMBEDDINGS_REGISTRY [ OnnxTextEmbedding ]. append (model_info )
82
89
else :
83
- OnnxTextEmbedding . add_custom_model (model_info )
90
+ cls . CUSTOM_EMBEDDINGS_REGISTRY [ PooledNormalizedEmbedding ]. append (model_info )
84
91
85
92
def __init__ (
86
93
self ,
0 commit comments