Skip to content

Commit 9eb12b7

Browse files
committed
refactor: replace remaining union and optional with |
1 parent 6e70369 commit 9eb12b7

4 files changed

Lines changed: 28 additions & 28 deletions

File tree

fastembed/common/utils.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,7 @@
55
import unicodedata
66
from pathlib import Path
77
from itertools import islice
8-
from typing import Iterable, Optional, TypeVar
8+
from typing import Iterable, TypeVar
99

1010
import numpy as np
1111
from numpy.typing import NDArray
@@ -45,7 +45,7 @@ def iter_batch(iterable: Iterable[T], size: int) -> Iterable[list[T]]:
4545
yield b
4646

4747

48-
def define_cache_dir(cache_dir: Optional[str] = None) -> Path:
48+
def define_cache_dir(cache_dir: str | None = None) -> Path:
4949
"""
5050
Define the cache directory for fastembed
5151
"""

fastembed/late_interaction_multimodal/onnx_multimodal_model.py

Lines changed: 20 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,7 @@
22
import os
33
from multiprocessing import get_all_start_methods
44
from pathlib import Path
5-
from typing import Any, Iterable, Optional, Sequence, Type, Union
5+
from typing import Any, Iterable, Sequence, Type
66

77
import numpy as np
88
from PIL import Image
@@ -18,12 +18,12 @@
1818

1919

2020
class OnnxMultimodalModel(OnnxModel[T]):
21-
ONNX_OUTPUT_NAMES: Optional[list[str]] = None
21+
ONNX_OUTPUT_NAMES: list[str] | None = None
2222

2323
def __init__(self) -> None:
2424
super().__init__()
25-
self.tokenizer: Optional[Tokenizer] = None
26-
self.processor: Optional[Compose] = None
25+
self.tokenizer: Tokenizer | None = None
26+
self.processor: Compose | None = None
2727
self.special_token_to_id: dict[str, int] = {}
2828

2929
def _preprocess_onnx_text_input(
@@ -60,11 +60,11 @@ def _load_onnx_model(
6060
self,
6161
model_dir: Path,
6262
model_file: str,
63-
threads: Optional[int],
64-
providers: Optional[Sequence[OnnxProvider]] = None,
63+
threads: int | None,
64+
providers: Sequence[OnnxProvider] | None = None,
6565
cuda: bool = False,
66-
device_id: Optional[int] = None,
67-
extra_session_options: Optional[dict[str, Any]] = None,
66+
device_id: int | None = None,
67+
extra_session_options: dict[str, Any] | None = None,
6868
) -> None:
6969
super()._load_onnx_model(
7070
model_dir=model_dir,
@@ -116,15 +116,15 @@ def _embed_documents(
116116
self,
117117
model_name: str,
118118
cache_dir: str,
119-
documents: Union[str, Iterable[str]],
119+
documents: str | Iterable[str],
120120
batch_size: int = 256,
121-
parallel: Optional[int] = None,
122-
providers: Optional[Sequence[OnnxProvider]] = None,
121+
parallel: int | None = None,
122+
providers: Sequence[OnnxProvider] | None = None,
123123
cuda: bool = False,
124-
device_ids: Optional[list[int]] = None,
124+
device_ids: list[int] | None = None,
125125
local_files_only: bool = False,
126-
specific_model_path: Optional[str] = None,
127-
extra_session_options: Optional[dict[str, Any]] = None,
126+
specific_model_path: str | None = None,
127+
extra_session_options: dict[str, Any] | None = None,
128128
**kwargs: Any,
129129
) -> Iterable[T]:
130130
is_small = False
@@ -187,15 +187,15 @@ def _embed_images(
187187
self,
188188
model_name: str,
189189
cache_dir: str,
190-
images: Union[Iterable[ImageInput], ImageInput],
190+
images: Iterable[ImageInput] | ImageInput,
191191
batch_size: int = 256,
192-
parallel: Optional[int] = None,
193-
providers: Optional[Sequence[OnnxProvider]] = None,
192+
parallel: int | None = None,
193+
providers: Sequence[OnnxProvider] | None = None,
194194
cuda: bool = False,
195-
device_ids: Optional[list[int]] = None,
195+
device_ids: list[int] | None = None,
196196
local_files_only: bool = False,
197-
specific_model_path: Optional[str] = None,
198-
extra_session_options: Optional[dict[str, Any]] = None,
197+
specific_model_path: str | None = None,
198+
extra_session_options: dict[str, Any] | None = None,
199199
**kwargs: Any,
200200
) -> Iterable[T]:
201201
is_small = False

fastembed/text/custom_text_embedding.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,4 @@
1-
from typing import Sequence, Any, Iterable, Optional
1+
from typing import Sequence, Any, Iterable
22

33
from dataclasses import dataclass
44

@@ -64,7 +64,7 @@ def _post_process_onnx_output(
6464
return self._normalize(self._pool(output.model_output, output.attention_mask))
6565

6666
def _pool(
67-
self, embeddings: NumpyArray, attention_mask: Optional[NDArray[np.int64]] = None
67+
self, embeddings: NumpyArray, attention_mask: NDArray[np.int64] | None = None
6868
) -> NumpyArray:
6969
if self._pooling == PoolingType.CLS:
7070
return embeddings[:, 0]

tests/test_multi_gpu.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,5 @@
11
import pytest
2-
from typing import Optional
2+
33
from fastembed import (
44
TextEmbedding,
55
SparseTextEmbedding,
@@ -14,7 +14,7 @@
1414

1515
@pytest.mark.skip(reason="Requires a multi-gpu server")
1616
@pytest.mark.parametrize("device_id", [None, 0, 1])
17-
def test_gpu_via_providers(device_id: Optional[int]) -> None:
17+
def test_gpu_via_providers(device_id: int | None) -> None:
1818
docs = ["hello world", "flag embedding"]
1919

2020
device_id = device_id if device_id is not None else 0
@@ -86,7 +86,7 @@ def test_gpu_via_providers(device_id: Optional[int]) -> None:
8686

8787
@pytest.mark.skip(reason="Requires a multi-gpu server")
8888
@pytest.mark.parametrize("device_ids", [None, [0], [1], [0, 1]])
89-
def test_gpu_cuda_device_ids(device_ids: Optional[list[int]]) -> None:
89+
def test_gpu_cuda_device_ids(device_ids: list[int] | None) -> None:
9090
docs = ["hello world", "flag embedding"]
9191
device_id = device_ids[0] if device_ids else 0
9292
embedding_model = TextEmbedding(
@@ -171,7 +171,7 @@ def test_gpu_cuda_device_ids(device_ids: Optional[list[int]]) -> None:
171171
@pytest.mark.parametrize(
172172
"device_ids,parallel", [(None, None), (None, 2), ([1], None), ([1], 1), ([1], 2), ([0, 1], 2)]
173173
)
174-
def test_multi_gpu_parallel_inference(device_ids: Optional[list[int]], parallel: int) -> None:
174+
def test_multi_gpu_parallel_inference(device_ids: list[int] | None, parallel: int) -> None:
175175
docs = ["hello world", "flag embedding"] * 100
176176
batch_size = 5
177177

0 commit comments

Comments
 (0)