Commit 3b5e4c8

new: Added jina clip v1 (#408)
* WIP: Added jina clip text embedding
* WIP: Added preprocess for jina clip
* WIP: Added jina clip vision (not sure if it works yet)
* improve: Improved mean pooling if the output doesnt have seq length
* fix: Fixed jina clip text
* nit
* fix: Fixed jina clip image preprocessor
* fix: Fix type hints new: added resize2square
* tests: Add jina clip vision test case
* nit
* refactor: Update fastembed/image/transform/operators.py
  Co-authored-by: George <george.panchuk@qdrant.tech>
* fix: Fix indentation
* refactor: Refactored how we call padding for image
* fix: Fix pad to image when resized size larger than new square canvas
* refactor: minor refactor
* refactor: Refactor some functions in preprocess image
* fix: Fix to pad image with specified fill color
* refactor: Change resize to classmethod
* fix: Fix jina clip text v1
* fix: fix pad to square for some rectangular images (#421)

---------

Co-authored-by: George <george.panchuk@qdrant.tech>
1 parent 516170c commit 3b5e4c8

File tree: 7 files changed, +144 -10 lines
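The commit registers jinaai/jina-clip-v1 in both the text and the image ONNX model lists (see the per-file diffs below). A minimal usage sketch, assuming the standard fastembed TextEmbedding / ImageEmbedding entry points; the image path is a hypothetical placeholder, not part of this commit:

import numpy as np
from fastembed import TextEmbedding, ImageEmbedding

# Text tower (onnx/text_model.onnx), registered in fastembed/text/onnx_embedding.py
text_model = TextEmbedding(model_name="jinaai/jina-clip-v1")
text_vecs = np.array(list(text_model.embed(["a photo of a cat"])))

# Vision tower (onnx/vision_model.onnx), registered in fastembed/image/onnx_embedding.py
image_model = ImageEmbedding(model_name="jinaai/jina-clip-v1")
image_vecs = np.array(list(image_model.embed(["path/to/cat.jpg"])))  # hypothetical image path

# Both towers produce 768-dimensional vectors, so text and image embeddings are directly comparable.
print(text_vecs.shape, image_vecs.shape)  # expected: (1, 768) (1, 768)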

fastembed/image/onnx_embedding.py

Lines changed: 11 additions & 0 deletions
@@ -53,6 +53,17 @@
         },
         "model_file": "model.onnx",
     },
+    {
+        "model": "jinaai/jina-clip-v1",
+        "dim": 768,
+        "description": "Image embeddings, Multimodal (text&image), 2024 year",
+        "license": "apache-2.0",
+        "size_in_GB": 0.34,
+        "sources": {
+            "hf": "jinaai/jina-clip-v1",
+        },
+        "model_file": "onnx/vision_model.onnx",
+    },
 ]

fastembed/image/transform/functional.py

Lines changed: 31 additions & 5 deletions
@@ -62,8 +62,8 @@ def center_crop(
 
 def normalize(
     image: np.ndarray,
-    mean=Union[float, np.ndarray],
-    std=Union[float, np.ndarray],
+    mean: Union[float, np.ndarray],
+    std: Union[float, np.ndarray],
 ) -> np.ndarray:
     if not isinstance(image, np.ndarray):
         raise ValueError("image must be a numpy array")
@@ -96,10 +96,10 @@ def normalize(
 
 
 def resize(
-    image: Image,
+    image: Image.Image,
     size: Union[int, tuple[int, int]],
-    resample: Image.Resampling = Image.Resampling.BILINEAR,
-) -> Image:
+    resample: Union[int, Image.Resampling] = Image.Resampling.BILINEAR,
+) -> Image.Image:
     if isinstance(size, tuple):
         return image.resize(size, resample)
 
@@ -122,3 +122,29 @@ def pil2ndarray(image: Union[Image.Image, np.ndarray]):
     if isinstance(image, Image.Image):
         return np.asarray(image).transpose((2, 0, 1))
     return image
+
+
+def pad2square(
+    image: Image.Image,
+    size: int,
+    fill_color: Union[str, int, tuple[int, ...]] = 0,
+) -> Image.Image:
+    height, width = image.height, image.width
+
+    left, right = 0, width
+    top, bottom = 0, height
+
+    crop_required = False
+    if width > size:
+        left = (width - size) // 2
+        right = left + size
+        crop_required = True
+
+    if height > size:
+        top = (height - size) // 2
+        bottom = top + size
+        crop_required = True
+
+    new_image = Image.new(mode="RGB", size=(size, size), color=fill_color)
+    new_image.paste(image.crop((left, top, right, bottom)) if crop_required else image)
+    return new_image
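The new pad2square helper pastes the image onto a size x size canvas filled with fill_color, center-cropping any dimension that exceeds the target size. A rough illustration, assuming Pillow is installed; the image sizes below are made up for the example:

from PIL import Image

from fastembed.image.transform.functional import pad2square

# Wide rectangle: the width is center-cropped to 512, the height is padded with the fill color.
wide = Image.new("RGB", (600, 200), "white")
print(pad2square(image=wide, size=512, fill_color=0).size)  # (512, 512)

# Small image: pasted as-is onto the 512x512 canvas (top-left corner); the rest stays fill_color.
small = Image.new("RGB", (100, 80), "white")
print(pad2square(image=small, size=512, fill_color=0).size)  # (512, 512)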

fastembed/image/transform/operators.py

Lines changed: 79 additions & 3 deletions
@@ -1,4 +1,4 @@
-from typing import Any, Union
+from typing import Any, Union, Optional
 
 import numpy as np
 from PIL import Image
@@ -10,6 +10,7 @@
     pil2ndarray,
     rescale,
     resize,
+    pad2square,
 )
 
 
@@ -66,6 +67,21 @@ def __call__(self, images: list[Union[Image.Image, np.ndarray]]) -> list[np.ndarray]:
         return [pil2ndarray(image) for image in images]
 
 
+class PadtoSquare(Transform):
+    def __init__(
+        self,
+        size: int,
+        fill_color: Optional[Union[str, int, tuple[int, ...]]] = None,
+    ):
+        self.size = size
+        self.fill_color = fill_color
+
+    def __call__(self, images: list[Image.Image]) -> list[Image.Image]:
+        return [
+            pad2square(image=image, size=self.size, fill_color=self.fill_color) for image in images
+        ]
+
+
 class Compose:
     def __init__(self, transforms: list[Transform]):
         self.transforms = transforms
@@ -85,14 +101,20 @@ def from_config(cls, config: dict[str, Any]) -> "Compose":
 
         Valid keys:
            - do_resize
+           - resize_mode
            - size
+           - fill_color
            - do_center_crop
            - crop_size
            - do_rescale
            - rescale_factor
            - do_normalize
            - image_mean
+           - mean
            - image_std
+           - std
+           - resample
+           - interpolation
        Valid size keys (nested):
            - {"height", "width"}
            - {"shortest_edge"}
@@ -103,6 +125,7 @@ def from_config(cls, config: dict[str, Any]) -> "Compose":
        transforms = []
        cls._get_convert_to_rgb(transforms, config)
        cls._get_resize(transforms, config)
+       cls._get_pad2square(transforms, config)
        cls._get_center_crop(transforms, config)
        cls._get_pil2ndarray(transforms, config)
        cls._get_rescale(transforms, config)
@@ -113,8 +136,8 @@ def from_config(cls, config: dict[str, Any]) -> "Compose":
    def _get_convert_to_rgb(transforms: list[Transform], config: dict[str, Any]):
        transforms.append(ConvertToRGB())
 
-   @staticmethod
-   def _get_resize(transforms: list[Transform], config: dict[str, Any]):
+   @classmethod
+   def _get_resize(cls, transforms: list[Transform], config: dict[str, Any]):
        mode = config.get("image_processor_type", "CLIPImageProcessor")
        if mode == "CLIPImageProcessor":
            if config.get("do_resize", False):
@@ -157,6 +180,24 @@ def _get_resize(transforms: list[Transform], config: dict[str, Any]):
                        resample=config.get("resample", Image.Resampling.BICUBIC),
                    )
                )
+       elif mode == "JinaCLIPImageProcessor":
+           interpolation = config.get("interpolation")
+           if isinstance(interpolation, str):
+               resample = cls._interpolation_resolver(interpolation)
+           else:
+               resample = interpolation or Image.Resampling.BICUBIC
+
+           if "size" in config:
+               resize_mode = config.get("resize_mode", "shortest")
+               if resize_mode == "shortest":
+                   transforms.append(
+                       Resize(
+                           size=config["size"],
+                           resample=resample,
+                       )
+                   )
+       else:
+           raise ValueError(f"Preprocessor {mode} is not supported")
 
    @staticmethod
    def _get_center_crop(transforms: list[Transform], config: dict[str, Any]):
@@ -173,6 +214,8 @@ def _get_center_crop(transforms: list[Transform], config: dict[str, Any]):
            transforms.append(CenterCrop(size=crop_size))
        elif mode == "ConvNextFeatureExtractor":
            pass
+       elif mode == "JinaCLIPImageProcessor":
+           pass
        else:
            raise ValueError(f"Preprocessor {mode} is not supported")
 
@@ -190,3 +233,36 @@ def _get_rescale(transforms: list[Transform], config: dict[str, Any]):
    def _get_normalize(transforms: list[Transform], config: dict[str, Any]):
        if config.get("do_normalize", False):
            transforms.append(Normalize(mean=config["image_mean"], std=config["image_std"]))
+       elif "mean" in config and "std" in config:
+           transforms.append(Normalize(mean=config["mean"], std=config["std"]))
+
+   @staticmethod
+   def _get_pad2square(transforms: list[Transform], config: dict[str, Any]):
+       mode = config.get("image_processor_type", "CLIPImageProcessor")
+       if mode == "CLIPImageProcessor":
+           pass
+       elif mode == "ConvNextFeatureExtractor":
+           pass
+       elif mode == "JinaCLIPImageProcessor":
+           transforms.append(
+               PadtoSquare(
+                   size=config["size"],
+                   fill_color=config.get("fill_color", 0),
+               )
+           )
+
+   @staticmethod
+   def _interpolation_resolver(resample: Optional[str] = None) -> Image.Resampling:
+       interpolation_map = {
+           "nearest": Image.Resampling.NEAREST,
+           "lanczos": Image.Resampling.LANCZOS,
+           "bilinear": Image.Resampling.BILINEAR,
+           "bicubic": Image.Resampling.BICUBIC,
+           "box": Image.Resampling.BOX,
+           "hamming": Image.Resampling.HAMMING,
+       }
+
+       if resample and (method := interpolation_map.get(resample.lower())):
+           return method
+
+       raise ValueError(f"Unknown interpolation method: {resample}")

fastembed/text/onnx_embedding.py

Lines changed: 18 additions & 1 deletion
@@ -164,6 +164,17 @@
         },
         "model_file": "onnx/model.onnx",
     },
+    {
+        "model": "jinaai/jina-clip-v1",
+        "dim": 768,
+        "description": "Text embeddings, Multimodal (text&image), English, Prefixes for queries/documents: not necessary, 2024 year",
+        "license": "apache-2.0",
+        "size_in_GB": 0.55,
+        "sources": {
+            "hf": "jinaai/jina-clip-v1",
+        },
+        "model_file": "onnx/text_model.onnx",
+    },
 ]
 
 
@@ -285,7 +296,13 @@ def _preprocess_onnx_input(
 
     def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[np.ndarray]:
         embeddings = output.model_output
-        return normalize(embeddings[:, 0]).astype(np.float32)
+        if embeddings.ndim == 3:  # (batch_size, seq_len, embedding_dim)
+            processed_embeddings = embeddings[:, 0]
+        elif embeddings.ndim == 2:  # (batch_size, embedding_dim)
+            processed_embeddings = embeddings
+        else:
+            raise ValueError(f"Unsupported embedding shape: {embeddings.shape}")
+        return normalize(processed_embeddings).astype(np.float32)
 
     def load_onnx_model(self) -> None:
         self._load_onnx_model(
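The post-processing change keeps the old CLS-token path for 3-D outputs while also accepting models, such as the jina-clip text tower, that already return one vector per input. A toy illustration of the branching (normalization omitted for brevity):

import numpy as np

def pool(embeddings: np.ndarray) -> np.ndarray:
    # Mirrors the branching in the updated _post_process_onnx_output.
    if embeddings.ndim == 3:  # (batch_size, seq_len, embedding_dim): take the CLS token
        return embeddings[:, 0]
    elif embeddings.ndim == 2:  # (batch_size, embedding_dim): already pooled
        return embeddings
    raise ValueError(f"Unsupported embedding shape: {embeddings.shape}")

print(pool(np.zeros((4, 10, 768))).shape)  # (4, 768): token-level model output
print(pool(np.zeros((4, 768))).shape)      # (4, 768): pre-pooled model output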

fastembed/text/text_embedding.py

Lines changed: 1 addition & 1 deletion
@@ -78,7 +78,7 @@ def __init__(
                 return
 
         raise ValueError(
-            f"Model {model_name} is not supported in TextEmbedding."
+            f"Model {model_name} is not supported in TextEmbedding. "
             "Please check the supported models using `TextEmbedding.list_supported_models()`"
         )

tests/test_image_onnx_embeddings.py

Lines changed: 3 additions & 0 deletions
@@ -21,6 +21,9 @@
     "Qdrant/Unicom-ViT-B-32": np.array(
         [0.0418, 0.0550, 0.0003, 0.0253, -0.0185, 0.0016, -0.0368, -0.0402, -0.0891, -0.0186]
     ),
+    "jinaai/jina-clip-v1": np.array(
+        [-0.029, 0.0216, 0.0396, 0.0283, -0.0023, 0.0151, 0.011, -0.0235, 0.0251, -0.0343]
+    ),
 }

tests/test_text_onnx_embeddings.py

Lines changed: 1 addition & 0 deletions
@@ -65,6 +65,7 @@
     "snowflake/snowflake-arctic-embed-l": np.array([0.0189, -0.0673, 0.0183, 0.0124, 0.0146]),
     "Qdrant/clip-ViT-B-32-text": np.array([0.0083, 0.0103, -0.0138, 0.0199, -0.0069]),
     "thenlper/gte-base": np.array([0.0038, 0.0355, 0.0181, 0.0092, 0.0654]),
+    "jinaai/jina-clip-v1": np.array([-0.0862, -0.0101, -0.0056, 0.0375, -0.0472]),
 }
