Skip to content

Commit 0bd761f

Browse files
committed
chore: try recursive model location
1 parent 5adf3f6 commit 0bd761f

File tree

1 file changed

+13
-22
lines changed

1 file changed

+13
-22
lines changed

fastembed/embedding.py

+13-22
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,18 @@ def iter_batch(iterable: Union[Iterable, Generator], size: int) -> Iterable:
3535
yield b
3636

3737

38+
def locate_model_file(model_dir: Path, file_names: list) -> Path:
    """Find a model file anywhere below ``model_dir``.

    ``file_names`` is treated as a priority list: a match for an earlier
    name is always preferred over a match for a later one, regardless of
    the order in which the filesystem happens to list entries. When the
    same name occurs several times, the shallowest path wins, with ties
    broken lexicographically so the result is deterministic.

    Args:
        model_dir: Directory to search recursively.
        file_names: Exact file names to look for, most preferred first.

    Returns:
        Path of the selected model file.

    Raises:
        ValueError: If ``model_dir`` is not a directory, or if none of
            ``file_names`` exists as a regular file below it.
    """
    if not model_dir.is_dir():
        raise ValueError(f"Provided model path '{model_dir}' is not a directory.")

    # Walk the tree once and bucket the hits by name; comparing names before
    # calling is_file() avoids a stat() on every unrelated entry.
    wanted = set(file_names)
    matches = {}
    for path in model_dir.rglob("*"):
        if path.name in wanted and path.is_file():
            matches.setdefault(path.name, []).append(path)

    # Honour the caller's preference order instead of directory-listing order.
    for file_name in file_names:
        found = matches.get(file_name)
        if found:
            return min(found, key=lambda p: (len(p.parts), str(p)))

    raise ValueError(f"Could not find model file in {model_dir}")
48+
49+
3850
def normalize(input_array, p=2, dim=1, eps=1e-12):
3951
# Calculate the Lp norm along the specified dimension
4052
norm = np.linalg.norm(input_array, ord=p, axis=dim, keepdims=True)
@@ -92,32 +104,11 @@ def __init__(
92104
):
93105
self.path = path
94106
self.model_name = model_name
95-
model_path = self.path / "model.onnx"
96-
optimized_model_path = self.path / "model_optimized.onnx"
97-
98-
xenova_model_path = self.path / "onnx" / "model.onnx"
99-
xenova_optimized_model_path = self.path / "onnx" / "model_optimized.onnx"
107+
model_path = locate_model_file(self.path, ["model.onnx", "model_optimized.onnx"])
100108

101109
# List of Execution Providers: https://onnxruntime.ai/docs/execution-providers
102110
onnx_providers = ["CPUExecutionProvider"]
103111

104-
if not model_path.exists():
105-
# Rename file model_optimized.onnx to model.onnx if it exists
106-
if optimized_model_path.exists():
107-
optimized_model_path.rename(model_path)
108-
109-
# Patch for inconsistent repo structure at
110-
# - https://huggingface.co/Xenova/jina-embeddings-v2-small-en
111-
# - https://huggingface.co/Xenova/jina-embeddings-v2-base-en
112-
elif xenova_model_path.exists():
113-
model_path = xenova_model_path
114-
115-
elif xenova_optimized_model_path.exists():
116-
model_path = xenova_optimized_model_path
117-
118-
else:
119-
raise ValueError(f"Could not find model.onnx in {self.path}")
120-
121112
# Hacky support for multilingual model
122113
self.exclude_token_type_ids = False
123114
if model_name == "intfloat/multilingual-e5-large":

0 commit comments

Comments
 (0)