Skip to content

Commit 73e1e5e

Browse files
chore: Add missing returns in defs (#451)
* chore: Add missing returns in defs
* remove return type from init
* remove incorrect ndarray specifier

---------

Co-authored-by: George Panchuk <[email protected]>
1 parent 105d6cf commit 73e1e5e

17 files changed

+64
-49
lines changed

docs/examples/ColBERT_with_FastEmbed.ipynb

+14-3
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,14 @@
5454
},
5555
{
5656
"data": {
57-
"text/plain": "[{'model': 'colbert-ir/colbertv2.0',\n 'dim': 128,\n 'description': 'Late interaction model',\n 'size_in_GB': 0.44,\n 'sources': {'hf': 'colbert-ir/colbertv2.0'},\n 'model_file': 'model.onnx'}]"
57+
"text/plain": [
58+
"[{'model': 'colbert-ir/colbertv2.0',\n",
59+
" 'dim': 128,\n",
60+
" 'description': 'Late interaction model',\n",
61+
" 'size_in_GB': 0.44,\n",
62+
" 'sources': {'hf': 'colbert-ir/colbertv2.0'},\n",
63+
" 'model_file': 'model.onnx'}]"
64+
]
5865
},
5966
"execution_count": 1,
6067
"metadata": {},
@@ -212,7 +219,9 @@
212219
"outputs": [
213220
{
214221
"data": {
215-
"text/plain": "((26, 128), (32, 128))"
222+
"text/plain": [
223+
"((26, 128), (32, 128))"
224+
]
216225
},
217226
"execution_count": 18,
218227
"metadata": {},
@@ -271,7 +280,9 @@
271280
"import numpy as np\n",
272281
"\n",
273282
"\n",
274-
"def compute_relevance_scores(query_embedding: np.array, document_embeddings: np.array, k: int):\n",
283+
"def compute_relevance_scores(\n",
284+
" query_embedding: np.array, document_embeddings: np.array, k: int\n",
285+
") -> list[int]:\n",
275286
" \"\"\"\n",
276287
" Compute relevance scores for top-k documents given a query.\n",
277288
"\n",

docs/examples/FastEmbed_vs_HF_Comparison.ipynb

+1-1
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@
152152
" HuggingFace Transformer implementation of FlagEmbedding\n",
153153
" \"\"\"\n",
154154
"\n",
155-
" def __init__(self, model_id: str):\n",
155+
" def __init__(self, model_id: str) -> None:\n",
156156
" self.model = AutoModel.from_pretrained(model_id)\n",
157157
" self.tokenizer = AutoTokenizer.from_pretrained(model_id)\n",
158158
"\n",

docs/examples/Hybrid_Search.ipynb

+3-3
Original file line numberDiff line numberDiff line change
@@ -488,7 +488,7 @@
488488
}
489489
],
490490
"source": [
491-
"def make_sparse_embedding(texts: list[str]):\n",
491+
"def make_sparse_embedding(texts: list[str]) -> list[SparseEmbedding]:\n",
492492
" return list(sparse_model.embed(texts, batch_size=32))\n",
493493
"\n",
494494
"\n",
@@ -615,7 +615,7 @@
615615
}
616616
],
617617
"source": [
618-
"def get_tokens_and_weights(sparse_embedding, model_name):\n",
618+
"def get_tokens_and_weights(sparse_embedding, model_name) -> dict[str, float]:\n",
619619
" # Find the tokenizer for the model\n",
620620
" tokenizer_source = None\n",
621621
" for model_info in SparseTextEmbedding.list_supported_models():\n",
@@ -626,7 +626,7 @@
626626
" raise ValueError(f\"Model {model_name} not found in the supported models.\")\n",
627627
"\n",
628628
" tokenizer = AutoTokenizer.from_pretrained(tokenizer_source)\n",
629-
" token_weight_dict = {}\n",
629+
" token_weight_dict: dict[str, float] = {}\n",
630630
" for i in range(len(sparse_embedding.indices)):\n",
631631
" token = tokenizer.decode([sparse_embedding.indices[i]])\n",
632632
" weight = sparse_embedding.values[i]\n",

fastembed/common/model_management.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,7 @@ def _save_file_metadata(model_dir: Path, meta: dict[str, dict[str, int]]) -> Non
255255
return result
256256

257257
@classmethod
258-
def decompress_to_cache(cls, targz_path: str, cache_dir: str):
258+
def decompress_to_cache(cls, targz_path: str, cache_dir: str) -> str:
259259
"""
260260
Decompresses a .tar.gz file to a cache directory.
261261

fastembed/image/onnx_image_model.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def _get_worker_class(cls) -> Type["ImageEmbeddingWorker"]:
2424
def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[T]:
2525
raise NotImplementedError("Subclasses must implement this method")
2626

27-
def __init__(self) -> None:
27+
def __init__(self):
2828
super().__init__()
2929
self.processor = None
3030

fastembed/image/transform/functional.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ def rescale(image: np.ndarray, scale: float, dtype=np.float32) -> np.ndarray:
118118
return (image * scale).astype(dtype)
119119

120120

121-
def pil2ndarray(image: Union[Image.Image, np.ndarray]):
121+
def pil2ndarray(image: Union[Image.Image, np.ndarray]) -> np.ndarray:
122122
if isinstance(image, Image.Image):
123123
return np.asarray(image).transpose((2, 0, 1))
124124
return image

fastembed/image/transform/operators.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -133,11 +133,11 @@ def from_config(cls, config: dict[str, Any]) -> "Compose":
133133
return cls(transforms=transforms)
134134

135135
@staticmethod
136-
def _get_convert_to_rgb(transforms: list[Transform], config: dict[str, Any]):
136+
def _get_convert_to_rgb(transforms: list[Transform], config: dict[str, Any]) -> None:
137137
transforms.append(ConvertToRGB())
138138

139139
@classmethod
140-
def _get_resize(cls, transforms: list[Transform], config: dict[str, Any]):
140+
def _get_resize(cls, transforms: list[Transform], config: dict[str, Any]) -> None:
141141
mode = config.get("image_processor_type", "CLIPImageProcessor")
142142
if mode == "CLIPImageProcessor":
143143
if config.get("do_resize", False):
@@ -200,7 +200,7 @@ def _get_resize(cls, transforms: list[Transform], config: dict[str, Any]):
200200
raise ValueError(f"Preprocessor {mode} is not supported")
201201

202202
@staticmethod
203-
def _get_center_crop(transforms: list[Transform], config: dict[str, Any]):
203+
def _get_center_crop(transforms: list[Transform], config: dict[str, Any]) -> None:
204204
mode = config.get("image_processor_type", "CLIPImageProcessor")
205205
if mode == "CLIPImageProcessor":
206206
if config.get("do_center_crop", False):
@@ -220,24 +220,24 @@ def _get_center_crop(transforms: list[Transform], config: dict[str, Any]):
220220
raise ValueError(f"Preprocessor {mode} is not supported")
221221

222222
@staticmethod
223-
def _get_pil2ndarray(transforms: list[Transform], config: dict[str, Any]):
223+
def _get_pil2ndarray(transforms: list[Transform], config: dict[str, Any]) -> None:
224224
transforms.append(PILtoNDarray())
225225

226226
@staticmethod
227-
def _get_rescale(transforms: list[Transform], config: dict[str, Any]):
227+
def _get_rescale(transforms: list[Transform], config: dict[str, Any]) -> None:
228228
if config.get("do_rescale", True):
229229
rescale_factor = config.get("rescale_factor", 1 / 255)
230230
transforms.append(Rescale(scale=rescale_factor))
231231

232232
@staticmethod
233-
def _get_normalize(transforms: list[Transform], config: dict[str, Any]):
233+
def _get_normalize(transforms: list[Transform], config: dict[str, Any]) -> None:
234234
if config.get("do_normalize", False):
235235
transforms.append(Normalize(mean=config["image_mean"], std=config["image_std"]))
236236
elif "mean" in config and "std" in config:
237237
transforms.append(Normalize(mean=config["mean"], std=config["std"]))
238238

239239
@staticmethod
240-
def _get_pad2square(transforms: list[Transform], config: dict[str, Any]):
240+
def _get_pad2square(transforms: list[Transform], config: dict[str, Any]) -> None:
241241
mode = config.get("image_processor_type", "CLIPImageProcessor")
242242
if mode == "CLIPImageProcessor":
243243
pass

fastembed/rerank/cross_encoder/onnx_text_model.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
import os
22
from multiprocessing import get_all_start_methods
33
from pathlib import Path
4-
from typing import Any, Iterable, Optional, Sequence, Type
4+
from typing import Any, Iterable, Optional, Sequence, Type, Union
55

66
import numpy as np
7+
from numpy.typing import NDArray
78
from tokenizers import Encoding
89

910
from fastembed.common.onnx_model import (
@@ -46,7 +47,9 @@ def _load_onnx_model(
4647
def tokenize(self, pairs: list[tuple[str, str]], **_: Any) -> list[Encoding]:
4748
return self.tokenizer.encode_batch(pairs)
4849

49-
def _build_onnx_input(self, tokenized_input):
50+
def _build_onnx_input(
51+
self, tokenized_input
52+
) -> dict[str, NDArray[Union[np.float32, np.int64]]]:
5053
input_names = {node.name for node in self.model.get_inputs()}
5154
inputs = {
5255
"input_ids": np.array([enc.ids for enc in tokenized_input], dtype=np.int64),

fastembed/text/onnx_text_model.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def _get_worker_class(cls) -> Type["TextEmbeddingWorker"]:
2323
def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[T]:
2424
raise NotImplementedError("Subclasses must implement this method")
2525

26-
def __init__(self) -> None:
26+
def __init__(self):
2727
super().__init__()
2828
self.tokenizer = None
2929
self.special_token_to_id = {}

fastembed/text/text_embedding_base.py

+1
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ def passage_embed(self, texts: Iterable[str], **kwargs: Any) -> Iterable[np.ndar
4343
yield from self.embed(texts, **kwargs)
4444

4545
def query_embed(self, query: Union[str, Iterable[str]], **kwargs: Any) -> Iterable[np.ndarray]:
46+
4647
"""
4748
Embeds queries
4849

tests/test_attention_embeddings.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99

1010
@pytest.mark.parametrize("model_name", ["Qdrant/bm42-all-minilm-l6-v2-attentions", "Qdrant/bm25"])
11-
def test_attention_embeddings(model_name):
11+
def test_attention_embeddings(model_name) -> None:
1212
is_ci = os.getenv("CI")
1313
model = SparseTextEmbedding(model_name=model_name)
1414

@@ -71,7 +71,7 @@ def test_attention_embeddings(model_name):
7171

7272

7373
@pytest.mark.parametrize("model_name", ["Qdrant/bm42-all-minilm-l6-v2-attentions", "Qdrant/bm25"])
74-
def test_parallel_processing(model_name):
74+
def test_parallel_processing(model_name) -> None:
7575
is_ci = os.getenv("CI")
7676

7777
model = SparseTextEmbedding(model_name=model_name)
@@ -96,7 +96,7 @@ def test_parallel_processing(model_name):
9696

9797

9898
@pytest.mark.parametrize("model_name", ["Qdrant/bm25"])
99-
def test_multilanguage(model_name):
99+
def test_multilanguage(model_name) -> None:
100100
is_ci = os.getenv("CI")
101101

102102
docs = ["Mangez-vous vraiment des grenouilles?", "Je suis au lit"]
@@ -122,7 +122,7 @@ def test_multilanguage(model_name):
122122

123123

124124
@pytest.mark.parametrize("model_name", ["Qdrant/bm25"])
125-
def test_special_characters(model_name):
125+
def test_special_characters(model_name) -> None:
126126
is_ci = os.getenv("CI")
127127

128128
docs = [
@@ -145,7 +145,7 @@ def test_special_characters(model_name):
145145

146146

147147
@pytest.mark.parametrize("model_name", ["Qdrant/bm42-all-minilm-l6-v2-attentions"])
148-
def test_lazy_load(model_name):
148+
def test_lazy_load(model_name) -> None:
149149
model = SparseTextEmbedding(model_name=model_name, lazy_load=True)
150150
assert not hasattr(model.model, "model")
151151
docs = ["hello world", "flag embedding"]

tests/test_image_onnx_embeddings.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
}
2828

2929

30-
def test_embedding():
30+
def test_embedding() -> None:
3131
is_ci = os.getenv("CI")
3232

3333
for model_desc in ImageEmbedding.list_supported_models():
@@ -61,7 +61,7 @@ def test_embedding():
6161

6262

6363
@pytest.mark.parametrize("n_dims,model_name", [(512, "Qdrant/clip-ViT-B-32-vision")])
64-
def test_batch_embedding(n_dims, model_name):
64+
def test_batch_embedding(n_dims, model_name) -> None:
6565
is_ci = os.getenv("CI")
6666
model = ImageEmbedding(model_name=model_name)
6767
n_images = 32
@@ -81,7 +81,7 @@ def test_batch_embedding(n_dims, model_name):
8181

8282

8383
@pytest.mark.parametrize("n_dims,model_name", [(512, "Qdrant/clip-ViT-B-32-vision")])
84-
def test_parallel_processing(n_dims, model_name):
84+
def test_parallel_processing(n_dims, model_name) -> None:
8585
is_ci = os.getenv("CI")
8686
model = ImageEmbedding(model_name=model_name)
8787

@@ -109,7 +109,7 @@ def test_parallel_processing(n_dims, model_name):
109109

110110

111111
@pytest.mark.parametrize("model_name", ["Qdrant/clip-ViT-B-32-vision"])
112-
def test_lazy_load(model_name):
112+
def test_lazy_load(model_name) -> None:
113113
is_ci = os.getenv("CI")
114114
model = ImageEmbedding(model_name=model_name, lazy_load=True)
115115
assert not hasattr(model.model, "model")

tests/test_multi_gpu.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414
@pytest.mark.skip(reason="Requires a multi-gpu server")
1515
@pytest.mark.parametrize("device_id", [None, 0, 1])
16-
def test_gpu_via_providers(device_id):
16+
def test_gpu_via_providers(device_id) -> None:
1717
docs = ["hello world", "flag embedding"]
1818

1919
device_id = device_id if device_id is not None else 0
@@ -85,7 +85,7 @@ def test_gpu_via_providers(device_id):
8585

8686
@pytest.mark.skip(reason="Requires a multi-gpu server")
8787
@pytest.mark.parametrize("device_ids", [None, [0], [1], [0, 1]])
88-
def test_gpu_cuda_device_ids(device_ids):
88+
def test_gpu_cuda_device_ids(device_ids) -> None:
8989
docs = ["hello world", "flag embedding"]
9090
device_id = device_ids[0] if device_ids else 0
9191
embedding_model = TextEmbedding(
@@ -170,7 +170,7 @@ def test_gpu_cuda_device_ids(device_ids):
170170
@pytest.mark.parametrize(
171171
"device_ids,parallel", [(None, None), (None, 2), ([1], None), ([1], 1), ([1], 2), ([0, 1], 2)]
172172
)
173-
def test_multi_gpu_parallel_inference(device_ids, parallel):
173+
def test_multi_gpu_parallel_inference(device_ids, parallel) -> None:
174174
docs = ["hello world", "flag embedding"] * 100
175175
batch_size = 5
176176

tests/test_sparse_embeddings.py

+8-8
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@
4949
docs = ["Hello World"]
5050

5151

52-
def test_batch_embedding():
52+
def test_batch_embedding() -> None:
5353
is_ci = os.getenv("CI")
5454
docs_to_embed = docs * 10
5555

@@ -64,7 +64,7 @@ def test_batch_embedding():
6464
delete_model_cache(model.model._model_dir)
6565

6666

67-
def test_single_embedding():
67+
def test_single_embedding() -> None:
6868
is_ci = os.getenv("CI")
6969
for model_name, expected_result in CANONICAL_COLUMN_VALUES.items():
7070
model = SparseTextEmbedding(model_name=model_name)
@@ -80,7 +80,7 @@ def test_single_embedding():
8080
delete_model_cache(model.model._model_dir)
8181

8282

83-
def test_parallel_processing():
83+
def test_parallel_processing() -> None:
8484
is_ci = os.getenv("CI")
8585
model = SparseTextEmbedding(model_name="prithivida/Splade_PP_en_v1")
8686
docs = ["hello world", "flag embedding"] * 30
@@ -111,15 +111,15 @@ def test_parallel_processing():
111111

112112

113113
@pytest.fixture
114-
def bm25_instance():
114+
def bm25_instance() -> None:
115115
ci = os.getenv("CI", True)
116116
model = Bm25("Qdrant/bm25", language="english")
117117
yield model
118118
if ci:
119119
delete_model_cache(model._model_dir)
120120

121121

122-
def test_stem_with_stopwords_and_punctuation(bm25_instance):
122+
def test_stem_with_stopwords_and_punctuation(bm25_instance) -> None:
123123
# Setup
124124
bm25_instance.stopwords = {"the", "is", "a"}
125125
bm25_instance.punctuation = {".", ",", "!"}
@@ -135,7 +135,7 @@ def test_stem_with_stopwords_and_punctuation(bm25_instance):
135135
assert result == expected, f"Expected {expected}, but got {result}"
136136

137137

138-
def test_stem_case_insensitive_stopwords(bm25_instance):
138+
def test_stem_case_insensitive_stopwords(bm25_instance) -> None:
139139
# Setup
140140
bm25_instance.stopwords = {"the", "is", "a"}
141141
bm25_instance.punctuation = {".", ",", "!"}
@@ -152,7 +152,7 @@ def test_stem_case_insensitive_stopwords(bm25_instance):
152152

153153

154154
@pytest.mark.parametrize("disable_stemmer", [True, False])
155-
def test_disable_stemmer_behavior(disable_stemmer):
155+
def test_disable_stemmer_behavior(disable_stemmer) -> None:
156156
# Setup
157157
model = Bm25("Qdrant/bm25", language="english", disable_stemmer=disable_stemmer)
158158
model.stopwords = {"the", "is", "a"}
@@ -176,7 +176,7 @@ def test_disable_stemmer_behavior(disable_stemmer):
176176
"model_name",
177177
["prithivida/Splade_PP_en_v1"],
178178
)
179-
def test_lazy_load(model_name):
179+
def test_lazy_load(model_name) -> None:
180180
is_ci = os.getenv("CI")
181181
model = SparseTextEmbedding(model_name=model_name, lazy_load=True)
182182
assert not hasattr(model.model, "model")

0 commit comments

Comments
 (0)