Skip to content

Commit f1393d1

Browse files
committed
Multitask embeddings test fix
1 parent 86b01b4 commit f1393d1

File tree

1 file changed

+18
-18
lines changed

1 file changed

+18
-18
lines changed

tests/test_text_multitask_embeddings.py

+18-18
Original file line numberDiff line numberDiff line change
@@ -66,11 +66,11 @@ def test_batch_embedding():
6666
default_task = Task.RETRIEVAL_PASSAGE
6767

6868
for model_desc in TextEmbedding.list_supported_models():
69-
if not is_ci and model_desc["size_in_GB"] > 1:
69+
if not is_ci and model_desc.size_in_GB > 1:
7070
continue
7171

72-
model_name = model_desc["model"]
73-
dim = model_desc["dim"]
72+
model_name = model_desc.model
73+
dim = model_desc.dim
7474

7575
if model_name not in CANONICAL_VECTOR_VALUES.keys():
7676
continue
@@ -87,7 +87,7 @@ def test_batch_embedding():
8787
canonical_vector = CANONICAL_VECTOR_VALUES[model_name][default_task]["vectors"]
8888
assert np.allclose(
8989
embeddings[: len(docs), : canonical_vector.shape[1]], canonical_vector, atol=1e-4
90-
), model_desc["model"]
90+
), model_desc.model
9191

9292
if is_ci:
9393
delete_model_cache(model.model._model_dir)
@@ -97,11 +97,11 @@ def test_single_embedding():
9797
is_ci = os.getenv("CI")
9898

9999
for model_desc in TextEmbedding.list_supported_models():
100-
if not is_ci and model_desc["size_in_GB"] > 1:
100+
if not is_ci and model_desc.size_in_GB > 1:
101101
continue
102102

103-
model_name = model_desc["model"]
104-
dim = model_desc["dim"]
103+
model_name = model_desc.model
104+
dim = model_desc.dim
105105

106106
if model_name not in CANONICAL_VECTOR_VALUES.keys():
107107
continue
@@ -119,7 +119,7 @@ def test_single_embedding():
119119
canonical_vector = task["vectors"]
120120
assert np.allclose(
121121
embeddings[: len(docs), : canonical_vector.shape[1]], canonical_vector, atol=1e-4
122-
), model_desc["model"]
122+
), model_desc.model
123123

124124
if is_ci:
125125
delete_model_cache(model.model._model_dir)
@@ -130,11 +130,11 @@ def test_single_embedding_query():
130130
task_id = Task.RETRIEVAL_QUERY
131131

132132
for model_desc in TextEmbedding.list_supported_models():
133-
if not is_ci and model_desc["size_in_GB"] > 1:
133+
if not is_ci and model_desc.size_in_GB > 1:
134134
continue
135135

136-
model_name = model_desc["model"]
137-
dim = model_desc["dim"]
136+
model_name = model_desc.model
137+
dim = model_desc.dim
138138

139139
if model_name not in CANONICAL_VECTOR_VALUES.keys():
140140
continue
@@ -151,7 +151,7 @@ def test_single_embedding_query():
151151
canonical_vector = CANONICAL_VECTOR_VALUES[model_name][task_id]["vectors"]
152152
assert np.allclose(
153153
embeddings[: len(docs), : canonical_vector.shape[1]], canonical_vector, atol=1e-4
154-
), model_desc["model"]
154+
), model_desc.model
155155

156156
if is_ci:
157157
delete_model_cache(model.model._model_dir)
@@ -162,11 +162,11 @@ def test_single_embedding_passage():
162162
task_id = Task.RETRIEVAL_PASSAGE
163163

164164
for model_desc in TextEmbedding.list_supported_models():
165-
if not is_ci and model_desc["size_in_GB"] > 1:
165+
if not is_ci and model_desc.size_in_GB > 1:
166166
continue
167167

168-
model_name = model_desc["model"]
169-
dim = model_desc["dim"]
168+
model_name = model_desc.model
169+
dim = model_desc.dim
170170

171171
if model_name not in CANONICAL_VECTOR_VALUES.keys():
172172
continue
@@ -183,7 +183,7 @@ def test_single_embedding_passage():
183183
canonical_vector = CANONICAL_VECTOR_VALUES[model_name][task_id]["vectors"]
184184
assert np.allclose(
185185
embeddings[: len(docs), : canonical_vector.shape[1]], canonical_vector, atol=1e-4
186-
), model_desc["model"]
186+
), model_desc.model
187187

188188
if is_ci:
189189
delete_model_cache(model.model._model_dir)
@@ -220,10 +220,10 @@ def test_task_assignment():
220220
is_ci = os.getenv("CI")
221221

222222
for model_desc in TextEmbedding.list_supported_models():
223-
if not is_ci and model_desc["size_in_GB"] > 1:
223+
if not is_ci and model_desc.size_in_GB > 1:
224224
continue
225225

226-
model_name = model_desc["model"]
226+
model_name = model_desc.model
227227
if model_name not in CANONICAL_VECTOR_VALUES.keys():
228228
continue
229229

0 commit comments

Comments
 (0)