Commit a6841a8

Added DeprecationWarning for Splade model (#331)
* Added DeprecationWarning for Splade model
* Dry and simple
1 parent 62607c2 · commit a6841a8

2 files changed (+10, -19 lines)


fastembed/sparse/sparse_text_embedding.py (+10, -7)

@@ -8,6 +8,7 @@
     SparseTextEmbeddingBase,
 )
 from fastembed.sparse.splade_pp import SpladePP
+import warnings


 class SparseTextEmbedding(SparseTextEmbeddingBase):

@@ -50,13 +51,17 @@ def __init__(
         **kwargs,
     ):
         super().__init__(model_name, cache_dir, threads, **kwargs)
+        if model_name == "prithvida/Splade_PP_en_v1":
+            warnings.warn(
+                "The right spelling is prithivida/Splade_PP_en_v1. "
+                "Support of this name will be removed soon, please fix the model_name",
+                DeprecationWarning,
+            )
+            model_name = "prithivida/Splade_PP_en_v1"

         for EMBEDDING_MODEL_TYPE in self.EMBEDDINGS_REGISTRY:
             supported_models = EMBEDDING_MODEL_TYPE.list_supported_models()
-            if any(
-                model_name.lower() == model["model"].lower()
-                for model in supported_models
-            ):
+            if any(model_name.lower() == model["model"].lower() for model in supported_models):
                 self.model = EMBEDDING_MODEL_TYPE(
                     model_name,
                     cache_dir,

@@ -95,9 +100,7 @@ def embed(
         """
         yield from self.model.embed(documents, batch_size, parallel, **kwargs)

-    def query_embed(
-        self, query: Union[str, Iterable[str]], **kwargs
-    ) -> Iterable[SparseEmbedding]:
+    def query_embed(self, query: Union[str, Iterable[str]], **kwargs) -> Iterable[SparseEmbedding]:
         """
         Embeds queries
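For reference, a minimal sketch of the new behaviour from the caller's side. It assumes fastembed at this commit is installed and that instantiating the model is acceptable (it downloads the ONNX weights on first use); the import path follows the file shown above.

    import warnings

    from fastembed.sparse.sparse_text_embedding import SparseTextEmbedding

    # The misspelled name still works, but it now emits a DeprecationWarning
    # and is remapped internally to "prithivida/Splade_PP_en_v1".
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        model = SparseTextEmbedding(model_name="prithvida/Splade_PP_en_v1")

    assert any(issubclass(w.category, DeprecationWarning) for w in caught)
    print(caught[0].message)

Callers should switch to prithivida/Splade_PP_en_v1 directly, since support for the misspelled name is slated for removal.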

fastembed/sparse/splade_pp.py (-12)

@@ -1,7 +1,6 @@
 from typing import Any, Dict, Iterable, List, Optional, Sequence, Type, Union

 import numpy as np
-
 from fastembed.common import OnnxProvider
 from fastembed.common.onnx_model import OnnxOutputContext
 from fastembed.common.utils import define_cache_dir

@@ -12,16 +11,6 @@
 from fastembed.text.onnx_text_model import OnnxTextModel, TextEmbeddingWorker

 supported_splade_models = [
-    {
-        "model": "prithvida/Splade_PP_en_v1",
-        "vocab_size": 30522,
-        "description": "Misspelled version of the model. Retained for backward compatibility. Independent Implementation of SPLADE++ Model for English",
-        "size_in_GB": 0.532,
-        "sources": {
-            "hf": "Qdrant/SPLADE_PP_en_v1",
-        },
-        "model_file": "model.onnx",
-    },
     {
         "model": "prithivida/Splade_PP_en_v1",
         "vocab_size": 30522,

@@ -78,7 +67,6 @@ def __init__(
         Raises:
             ValueError: If the model_name is not in the format <org>/<model> e.g. BAAI/bge-base-en.
         """
-
         super().__init__(model_name, cache_dir, threads, **kwargs)

         model_description = self._get_model_description(model_name)
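To confirm the registry side of the change, one can inspect the models that SpladePP still advertises. This is a sketch that assumes SpladePP (the class imported by sparse_text_embedding.py above) is the registry entry that previously served both spellings:

    from fastembed.sparse.splade_pp import SpladePP

    # After this commit only the correctly spelled entry remains in
    # supported_splade_models; the misspelled record is gone from the registry.
    names = [entry["model"] for entry in SpladePP.list_supported_models()]
    assert "prithivida/Splade_PP_en_v1" in names
    assert "prithvida/Splade_PP_en_v1" not in names

The misspelled name now only works through the alias added in sparse_text_embedding.py, not through a registry entry of its own.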

0 commit comments
