@@ -35,6 +35,21 @@ def iter_batch(iterable: Union[Iterable, Generator], size: int) -> Iterable:
35
35
yield b
36
36
37
37
38
def locate_model_file(model_dir: Path, file_names: List[str]) -> Path:
    """
    Find model path for both TransformerJS style `onnx` subdirectory structure
    and direct model weights structure used by Optimum and Qdrant.

    Args:
        model_dir: Directory searched recursively for the model weights.
        file_names: Candidate file names, listed in order of preference
            (the first name found anywhere under ``model_dir`` wins).

    Returns:
        Path to the first matching model file.

    Raises:
        ValueError: If ``model_dir`` is not a directory, or none of the
            candidate files exist under it.
    """
    if not model_dir.is_dir():
        raise ValueError(f"Provided model path '{model_dir}' is not a directory.")

    # Search candidates in the caller's preference order, so that e.g.
    # "model.onnx" deterministically wins over "model_optimized.onnx"
    # when both are present (rglob's enumeration order is filesystem-
    # dependent and must not decide which weights get loaded).
    for file_name in file_names:
        for path in model_dir.rglob(file_name):
            if path.is_file():
                return path

    raise ValueError(f"Could not find either of {', '.join(file_names)} in {model_dir}")
51
+
52
+
38
53
def normalize (input_array , p = 2 , dim = 1 , eps = 1e-12 ):
39
54
# Calculate the Lp norm along the specified dimension
40
55
norm = np .linalg .norm (input_array , ord = p , axis = dim , keepdims = True )
@@ -92,19 +107,11 @@ def __init__(
92
107
):
93
108
self .path = path
94
109
self .model_name = model_name
95
- model_path = self .path / "model.onnx"
96
- optimized_model_path = self .path / "model_optimized.onnx"
110
+ model_path = locate_model_file (self .path , ["model.onnx" , "model_optimized.onnx" ])
97
111
98
112
# List of Execution Providers: https://onnxruntime.ai/docs/execution-providers
99
113
onnx_providers = ["CPUExecutionProvider" ]
100
114
101
- if not model_path .exists ():
102
- # Rename file model_optimized.onnx to model.onnx if it exists
103
- if optimized_model_path .exists ():
104
- optimized_model_path .rename (model_path )
105
- else :
106
- raise ValueError (f"Could not find model.onnx in { self .path } " )
107
-
108
115
# Hacky support for multilingual model
109
116
self .exclude_token_type_ids = False
110
117
if model_name == "intfloat/multilingual-e5-large" :
0 commit comments