
Commit b718cc6

new: add token count method (#583)
* new: add token count method
* fix: fix mypy
* fix: load model in token_count
* fix: remove debug code
1 parent 2ba8990 commit b718cc6

25 files changed: +340 -2 lines changed

fastembed/late_interaction/colbert.py

Lines changed: 33 additions & 1 deletion
@@ -8,7 +8,7 @@
 from fastembed.common.types import NumpyArray
 from fastembed.common import OnnxProvider
 from fastembed.common.onnx_model import OnnxOutputContext
-from fastembed.common.utils import define_cache_dir
+from fastembed.common.utils import define_cache_dir, iter_batch
 from fastembed.late_interaction.late_interaction_embedding_base import (
     LateInteractionTextEmbeddingBase,
 )
@@ -96,6 +96,38 @@ def _tokenize_documents(self, documents: list[str]) -> list[Encoding]:
         encoded = self.tokenizer.encode_batch(documents)  # type: ignore[union-attr]
         return encoded
 
+    def token_count(
+        self,
+        texts: Union[str, Iterable[str]],
+        batch_size: int = 1024,
+        is_doc: bool = True,
+        include_extension: bool = False,
+        **kwargs: Any,
+    ) -> int:
+        if not hasattr(self, "model") or self.model is None:
+            self.load_onnx_model()  # loads the tokenizer as well
+        token_num = 0
+        texts = [texts] if isinstance(texts, str) else texts
+        tokenizer = self.tokenizer if is_doc else self.query_tokenizer
+        assert tokenizer is not None
+        for batch in iter_batch(texts, batch_size):
+            for tokens in tokenizer.encode_batch(batch):
+                if is_doc:
+                    token_num += sum(tokens.attention_mask)
+                else:
+                    attend_count = sum(tokens.attention_mask)
+                    if include_extension:
+                        token_num += max(attend_count, self.MIN_QUERY_LENGTH)
+
+                    else:
+                        token_num += attend_count
+            if include_extension:
+                token_num += len(
+                    batch
+                )  # add 1 for each cls.DOC_MARKER_TOKEN_ID or cls.QUERY_MARKER_TOKEN_ID
+
+        return token_num
+
     @classmethod
     def _list_supported_models(cls) -> list[DenseModelDescription]:
         """Lists the supported models.

fastembed/late_interaction/late_interaction_embedding_base.py

Lines changed: 9 additions & 0 deletions
@@ -69,3 +69,12 @@ def get_embedding_size(cls, model_name: str) -> int:
     def embedding_size(self) -> int:
         """Returns embedding size for the current model"""
         raise NotImplementedError("Subclasses must implement this method")
+
+    def token_count(
+        self,
+        texts: Union[str, Iterable[str]],
+        batch_size: int = 1024,
+        **kwargs: Any,
+    ) -> int:
+        """Returns the number of tokens in the texts."""
+        raise NotImplementedError("Subclasses must implement this method")

fastembed/late_interaction/late_interaction_text_embedding.py

Lines changed: 27 additions & 0 deletions
@@ -151,3 +151,30 @@ def query_embed(self, query: Union[str, Iterable[str]], **kwargs: Any) -> Iterab
 
         # This is model-specific, so that different models can have specialized implementations
         yield from self.model.query_embed(query, **kwargs)
+
+    def token_count(
+        self,
+        texts: Union[str, Iterable[str]],
+        batch_size: int = 1024,
+        is_doc: bool = True,
+        include_extension: bool = False,
+        **kwargs: Any,
+    ) -> int:
+        """Returns the number of tokens in the texts.
+
+        Args:
+            texts (str | Iterable[str]): The list of texts to embed.
+            batch_size (int): Batch size for encoding
+            is_doc (bool): Whether the texts are documents (disable embedding a query with include_mask=True).
+            include_extension (bool): Turn on to count DOC / QUERY marker tokens, and [MASK] token in query mode.
+
+        Returns:
+            int: Sum of number of tokens in the texts.
+        """
+        return self.model.token_count(
+            texts,
+            batch_size=batch_size,
+            is_doc=is_doc,
+            include_extension=include_extension,
+            **kwargs,
+        )
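This wrapper is the user-facing entry point. A hedged usage sketch; the model name and texts are illustrative, not mandated by this commit:

from fastembed import LateInteractionTextEmbedding

# Illustrative usage; "colbert-ir/colbertv2.0" is one supported
# late-interaction model, chosen here only as an example.
model = LateInteractionTextEmbedding("colbert-ir/colbertv2.0")

docs = [
    "FastEmbed is a lightweight embedding library.",
    "ColBERT is a late-interaction retrieval model.",
]

doc_tokens = model.token_count(docs)  # document mode, marker tokens excluded
query_tokens = model.token_count(
    "what is fastembed?", is_doc=False, include_extension=True
)  # query mode: counts the [MASK] padding and the QUERY marker token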

fastembed/late_interaction_multimodal/colpali.py

Lines changed: 18 additions & 1 deletion
@@ -6,7 +6,7 @@
 from fastembed.common import OnnxProvider, ImageInput
 from fastembed.common.onnx_model import OnnxOutputContext
 from fastembed.common.types import NumpyArray
-from fastembed.common.utils import define_cache_dir
+from fastembed.common.utils import define_cache_dir, iter_batch
 from fastembed.late_interaction_multimodal.late_interaction_multimodal_embedding_base import (
     LateInteractionMultimodalEmbeddingBase,
 )
@@ -172,6 +172,23 @@ def tokenize(self, documents: list[str], **kwargs: Any) -> list[Encoding]:
         encoded = self.tokenizer.encode_batch(texts_query)  # type: ignore[union-attr]
         return encoded
 
+    def token_count(
+        self,
+        texts: Union[str, Iterable[str]],
+        batch_size: int = 1024,
+        include_extension: bool = False,
+        **kwargs: Any,
+    ) -> int:
+        if not hasattr(self, "model") or self.model is None:
+            self.load_onnx_model()  # loads the tokenizer as well
+        token_num = 0
+        texts = [texts] if isinstance(texts, str) else texts
+        assert self.tokenizer is not None
+        tokenize_func = self.tokenize if include_extension else self.tokenizer.encode_batch
+        for batch in iter_batch(texts, batch_size):
+            token_num += sum([sum(encoding.attention_mask) for encoding in tokenize_func(batch)])
+        return token_num
+
     def _preprocess_onnx_text_input(
         self, onnx_input: dict[str, NumpyArray], **kwargs: Any
     ) -> dict[str, NumpyArray]:

fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding.py

Lines changed: 21 additions & 0 deletions
@@ -162,3 +162,24 @@ def embed_image(
             List of embeddings, one per image
         """
         yield from self.model.embed_image(images, batch_size, parallel, **kwargs)
+
+    def token_count(
+        self,
+        texts: Union[str, Iterable[str]],
+        batch_size: int = 1024,
+        include_extension: bool = False,
+        **kwargs: Any,
+    ) -> int:
+        """Returns the number of tokens in the texts.
+
+        Args:
+            texts (str | Iterable[str]): The list of texts to embed.
+            batch_size (int): Batch size for encoding
+            include_extension (bool): Whether to include tokens added by preprocessing
+
+        Returns:
+            int: Sum of number of tokens in the texts.
+        """
+        return self.model.token_count(
+            texts, batch_size=batch_size, include_extension=include_extension, **kwargs
+        )
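Same pattern for the multimodal wrapper. A hedged sketch; the checkpoint name is an assumption (any supported ColPali model would do):

from fastembed import LateInteractionMultimodalEmbedding

# Illustrative usage; the model name is assumed, not part of this commit.
model = LateInteractionMultimodalEmbedding("Qdrant/colpali-v1.3-fp16")

queries = ["invoice total", "shipping address"]

raw = model.token_count(queries)  # counts raw tokenizer output only
augmented = model.token_count(
    queries, include_extension=True
)  # routes through tokenize(), so the query prefix added in preprocessing is counted too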

fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding_base.py

Lines changed: 8 additions & 0 deletions
@@ -76,3 +76,11 @@ def get_embedding_size(cls, model_name: str) -> int:
     def embedding_size(self) -> int:
         """Returns embedding size for the current model"""
         raise NotImplementedError("Subclasses must implement this method")
+
+    def token_count(
+        self,
+        texts: Union[str, Iterable[str]],
+        **kwargs: Any,
+    ) -> int:
+        """Returns the number of tokens in the texts."""
+        raise NotImplementedError("Subclasses must implement this method")

fastembed/rerank/cross_encoder/onnx_text_cross_encoder.py

Lines changed: 14 additions & 0 deletions
@@ -207,6 +207,20 @@ def _post_process_onnx_output(
     ) -> Iterable[float]:
         return (float(elem) for elem in output.model_output)
 
+    def token_count(
+        self, pairs: Iterable[tuple[str, str]], batch_size: int = 1024, **kwargs: Any
+    ) -> int:
+        """Returns the number of tokens in the pairs.
+
+        Args:
+            pairs: Iterable of tuples, where each tuple contains a query and a document to be tokenized
+            batch_size: Batch size for tokenizing
+
+        Returns:
+            token count: overall number of tokens in the pairs
+        """
+        return self._token_count(pairs, batch_size=batch_size, **kwargs)
+
 
 class TextCrossEncoderWorker(TextRerankerWorker):
     def init_embedding(

fastembed/rerank/cross_encoder/onnx_text_model.py

Lines changed: 14 additions & 0 deletions
@@ -165,6 +165,20 @@ def _preprocess_onnx_input(
         """
         return onnx_input
 
+    def _token_count(
+        self, pairs: Iterable[tuple[str, str]], batch_size: int = 1024, **_: Any
+    ) -> int:
+        if not hasattr(self, "model") or self.model is None:
+            self.load_onnx_model()  # loads the tokenizer as well
+
+        token_num = 0
+        assert self.tokenizer is not None
+        for batch in iter_batch(pairs, batch_size):
+            for tokens in self.tokenizer.encode_batch(batch):
+                token_num += sum(tokens.attention_mask)
+
+        return token_num
+
 
 class TextRerankerWorker(EmbeddingWorker[float]):
     def __init__(

fastembed/rerank/cross_encoder/text_cross_encoder.py

Lines changed: 14 additions & 0 deletions
@@ -161,3 +161,17 @@ def add_custom_model(
                 additional_files=additional_files or [],
             )
         )
+
+    def token_count(
+        self, pairs: Iterable[tuple[str, str]], batch_size: int = 1024, **kwargs: Any
+    ) -> int:
+        """Returns the number of tokens in the pairs.
+
+        Args:
+            pairs: Iterable of tuples, where each tuple contains a query and a document to be tokenized
+            batch_size: Batch size for tokenizing
+
+        Returns:
+            token count: overall number of tokens in the pairs
+        """
+        return self.model.token_count(pairs, batch_size=batch_size, **kwargs)
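On the reranker side the unit is a (query, document) pair: each pair is encoded as a single sequence, so the count covers both texts plus any special tokens reflected in the attention mask. A hedged usage sketch; the model name is an example, not mandated by the commit:

from fastembed.rerank.cross_encoder import TextCrossEncoder

# Illustrative usage; the model name is assumed.
encoder = TextCrossEncoder("Xenova/ms-marco-MiniLM-L-6-v2")

pairs = [
    ("what is fastembed?", "FastEmbed is a lightweight embedding library."),
    ("how does colbert work?", "ColBERT matches query and document token embeddings."),
]

total = encoder.token_count(pairs)  # total tokens across all pairs, per the attention mask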

fastembed/rerank/cross_encoder/text_cross_encoder_base.py

Lines changed: 4 additions & 0 deletions
@@ -57,3 +57,7 @@ def rerank_pairs(
             Iterable[float]: Scores for each individual pair
         """
         raise NotImplementedError("This method should be overridden by subclasses")
+
+    def token_count(self, pairs: Iterable[tuple[str, str]], **kwargs: Any) -> int:
+        """Returns the number of tokens in the pairs."""
+        raise NotImplementedError("This method should be overridden by subclasses")
