|
1 | 1 | from pathlib import Path |
| 2 | + |
| 3 | +from typing import Any, Optional, Sequence, Iterable, Union, Type |
| 4 | + |
2 | 5 | import numpy as np |
3 | | -from typing import Any, Dict, Optional, Sequence, Iterable, Union, Set |
| 6 | +from numpy.typing import NDArray |
| 7 | +from py_rust_stemmers import SnowballStemmer |
| 8 | +from tokenizers import Tokenizer |
4 | 9 |
|
5 | 10 | from fastembed.common.model_description import SparseModelDescription, ModelSource |
| 11 | +from fastembed.common.onnx_model import OnnxOutputContext |
| 12 | +from fastembed.common import OnnxProvider |
| 13 | +from fastembed.common.utils import define_cache_dir |
6 | 14 | from fastembed.sparse.sparse_embedding_base import ( |
7 | 15 | SparseEmbedding, |
8 | 16 | SparseTextEmbeddingBase, |
9 | 17 | ) |
10 | | - |
11 | | -from numpy.typing import NDArray |
12 | | - |
13 | | -from fastembed.common.onnx_model import OnnxOutputContext |
14 | 18 | from fastembed.sparse.utils.minicoil_encoder import Encoder |
15 | 19 | from fastembed.sparse.utils.sparse_vectors_converter import SparseVectorConverter, WordEmbedding |
16 | 20 | from fastembed.sparse.utils.vocab_resolver import VocabResolver, VocabTokenizerTokenizer |
17 | 21 | from fastembed.text.onnx_text_model import OnnxTextModel, TextEmbeddingWorker |
18 | | -from py_rust_stemmers import SnowballStemmer |
19 | | -from fastembed.common import OnnxProvider |
20 | | -from fastembed.common.utils import define_cache_dir |
21 | | -from tokenizers import Tokenizer |
22 | 22 |
|
23 | 23 |
|
24 | 24 | MINICOIL_MODEL_FILE = "minicoil.triplet.model.npy" |
|
29 | 29 | supported_minicoil_models: list[SparseModelDescription] = [ |
30 | 30 | SparseModelDescription( |
31 | 31 | model="Qdrant/minicoil-v1", |
32 | | - vocab_size=30522, |
| 32 | + vocab_size=19125, |
33 | 33 | description="Sparse embedding model, that resolves semantic meaning of the words, " |
34 | 34 | "while keeping exact keyword match behavior. " |
35 | 35 | "Based on jinaai/jina-embeddings-v2-small-en-tokens", |
@@ -57,7 +57,7 @@ class MiniCOIL(SparseTextEmbeddingBase, OnnxTextModel[SparseEmbedding]): |
57 | 57 | while keeping exact keyword match behavior. |
58 | 58 |
|
59 | 59 | Each vocabulary token is converted into a 4d component of a sparse vector, which is then weighted by the token frequency in the corpus. |
60 | | - If the token is not found in the corpus, it is trearted exactly like in BM25. |
| 60 | + If the token is not found in the corpus, it is treated exactly like in BM25. |
61 | 61 | |
62 | 62 | The model is based on `jinaai/jina-embeddings-v2-small-en-tokens` |
63 | 63 | """ |
@@ -116,10 +116,10 @@ def __init__( |
116 | 116 |
|
117 | 117 | # Initialize class attributes |
118 | 118 | self.tokenizer: Optional[Tokenizer] = None |
119 | | - self.invert_vocab: Dict[int, str] = {} |
120 | | - self.special_tokens: Set[str] = set() |
121 | | - self.special_tokens_ids: Set[int] = set() |
122 | | - self.stopwords: Set[str] = set() |
| 119 | + self.invert_vocab: dict[int, str] = {} |
| 120 | + self.special_tokens: set[str] = set() |
| 121 | + self.special_tokens_ids: set[int] = set() |
| 122 | + self.stopwords: set[str] = set() |
123 | 123 | self.vocab_resolver: Optional[VocabResolver] = None |
124 | 124 | self.encoder: Optional[Encoder] = None |
125 | 125 | self.output_dim: Optional[int] = None |
@@ -297,7 +297,7 @@ def _post_process_onnx_output( |
297 | 297 | # Size of counts: (unique_words) |
298 | 298 | words_ids = ids_mapping[:, 0].tolist() |
299 | 299 |
|
300 | | - sentence_result: Dict[str, WordEmbedding] = {} |
| 300 | + sentence_result: dict[str, WordEmbedding] = {} |
301 | 301 |
|
302 | 302 | words = [self.vocab_resolver.lookup_word(word_id) for word_id in words_ids] |
303 | 303 |
|
@@ -325,36 +325,25 @@ def _post_process_onnx_output( |
325 | 325 | word=oov_word, forms=[oov_word], count=int(count), word_id=-1, embedding=[1] |
326 | 326 | ) |
327 | 327 |
|
328 | | - if is_query: |
329 | | - yield self.sparse_vector_converter.embedding_to_vector_query( |
| 328 | + if not is_query: |
| 329 | + yield self.sparse_vector_converter.embedding_to_vector( |
330 | 330 | sentence_result, vocab_size=vocab_size, embedding_size=embedding_size |
331 | 331 | ) |
332 | 332 | else: |
333 | | - yield self.sparse_vector_converter.embedding_to_vector( |
| 333 | + yield self.sparse_vector_converter.embedding_to_vector_query( |
334 | 334 | sentence_result, vocab_size=vocab_size, embedding_size=embedding_size |
335 | 335 | ) |
336 | 336 |
|
| 337 | + @classmethod |
| 338 | + def _get_worker_class(cls) -> Type["MiniCoilTextEmbeddingWorker"]: |
| 339 | + return MiniCoilTextEmbeddingWorker |
| 340 | + |
337 | 341 |
|
338 | 342 | class MiniCoilTextEmbeddingWorker(TextEmbeddingWorker[SparseEmbedding]): |
339 | 343 | def init_embedding(self, model_name: str, cache_dir: str, **kwargs: Any) -> MiniCOIL: |
340 | 344 | return MiniCOIL( |
341 | 345 | model_name=model_name, |
342 | 346 | cache_dir=cache_dir, |
| 347 | + threads=1, |
343 | 348 | **kwargs, |
344 | 349 | ) |
345 | | - |
346 | | - |
347 | | -def test_minicoil() -> None: |
348 | | - model = MiniCOIL(model_name="Qdrant/minicoil-v1") |
349 | | - |
350 | | - embedding = next(iter(model.embed("Hello World"))) |
351 | | - |
352 | | - print(embedding) |
353 | | - |
354 | | - embedding = next(iter(model.query_embed("Hello World"))) |
355 | | - |
356 | | - print(embedding) |
357 | | - |
358 | | - |
359 | | -if __name__ == "__main__": |
360 | | - test_minicoil() |
0 commit comments