@@ -48,10 +48,10 @@
 
 
 class Colbert(LateInteractionTextEmbeddingBase, OnnxTextModel[np.ndarray]):
-    QUERY_MARKER_TOKEN_ID = 1
-    DOCUMENT_MARKER_TOKEN_ID = 2
     MIN_QUERY_LENGTH = 32
     MASK_TOKENS = ["[MASK]", "<mask>"]
+    QUERY_MARKER_TOKENS = ["[Q]", "[QueryMarker]"]
+    DOCUMENT_MARKER_TOKENS = ["[D]", "[DocumentMarker]"]
 
     def _post_process_onnx_output(
         self, output: OnnxOutputContext, is_doc: bool = True
@@ -74,9 +74,9 @@ def _preprocess_onnx_input(
         self, onnx_input: Dict[str, np.ndarray], is_doc: bool = True
     ) -> Dict[str, np.ndarray]:
         if is_doc:
-            onnx_input["input_ids"][:, 1] = self.DOCUMENT_MARKER_TOKEN_ID
+            onnx_input["input_ids"][:, 1] = self.document_marker_token_id
         else:
-            onnx_input["input_ids"][:, 1] = self.QUERY_MARKER_TOKEN_ID
+            onnx_input["input_ids"][:, 1] = self.query_marker_token_id
         return onnx_input
 
     def tokenize(self, documents: List[str], is_doc: bool = True) -> List[Encoding]:
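The mechanism of `_preprocess_onnx_input` is unchanged by this hunk: position 1 of every encoded sequence holds a placeholder token that gets overwritten with the ColBERT document or query marker; only the source of the marker ID changes from a hard-coded constant to a value resolved from the tokenizer. A minimal numpy sketch of that overwrite, with made-up token IDs:

```python
import numpy as np

# Toy batch of token IDs: [CLS] at position 0, the placeholder at position 1,
# then the actual text tokens. All IDs here are made up for illustration.
input_ids = np.array(
    [
        [101, 1030, 7592, 2088, 102, 0, 0],
        [101, 1030, 2023, 2003, 1037, 6254, 102],
    ]
)

document_marker_token_id = 2  # in the real class this is resolved from the tokenizer

# Same batched assignment as _preprocess_onnx_input: replace the placeholder
# in every row with the document marker before running the ONNX model.
input_ids[:, 1] = document_marker_token_id
print(input_ids)
```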
@@ -87,16 +87,17 @@ def tokenize(self, documents: List[str], is_doc: bool = True) -> List[Encoding]:
         )
 
     def _tokenize_query(self, query: str) -> List[Encoding]:
-        # ". " is added to a query to be replaced with a special query token
-        query = [f". {query}"]
+        # "@ " is added to a query to be replaced with a special query token
+        # please make sure that "@ " is considered as one token in all tokenizers we use in Late Interaction Models
+        query = [f"@ {query}"]
         encoded = self.tokenizer.encode_batch(query)
         # colbert authors recommend to pad queries with [MASK] tokens for query augmentation to improve performance
         if len(encoded[0].ids) < self.MIN_QUERY_LENGTH:
             prev_padding = None
             if self.tokenizer.padding:
                 prev_padding = self.tokenizer.padding
             self.tokenizer.enable_padding(
-                pad_token=self.MASK_TOKENS[0],
+                pad_token=self.mask_token,
                 pad_id=self.mask_token_id,
                 length=self.MIN_QUERY_LENGTH,
             )
@@ -108,8 +109,9 @@ def _tokenize_query(self, query: str) -> List[Encoding]:
         return encoded
 
     def _tokenize_documents(self, documents: List[str]) -> List[Encoding]:
-        # ". " is added to a document to be replaced with a special document token
-        documents = [". " + doc for doc in documents]
+        # "@ " is added to a document to be replaced with a special document token
+        # please make sure that "@ " is considered as one token in all tokenizers we use in Late Interaction Models
+        documents = ["@ " + doc for doc in documents]
         encoded = self.tokenizer.encode_batch(documents)
         return encoded
 
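Both tokenization paths now prepend "@ " as a single-token placeholder that is later swapped for the marker as shown above, and the query path additionally pads short queries to MIN_QUERY_LENGTH with the mask token, which is ColBERT's query augmentation. A rough standalone sketch of the query flow using the `tokenizers` library; the checkpoint name and the 32-token minimum are illustrative assumptions, not necessarily what the class loads:

```python
from tokenizers import Tokenizer

MIN_QUERY_LENGTH = 32  # assumed value, mirrors the class constant
tokenizer = Tokenizer.from_pretrained("bert-base-uncased")  # illustrative checkpoint

# "@ " is prepended so that a single, easily located token sits at position 1.
query = "@ what is late interaction retrieval"
encoded = tokenizer.encode_batch([query])

if len(encoded[0].ids) < MIN_QUERY_LENGTH:
    # Query augmentation: pad with [MASK] instead of the regular padding token.
    mask_token = "[MASK]"
    tokenizer.enable_padding(
        pad_token=mask_token,
        pad_id=tokenizer.token_to_id(mask_token),
        length=MIN_QUERY_LENGTH,
    )
    encoded = tokenizer.encode_batch([query])
    tokenizer.no_padding()  # restore the previous (un-padded) state

print(encoded[0].tokens)  # ends with '[MASK]' tokens up to length 32
```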
@@ -157,12 +159,28 @@ def __init__(
             threads=threads,
             providers=providers,
         )
-        self.mask_token_id = next(
+        self.mask_token_id, self.mask_token = next(
             (
-                self.special_token_to_id[token]
+                (self.special_token_to_id[token], token)
                 for token in self.MASK_TOKENS
                 if token in self.special_token_to_id
             ),
+            (None, None),
+        )
+        self.query_marker_token_id = next(
+            (
+                self.special_token_to_id[token]
+                for token in self.QUERY_MARKER_TOKENS
+                if token in self.special_token_to_id
+            ),
+            None,
+        )
+        self.document_marker_token_id = next(
+            (
+                self.special_token_to_id[token]
+                for token in self.DOCUMENT_MARKER_TOKENS
+                if token in self.special_token_to_id
+            ),
             None,
         )
         self.pad_token_id = self.tokenizer.padding["pad_id"]
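The three lookups added in `__init__` share one pattern: `next()` over a generator expression returns the first candidate token that exists in the tokenizer's special-token table, with a `None` (or `(None, None)`) fallback when no spelling matches. A self-contained sketch of that pattern, using a hypothetical special-token table for illustration:

```python
from typing import Dict, List, Optional, Tuple

def resolve_special_token(
    candidates: List[str], special_token_to_id: Dict[str, int]
) -> Tuple[Optional[int], Optional[str]]:
    """Return (id, token) for the first candidate present in the table, else (None, None)."""
    return next(
        (
            (special_token_to_id[token], token)
            for token in candidates
            if token in special_token_to_id
        ),
        (None, None),
    )

# Hypothetical special-token table, for illustration only.
vocab = {"[MASK]": 103, "[Q]": 1, "[D]": 2}
print(resolve_special_token(["[MASK]", "<mask>"], vocab))      # (103, '[MASK]')
print(resolve_special_token(["[QueryMarker]", "[Q]"], vocab))  # (1, '[Q]')
print(resolve_special_token(["<missing>"], vocab))             # (None, None)
```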