Skip to content

Commit c0ec55c

Browse files
authored
wip: dataclass idea, small fixes (#475)
* wip: dataclass idea, small fixes
* fix: fix exception message in base model description
1 parent 9fdf78d commit c0ec55c

26 files changed

+155
-171
lines changed

fastembed/common/model_description.py

+40 -53
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,5 @@
1-
from dataclasses import dataclass, field, InitVar
2-
from typing import Optional, List, Dict
1+
from dataclasses import dataclass, field
2+
from typing import Optional, Any
33

44

55
@dataclass(frozen=True)
@@ -15,67 +15,54 @@ def __post_init__(self) -> None:
1515

1616

1717
@dataclass(frozen=True)
18-
class ModelDescription:
18+
class BaseModelDescription:
1919
model: str
2020
sources: ModelSource
2121
model_file: str
22-
dim: Optional[int]
22+
description: str = ""
23+
license: str = ""
24+
size_in_GB: Optional[float] = None
25+
additional_files: list[str] = field(default_factory=list)
2326

24-
description: str
25-
license: str
26-
size_in_GB: Optional[float]
27-
additional_files: List[str] = field(default_factory=list)
28-
tasks: Dict[str, int] = field(default_factory=dict)
27+
def validate_info(self) -> None:
28+
if self.license == "":
29+
raise ValueError("license is required in builtin model description")
30+
31+
if self.description == "":
32+
raise ValueError("description is required in builtin model description")
33+
34+
if self.size_in_GB is None:
35+
raise ValueError("size_in_GB is required in builtin model description")
36+
37+
def __post_init__(self) -> None:
38+
self.validate_info()
2939

3040

3141
@dataclass(frozen=True)
32-
class MultimodalModelDescription(ModelDescription):
33-
dim: int
42+
class DenseModelDescription(BaseModelDescription):
43+
dim: Optional[int] = None
44+
tasks: Optional[dict[str, Any]] = None
45+
46+
def __post_init__(self) -> None:
47+
assert self.dim is not None, "dim is required for dense model description"
48+
self.validate_info()
3449

3550

3651
@dataclass(frozen=True)
37-
class SparseModelDescription(ModelDescription):
38-
_vocab_size: InitVar[Optional[int]] = None
39-
_requires_idf: InitVar[Optional[bool]] = None
40-
41-
vocab_size: int = field(init=False)
42-
requires_idf: Optional[bool] = field(init=False, default=None)
43-
dim: Optional[int] = field(default=None, init=False)
44-
45-
def __init__(
46-
self,
47-
*,
48-
model: str,
49-
sources: ModelSource,
50-
model_file: str,
51-
description: str,
52-
license: str,
53-
size_in_GB: Optional[float],
54-
dim: Optional[int] = None,
55-
additional_files: Optional[List[str]] = None,
56-
tasks: Optional[Dict[str, int]] = None,
57-
vocab_size: int,
58-
requires_idf: Optional[bool] = None,
59-
):
60-
# Call the parent initializer with the fields it needs.
61-
object.__setattr__(self, "model", model)
62-
object.__setattr__(self, "sources", sources)
63-
object.__setattr__(self, "model_file", model_file)
64-
object.__setattr__(self, "dim", dim if dim else None)
65-
object.__setattr__(self, "description", description)
66-
object.__setattr__(self, "license", license)
67-
object.__setattr__(self, "size_in_GB", size_in_GB)
68-
object.__setattr__(
69-
self, "additional_files", additional_files if additional_files is not None else []
70-
)
71-
object.__setattr__(self, "tasks", tasks if tasks is not None else {})
72-
# Set new fields.
73-
object.__setattr__(self, "vocab_size", vocab_size)
74-
object.__setattr__(self, "requires_idf", requires_idf)
52+
class SparseModelDescription(BaseModelDescription):
53+
requires_idf: Optional[bool] = None
54+
vocab_size: Optional[int] = None
7555

7656

7757
@dataclass(frozen=True)
78-
class CustomModelDescription(ModelDescription):
79-
description: str = ""
80-
license: str = ""
81-
size_in_GB: Optional[float] = None
58+
class CustomDenseModelDescription(DenseModelDescription):
59+
def __post_init__(self) -> None:
60+
if self.dim is None:
61+
raise ValueError("dim is required for custom dense model description")
62+
# disable self.validate_info
63+
64+
65+
@dataclass(frozen=True)
66+
class CustomSparseModelDescription(SparseModelDescription):
67+
def __post_init__(self) -> None:
68+
pass # disable self.validate_info

fastembed/common/model_management.py

+4 -4
Original file line number | Diff line number | Diff line change
@@ -4,7 +4,7 @@
44
import shutil
55
import tarfile
66
from pathlib import Path
7-
from typing import Any, Optional, Union, Sequence, TypeVar, Generic
7+
from typing import Any, Optional, Union, TypeVar, Generic
88

99
import requests
1010
from huggingface_hub import snapshot_download, model_info, list_repo_tree
@@ -16,16 +16,16 @@
1616
)
1717
from loguru import logger
1818
from tqdm import tqdm
19-
from fastembed.common.model_description import ModelDescription
19+
from fastembed.common.model_description import BaseModelDescription
2020

