|
8 | 8 | from fastembed.common.types import NumpyArray |
9 | 9 | from fastembed.common import OnnxProvider |
10 | 10 | from fastembed.common.onnx_model import OnnxOutputContext |
11 | | -from fastembed.common.utils import define_cache_dir |
| 11 | +from fastembed.common.utils import define_cache_dir, iter_batch |
12 | 12 | from fastembed.late_interaction.late_interaction_embedding_base import ( |
13 | 13 | LateInteractionTextEmbeddingBase, |
14 | 14 | ) |
@@ -96,6 +96,38 @@ def _tokenize_documents(self, documents: list[str]) -> list[Encoding]: |
96 | 96 | encoded = self.tokenizer.encode_batch(documents) # type: ignore[union-attr] |
97 | 97 | return encoded |
98 | 98 |
|
def token_count(
    self,
    texts: Union[str, Iterable[str]],
    batch_size: int = 1024,
    is_doc: bool = True,
    include_extension: bool = False,
    **kwargs: Any,
) -> int:
    """Return the total number of attended tokens across ``texts``.

    Each text is encoded with the document or the query tokenizer and the
    entries of its attention mask are summed.

    Args:
        texts: A single string or an iterable of strings to tokenize.
        batch_size: Number of texts handed to the tokenizer per batch.
        is_doc: Use ``self.tokenizer`` when true, ``self.query_tokenizer``
            otherwise.
        include_extension: When true, every query count is raised to at
            least ``self.MIN_QUERY_LENGTH`` and one extra token is added per
            text for the marker token.
        **kwargs: Accepted but not used by this method.

    Returns:
        The accumulated token count over all texts.
    """
    if getattr(self, "model", None) is None:
        self.load_onnx_model()  # loads the tokenizer as well

    if isinstance(texts, str):
        corpus = [texts]
    else:
        corpus = texts

    # Tokenizers are only guaranteed to exist after the model has been loaded.
    if is_doc:
        tokenizer = self.tokenizer
    else:
        tokenizer = self.query_tokenizer
    assert tokenizer is not None

    # Only queries are padded up to the minimal query length.
    pad_queries = include_extension and not is_doc

    total = 0
    for chunk in iter_batch(corpus, batch_size):
        for encoding in tokenizer.encode_batch(chunk):
            attended = sum(encoding.attention_mask)
            if pad_queries:
                attended = max(attended, self.MIN_QUERY_LENGTH)
            total += attended
        if include_extension:
            # add 1 for each cls.DOC_MARKER_TOKEN_ID or cls.QUERY_MARKER_TOKEN_ID
            total += len(chunk)

    return total
| 130 | + |
99 | 131 | @classmethod |
100 | 132 | def _list_supported_models(cls) -> list[DenseModelDescription]: |
101 | 133 | """Lists the supported models. |
|
0 commit comments