Skip to content

Commit 3148421

Browse files
I8dNLojoein
and authored
Load from local dir (#443)
* HF sources for all models * Specific_model_path model path support * Fix hf download * fix: rollback incorrect model replacement * refactor: remove redundant type imports * refactor: replace List with list * fix: remove redundant param in late interaction text embedding * Update fastembed/common/model_management.py * fix: rollback post process onnx output --------- Co-authored-by: George Panchuk <[email protected]>
1 parent c2f6fd1 commit 3148421

File tree

9 files changed

+44
-8
lines changed

9 files changed

+44
-8
lines changed

fastembed/common/model_management.py

+4
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,7 @@ def download_files_from_huggingface(
114114
extra_patterns (list[str]): extra patterns to allow in the snapshot download, typically
115115
includes the required model files.
116116
local_files_only (bool, optional): Whether to only use local files. Defaults to False.
117+
specific_model_path (Optional[str], optional): The path to the model dir already pulled from an external source
117118
Returns:
118119
Path: The path to the model directory.
119120
"""
@@ -364,6 +365,9 @@ def download_model(
364365
Path: The path to the downloaded model directory.
365366
"""
366367
local_files_only = kwargs.get("local_files_only", False)
368+
specific_model_path: Optional[str] = kwargs.pop("specific_model_path", None)
369+
if specific_model_path:
370+
return Path(specific_model_path)
367371
retries = 1 if local_files_only else retries
368372
hf_source = model.get("sources", {}).get("hf")
369373
url_source = model.get("sources", {}).get("url")

fastembed/image/onnx_embedding.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from typing import Any, Iterable, Optional, Sequence, Type
22

33
import numpy as np
4-
54
from fastembed.common import ImageInput, OnnxProvider
65
from fastembed.common.onnx_model import OnnxOutputContext
76
from fastembed.common.utils import define_cache_dir, normalize
@@ -78,6 +77,7 @@ def __init__(
7877
device_ids: Optional[list[int]] = None,
7978
lazy_load: bool = False,
8079
device_id: Optional[int] = None,
80+
specific_model_path: Optional[str] = None,
8181
**kwargs,
8282
):
8383
"""
@@ -96,6 +96,7 @@ def __init__(
9696
lazy_load (bool, optional): Whether to load the model during class initialization or on demand.
9797
Should be set to True when using multiple-gpu and parallel encoding. Defaults to False.
9898
device_id (Optional[int], optional): The device id to use for loading the model in the worker process.
99+
specific_model_path (Optional[str], optional): The specific path to the onnx model dir if it should be imported from somewhere else
99100
100101
Raises:
101102
ValueError: If the model_name is not in the format <org>/<model> e.g. BAAI/bge-base-en.
@@ -120,7 +121,10 @@ def __init__(
120121
self.model_description = self._get_model_description(model_name)
121122
self.cache_dir = define_cache_dir(cache_dir)
122123
self._model_dir = self.download_model(
123-
self.model_description, self.cache_dir, local_files_only=self._local_files_only
124+
self.model_description,
125+
self.cache_dir,
126+
local_files_only=self._local_files_only,
127+
specific_model_path=specific_model_path,
124128
)
125129

126130
if not self.lazy_load:
@@ -145,7 +149,7 @@ def list_supported_models(cls) -> list[dict[str, Any]]:
145149
Lists the supported models.
146150
147151
Returns:
148-
list[Dict[str, Any]]: A list of dictionaries containing the model information.
152+
list[dict[str, Any]]: A list of dictionaries containing the model information.
149153
"""
150154
return supported_onnx_models
151155

fastembed/late_interaction/colbert.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,7 @@ def __init__(
124124
device_ids: Optional[list[int]] = None,
125125
lazy_load: bool = False,
126126
device_id: Optional[int] = None,
127+
specific_model_path: Optional[str] = None,
127128
**kwargs,
128129
):
129130
"""
@@ -142,6 +143,7 @@ def __init__(
142143
lazy_load (bool, optional): Whether to load the model during class initialization or on demand.
143144
Should be set to True when using multiple-gpu and parallel encoding. Defaults to False.
144145
device_id (Optional[int], optional): The device id to use for loading the model in the worker process.
146+
specific_model_path (Optional[str], optional): The specific path to the onnx model dir if it should be imported from somewhere else
145147
146148
Raises:
147149
ValueError: If the model_name is not in the format <org>/<model> e.g. BAAI/bge-base-en.
@@ -167,7 +169,10 @@ def __init__(
167169
self.cache_dir = define_cache_dir(cache_dir)
168170

169171
self._model_dir = self.download_model(
170-
self.model_description, self.cache_dir, local_files_only=self._local_files_only
172+
self.model_description,
173+
self.cache_dir,
174+
local_files_only=self._local_files_only,
175+
specific_model_path=specific_model_path,
171176
)
172177
self.mask_token_id = None
173178
self.pad_token_id = None

fastembed/rerank/cross_encoder/onnx_text_cross_encoder.py

+3
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ def __init__(
9595
device_ids: Optional[list[int]] = None,
9696
lazy_load: bool = False,
9797
device_id: Optional[int] = None,
98+
specific_model_path: Optional[str] = None,
9899
**kwargs: Any,
99100
):
100101
"""
@@ -113,6 +114,7 @@ def __init__(
113114
lazy_load (bool, optional): Whether to load the model during class initialization or on demand.
114115
Should be set to True when using multiple-gpu and parallel encoding. Defaults to False.
115116
device_id (Optional[int], optional): The device id to use for loading the model in the worker process.
117+
specific_model_path (Optional[str], optional): The specific path to the onnx model dir if it should be imported from somewhere else
116118
117119
Raises:
118120
ValueError: If the model_name is not in the format <org>/<model> e.g. Xenova/ms-marco-MiniLM-L-6-v2.
@@ -145,6 +147,7 @@ def __init__(
145147
self.model_description,
146148
self.cache_dir,
147149
local_files_only=self._local_files_only,
150+
specific_model_path=specific_model_path,
148151
)
149152

150153
if not self.lazy_load:

fastembed/sparse/bm25.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@ def __init__(
108108
language: str = "english",
109109
token_max_length: int = 40,
110110
disable_stemmer: bool = False,
111+
specific_model_path: Optional[str] = None,
111112
**kwargs,
112113
):
113114
super().__init__(model_name, cache_dir, **kwargs)
@@ -125,7 +126,10 @@ def __init__(
125126
self.cache_dir = define_cache_dir(cache_dir)
126127

127128
self._model_dir = self.download_model(
128-
model_description, self.cache_dir, local_files_only=self._local_files_only
129+
model_description,
130+
self.cache_dir,
131+
local_files_only=self._local_files_only,
132+
specific_model_path=specific_model_path,
129133
)
130134

131135
self.token_max_length = token_max_length

fastembed/sparse/bm42.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ def __init__(
6666
device_ids: Optional[list[int]] = None,
6767
lazy_load: bool = False,
6868
device_id: Optional[int] = None,
69+
specific_model_path: Optional[str] = None,
6970
**kwargs,
7071
):
7172
"""
@@ -86,6 +87,7 @@ def __init__(
8687
lazy_load (bool, optional): Whether to load the model during class initialization or on demand.
8788
Should be set to True when using multiple-gpu and parallel encoding. Defaults to False.
8889
device_id (Optional[int], optional): The device id to use for loading the model in the worker process.
90+
specific_model_path (Optional[str], optional): The specific path to the onnx model dir if it should be imported from somewhere else
8991
9092
Raises:
9193
ValueError: If the model_name is not in the format <org>/<model> e.g. BAAI/bge-base-en.
@@ -111,7 +113,10 @@ def __init__(
111113
self.cache_dir = define_cache_dir(cache_dir)
112114

113115
self._model_dir = self.download_model(
114-
self.model_description, self.cache_dir, local_files_only=self._local_files_only
116+
self.model_description,
117+
self.cache_dir,
118+
local_files_only=self._local_files_only,
119+
specific_model_path=specific_model_path,
115120
)
116121

117122
self.invert_vocab = {}

fastembed/sparse/splade_pp.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ def __init__(
7373
device_ids: Optional[list[int]] = None,
7474
lazy_load: bool = False,
7575
device_id: Optional[int] = None,
76+
specific_model_path: Optional[str] = None,
7677
**kwargs,
7778
):
7879
"""
@@ -91,6 +92,7 @@ def __init__(
9192
lazy_load (bool, optional): Whether to load the model during class initialization or on demand.
9293
Should be set to True when using multiple-gpu and parallel encoding. Defaults to False.
9394
device_id (Optional[int], optional): The device id to use for loading the model in the worker process.
95+
specific_model_path (Optional[str], optional): The specific path to the onnx model dir if it should be imported from somewhere else
9496
9597
Raises:
9698
ValueError: If the model_name is not in the format <org>/<model> e.g. BAAI/bge-base-en.
@@ -115,7 +117,10 @@ def __init__(
115117
self.cache_dir = define_cache_dir(cache_dir)
116118

117119
self._model_dir = self.download_model(
118-
self.model_description, self.cache_dir, local_files_only=self._local_files_only
120+
self.model_description,
121+
self.cache_dir,
122+
local_files_only=self._local_files_only,
123+
specific_model_path=specific_model_path,
119124
)
120125

121126
if not self.lazy_load:

fastembed/text/onnx_embedding.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,7 @@ def __init__(
193193
device_ids: Optional[list[int]] = None,
194194
lazy_load: bool = False,
195195
device_id: Optional[int] = None,
196+
specific_model_path: Optional[str] = None,
196197
**kwargs,
197198
):
198199
"""
@@ -211,6 +212,7 @@ def __init__(
211212
lazy_load (bool, optional): Whether to load the model during class initialization or on demand.
212213
Should be set to True when using multiple-gpu and parallel encoding. Defaults to False.
213214
device_id (Optional[int], optional): The device id to use for loading the model in the worker process.
215+
specific_model_path (Optional[str], optional): The specific path to the onnx model dir if it should be imported from somewhere else
214216
215217
Raises:
216218
ValueError: If the model_name is not in the format <org>/<model> e.g. BAAI/bge-base-en.
@@ -234,7 +236,10 @@ def __init__(
234236
self.model_description = self._get_model_description(model_name)
235237
self.cache_dir = define_cache_dir(cache_dir)
236238
self._model_dir = self.download_model(
237-
self.model_description, self.cache_dir, local_files_only=self._local_files_only
239+
self.model_description,
240+
self.cache_dir,
241+
local_files_only=self._local_files_only,
242+
specific_model_path=specific_model_path,
238243
)
239244

240245
if not self.lazy_load:

fastembed/text/onnx_text_model.py

+1
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ def _load_onnx_model(
4444
providers: Optional[Sequence[OnnxProvider]] = None,
4545
cuda: bool = False,
4646
device_id: Optional[int] = None,
47+
specific_model_path: Optional[str] = None,
4748
) -> None:
4849
super()._load_onnx_model(
4950
model_dir=model_dir,

0 commit comments

Comments
 (0)