Skip to content

Commit 97f2fb2

Browse files
authored
wip: remove model dir (#350)
* wip: remove model dir * fix: update pytest run * wip: disable some tests not used atm * wip: disable some tests in ci * wip: add debug print * fix: fix ci, remove models after usage * fix: fix bm25 deletion * fix: remove redundant ci commands
1 parent fa22051 commit 97f2fb2

12 files changed

+93
-31
lines changed

.github/workflows/python-tests.yml

+2-11
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,6 @@ name: Tests
33
on:
44
push:
55
branches: [ master, main, gpu ]
6-
schedule:
7-
- cron: 0 0 * * *
86
pull_request:
97

108
env:
@@ -42,13 +40,6 @@ jobs:
4240
poetry config virtualenvs.create false
4341
poetry install --no-interaction --no-ansi --without docs
4442
45-
- name: Install Test Dependencies
46-
run: pip install pytest pytest-md pytest-emoji
47-
4843
- name: Run pytest
49-
uses: pavelzw/pytest-action@v2
50-
with:
51-
verbose: true
52-
emoji: true
53-
job-summary: true
54-
report-title: 'FastEmbed Test Report'
44+
run: |
45+
poetry run pytest

fastembed/image/onnx_embedding.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -77,12 +77,12 @@ def __init__(
7777

7878
model_description = self._get_model_description(model_name)
7979
self.cache_dir = define_cache_dir(cache_dir)
80-
model_dir = self.download_model(
80+
self._model_dir = self.download_model(
8181
model_description, self.cache_dir, local_files_only=self._local_files_only
8282
)
8383

8484
self.load_onnx_model(
85-
model_dir=model_dir,
85+
model_dir=self._model_dir,
8686
model_file=model_description["model_file"],
8787
threads=threads,
8888
providers=providers,

fastembed/late_interaction/colbert.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -136,12 +136,12 @@ def __init__(
136136
model_description = self._get_model_description(model_name)
137137
self.cache_dir = define_cache_dir(cache_dir)
138138

139-
model_dir = self.download_model(
139+
self._model_dir = self.download_model(
140140
model_description, self.cache_dir, local_files_only=self._local_files_only
141141
)
142142

143143
self.load_onnx_model(
144-
model_dir=model_dir,
144+
model_dir=self._model_dir,
145145
model_file=model_description["model_file"],
146146
threads=threads,
147147
providers=providers,

fastembed/sparse/bm25.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -116,12 +116,12 @@ def __init__(
116116
model_description = self._get_model_description(model_name)
117117
self.cache_dir = define_cache_dir(cache_dir)
118118

119-
model_dir = self.download_model(
119+
self._model_dir = self.download_model(
120120
model_description, self.cache_dir, local_files_only=self._local_files_only
121121
)
122122

123123
self.punctuation = set(string.punctuation)
124-
self.stopwords = set(self._load_stopwords(model_dir, self.language))
124+
self.stopwords = set(self._load_stopwords(self._model_dir, self.language))
125125
self.stemmer = get_stemmer(language)
126126
self.tokenizer = WordTokenizer
127127

fastembed/sparse/bm42.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -84,12 +84,12 @@ def __init__(
8484
model_description = self._get_model_description(model_name)
8585
self.cache_dir = define_cache_dir(cache_dir)
8686

87-
model_dir = self.download_model(
87+
self._model_dir = self.download_model(
8888
model_description, self.cache_dir, local_files_only=self._local_files_only
8989
)
9090

9191
self.load_onnx_model(
92-
model_dir=model_dir,
92+
model_dir=self._model_dir,
9393
model_file=model_description["model_file"],
9494
threads=threads,
9595
providers=providers,
@@ -103,7 +103,7 @@ def __init__(
103103
self.special_tokens = set(self.special_token_to_id.keys())
104104
self.special_tokens_ids = set(self.special_token_to_id.values())
105105
self.punctuation = set(string.punctuation)
106-
self.stopwords = set(self._load_stopwords(model_dir))
106+
self.stopwords = set(self._load_stopwords(self._model_dir))
107107
self.stemmer = get_stemmer(MODEL_TO_LANGUAGE[model_name])
108108
self.alpha = alpha
109109

fastembed/sparse/splade_pp.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -82,12 +82,12 @@ def __init__(
8282
model_description = self._get_model_description(model_name)
8383
self.cache_dir = define_cache_dir(cache_dir)
8484

85-
model_dir = self.download_model(
85+
self._model_dir = self.download_model(
8686
model_description, self.cache_dir, local_files_only=self._local_files_only
8787
)
8888

8989
self.load_onnx_model(
90-
model_dir=model_dir,
90+
model_dir=self._model_dir,
9191
model_file=model_description["model_file"],
9292
threads=threads,
9393
providers=providers,

fastembed/text/onnx_embedding.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -190,12 +190,12 @@ def __init__(
190190

191191
model_description = self._get_model_description(model_name)
192192
self.cache_dir = define_cache_dir(cache_dir)
193-
model_dir = self.download_model(
193+
self._model_dir = self.download_model(
194194
model_description, self.cache_dir, local_files_only=self._local_files_only
195195
)
196196

197197
self.load_onnx_model(
198-
model_dir=model_dir,
198+
model_dir=self._model_dir,
199199
model_file=model_description["model_file"],
200200
threads=threads,
201201
providers=providers,

tests/test_attention_embeddings.py

+17
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
import os
2+
import shutil
3+
14
import numpy as np
25
import pytest
36

@@ -6,6 +9,7 @@
69

710
@pytest.mark.parametrize("model_name", ["Qdrant/bm42-all-minilm-l6-v2-attentions", "Qdrant/bm25"])
811
def test_attention_embeddings(model_name):
12+
is_ci = os.getenv("CI")
913
model = SparseTextEmbedding(model_name=model_name)
1014

1115
output = list(
@@ -62,9 +66,14 @@ def test_attention_embeddings(model_name):
6266
assert len(result.indices) == len(result.values)
6367
assert len(result.indices) == 2
6468

69+
if is_ci:
70+
shutil.rmtree(model.model._model_dir)
71+
6572

6673
@pytest.mark.parametrize("model_name", ["Qdrant/bm42-all-minilm-l6-v2-attentions", "Qdrant/bm25"])
6774
def test_parallel_processing(model_name):
75+
is_ci = os.getenv("CI")
76+
6877
model = SparseTextEmbedding(model_name=model_name)
6978

7079
docs = ["hello world", "attention embedding", "Mangez-vous vraiment des grenouilles?"] * 100
@@ -82,9 +91,14 @@ def test_parallel_processing(model_name):
8291
assert np.allclose(emb_1.values, emb_2.values)
8392
assert np.allclose(emb_1.values, emb_3.values)
8493

94+
if is_ci:
95+
shutil.rmtree(model.model._model_dir)
96+
8597

8698
@pytest.mark.parametrize("model_name", ["Qdrant/bm25"])
8799
def test_multilanguage(model_name):
100+
is_ci = os.getenv("CI")
101+
88102
docs = ["Mangez-vous vraiment des grenouilles?", "Je suis au lit"]
89103

90104
model = SparseTextEmbedding(model_name=model_name, language="french")
@@ -102,3 +116,6 @@ def test_multilanguage(model_name):
102116

103117
assert embeddings[1].values.shape == (4,)
104118
assert embeddings[1].indices.shape == (4,)
119+
120+
if is_ci:
121+
shutil.rmtree(model.model._model_dir)

tests/test_image_onnx_embeddings.py

+10
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import os
2+
import shutil
23
from io import BytesIO
34

45
import numpy as np
@@ -52,9 +53,13 @@ def test_embedding():
5253

5354
assert np.allclose(embeddings[1], embeddings[2]), model_desc["model"]
5455

56+
if is_ci:
57+
shutil.rmtree(model.model._model_dir)
58+
5559

5660
@pytest.mark.parametrize("n_dims,model_name", [(512, "Qdrant/clip-ViT-B-32-vision")])
5761
def test_batch_embedding(n_dims, model_name):
62+
is_ci = os.getenv("CI")
5863
model = ImageEmbedding(model_name=model_name)
5964
n_images = 32
6065
test_images = [
@@ -68,10 +73,13 @@ def test_batch_embedding(n_dims, model_name):
6873
embeddings = np.stack(embeddings, axis=0)
6974

7075
assert embeddings.shape == (len(test_images) * n_images, n_dims)
76+
if is_ci:
77+
shutil.rmtree(model.model._model_dir)
7178

7279

7380
@pytest.mark.parametrize("n_dims,model_name", [(512, "Qdrant/clip-ViT-B-32-vision")])
7481
def test_parallel_processing(n_dims, model_name):
82+
is_ci = os.getenv("CI")
7583
model = ImageEmbedding(model_name=model_name)
7684

7785
n_images = 32
@@ -93,3 +101,5 @@ def test_parallel_processing(n_dims, model_name):
93101
assert embeddings.shape == (n_images * len(test_images), n_dims)
94102
assert np.allclose(embeddings, embeddings_2, atol=1e-3)
95103
assert np.allclose(embeddings, embeddings_3, atol=1e-3)
104+
if is_ci:
105+
shutil.rmtree(model.model._model_dir)

tests/test_late_interaction_embeddings.py

+19
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
import os
2+
import shutil
3+
14
import numpy as np
25

36
from fastembed.late_interaction.late_interaction_text_embedding import (
@@ -105,6 +108,7 @@
105108

106109

107110
def test_batch_embedding():
111+
is_ci = os.getenv("CI")
108112
docs_to_embed = docs * 10
109113

110114
for model_name, expected_result in CANONICAL_COLUMN_VALUES.items():
@@ -116,8 +120,12 @@ def test_batch_embedding():
116120
token_num, abridged_dim = expected_result.shape
117121
assert np.allclose(value[:, :abridged_dim], expected_result, atol=10e-4)
118122

123+
if is_ci:
124+
shutil.rmtree(model.model._model_dir)
125+
119126

120127
def test_single_embedding():
128+
is_ci = os.getenv("CI")
121129
docs_to_embed = docs
122130

123131
for model_name, expected_result in CANONICAL_COLUMN_VALUES.items():
@@ -127,8 +135,12 @@ def test_single_embedding():
127135
token_num, abridged_dim = expected_result.shape
128136
assert np.allclose(result[:, :abridged_dim], expected_result, atol=10e-4)
129137

138+
if is_ci:
139+
shutil.rmtree(model.model._model_dir)
140+
130141

131142
def test_single_embedding_query():
143+
is_ci = os.getenv("CI")
132144
queries_to_embed = docs
133145

134146
for model_name, expected_result in CANONICAL_QUERY_VALUES.items():
@@ -138,8 +150,12 @@ def test_single_embedding_query():
138150
token_num, abridged_dim = expected_result.shape
139151
assert np.allclose(result[:, :abridged_dim], expected_result, atol=10e-4)
140152

153+
if is_ci:
154+
shutil.rmtree(model.model._model_dir)
155+
141156

142157
def test_parallel_processing():
158+
is_ci = os.getenv("CI")
143159
model = LateInteractionTextEmbedding(model_name="colbert-ir/colbertv2.0")
144160
token_dim = 128
145161
docs = ["hello world", "flag embedding"] * 100
@@ -155,3 +171,6 @@ def test_parallel_processing():
155171
assert embeddings.shape[0] == len(docs) and embeddings.shape[-1] == token_dim
156172
assert np.allclose(embeddings, embeddings_2, atol=1e-3)
157173
assert np.allclose(embeddings, embeddings_3, atol=1e-3)
174+
175+
if is_ci:
176+
shutil.rmtree(model.model._model_dir)

tests/test_sparse_embeddings.py

+23-7
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
1+
import os
2+
import shutil
3+
14
import pytest
5+
import numpy as np
26

37
from fastembed.sparse.bm25 import Bm25
48
from fastembed.sparse.sparse_text_embedding import SparseTextEmbedding
@@ -46,6 +50,7 @@
4650

4751

4852
def test_batch_embedding():
53+
is_ci = os.getenv("CI")
4954
docs_to_embed = docs * 10
5055

5156
for model_name, expected_result in CANONICAL_COLUMN_VALUES.items():
@@ -55,9 +60,12 @@ def test_batch_embedding():
5560

5661
for i, value in enumerate(result.values):
5762
assert pytest.approx(value, abs=0.001) == expected_result["values"][i]
63+
if is_ci:
64+
shutil.rmtree(model.model._model_dir)
5865

5966

6067
def test_single_embedding():
68+
is_ci = os.getenv("CI")
6169
for model_name, expected_result in CANONICAL_COLUMN_VALUES.items():
6270
model = SparseTextEmbedding(model_name=model_name)
6371

@@ -68,11 +76,12 @@ def test_single_embedding():
6876

6977
for i, value in enumerate(result.values):
7078
assert pytest.approx(value, abs=0.001) == expected_result["values"][i]
79+
if is_ci:
80+
shutil.rmtree(model.model._model_dir)
7181

7282

7383
def test_parallel_processing():
74-
import numpy as np
75-
84+
is_ci = os.getenv("CI")
7685
model = SparseTextEmbedding(model_name="prithivida/Splade_PP_en_v1")
7786
docs = ["hello world", "flag embedding"] * 30
7887
sparse_embeddings_duo = list(model.embed(docs, batch_size=10, parallel=2))
@@ -97,16 +106,23 @@ def test_parallel_processing():
97106
assert np.allclose(sparse_embedding.values, sparse_embedding_duo.values, atol=1e-3)
98107
assert np.allclose(sparse_embedding.values, sparse_embedding_all.values, atol=1e-3)
99108

109+
if is_ci:
110+
shutil.rmtree(model.model._model_dir)
111+
100112

101113
@pytest.fixture
102114
def bm25_instance():
103-
return Bm25("Qdrant/bm25", language="english")
115+
ci = os.getenv("CI", True)
116+
model = Bm25("Qdrant/bm25", language="english")
117+
yield model
118+
if ci:
119+
shutil.rmtree(model._model_dir)
104120

105121

106122
def test_stem_with_stopwords_and_punctuation(bm25_instance):
107123
# Setup
108-
bm25_instance.stopwords = set(["the", "is", "a"])
109-
bm25_instance.punctuation = set([".", ",", "!"])
124+
bm25_instance.stopwords = {"the", "is", "a"}
125+
bm25_instance.punctuation = {".", ",", "!"}
110126

111127
# Test data
112128
tokens = ["The", "quick", "brown", "fox", "is", "a", "test", "sentence", ".", "!"]
@@ -121,8 +137,8 @@ def test_stem_with_stopwords_and_punctuation(bm25_instance):
121137

122138
def test_stem_case_insensitive_stopwords(bm25_instance):
123139
# Setup
124-
bm25_instance.stopwords = set(["the", "is", "a"])
125-
bm25_instance.punctuation = set([".", ",", "!"])
140+
bm25_instance.stopwords = {"the", "is", "a"}
141+
bm25_instance.punctuation = {".", ",", "!"}
126142

127143
# Test data
128144
tokens = ["THE", "Quick", "Brown", "Fox", "IS", "A", "Test", "Sentence", ".", "!"]

0 commit comments

Comments
 (0)