@@ -35,6 +35,18 @@ def iter_batch(iterable: Union[Iterable, Generator], size: int) -> Iterable:
35
35
yield b
36
36
37
37
38
def locate_model_file(model_dir: Path, file_names: list):
    """Find a model file inside ``model_dir``, searching sub-directories too.

    Candidates are ranked so the result does not depend on the order in which
    the filesystem happens to list entries:

    1. files closer to ``model_dir`` win over more deeply nested ones;
    2. at equal depth, names appearing earlier in ``file_names`` win;
    3. remaining ties are broken by the lexicographic path.

    Args:
        model_dir: Directory to search.
        file_names: Acceptable file names, most preferred first.

    Returns:
        Path of the best-ranked regular file whose name is in ``file_names``.

    Raises:
        ValueError: If ``model_dir`` is not a directory, or no file with one
            of the requested names exists beneath it.
    """
    model_dir = Path(model_dir)
    if not model_dir.is_dir():
        raise ValueError(f"Provided model path '{model_dir}' is not a directory.")

    # Map each name to its priority; the first occurrence wins on duplicates.
    priority = {}
    for index, file_name in enumerate(file_names):
        priority.setdefault(file_name, index)

    best_rank = None
    best_path = None
    for path in model_dir.rglob("*"):
        name_rank = priority.get(path.name)
        # Cheap name lookup first; only touch the filesystem for real candidates.
        if name_rank is None or not path.is_file():
            continue

        depth = len(path.relative_to(model_dir).parts)
        rank = (depth, name_rank, path.as_posix())
        if best_rank is None or rank < best_rank:
            best_rank, best_path = rank, path

    if best_path is None:
        raise ValueError(f"Could not find model file in {model_dir}")

    return best_path
48
+
49
+
38
50
def normalize (input_array , p = 2 , dim = 1 , eps = 1e-12 ):
39
51
# Calculate the Lp norm along the specified dimension
40
52
norm = np .linalg .norm (input_array , ord = p , axis = dim , keepdims = True )
@@ -92,32 +104,11 @@ def __init__(
92
104
):
93
105
self .path = path
94
106
self .model_name = model_name
95
- model_path = self .path / "model.onnx"
96
- optimized_model_path = self .path / "model_optimized.onnx"
97
-
98
- xenova_model_path = self .path / "onnx" / "model.onnx"
99
- xenova_optimized_model_path = self .path / "onnx" / "model_optimized.onnx"
107
+ model_path = locate_model_file (self .path , ["model.onnx" , "model_optimized.onnx" ])
100
108
101
109
# List of Execution Providers: https://onnxruntime.ai/docs/execution-providers
102
110
onnx_providers = ["CPUExecutionProvider" ]
103
111
104
- if not model_path .exists ():
105
- # Rename file model_optimized.onnx to model.onnx if it exists
106
- if optimized_model_path .exists ():
107
- optimized_model_path .rename (model_path )
108
-
109
- # Patch for inconsistent repo structure at
110
- # - https://huggingface.co/Xenova/jina-embeddings-v2-small-en
111
- # - https://huggingface.co/Xenova/jina-embeddings-v2-base-en
112
- elif xenova_model_path .exists ():
113
- model_path = xenova_model_path
114
-
115
- elif xenova_optimized_model_path .exists ():
116
- model_path = xenova_optimized_model_path
117
-
118
- else :
119
- raise ValueError (f"Could not find model.onnx in { self .path } " )
120
-
121
112
# Hacky support for multilingual model
122
113
self .exclude_token_type_ids = False
123
114
if model_name == "intfloat/multilingual-e5-large" :
0 commit comments