Skip to content

Commit dfd25d4

Browse files
generalljoein
andauthored
Attention sparse embeddings (#235)
* WIP: sparse embeddings using attention * support for stopwords * apply stopwords * proceed implementation of sparse attention embeddings (#234) * complete inference * query embed + comment * use simpler weights formula instead of sorting of words * update tests * fix: fix bm42 usage, add query_embed to SparseTextEmbedding, update tests --------- Co-authored-by: George <[email protected]>
1 parent 316c336 commit dfd25d4

15 files changed

+486
-66
lines changed

fastembed/common/onnx_model.py

+10-3
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from dataclasses import dataclass
12
from pathlib import Path
23
from typing import Any, Dict, Generic, Iterable, Optional, Tuple, Type, TypeVar, Sequence
34
import warnings
@@ -13,13 +14,19 @@
1314
T = TypeVar("T")
1415

1516

17+
@dataclass
18+
class OnnxOutputContext:
19+
model_output: np.ndarray
20+
attention_mask: Optional[np.ndarray] = None
21+
input_ids: Optional[np.ndarray] = None
22+
23+
1624
class OnnxModel(Generic[T]):
1725
@classmethod
1826
def _get_worker_class(cls) -> Type["EmbeddingWorker"]:
1927
raise NotImplementedError("Subclasses must implement this method")
2028

21-
@classmethod
22-
def _post_process_onnx_output(cls, output: Tuple[np.ndarray, np.ndarray]) -> Iterable[T]:
29+
def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[T]:
2330
raise NotImplementedError("Subclasses must implement this method")
2431

2532
def __init__(self) -> None:
@@ -74,7 +81,7 @@ def load_onnx_model(
7481
RuntimeWarning,
7582
)
7683

77-
def onnx_embed(self, *args, **kwargs) -> Tuple[np.ndarray, np.ndarray]:
84+
def onnx_embed(self, *args, **kwargs) -> OnnxOutputContext:
7885
raise NotImplementedError("Subclasses must implement this method")
7986

8087

fastembed/common/preprocessor_utils.py

+24-8
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,24 @@
11
import json
22
from pathlib import Path
3+
from typing import Tuple
34

45
from tokenizers import Tokenizer, AddedToken
56

67
from fastembed.image.transform.operators import Compose
78

89

9-
def load_tokenizer(model_dir: Path, max_length: int = 512) -> Tokenizer:
10+
def load_special_tokens(model_dir: Path) -> dict:
11+
tokens_map_path = model_dir / "special_tokens_map.json"
12+
if not tokens_map_path.exists():
13+
raise ValueError(f"Could not find special_tokens_map.json in {model_dir}")
14+
15+
with open(str(tokens_map_path)) as tokens_map_file:
16+
tokens_map = json.load(tokens_map_file)
17+
18+
return tokens_map
19+
20+
21+
def load_tokenizer(model_dir: Path, max_length: int = 512) -> Tuple[Tokenizer, dict]:
1022
config_path = model_dir / "config.json"
1123
if not config_path.exists():
1224
raise ValueError(f"Could not find config.json in {model_dir}")
@@ -19,18 +31,13 @@ def load_tokenizer(model_dir: Path, max_length: int = 512) -> Tokenizer:
1931
if not tokenizer_config_path.exists():
2032
raise ValueError(f"Could not find tokenizer_config.json in {model_dir}")
2133

22-
tokens_map_path = model_dir / "special_tokens_map.json"
23-
if not tokens_map_path.exists():
24-
raise ValueError(f"Could not find special_tokens_map.json in {model_dir}")
25-
2634
with open(str(config_path)) as config_file:
2735
config = json.load(config_file)
2836

2937
with open(str(tokenizer_config_path)) as tokenizer_config_file:
3038
tokenizer_config = json.load(tokenizer_config_file)
3139

32-
with open(str(tokens_map_path)) as tokens_map_file:
33-
tokens_map = json.load(tokens_map_file)
40+
tokens_map = load_special_tokens(model_dir)
3441

3542
tokenizer = Tokenizer.from_file(str(tokenizer_path))
3643
tokenizer.enable_truncation(max_length=min(tokenizer_config["model_max_length"], max_length))
@@ -44,7 +51,16 @@ def load_tokenizer(model_dir: Path, max_length: int = 512) -> Tokenizer:
4451
elif isinstance(token, dict):
4552
tokenizer.add_special_tokens([AddedToken(**token)])
4653

47-
return tokenizer
54+
special_token_to_id = {}
55+
56+
for token in tokens_map.values():
57+
if isinstance(token, str):
58+
special_token_to_id[token] = tokenizer.token_to_id(token)
59+
elif isinstance(token, dict):
60+
token_str = token.get("content", "")
61+
special_token_to_id[token_str] = tokenizer.token_to_id(token_str)
62+
63+
return tokenizer, special_token_to_id
4864

4965

5066
def load_preprocessor(model_dir: Path) -> Compose:

fastembed/image/onnx_embedding.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import numpy as np
44

5+
from fastembed.common.onnx_model import OnnxOutputContext
56
from fastembed.common.utils import normalize, define_cache_dir
67
from fastembed.common import ImageInput, OnnxProvider
78
from fastembed.image.image_embedding_base import ImageEmbeddingBase
@@ -108,9 +109,8 @@ def _preprocess_onnx_input(self, onnx_input: Dict[str, np.ndarray]) -> Dict[str,
108109

109110
return onnx_input
110111

111-
@classmethod
112-
def _post_process_onnx_output(cls, output: np.ndarray) -> Iterable[np.ndarray]:
113-
return normalize(output).astype(np.float32)
112+
def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[np.ndarray]:
113+
return normalize(output.model_output).astype(np.float32)
114114

115115

116116
class OnnxImageEmbeddingWorker(ImageEmbeddingWorker):

fastembed/image/onnx_image_model.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import numpy as np
99

1010
from fastembed.common.preprocessor_utils import load_preprocessor
11-
from fastembed.common.onnx_model import OnnxModel, EmbeddingWorker, T
11+
from fastembed.common.onnx_model import OnnxModel, EmbeddingWorker, T, OnnxOutputContext
1212
from fastembed.common import PathInput, ImageInput, OnnxProvider
1313
from fastembed.common.utils import iter_batch
1414
from fastembed.parallel_processor import ParallelWorkerPool
@@ -21,8 +21,7 @@ class OnnxImageModel(OnnxModel[T]):
2121
def _get_worker_class(cls) -> Type["ImageEmbeddingWorker"]:
2222
raise NotImplementedError("Subclasses must implement this method")
2323

24-
@classmethod
25-
def _post_process_onnx_output(cls, output: np.ndarray) -> Iterable[T]:
24+
def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[T]:
2625
raise NotImplementedError("Subclasses must implement this method")
2726

2827
def __init__(self) -> None:
@@ -47,7 +46,7 @@ def load_onnx_model(
4746
)
4847
self.processor = load_preprocessor(model_dir=model_dir)
4948

50-
def onnx_embed(self, images: List[PathInput]) -> np.ndarray:
49+
def onnx_embed(self, images: List[PathInput]) -> OnnxOutputContext:
5150
with contextlib.ExitStack():
5251
image_files = [Image.open(image) for image in images]
5352
encoded = self.processor(image_files)
@@ -56,7 +55,9 @@ def onnx_embed(self, images: List[PathInput]) -> np.ndarray:
5655

5756
model_output = self.model.run(None, onnx_input)
5857
embeddings = model_output[0]
59-
return embeddings
58+
return OnnxOutputContext(
59+
model_output=embeddings
60+
)
6061

6162
def _embed_images(
6263
self,

0 commit comments

Comments
 (0)