Commit ce0e544

chore: Generalized query marker and document marker
1 parent 01b15e0 commit ce0e544

File tree

fastembed/common/preprocessor_utils.py
fastembed/late_interaction/colbert.py

2 files changed: +41 −11

fastembed/common/preprocessor_utils.py

Lines changed: 12 additions & 0 deletions
@@ -68,6 +68,18 @@ def load_tokenizer(model_dir: Path) -> Tuple[Tokenizer, dict]:
             token_str = token.get("content", "")
             special_token_to_id[token_str] = tokenizer.token_to_id(token_str)
 
+    if tokenizer_config["tokenizer_class"] == "BertTokenizer":
+        query_marker = {"[Q]": 1}
+        document_marker = {"[D]": 2}
+    elif tokenizer_config["tokenizer_class"] == "XLMRobertaTokenizer":
+        query_marker = {"[QueryMarker]": 250002}
+        document_marker = {"[DocumentMarker]": 250003}
+    else:
+        query_marker = {}
+        document_marker = {}
+
+    special_token_to_id.update(query_marker)
+    special_token_to_id.update(document_marker)
     return tokenizer, special_token_to_id
 
 
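For context, a minimal standalone sketch (not part of the commit) of what this new branch contributes to special_token_to_id for the two supported tokenizer classes; the tokenizer_config dicts below are hypothetical stand-ins, not values read from a real checkpoint.

# Hedged illustration of the marker-selection logic added above; inputs are made up.
def resolve_markers(tokenizer_config: dict) -> dict:
    if tokenizer_config["tokenizer_class"] == "BertTokenizer":
        query_marker, document_marker = {"[Q]": 1}, {"[D]": 2}
    elif tokenizer_config["tokenizer_class"] == "XLMRobertaTokenizer":
        query_marker, document_marker = {"[QueryMarker]": 250002}, {"[DocumentMarker]": 250003}
    else:
        query_marker, document_marker = {}, {}
    # These pairs are merged into special_token_to_id in load_tokenizer.
    return {**query_marker, **document_marker}

print(resolve_markers({"tokenizer_class": "BertTokenizer"}))
# {'[Q]': 1, '[D]': 2}
print(resolve_markers({"tokenizer_class": "XLMRobertaTokenizer"}))
# {'[QueryMarker]': 250002, '[DocumentMarker]': 250003}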

fastembed/late_interaction/colbert.py

Lines changed: 29 additions & 11 deletions
@@ -48,10 +48,10 @@
 
 
 class Colbert(LateInteractionTextEmbeddingBase, OnnxTextModel[np.ndarray]):
-    QUERY_MARKER_TOKEN_ID = 1
-    DOCUMENT_MARKER_TOKEN_ID = 2
     MIN_QUERY_LENGTH = 32
     MASK_TOKENS = ["[MASK]", "<mask>"]
+    QUERY_MARKER_TOKENS = ["[Q]", "[QueryMarker]"]
+    DOCUMENT_MARKER_TOKENS = ["[D]", "[DocumentMarker]"]
 
     def _post_process_onnx_output(
         self, output: OnnxOutputContext, is_doc: bool = True
@@ -74,9 +74,9 @@ def _preprocess_onnx_input(
         self, onnx_input: Dict[str, np.ndarray], is_doc: bool = True
     ) -> Dict[str, np.ndarray]:
         if is_doc:
-            onnx_input["input_ids"][:, 1] = self.DOCUMENT_MARKER_TOKEN_ID
+            onnx_input["input_ids"][:, 1] = self.document_marker_token_id
         else:
-            onnx_input["input_ids"][:, 1] = self.QUERY_MARKER_TOKEN_ID
+            onnx_input["input_ids"][:, 1] = self.query_marker_token_id
         return onnx_input
 
     def tokenize(self, documents: List[str], is_doc: bool = True) -> List[Encoding]:
@@ -87,16 +87,17 @@ def tokenize(self, documents: List[str], is_doc: bool = True) -> List[Encoding]:
         )
 
     def _tokenize_query(self, query: str) -> List[Encoding]:
-        # ". " is added to a query to be replaced with a special query token
-        query = [f". {query}"]
+        # "@ " is added to a query to be replaced with a special query token
+        # please make sure that "@ " is considered as one token in all tokenizers we use in Late Interaction Models
+        query = [f"@ {query}"]
         encoded = self.tokenizer.encode_batch(query)
         # colbert authors recommend to pad queries with [MASK] tokens for query augmentation to improve performance
         if len(encoded[0].ids) < self.MIN_QUERY_LENGTH:
             prev_padding = None
             if self.tokenizer.padding:
                 prev_padding = self.tokenizer.padding
             self.tokenizer.enable_padding(
-                pad_token=self.MASK_TOKENS[0],
+                pad_token=self.mask_token,
                 pad_id=self.mask_token_id,
                 length=self.MIN_QUERY_LENGTH,
             )
@@ -108,8 +109,9 @@ def _tokenize_query(self, query: str) -> List[Encoding]:
         return encoded
 
     def _tokenize_documents(self, documents: List[str]) -> List[Encoding]:
-        # ". " is added to a document to be replaced with a special document token
-        documents = [". " + doc for doc in documents]
+        # "@ " is added to a document to be replaced with a special document token
+        # please make sure that "@ " is considered as one token in all tokenizers we use in Late Interaction Models
+        documents = ["@ " + doc for doc in documents]
         encoded = self.tokenizer.encode_batch(documents)
         return encoded
 
@@ -157,12 +159,28 @@ def __init__(
             threads=threads,
             providers=providers,
         )
-        self.mask_token_id = next(
+        self.mask_token_id, self.mask_token = next(
             (
-                self.special_token_to_id[token]
+                (self.special_token_to_id[token], token)
                 for token in self.MASK_TOKENS
                 if token in self.special_token_to_id
            ),
+            (None, None),
+        )
+        self.query_marker_token_id = next(
+            (
+                self.special_token_to_id[token]
+                for token in self.QUERY_MARKER_TOKENS
+                if token in self.special_token_to_id
+            ),
+            None,
+        )
+        self.document_marker_token_id = next(
+            (
+                self.special_token_to_id[token]
+                for token in self.DOCUMENT_MARKER_TOKENS
+                if token in self.special_token_to_id
+            ),
             None,
         )
         self.pad_token_id = self.tokenizer.padding["pad_id"]
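As a rough sketch of the lookup pattern these __init__ additions rely on, assuming a hypothetical special_token_to_id map for a BERT-style checkpoint (the map below is illustrative, not taken from a real model):

from typing import Dict, List, Optional

# Hypothetical token-to-id map; a real one is built by load_tokenizer at model load time.
special_token_to_id: Dict[str, int] = {"[MASK]": 103, "[Q]": 1, "[D]": 2}

QUERY_MARKER_TOKENS: List[str] = ["[Q]", "[QueryMarker]"]
DOCUMENT_MARKER_TOKENS: List[str] = ["[D]", "[DocumentMarker]"]

# next() yields the id of the first marker alias present in the map, or None if no alias matches.
query_marker_token_id: Optional[int] = next(
    (special_token_to_id[t] for t in QUERY_MARKER_TOKENS if t in special_token_to_id), None
)
document_marker_token_id: Optional[int] = next(
    (special_token_to_id[t] for t in DOCUMENT_MARKER_TOKENS if t in special_token_to_id), None
)
print(query_marker_token_id, document_marker_token_id)  # 1 2 for this example map

The resolved id is then written over position 1 of input_ids in _preprocess_onnx_input, i.e. over the token produced by the "@ " placeholder that _tokenize_query and _tokenize_documents prepend.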
