
Commit 37a66d9

new: Add sparse type hints (#460)
* new: Add sparse type hints
* fix: ndarray -> numpyarray

Co-authored-by: George Panchuk <[email protected]>
1 parent b08febb commit 37a66d9

5 files changed: +30 -29 lines changed

fastembed/sparse/bm25.py

+6-6
@@ -123,7 +123,7 @@ def __init__(
         self.avg_len = avg_len
 
         model_description = self._get_model_description(model_name)
-        self.cache_dir = define_cache_dir(cache_dir)
+        self.cache_dir = str(define_cache_dir(cache_dir))
 
         self._model_dir = self.download_model(
             model_description,
@@ -137,7 +137,7 @@ def __init__(
         self.disable_stemmer = disable_stemmer
 
         if disable_stemmer:
-            self.stopwords = set()
+            self.stopwords: set[str] = set()
             self.stemmer = None
         else:
             self.stopwords = set(self._load_stopwords(self._model_dir, self.language))
@@ -239,7 +239,7 @@ def embed(
         )
 
     def _stem(self, tokens: list[str]) -> list[str]:
-        stemmed_tokens = []
+        stemmed_tokens: list[str] = []
         for token in tokens:
             lower_token = token.lower()
 
@@ -262,7 +262,7 @@ def raw_embed(
         self,
         documents: list[str],
     ) -> list[SparseEmbedding]:
-        embeddings = []
+        embeddings: list[SparseEmbedding] = []
         for document in documents:
             document = remove_non_alphanumeric(document)
             tokens = self.tokenizer.tokenize(document)
@@ -286,8 +286,8 @@ def _term_frequency(self, tokens: list[str]) -> dict[int, float]:
         Returns:
             dict[int, float]: The token_id to term frequency mapping.
         """
-        tf_map = {}
-        counter = defaultdict(int)
+        tf_map: dict[int, float] = {}
+        counter: defaultdict[str, int] = defaultdict(int)
         for stemmed_token in tokens:
             counter[stemmed_token] += 1
 

fastembed/sparse/bm42.py

+15-15
@@ -110,7 +110,7 @@ def __init__(
         self.device_id = None
 
         self.model_description = self._get_model_description(model_name)
-        self.cache_dir = define_cache_dir(cache_dir)
+        self.cache_dir = str(define_cache_dir(cache_dir))
 
         self._model_dir = self.download_model(
             self.model_description,
@@ -119,10 +119,10 @@ def __init__(
             specific_model_path=specific_model_path,
         )
 
-        self.invert_vocab = {}
+        self.invert_vocab: dict[int, str] = {}
 
-        self.special_tokens = set()
-        self.special_tokens_ids = set()
+        self.special_tokens: set[str] = set()
+        self.special_tokens_ids: set[int] = set()
         self.punctuation = set(string.punctuation)
         self.stopwords = set(self._load_stopwords(self._model_dir))
         self.stemmer = SnowballStemmer(MODEL_TO_LANGUAGE[model_name])
@@ -147,15 +147,15 @@ def load_onnx_model(self) -> None:
         self.stopwords = set(self._load_stopwords(self._model_dir))
 
     def _filter_pair_tokens(self, tokens: list[tuple[str, Any]]) -> list[tuple[str, Any]]:
-        result = []
+        result: list[tuple[str, Any]] = []
         for token, value in tokens:
             if token in self.stopwords or token in self.punctuation:
                 continue
             result.append((token, value))
         return result
 
     def _stem_pair_tokens(self, tokens: list[tuple[str, Any]]) -> list[tuple[str, Any]]:
-        result = []
+        result: list[tuple[str, Any]] = []
         for token, value in tokens:
             processed_token = self.stemmer.stem_word(token)
             result.append((processed_token, value))
@@ -165,7 +165,7 @@ def _stem_pair_tokens(self, tokens: list[tuple[str, Any]]) -> list[tuple[str, Any]]:
     def _aggregate_weights(
         cls, tokens: list[tuple[str, list[int]]], weights: list[float]
     ) -> list[tuple[str, float]]:
-        result = []
+        result: list[tuple[str, float]] = []
         for token, idxs in tokens:
             sum_weight = sum(weights[idx] for idx in idxs)
             result.append((token, sum_weight))
@@ -174,9 +174,9 @@ def _aggregate_weights(
     def _reconstruct_bpe(
         self, bpe_tokens: Iterable[tuple[int, str]]
     ) -> list[tuple[str, list[int]]]:
-        result = []
-        acc = ""
-        acc_idx = []
+        result: list[tuple[str, list[int]]] = []
+        acc: str = ""
+        acc_idx: list[int] = []
 
         continuing_subword_prefix = self.tokenizer.model.continuing_subword_prefix
         continuing_subword_prefix_len = len(continuing_subword_prefix)
@@ -206,7 +206,7 @@ def _rescore_vector(self, vector: dict[str, float]) -> dict[int, float]:
        So that the scoring doesn't depend on absolute values assigned by the model, but on the relative importance.
        """
 
-        new_vector = {}
+        new_vector: dict[int, float] = {}
 
        for token, value in vector.items():
            token_id = abs(mmh3.hash(token))
@@ -241,7 +241,7 @@ def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[SparseEmbedding]:
 
         weighted = self._aggregate_weights(stemmed, attention_value)
 
-        max_token_weight = {}
+        max_token_weight: dict[str, float] = {}
 
         for token, weight in weighted:
             max_token_weight[token] = max(max_token_weight.get(token, 0), weight)
@@ -304,7 +304,7 @@ def embed(
 
     @classmethod
     def _query_rehash(cls, tokens: Iterable[str]) -> dict[int, float]:
-        result = {}
+        result: dict[int, float] = {}
         for token in tokens:
             token_id = abs(mmh3.hash(token))
             result[token_id] = 1.0
@@ -334,11 +334,11 @@ def query_embed(
             yield SparseEmbedding.from_dict(self._query_rehash(token for token, _ in stemmed))
 
     @classmethod
-    def _get_worker_class(cls) -> Type[TextEmbeddingWorker]:
+    def _get_worker_class(cls) -> Type[TextEmbeddingWorker[SparseEmbedding]]:
         return Bm42TextEmbeddingWorker
 
 
-class Bm42TextEmbeddingWorker(TextEmbeddingWorker):
+class Bm42TextEmbeddingWorker(TextEmbeddingWorker[SparseEmbedding]):
     def init_embedding(self, model_name: str, cache_dir: str, **kwargs: Any) -> Bm42:
         return Bm42(
             model_name=model_name,
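The last hunk parameterizes the worker: TextEmbeddingWorker is generic over the embedding type it produces, so subscripting it with SparseEmbedding lets the type checker connect _get_worker_class and the worker subclass to the SparseEmbedding objects Bm42 yields. A rough sketch of the generics pattern follows; the names EmbeddingWorker, SparseWorker, process, and the dict-based payload are invented for illustration and are not fastembed APIs.

from collections.abc import Iterable
from typing import Generic, TypeVar

T = TypeVar("T")

class EmbeddingWorker(Generic[T]):
    # Toy stand-in for a generic worker base class such as TextEmbeddingWorker[T].
    def process(self, documents: Iterable[str]) -> list[T]:
        raise NotImplementedError

class SparseWorker(EmbeddingWorker[dict[int, float]]):
    # Mirrors the shape of Bm42TextEmbeddingWorker(TextEmbeddingWorker[SparseEmbedding]):
    # the subscript pins down, for mypy, what process() returns.
    def process(self, documents: Iterable[str]) -> list[dict[int, float]]:
        return [{abs(hash(doc)) % 2**31: 1.0} for doc in documents]

print(SparseWorker().process(["hello sparse world"]))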

fastembed/sparse/sparse_embedding_base.py

+5-4
@@ -3,15 +3,16 @@
 
 import numpy as np
 
+from fastembed.common.types import NumpyArray
 from fastembed.common.model_management import ModelManagement
 
 
 @dataclass
 class SparseEmbedding:
-    values: np.ndarray
-    indices: np.ndarray
+    values: NumpyArray
+    indices: NumpyArray
 
-    def as_object(self) -> dict[str, np.ndarray]:
+    def as_object(self) -> dict[str, NumpyArray]:
         return {
             "values": self.values,
             "indices": self.indices,
@@ -81,5 +82,5 @@ def query_embed(
         # This is model-specific, so that different models can have specialized implementations
         if isinstance(query, str):
             yield from self.embed([query], **kwargs)
-        if isinstance(query, Iterable):
+        else:
             yield from self.embed(query, **kwargs)
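Two things change in this file: the SparseEmbedding dataclass is annotated with the NumpyArray alias from fastembed.common.types instead of bare np.ndarray, and the second branch of query_embed becomes a plain else, since a str is itself an Iterable and would otherwise satisfy both isinstance checks. A small usage sketch; the dtypes and sample values below are arbitrary, chosen only for illustration.

import numpy as np
from collections.abc import Iterable
from fastembed.sparse.sparse_embedding_base import SparseEmbedding

# Construct a sparse embedding directly from numpy arrays; NumpyArray only changes
# the annotation, so regular np.ndarray values are still what gets stored.
emb = SparseEmbedding(
    values=np.array([0.42, 1.7], dtype=np.float32),
    indices=np.array([17, 1029], dtype=np.int64),
)
print(emb.as_object())  # {'values': array([0.42, 1.7 ], ...), 'indices': array([  17, 1029])}

# Why the else-branch fix matters: a string is itself an Iterable, so the old
# isinstance(query, Iterable) check was also true for str inputs.
print(isinstance("sparse query", Iterable))  # True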

fastembed/sparse/sparse_text_embedding.py

+1-1
@@ -38,7 +38,7 @@ def list_supported_models(cls) -> list[dict[str, Any]]:
             ]
             ```
         """
-        result = []
+        result: list[dict[str, Any]] = []
         for embedding in cls.EMBEDDINGS_REGISTRY:
             result.extend(embedding.list_supported_models())
         return result

fastembed/sparse/splade_pp.py

+3-3
@@ -114,7 +114,7 @@ def __init__(
         self.device_id = None
 
         self.model_description = self._get_model_description(model_name)
-        self.cache_dir = define_cache_dir(cache_dir)
+        self.cache_dir = str(define_cache_dir(cache_dir))
 
         self._model_dir = self.download_model(
             self.model_description,
@@ -171,11 +171,11 @@ def embed(
         )
 
     @classmethod
-    def _get_worker_class(cls) -> Type[TextEmbeddingWorker]:
+    def _get_worker_class(cls) -> Type[TextEmbeddingWorker[SparseEmbedding]]:
         return SpladePPEmbeddingWorker
 
 
-class SpladePPEmbeddingWorker(TextEmbeddingWorker):
+class SpladePPEmbeddingWorker(TextEmbeddingWorker[SparseEmbedding]):
     def init_embedding(self, model_name: str, cache_dir: str, **kwargs: Any) -> SpladePP:
         return SpladePP(
             model_name=model_name,

0 commit comments
