Skip to content

Commit e7d6fb9

Browse files
joein and generall authored
some pr issues (#514)
* some pr issues
* revert query embed refactor
* test: add query embed tests
* nit
* Update tests/test_sparse_embeddings.py

---------

Co-authored-by: Andrey Vasnetsov <andrey@vasnetsov.com>
1 parent 5279e3f commit e7d6fb9

7 files changed

Lines changed: 115 additions & 98 deletions

File tree

fastembed/late_interaction/colbert.py

Lines changed: 1 addition & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -244,9 +244,7 @@ def query_embed(self, query: Union[str, Iterable[str]], **kwargs: Any) -> Iterab
244244
self.load_onnx_model()
245245

246246
for text in query:
247-
yield from self._post_process_onnx_output(
248-
self.onnx_embed([text], is_doc=False), is_doc=False
249-
)
247+
yield from self._post_process_onnx_output(self.onnx_embed([text]), is_doc=False)
250248

251249
@classmethod
252250
def _get_worker_class(cls) -> Type[TextEmbeddingWorker[NumpyArray]]:

fastembed/late_interaction/token_embeddings.py

Lines changed: 26 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,7 @@
1-
from typing import Union, Iterable, Optional, List, Dict, Any, Type
1+
from dataclasses import asdict
2+
from typing import Union, Iterable, Optional, Any, Type
23

4+
from fastembed.common.model_description import DenseModelDescription, ModelSource
35
from fastembed.common.onnx_model import OnnxOutputContext
46
from fastembed.common.types import NumpyArray
57
from fastembed.late_interaction.late_interaction_embedding_base import (
@@ -10,29 +12,38 @@
1012
import numpy as np
1113

1214
supported_token_embeddings_models = [
13-
{
14-
"model": "jinaai/jina-embeddings-v2-small-en-tokens",
15-
"dim": 512,
16-
"description": "Text embeddings, Unimodal (text), English, 8192 input tokens truncation,"
15+
DenseModelDescription(
16+
model="jinaai/jina-embeddings-v2-small-en-tokens",
17+
dim=512,
18+
description="Text embeddings, Unimodal (text), English, 8192 input tokens truncation,"
1719
" Prefixes for queries/documents: not necessary, 2023 year.",
18-
"license": "apache-2.0",
19-
"size_in_GB": 0.12,
20-
"sources": {"hf": "xenova/jina-embeddings-v2-small-en"},
21-
"model_file": "onnx/model.onnx",
22-
},
20+
license="apache-2.0",
21+
size_in_GB=0.12,
22+
sources=ModelSource(hf="xenova/jina-embeddings-v2-small-en"),
23+
model_file="onnx/model.onnx",
24+
),
2325
]
2426

2527

2628
class TokenEmbeddingsModel(OnnxTextEmbedding, LateInteractionTextEmbeddingBase):
2729
@classmethod
28-
def list_supported_models(cls) -> List[Dict[str, Any]]:
30+
def _list_supported_models(cls) -> list[DenseModelDescription]:
2931
"""Lists the supported models.
3032
3133
Returns:
32-
List[Dict[str, Any]]: A list of dictionaries containing the model information.
34+
list[DenseModelDescription]: A list of DenseModelDescription objects containing the model information.
3335
"""
3436
return supported_token_embeddings_models
3537

38+
@classmethod
39+
def list_supported_models(cls) -> list[dict[str, Any]]:
40+
"""Lists the supported models.
41+
42+
Returns:
43+
list[dict[str, Any]]: A list of dictionaries containing the model information.
44+
"""
45+
return [asdict(model) for model in cls._list_supported_models()]
46+
3647
@classmethod
3748
def _get_worker_class(cls) -> Type[TextEmbeddingWorker[NumpyArray]]:
3849
return TokensEmbeddingWorker
@@ -47,7 +58,6 @@ def _post_process_onnx_output(
4758
masks = output.attention_mask
4859

4960
# For each document we only select those embeddings that are not masked out
50-
5161
for i in range(embeddings.shape[0]):
5262
yield embeddings[i, masks[i] == 1]
5363

@@ -58,11 +68,9 @@ def embed(
5868
parallel: Optional[int] = None,
5969
**kwargs: Any,
6070
) -> Iterable[NumpyArray]:
61-
yield from OnnxTextEmbedding.embed(
62-
self, documents, batch_size=batch_size, parallel=parallel, **kwargs
63-
)
71+
yield from super().embed(documents, batch_size=batch_size, parallel=parallel, **kwargs)
6472

65-
def tokenize_docs(self, documents: List[str]) -> List[NumpyArray]:
73+
def tokenize_docs(self, documents: list[str]) -> list[NumpyArray]:
6674
if self.tokenizer is None:
6775
raise ValueError("Tokenizer not initialized")
6876
encoded = self.tokenizer.encode_batch(documents)
@@ -83,6 +91,7 @@ def init_embedding(
8391

8492
if __name__ == "__main__":
8593
# Example usage
94+
print(TokenEmbeddingsModel.list_supported_models())
8695
model = TokenEmbeddingsModel(model_name="jinaai/jina-embeddings-v2-small-en-tokens")
8796
docs = ["Hello, world!", "hello", "hello hello"]
8897

fastembed/sparse/minicoil.py

Lines changed: 24 additions & 35 deletions
Original file line number | Diff line number | Diff line change
@@ -1,24 +1,24 @@
11
from pathlib import Path
2+
3+
from typing import Any, Optional, Sequence, Iterable, Union, Type
4+
25
import numpy as np
3-
from typing import Any, Dict, Optional, Sequence, Iterable, Union, Set
6+
from numpy.typing import NDArray
7+
from py_rust_stemmers import SnowballStemmer
8+
from tokenizers import Tokenizer
49

510
from fastembed.common.model_description import SparseModelDescription, ModelSource
11+
from fastembed.common.onnx_model import OnnxOutputContext
12+
from fastembed.common import OnnxProvider
13+
from fastembed.common.utils import define_cache_dir
614
from fastembed.sparse.sparse_embedding_base import (
715
SparseEmbedding,
816
SparseTextEmbeddingBase,
917
)
10-
11-
from numpy.typing import NDArray
12-
13-
from fastembed.common.onnx_model import OnnxOutputContext
1418
from fastembed.sparse.utils.minicoil_encoder import Encoder
1519
from fastembed.sparse.utils.sparse_vectors_converter import SparseVectorConverter, WordEmbedding
1620
from fastembed.sparse.utils.vocab_resolver import VocabResolver, VocabTokenizerTokenizer
1721
from fastembed.text.onnx_text_model import OnnxTextModel, TextEmbeddingWorker
18-
from py_rust_stemmers import SnowballStemmer
19-
from fastembed.common import OnnxProvider
20-
from fastembed.common.utils import define_cache_dir
21-
from tokenizers import Tokenizer
2222

2323

2424
MINICOIL_MODEL_FILE = "minicoil.triplet.model.npy"
@@ -29,7 +29,7 @@
2929
supported_minicoil_models: list[SparseModelDescription] = [
3030
SparseModelDescription(
3131
model="Qdrant/minicoil-v1",
32-
vocab_size=30522,
32+
vocab_size=19125,
3333
description="Sparse embedding model, that resolves semantic meaning of the words, "
3434
"while keeping exact keyword match behavior. "
3535
"Based on jinaai/jina-embeddings-v2-small-en-tokens",
@@ -57,7 +57,7 @@ class MiniCOIL(SparseTextEmbeddingBase, OnnxTextModel[SparseEmbedding]):
5757
while keeping exact keyword match behavior.
5858
5959
Each vocabulary token is converted into 4d component of a sparse vector, which is then weighted by the token frequency in the corpus.
60-
If the token is not found in the corpus, it is trearted exactly like in BM25.
60+
If the token is not found in the corpus, it is treated exactly like in BM25.
6161
`
6262
The model is based on `jinaai/jina-embeddings-v2-small-en-tokens`
6363
"""
@@ -116,10 +116,10 @@ def __init__(
116116

117117
# Initialize class attributes
118118
self.tokenizer: Optional[Tokenizer] = None
119-
self.invert_vocab: Dict[int, str] = {}
120-
self.special_tokens: Set[str] = set()
121-
self.special_tokens_ids: Set[int] = set()
122-
self.stopwords: Set[str] = set()
119+
self.invert_vocab: dict[int, str] = {}
120+
self.special_tokens: set[str] = set()
121+
self.special_tokens_ids: set[int] = set()
122+
self.stopwords: set[str] = set()
123123
self.vocab_resolver: Optional[VocabResolver] = None
124124
self.encoder: Optional[Encoder] = None
125125
self.output_dim: Optional[int] = None
@@ -297,7 +297,7 @@ def _post_process_onnx_output(
297297
# Size of counts: (unique_words)
298298
words_ids = ids_mapping[:, 0].tolist()
299299

300-
sentence_result: Dict[str, WordEmbedding] = {}
300+
sentence_result: dict[str, WordEmbedding] = {}
301301

302302
words = [self.vocab_resolver.lookup_word(word_id) for word_id in words_ids]
303303

@@ -325,36 +325,25 @@ def _post_process_onnx_output(
325325
word=oov_word, forms=[oov_word], count=int(count), word_id=-1, embedding=[1]
326326
)
327327

328-
if is_query:
329-
yield self.sparse_vector_converter.embedding_to_vector_query(
328+
if not is_query:
329+
yield self.sparse_vector_converter.embedding_to_vector(
330330
sentence_result, vocab_size=vocab_size, embedding_size=embedding_size
331331
)
332332
else:
333-
yield self.sparse_vector_converter.embedding_to_vector(
333+
yield self.sparse_vector_converter.embedding_to_vector_query(
334334
sentence_result, vocab_size=vocab_size, embedding_size=embedding_size
335335
)
336336

337+
@classmethod
338+
def _get_worker_class(cls) -> Type["MiniCoilTextEmbeddingWorker"]:
339+
return MiniCoilTextEmbeddingWorker
340+
337341

338342
class MiniCoilTextEmbeddingWorker(TextEmbeddingWorker[SparseEmbedding]):
339343
def init_embedding(self, model_name: str, cache_dir: str, **kwargs: Any) -> MiniCOIL:
340344
return MiniCOIL(
341345
model_name=model_name,
342346
cache_dir=cache_dir,
347+
threads=1,
343348
**kwargs,
344349
)
345-
346-
347-
def test_minicoil() -> None:
348-
model = MiniCOIL(model_name="Qdrant/minicoil-v1")
349-
350-
embedding = next(iter(model.embed("Hello World")))
351-
352-
print(embedding)
353-
354-
embedding = next(iter(model.query_embed("Hello World")))
355-
356-
print(embedding)
357-
358-
359-
if __name__ == "__main__":
360-
test_minicoil()

fastembed/sparse/utils/minicoil_encoder.py

Lines changed: 3 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -7,8 +7,6 @@
77
import numpy as np
88
from fastembed.common.types import NumpyArray
99

10-
from typing import Tuple
11-
1210

1311
class Encoder:
1412
"""
@@ -35,7 +33,7 @@ class Encoder:
3533
│ │
3634
└─────────────────────┘
3735
38-
Final liner transformation is accompanied by a non-linear activation function: Tanh.
36+
Final linear transformation is accompanied by a non-linear activation function: Tanh.
3937
4038
Tanh is used to ensure that the output is in the range [-1, 1].
4139
It would be easier to visually interpret the output of the model, assuming that each dimension
@@ -70,7 +68,7 @@ def convert_vocab_ids(vocab_ids: NumpyArray) -> NumpyArray:
7068
@classmethod
7169
def avg_by_vocab_ids(
7270
cls, vocab_ids: NumpyArray, embeddings: NumpyArray
73-
) -> Tuple[NumpyArray, NumpyArray]:
71+
) -> tuple[NumpyArray, NumpyArray]:
7472
"""
7573
Takes:
7674
vocab_ids: (batch_size, seq_len) int array
@@ -112,7 +110,7 @@ def avg_by_vocab_ids(
112110

113111
def forward(
114112
self, vocab_ids: NumpyArray, embeddings: NumpyArray
115-
) -> Tuple[NumpyArray, NumpyArray]:
113+
) -> tuple[NumpyArray, NumpyArray]:
116114
"""
117115
Args:
118116
vocab_ids: (batch_size, seq_len) int array

fastembed/sparse/utils/sparse_vectors_converter.py

Lines changed: 8 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -32,7 +32,7 @@ def __init__(
3232
avg_len: float = 150.0,
3333
):
3434
punctuation = set(get_all_punctuation())
35-
special_tokens = set(["[CLS]", "[SEP]", "[PAD]", "[UNK]", "[MASK]"])
35+
special_tokens = {"[CLS]", "[SEP]", "[PAD]", "[UNK]", "[MASK]"}
3636

3737
self.stemmer = stemmer
3838
self.unwanted_tokens = punctuation | special_tokens | stopwords
@@ -163,20 +163,21 @@ def embedding_to_vector(
163163
unknown_words_shift = (
164164
(vocab_size * embedding_size) // GAP + 2
165165
) * GAP # miniCOIL vocab + at least (GAP // embedding_size) + 1 new words gap
166-
167166
sentence_embedding_cleaned = self.clean_words(sentence_embedding)
168167

169-
# Calcualte sentence length after cleaning
168+
# Calculate sentence length after cleaning
170169
sentence_len = 0
171170
for embedding in sentence_embedding_cleaned.values():
172171
sentence_len += embedding.count
173172

174173
for embedding in sentence_embedding_cleaned.values():
175174
word_id = embedding.word_id
176-
num_occurences = embedding.count
177-
tf = self.bm25_tf(num_occurences, sentence_len)
178-
179-
if word_id >= 0: # miniCOIL starts with ID 1
175+
num_occurrences = embedding.count
176+
tf = self.bm25_tf(num_occurrences, sentence_len)
177+
if (
178+
word_id > 0
179+
): # miniCOIL starts with ID 1, we generally won't have word_id == 0 (UNK), as we don't add
180+
# these words to sentence_embedding
180181
embedding_values = embedding.embedding
181182
normalized_embedding = self.normalize_vector(embedding_values)
182183

0 commit comments

Comments
 (0)