Skip to content

Commit 0f79d3f

Browse files
feat: Added a toggle to disable stemmer in bm25 (#416)
* feat: Added a toggle to disable stemmer in bm25 * refactor: Refactored how to disable stemming in bm25 * refactor: Refactored the way of disabling stemmer in bm25 * new: Added english fallback if language = None * tests: Added test case for disable stemmer * fix: Fix language to be only string * tests: Updated bm25 toggle stemmer tests * refactor: fix stopwords type * fix: fix param propagation in parallel embed in bm25 --------- Co-authored-by: George Panchuk <[email protected]>
1 parent 2ef9c38 commit 0f79d3f

File tree

2 files changed

+39
-4
lines changed

2 files changed

+39
-4
lines changed

fastembed/sparse/bm25.py

+18-4
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,8 @@ class Bm25(SparseTextEmbeddingBase):
9292
b (float, optional): The b parameter in the BM25 formula. Defines the importance of the document length.
9393
Defaults to 0.75.
9494
avg_len (float, optional): The average length of the documents in the corpus. Defaults to 256.0.
95+
language (str): Specifies the language for the stemmer.
96+
disable_stemmer (bool): Disable the stemmer.
9597
Raises:
9698
ValueError: If the model_name is not in the format <org>/<model> e.g. BAAI/bge-base-en.
9799
"""
@@ -105,6 +107,7 @@ def __init__(
105107
avg_len: float = 256.0,
106108
language: str = "english",
107109
token_max_length: int = 40,
110+
disable_stemmer: bool = False,
108111
**kwargs,
109112
):
110113
super().__init__(model_name, cache_dir, **kwargs)
@@ -127,9 +130,15 @@ def __init__(
127130

128131
self.token_max_length = token_max_length
129132
self.punctuation = set(get_all_punctuation())
130-
self.stopwords = set(self._load_stopwords(self._model_dir, self.language))
133+
self.disable_stemmer = disable_stemmer
134+
135+
if disable_stemmer:
136+
self.stopwords = set()
137+
self.stemmer = None
138+
else:
139+
self.stopwords = set(self._load_stopwords(self._model_dir, self.language))
140+
self.stemmer = SnowballStemmer(language)
131141

132-
self.stemmer = SnowballStemmer(language)
133142
self.tokenizer = SimpleTokenizer
134143

135144
@classmethod
@@ -182,6 +191,9 @@ def _embed_documents(
182191
"k": self.k,
183192
"b": self.b,
184193
"avg_len": self.avg_len,
194+
"language": self.language,
195+
"token_max_length": self.token_max_length,
196+
"disable_stemmer": self.disable_stemmer,
185197
}
186198
pool = ParallelWorkerPool(
187199
num_workers=parallel or 1,
@@ -225,16 +237,18 @@ def embed(
225237
def _stem(self, tokens: list[str]) -> list[str]:
226238
stemmed_tokens = []
227239
for token in tokens:
240+
lower_token = token.lower()
241+
228242
if token in self.punctuation:
229243
continue
230244

231-
if token.lower() in self.stopwords:
245+
if lower_token in self.stopwords:
232246
continue
233247

234248
if len(token) > self.token_max_length:
235249
continue
236250

237-
stemmed_token = self.stemmer.stem_word(token.lower())
251+
stemmed_token = self.stemmer.stem_word(lower_token) if self.stemmer else lower_token
238252

239253
if stemmed_token:
240254
stemmed_tokens.append(stemmed_token)

tests/test_sparse_embeddings.py

+21
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,27 @@ def test_stem_case_insensitive_stopwords(bm25_instance):
151151
assert result == expected, f"Expected {expected}, but got {result}"
152152

153153

154+
@pytest.mark.parametrize("disable_stemmer", [True, False])
155+
def test_disable_stemmer_behavior(disable_stemmer):
156+
# Setup
157+
model = Bm25("Qdrant/bm25", language="english", disable_stemmer=disable_stemmer)
158+
model.stopwords = {"the", "is", "a"}
159+
model.punctuation = {".", ",", "!"}
160+
161+
# Test data
162+
tokens = ["The", "quick", "brown", "fox", "is", "a", "test", "sentence", ".", "!"]
163+
164+
# Execute
165+
result = model._stem(tokens)
166+
167+
# Assert
168+
if disable_stemmer:
169+
expected = ["quick", "brown", "fox", "test", "sentence"] # no stemming, lower case only
170+
else:
171+
expected = ["quick", "brown", "fox", "test", "sentenc"]
172+
assert result == expected, f"Expected {expected}, but got {result}"
173+
174+
154175
@pytest.mark.parametrize(
155176
"model_name",
156177
["prithivida/Splade_PP_en_v1"],

0 commit comments

Comments
 (0)