1
+ import os
2
+ import shutil
3
+
1
4
import pytest
5
+ import numpy as np
2
6
3
7
from fastembed .sparse .bm25 import Bm25
4
8
from fastembed .sparse .sparse_text_embedding import SparseTextEmbedding
46
50
47
51
48
52
def test_batch_embedding ():
53
+ is_ci = os .getenv ("CI" )
49
54
docs_to_embed = docs * 10
50
55
51
56
for model_name , expected_result in CANONICAL_COLUMN_VALUES .items ():
@@ -55,9 +60,12 @@ def test_batch_embedding():
55
60
56
61
for i , value in enumerate (result .values ):
57
62
assert pytest .approx (value , abs = 0.001 ) == expected_result ["values" ][i ]
63
+ if is_ci :
64
+ shutil .rmtree (model .model ._model_dir )
58
65
59
66
60
67
def test_single_embedding ():
68
+ is_ci = os .getenv ("CI" )
61
69
for model_name , expected_result in CANONICAL_COLUMN_VALUES .items ():
62
70
model = SparseTextEmbedding (model_name = model_name )
63
71
@@ -68,11 +76,12 @@ def test_single_embedding():
68
76
69
77
for i , value in enumerate (result .values ):
70
78
assert pytest .approx (value , abs = 0.001 ) == expected_result ["values" ][i ]
79
+ if is_ci :
80
+ shutil .rmtree (model .model ._model_dir )
71
81
72
82
73
83
def test_parallel_processing ():
74
- import numpy as np
75
-
84
+ is_ci = os .getenv ("CI" )
76
85
model = SparseTextEmbedding (model_name = "prithivida/Splade_PP_en_v1" )
77
86
docs = ["hello world" , "flag embedding" ] * 30
78
87
sparse_embeddings_duo = list (model .embed (docs , batch_size = 10 , parallel = 2 ))
@@ -97,16 +106,23 @@ def test_parallel_processing():
97
106
assert np .allclose (sparse_embedding .values , sparse_embedding_duo .values , atol = 1e-3 )
98
107
assert np .allclose (sparse_embedding .values , sparse_embedding_all .values , atol = 1e-3 )
99
108
109
+ if is_ci :
110
+ shutil .rmtree (model .model ._model_dir )
111
+
100
112
101
113
@pytest .fixture
102
114
def bm25_instance ():
103
- return Bm25 ("Qdrant/bm25" , language = "english" )
115
+ ci = os .getenv ("CI" , True )
116
+ model = Bm25 ("Qdrant/bm25" , language = "english" )
117
+ yield model
118
+ if ci :
119
+ shutil .rmtree (model ._model_dir )
104
120
105
121
106
122
def test_stem_with_stopwords_and_punctuation (bm25_instance ):
107
123
# Setup
108
- bm25_instance .stopwords = set ([ "the" , "is" , "a" ])
109
- bm25_instance .punctuation = set ([ "." , "," , "!" ])
124
+ bm25_instance .stopwords = { "the" , "is" , "a" }
125
+ bm25_instance .punctuation = { "." , "," , "!" }
110
126
111
127
# Test data
112
128
tokens = ["The" , "quick" , "brown" , "fox" , "is" , "a" , "test" , "sentence" , "." , "!" ]
@@ -121,8 +137,8 @@ def test_stem_with_stopwords_and_punctuation(bm25_instance):
121
137
122
138
def test_stem_case_insensitive_stopwords (bm25_instance ):
123
139
# Setup
124
- bm25_instance .stopwords = set ([ "the" , "is" , "a" ])
125
- bm25_instance .punctuation = set ([ "." , "," , "!" ])
140
+ bm25_instance .stopwords = { "the" , "is" , "a" }
141
+ bm25_instance .punctuation = { "." , "," , "!" }
126
142
127
143
# Test data
128
144
tokens = ["THE" , "Quick" , "Brown" , "Fox" , "IS" , "A" , "Test" , "Sentence" , "." , "!" ]
0 commit comments