@@ -45,19 +45,17 @@
 
 @pytest.mark.parametrize("scenario", canonical_vectors)
 def test_add_custom_model_variations(scenario):
-    """
-    Tests add_custom_model for different base models and different
-    (mean_pooling, normalization) configs. Checks the first 5 dims
-    of "hello world" match the scenario's canonical vector.
-    """
-
     is_ci = bool(os.getenv("CI", False))
 
    base_model_name = scenario["model"]
     mean_pooling = scenario["mean_pooling"]
     normalization = scenario["normalization"]
     cv = np.array(scenario["canonical_vector"], dtype=np.float32)
 
+    backup_supported_models = {}
+    for embedding_cls in TextEmbedding.EMBEDDINGS_REGISTRY:
+        backup_supported_models[embedding_cls] = embedding_cls.list_supported_models().copy()
+
     suffixes = []
     suffixes.append("mean" if mean_pooling else "no-mean")
     suffixes.append("norm" if normalization else "no-norm")
@@ -81,32 +79,37 @@ def test_add_custom_model_variations(scenario):
         "additional_files": [],
     }
 
-    if is_ci and model_info["size_in_GB"] > 1:
+    if is_ci and model_info["size_in_GB"] > 1.0:
         pytest.skip(
             f"Skipping {custom_model_name} on CI due to size_in_GB={model_info['size_in_GB']}"
         )
 
-    TextEmbedding.add_custom_model(
-        model_info=model_info, mean_pooling=mean_pooling, normalization=normalization
-    )
+    try:
+        TextEmbedding.add_custom_model(
+            model_info=model_info, mean_pooling=mean_pooling, normalization=normalization
+        )
+
+        model = TextEmbedding(model_name=custom_model_name)
 
-    model = TextEmbedding(model_name=custom_model_name)
+        docs = ["hello world", "flag embedding"]
+        embeddings = list(model.embed(docs))
+        embeddings = np.stack(embeddings, axis=0)
 
-    docs = ["hello world", "flag embedding"]
-    embeddings = list(model.embed(docs))
-    embeddings = np.stack(embeddings, axis=0)
+        assert embeddings.shape == (
+            2,
+            dim,
+        ), f"Expected shape (2, {dim}) for {custom_model_name}, but got {embeddings.shape}"
 
-    assert embeddings.shape == (
-        2,
-        dim,
-    ), f"Expected shape (2, {dim}) for {custom_model_name}, but got {embeddings.shape}"
+        num_compare_dims = cv.shape[0]
+        assert np.allclose(
+            embeddings[0, :num_compare_dims], cv, atol=1e-3
+        ), f"Embedding mismatch for {custom_model_name} (first {num_compare_dims} dims)."
 
-    num_compare_dims = cv.shape[0]
-    assert np.allclose(
-        embeddings[0, :num_compare_dims], cv, atol=1e-3
-    ), f"Embedding mismatch for {custom_model_name} (first {num_compare_dims} dims)."
+        assert not np.allclose(embeddings[0, :], 0.0), "Embedding should not be all zeros."
 
-    assert not np.allclose(embeddings[0, :], 0.0), "Embedding should not be all zeros."
+        if is_ci:
+            delete_model_cache(model.model._model_dir)
 
-    if is_ci:
-        delete_model_cache(model.model._model_dir)
+    finally:
+        for embedding_cls, old_list in backup_supported_models.items():
+            embedding_cls.supported_models = old_list
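
The try/finally exists because `TextEmbedding.add_custom_model` mutates class-level state shared across tests: without the restore step, the custom entry would leak into every later call to `list_supported_models`. The same cleanup could be expressed once as a pytest fixture instead of per-test try/finally. This is a sketch, not the repository's code: it assumes `TextEmbedding.EMBEDDINGS_REGISTRY`, `list_supported_models()`, and the `supported_models` attribute behave exactly as shown in the diff, and the fixture name `_restore_supported_models` is hypothetical.

```python
import pytest

from fastembed import TextEmbedding


@pytest.fixture
def _restore_supported_models():
    # Snapshot each registered embedding class's supported-models list
    # before the test mutates the registry via add_custom_model.
    backup = {
        cls: cls.list_supported_models().copy()
        for cls in TextEmbedding.EMBEDDINGS_REGISTRY
    }
    yield
    # Teardown runs even if the test body raised, mirroring the finally branch.
    for cls, old_list in backup.items():
        cls.supported_models = old_list
```

A test that requests `_restore_supported_models` could then drop the explicit backup dict and the try/finally, keeping its assertions at a single indentation level.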