21-
T = TypeVar("T", bound=ModelDescription)
21+
T = TypeVar("T", bound=BaseModelDescription)
2222

2323

2424
class ModelManagement(Generic[T]):
2525
METADATA_FILE = "files_metadata.json"
2626

2727
@classmethod
28-
def list_supported_models(cls) -> Sequence[T]:
28+
def list_supported_models(cls) -> list[T]:
2929
"""Lists the supported models.
3030
3131
Returns:

fastembed/image/image_embedding.py

+3 -3
Original file line number | Diff line number | Diff line change
@@ -4,14 +4,14 @@
44
from fastembed.common import ImageInput, OnnxProvider
55
from fastembed.image.image_embedding_base import ImageEmbeddingBase
66
from fastembed.image.onnx_embedding import OnnxImageEmbedding
7-
from fastembed.common.model_description import ModelDescription
7+
from fastembed.common.model_description import DenseModelDescription
88

99

1010
class ImageEmbedding(ImageEmbeddingBase):
1111
EMBEDDINGS_REGISTRY: list[Type[ImageEmbeddingBase]] = [OnnxImageEmbedding]
1212

1313
@classmethod
14-
def list_supported_models(cls) -> list[ModelDescription]:
14+
def list_supported_models(cls) -> list[DenseModelDescription]:
1515
"""
1616
Lists the supported models.
1717
@@ -35,7 +35,7 @@ def list_supported_models(cls) -> list[ModelDescription]:
3535
]
3636
```
3737
"""
38-
result: list[ModelDescription] = []
38+
result: list[DenseModelDescription] = []
3939
for embedding in cls.EMBEDDINGS_REGISTRY:
4040
result.extend(embedding.list_supported_models())
4141
return result

fastembed/image/image_embedding_base.py

+2 -1
Original file line number | Diff line number | Diff line change
@@ -1,11 +1,12 @@
11
from typing import Iterable, Optional, Any, Union
22

3+
from fastembed.common.model_description import DenseModelDescription
34
from fastembed.common.types import NumpyArray
45
from fastembed.common.model_management import ModelManagement
56
from fastembed.common.types import ImageInput
67

78

8-
class ImageEmbeddingBase(ModelManagement):
9+
class ImageEmbeddingBase(ModelManagement[DenseModelDescription]):
910
def __init__(
1011
self,
1112
model_name: str,

fastembed/image/onnx_embedding.py

+8 -8
Original file line number | Diff line number | Diff line change
@@ -9,10 +9,10 @@
99
from fastembed.image.image_embedding_base import ImageEmbeddingBase
1010
from fastembed.image.onnx_image_model import ImageEmbeddingWorker, OnnxImageModel
1111

12-
from fastembed.common.model_description import ModelDescription, ModelSource
12+
from fastembed.common.model_description import DenseModelDescription, ModelSource
1313

14-
supported_onnx_models: list[ModelDescription] = [
15-
ModelDescription(
14+
supported_onnx_models: list[DenseModelDescription] = [
15+
DenseModelDescription(
1616
model="Qdrant/clip-ViT-B-32-vision",
1717
dim=512,
1818
description="Image embeddings, Multimodal (text&image), 2021 year",
@@ -21,7 +21,7 @@
2121
sources=ModelSource(hf="Qdrant/clip-ViT-B-32-vision"),
2222
model_file="model.onnx",
2323
),
24-
ModelDescription(
24+
DenseModelDescription(
2525
model="Qdrant/resnet50-onnx",
2626
dim=2048,
2727
description="Image embeddings, Unimodal (image), 2016 year",
@@ -30,7 +30,7 @@
3030
sources=ModelSource(hf="Qdrant/resnet50-onnx"),
3131
model_file="model.onnx",
3232
),
33-
ModelDescription(
33+
DenseModelDescription(
3434
model="Qdrant/Unicom-ViT-B-16",
3535
dim=768,
3636
description="Image embeddings (more detailed than Unicom-ViT-B-32), Multimodal (text&image), 2023 year",
@@ -39,7 +39,7 @@
3939
sources=ModelSource(hf="Qdrant/Unicom-ViT-B-16"),
4040
model_file="model.onnx",
4141
),
42-
ModelDescription(
42+
DenseModelDescription(
4343
model="Qdrant/Unicom-ViT-B-32",
4444
dim=512,
4545
description="Image embeddings, Multimodal (text&image), 2023 year",
@@ -48,7 +48,7 @@
4848
sources=ModelSource(hf="Qdrant/Unicom-ViT-B-32"),
4949
model_file="model.onnx",
5050
),
51-
ModelDescription(
51+
DenseModelDescription(
5252
model="jinaai/jina-clip-v1",
5353
dim=768,
5454
description="Image embeddings, Multimodal (text&image), 2024 year",
@@ -137,7 +137,7 @@ def load_onnx_model(self) -> None:
137137
)
138138

139139
@classmethod
140-
def list_supported_models(cls) -> list[ModelDescription]:
140+
def list_supported_models(cls) -> list[DenseModelDescription]:
141141
"""
142142
Lists the supported models.
143143

fastembed/late_interaction/colbert.py

+5 -5
Original file line number | Diff line number | Diff line change
@@ -12,10 +12,10 @@
1212
LateInteractionTextEmbeddingBase,
1313
)
1414
from fastembed.text.onnx_text_model import OnnxTextModel, TextEmbeddingWorker
15-
from fastembed.common.model_description import ModelDescription, ModelSource
15+
from fastembed.common.model_description import DenseModelDescription, ModelSource
1616

17-
supported_colbert_models: list[ModelDescription] = [
18-
ModelDescription(
17+
supported_colbert_models: list[DenseModelDescription] = [
18+
DenseModelDescription(
1919
model="colbert-ir/colbertv2.0",
2020
dim=128,
2121
description="Late interaction model",
@@ -24,7 +24,7 @@
2424
sources=ModelSource(hf="colbert-ir/colbertv2.0"),
2525
model_file="model.onnx",
2626
),
27-
ModelDescription(
27+
DenseModelDescription(
2828
model="answerdotai/answerai-colbert-small-v1",
2929
dim=96,
3030
description="Text embeddings, Unimodal (text), Multilingual (~100 languages), 512 input tokens truncation, 2024 year",
@@ -108,7 +108,7 @@ def _tokenize_documents(self, documents: list[str]) -> list[Encoding]:
108108
return encoded
109109

110110
@classmethod
111-
def list_supported_models(cls) -> list[ModelDescription]:
111+
def list_supported_models(cls) -> list[DenseModelDescription]:
112112
"""Lists the supported models.
113113
114114
Returns:

fastembed/late_interaction/jina_colbert.py

+4 -4
Original file line number | Diff line number | Diff line change
@@ -2,10 +2,10 @@
22

33
from fastembed.common.types import NumpyArray
44
from fastembed.late_interaction.colbert import Colbert, ColbertEmbeddingWorker
5-
from fastembed.common.model_description import ModelDescription, ModelSource
5+
from fastembed.common.model_description import DenseModelDescription, ModelSource
66

7-
supported_jina_colbert_models: list[ModelDescription] = [
8-
ModelDescription(
7+
supported_jina_colbert_models: list[DenseModelDescription] = [
8+
DenseModelDescription(
99
model="jinaai/jina-colbert-v2",
1010
dim=128,
1111
description="New model that expands capabilities of colbert-v1 with multilingual and context length of 8192, 2024 year",
@@ -29,7 +29,7 @@ def _get_worker_class(cls) -> Type[ColbertEmbeddingWorker]:
2929
return JinaColbertEmbeddingWorker
3030

3131
@classmethod
32-
def list_supported_models(cls) -> list[ModelDescription]:
32+
def list_supported_models(cls) -> list[DenseModelDescription]:
3333
"""Lists the supported models.
3434
3535
Returns:

fastembed/late_interaction/late_interaction_embedding_base.py

+2 -1
Original file line number | Diff line number | Diff line change
@@ -1,10 +1,11 @@
11
from typing import Iterable, Optional, Union, Any
22

3+
from fastembed.common.model_description import DenseModelDescription
34
from fastembed.common.types import NumpyArray
45
from fastembed.common.model_management import ModelManagement
56

67

7-
class LateInteractionTextEmbeddingBase(ModelManagement):
8+
class LateInteractionTextEmbeddingBase(ModelManagement[DenseModelDescription]):
89
def __init__(
910
self,
1011
model_name: str,

fastembed/late_interaction/late_interaction_text_embedding.py

+4 -4
Original file line number | Diff line number | Diff line change
@@ -7,19 +7,19 @@
77
from fastembed.late_interaction.late_interaction_embedding_base import (
88
LateInteractionTextEmbeddingBase,
99
)
10-
from fastembed.common.model_description import ModelDescription
10+
from fastembed.common.model_description import DenseModelDescription
1111

1212

1313
class LateInteractionTextEmbedding(LateInteractionTextEmbeddingBase):
1414
EMBEDDINGS_REGISTRY: list[Type[LateInteractionTextEmbeddingBase]] = [Colbert, JinaColbert]
1515

1616
@classmethod
17-
def list_supported_models(cls) -> list[ModelDescription]:
17+
def list_supported_models(cls) -> list[DenseModelDescription]:
1818
"""
1919
Lists the supported models.
2020
2121
Returns:
22-
list[ModelDescription]: A list of dictionaries containing the model information.
22+
list[DenseModelDescription]: A list of dictionaries containing the model information.
2323
2424
Example:
2525
```
@@ -38,7 +38,7 @@ def list_supported_models(cls) -> list[ModelDescription]:
3838
]
3939
```
4040
"""
41-
result: list[ModelDescription] = []
41+
result: list[DenseModelDescription] = []
4242
for embedding in cls.EMBEDDINGS_REGISTRY:
4343
result.extend(embedding.list_supported_models())
4444
return result

fastembed/late_interaction_multimodal/colpali.py

+5 -4
Original file line number | Diff line number | Diff line change
@@ -15,10 +15,10 @@
1515
TextEmbeddingWorker,
1616
ImageEmbeddingWorker,
1717
)
18-
from fastembed.common.model_description import MultimodalModelDescription, ModelSource
18+
from fastembed.common.model_description import DenseModelDescription, ModelSource
1919

20-
supported_colpali_models: list[MultimodalModelDescription] = [
21-
MultimodalModelDescription(
20+
supported_colpali_models: list[DenseModelDescription] = [
21+
DenseModelDescription(
2222
model="Qdrant/colpali-v1.3-fp16",
2323
dim=128,
2424
description="Text embeddings, Multimodal (text&image), English, 50 tokens query length truncation, 2024.",
@@ -108,7 +108,7 @@ def __init__(
108108
self.load_onnx_model()
109109

110110
@classmethod
111-
def list_supported_models(cls) -> list[MultimodalModelDescription]:
111+
def list_supported_models(cls) -> list[DenseModelDescription]:
112112
"""Lists the supported models.
113113
114114
Returns:
@@ -139,6 +139,7 @@ def _post_process_onnx_image_output(
139139
Returns:
140140
Iterable[NumpyArray]: Post-processed output as NumPy arrays.
141141
"""
142+
assert self.model_description.dim is not None, "Model dim is not defined"
142143
return output.model_output.reshape(
143144
output.model_output.shape[0], -1, self.model_description.dim
144145
).astype(np.float32)

0 commit comments

Comments
 (0)