@@ -35,6 +35,21 @@ def iter_batch(iterable: Union[Iterable, Generator], size: int) -> Iterable:
3535 yield b
3636
3737
def locate_model_file(model_dir: Path, file_names: List[str]) -> Path:
    """
    Find a model file anywhere under `model_dir`.

    Supports both the TransformerJS style layout (weights inside an `onnx`
    subdirectory) and the flat layout used by Optimum and Qdrant (weights
    directly in the model directory).

    Args:
        model_dir: Directory that is searched recursively.
        file_names: Acceptable file names, in order of preference. The first
            name that exists under `model_dir` wins, regardless of the order
            in which the filesystem lists the files.

    Returns:
        Path to the matching file. If the same name occurs more than once,
        the shallowest path is returned (ties broken lexicographically), so
        the result is deterministic.

    Raises:
        ValueError: If `model_dir` is not a directory, or none of
            `file_names` is found under it.
    """
    if not model_dir.is_dir():
        raise ValueError(f"Provided model path '{model_dir}' is not a directory.")

    wanted = set(file_names)
    # Walk the tree once; compare names before touching the filesystem again
    # with is_file(). Sorting makes the choice independent of traversal order.
    candidates = sorted(
        (
            path
            for path in model_dir.rglob("*.onnx")
            if path.name in wanted and path.is_file()
        ),
        key=lambda path: (len(path.parts), str(path)),
    )

    # Outer loop over file_names so earlier names take priority over later ones.
    for file_name in file_names:
        for path in candidates:
            if path.name == file_name:
                return path

    raise ValueError(f"Could not find either of {', '.join(file_names)} in {model_dir}")
51+
52+
3853def normalize (input_array , p = 2 , dim = 1 , eps = 1e-12 ):
3954 # Calculate the Lp norm along the specified dimension
4055 norm = np .linalg .norm (input_array , ord = p , axis = dim , keepdims = True )
@@ -92,19 +107,11 @@ def __init__(
92107 ):
93108 self .path = path
94109 self .model_name = model_name
95- model_path = self .path / "model.onnx"
96- optimized_model_path = self .path / "model_optimized.onnx"
110+ model_path = locate_model_file (self .path , ["model.onnx" , "model_optimized.onnx" ])
97111
98112 # List of Execution Providers: https://onnxruntime.ai/docs/execution-providers
99113 onnx_providers = ["CPUExecutionProvider" ]
100114
101- if not model_path .exists ():
102- # Rename file model_optimized.onnx to model.onnx if it exists
103- if optimized_model_path .exists ():
104- optimized_model_path .rename (model_path )
105- else :
106- raise ValueError (f"Could not find model.onnx in { self .path } " )
107-
108115 # Hacky support for multilingual model
109116 self .exclude_token_type_ids = False
110117 if model_name == "intfloat/multilingual-e5-large" :
0 commit comments