
Commit 25671ec

joein and NirantK authored
Update ruff (#172)
* refactoring: reduce max line-length
* new: update ruff

Co-authored-by: Nirant <[email protected]>
1 parent ce98631 commit 25671ec

13 files changed: +193 additions, −163 deletions

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
```diff
@@ -1,6 +1,6 @@
 repos:
   - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: v0.1.13
+    rev: v0.3.4
     hooks:
       - id: ruff
         types_or: [ python, pyi, jupyter ]
```

fastembed/common/model_management.py

Lines changed: 13 additions & 4 deletions
```diff
@@ -92,7 +92,9 @@ def download_file_from_gcs(cls, url: str, output_path: str, show_progress: bool
 
         show_progress = total_size_in_bytes and show_progress
 
-        with tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True, disable=not show_progress) as progress_bar:
+        with tqdm(
+            total=total_size_in_bytes, unit="iB", unit_scale=True, disable=not show_progress
+        ) as progress_bar:
             with open(output_path, "wb") as file:
                 for chunk in response.iter_content(chunk_size=1024):
                     if chunk:  # Filter out keep-alive new chunks
@@ -101,7 +103,9 @@ def download_file_from_gcs(cls, url: str, output_path: str, show_progress: bool
         return output_path
 
     @classmethod
-    def download_files_from_huggingface(cls, hf_source_repo: str, cache_dir: Optional[str] = None) -> str:
+    def download_files_from_huggingface(
+        cls, hf_source_repo: str, cache_dir: Optional[str] = None
+    ) -> str:
         """
         Downloads a model from HuggingFace Hub.
         Args:
@@ -216,9 +220,14 @@ def download_model(cls, model: Dict[str, Any], cache_dir: Path) -> Path:
 
         if hf_source:
             try:
-                return Path(cls.download_files_from_huggingface(hf_source, cache_dir=str(cache_dir)))
+                return Path(
+                    cls.download_files_from_huggingface(hf_source, cache_dir=str(cache_dir))
+                )
             except (EnvironmentError, RepositoryNotFoundError, ValueError) as e:
-                logger.error(f"Could not download model from HuggingFace: {e}" "Falling back to other sources.")
+                logger.error(
+                    f"Could not download model from HuggingFace: {e}"
+                    "Falling back to other sources."
+                )
 
         if url_source:
             return cls.retrieve_model_gcs(model["model"], url_source, str(cache_dir))
```

fastembed/common/models.py

Lines changed: 3 additions & 1 deletion
```diff
@@ -33,7 +33,9 @@ def load_tokenizer(model_dir: Path, max_length: int = 512) -> Tokenizer:
 
     tokenizer = Tokenizer.from_file(str(tokenizer_path))
     tokenizer.enable_truncation(max_length=min(tokenizer_config["model_max_length"], max_length))
-    tokenizer.enable_padding(pad_id=config.get("pad_token_id", 0), pad_token=tokenizer_config["pad_token"])
+    tokenizer.enable_padding(
+        pad_id=config.get("pad_token_id", 0), pad_token=tokenizer_config["pad_token"]
+    )
 
     for token in tokens_map.values():
         if isinstance(token, str):
```
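
For readers less familiar with the `tokenizers` API that `load_tokenizer` configures, here is a minimal standalone sketch of the truncation/padding setup; the file path, pad token, and sample inputs are placeholders, not values from fastembed:

```python
from tokenizers import Tokenizer

# Placeholders for illustration; load_tokenizer reads these from the model directory.
tokenizer = Tokenizer.from_file("tokenizer.json")
tokenizer.enable_truncation(max_length=512)
tokenizer.enable_padding(pad_id=0, pad_token="[PAD]")

encoded = tokenizer.encode_batch(["hello world", "a longer second document"])
input_ids = [e.ids for e in encoded]                  # padded to equal length
attention_mask = [e.attention_mask for e in encoded]  # 1 for tokens, 0 for padding
```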

fastembed/common/onnx_model.py

Lines changed: 9 additions & 3 deletions
```diff
@@ -48,7 +48,9 @@ def load_onnx_model(self, model_dir: Path, threads: Optional[int], max_length: i
         so.inter_op_num_threads = threads
 
         self.tokenizer = load_tokenizer(model_dir=model_dir, max_length=max_length)
-        self.model = ort.InferenceSession(str(model_path), providers=onnx_providers, sess_options=so)
+        self.model = ort.InferenceSession(
+            str(model_path), providers=onnx_providers, sess_options=so
+        )
 
     def onnx_embed(self, documents: List[str]) -> Tuple[np.ndarray, np.ndarray]:
         encoded = self.tokenizer.encode_batch(documents)
@@ -58,7 +60,9 @@ def onnx_embed(self, documents: List[str]) -> Tuple[np.ndarray, np.ndarray]:
         onnx_input = {
             "input_ids": np.array(input_ids, dtype=np.int64),
             "attention_mask": np.array(attention_mask, dtype=np.int64),
-            "token_type_ids": np.array([np.zeros(len(e), dtype=np.int64) for e in input_ids], dtype=np.int64),
+            "token_type_ids": np.array(
+                [np.zeros(len(e), dtype=np.int64) for e in input_ids], dtype=np.int64
+            ),
         }
 
         onnx_input = self._preprocess_onnx_input(onnx_input)
@@ -97,7 +101,9 @@ def _embed_documents(
             "model_name": model_name,
             "cache_dir": cache_dir,
         }
-        pool = ParallelWorkerPool(parallel, self._get_worker_class(), start_method=start_method)
+        pool = ParallelWorkerPool(
+            parallel, self._get_worker_class(), start_method=start_method
+        )
         for batch in pool.ordered_map(iter_batch(documents, batch_size), **params):
             yield from self._post_process_onnx_output(batch)
```
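
The `onnx_embed` hunk above shows how fastembed feeds a BERT-style ONNX model. A self-contained sketch of the same input construction, where the model path and token ids are placeholders:

```python
import numpy as np
import onnxruntime as ort

# Placeholder token ids for two already-tokenized documents.
input_ids = [[101, 7592, 102], [101, 2088, 102]]
attention_mask = [[1, 1, 1], [1, 1, 1]]

onnx_input = {
    "input_ids": np.array(input_ids, dtype=np.int64),
    "attention_mask": np.array(attention_mask, dtype=np.int64),
    # BERT-style models also expect token_type_ids; all zeros for single-segment input.
    "token_type_ids": np.array(
        [np.zeros(len(e), dtype=np.int64) for e in input_ids], dtype=np.int64
    ),
}

session = ort.InferenceSession("model.onnx")  # path is a placeholder
outputs = session.run(None, onnx_input)       # first output holds the token embeddings
```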

fastembed/embedding.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -5,7 +5,8 @@
 from fastembed.text.text_embedding import TextEmbedding
 
 logger.warning(
-    "DefaultEmbedding, FlagEmbedding, JinaEmbedding are deprecated." "Use from fastembed import TextEmbedding instead."
+    "DefaultEmbedding, FlagEmbedding, JinaEmbedding are deprecated."
+    "Use from fastembed import TextEmbedding instead."
 )
 
 DefaultEmbedding = TextEmbedding
```
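
This file is a deprecation shim: the old class names stay importable as aliases of `TextEmbedding`, with the warning above logged on import. A sketch of the migration the warning asks for; the model name is illustrative:

```python
# Deprecated: still works, but logs the warning shown in the diff.
from fastembed.embedding import DefaultEmbedding

# Preferred going forward:
from fastembed import TextEmbedding

model = TextEmbedding(model_name="BAAI/bge-small-en-v1.5")  # illustrative model name
vectors = list(model.embed(["hello world"]))
```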

fastembed/parallel_processor.py

Lines changed: 3 additions & 1 deletion
```diff
@@ -128,7 +128,9 @@ def ordered_map(self, stream: Iterable[Any], *args: Any, **kwargs: Any) -> Itera
                 yield buffer.pop(next_expected)
                 next_expected += 1
 
-    def semi_ordered_map(self, stream: Iterable[Any], *args: Any, **kwargs: Any) -> Iterable[Tuple[int, Any]]:
+    def semi_ordered_map(
+        self, stream: Iterable[Any], *args: Any, **kwargs: Any
+    ) -> Iterable[Tuple[int, Any]]:
         try:
             self.start(**kwargs)
```

fastembed/sparse/sparse_embedding_base.py

Lines changed: 7 additions & 1 deletion
```diff
@@ -22,7 +22,13 @@ def as_dict(self) -> Dict[int, float]:
 
 
 class SparseTextEmbeddingBase(ModelManagement):
-    def __init__(self, model_name: str, cache_dir: Optional[str] = None, threads: Optional[int] = None, **kwargs):
+    def __init__(
+        self,
+        model_name: str,
+        cache_dir: Optional[str] = None,
+        threads: Optional[int] = None,
+        **kwargs,
+    ):
         self.model_name = model_name
         self.cache_dir = cache_dir
         self.threads = threads
```

fastembed/sparse/splade_pp.py

Lines changed: 3 additions & 1 deletion
```diff
@@ -30,7 +30,9 @@
 
 class SpladePP(SparseTextEmbeddingBase, OnnxModel[SparseEmbedding]):
     @classmethod
-    def _post_process_onnx_output(cls, output: Tuple[np.ndarray, np.ndarray]) -> Iterable[SparseEmbedding]:
+    def _post_process_onnx_output(
+        cls, output: Tuple[np.ndarray, np.ndarray]
+    ) -> Iterable[SparseEmbedding]:
         logits, attention_mask = output
         relu_log = np.log(1 + np.maximum(logits, 0))
```
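
The hunk ends at the `log(1 + relu(logits))` activation. In SPLADE-style models this is typically followed by masked max-pooling over token positions to get one weight per vocabulary term; the pooling step below is that common pattern, an assumption rather than code from this commit:

```python
import numpy as np

# Illustrative shapes: (batch, seq_len, vocab_size) logits and a matching mask.
logits = np.random.randn(2, 8, 30522).astype(np.float32)
attention_mask = np.ones((2, 8), dtype=np.int64)

relu_log = np.log(1 + np.maximum(logits, 0))      # the step shown in the diff
weighted = relu_log * attention_mask[..., None]   # zero out padding positions (assumed)
scores = weighted.max(axis=1)                     # max-pool per vocab term (assumed)

# A sparse embedding keeps only the non-zero (index, value) pairs per document.
for row in scores:
    indices = np.nonzero(row)[0]
    values = row[indices]
```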

fastembed/text/jina_onnx_embedding.py

Lines changed: 3 additions & 1 deletion
```diff
@@ -49,7 +49,9 @@ def list_supported_models(cls) -> List[Dict[str, Any]]:
         return supported_jina_models
 
     @classmethod
-    def _post_process_onnx_output(cls, output: Tuple[np.ndarray, np.ndarray]) -> Iterable[np.ndarray]:
+    def _post_process_onnx_output(
+        cls, output: Tuple[np.ndarray, np.ndarray]
+    ) -> Iterable[np.ndarray]:
         embeddings, attn_mask = output
         return normalize(cls.mean_pooling(embeddings, attn_mask)).astype(np.float32)
```
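
For context, a minimal numpy sketch of the mean pooling plus L2 normalization used in this post-processing step; the helpers are illustrative stand-ins, not fastembed's exact implementations:

```python
import numpy as np

def mean_pooling(embeddings: np.ndarray, attn_mask: np.ndarray) -> np.ndarray:
    # Average token vectors, counting only positions the attention mask marks as real.
    mask = attn_mask[..., None].astype(np.float32)   # (batch, seq, 1)
    summed = (embeddings * mask).sum(axis=1)
    counts = np.clip(mask.sum(axis=1), 1e-9, None)   # avoid division by zero
    return summed / counts

def normalize(v: np.ndarray) -> np.ndarray:
    # L2-normalize each row so embeddings are unit length.
    return v / np.linalg.norm(v, axis=1, keepdims=True)
```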

fastembed/text/onnx_embedding.py

Lines changed: 9 additions & 7 deletions
```diff
@@ -144,12 +144,12 @@
     # ]
     # }
     {
-            "model": "mixedbread-ai/mxbai-embed-large-v1",
-            "dim": 1024,
-            "description": "MixedBread Base sentence embedding model, does well on MTEB",
-            "size_in_GB": 1.34,
-            "sources": {
-                "hf": "mixedbread-ai/mxbai-embed-large-v1",
+        "model": "mixedbread-ai/mxbai-embed-large-v1",
+        "dim": 1024,
+        "description": "MixedBread Base sentence embedding model, does well on MTEB",
+        "size_in_GB": 1.34,
+        "sources": {
+            "hf": "mixedbread-ai/mxbai-embed-large-v1",
         },
     },
 ]
@@ -239,7 +239,9 @@ def _preprocess_onnx_input(self, onnx_input: Dict[str, np.ndarray]) -> Dict[str,
         return onnx_input
 
     @classmethod
-    def _post_process_onnx_output(cls, output: Tuple[np.ndarray, np.ndarray]) -> Iterable[np.ndarray]:
+    def _post_process_onnx_output(
+        cls, output: Tuple[np.ndarray, np.ndarray]
+    ) -> Iterable[np.ndarray]:
         embeddings, _ = output
         return normalize(embeddings[:, 0]).astype(np.float32)
```
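
Unlike the Jina model's mean pooling, this path takes the first (CLS) token vector per document before normalizing; a short sketch with illustrative shapes:

```python
import numpy as np

# (batch, seq_len, dim) token embeddings from the ONNX model; shapes are illustrative.
embeddings = np.random.randn(4, 16, 384).astype(np.float32)

cls_vectors = embeddings[:, 0]  # first token per document, as in the diff above
normed = cls_vectors / np.linalg.norm(cls_vectors, axis=1, keepdims=True)
```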
