
Commit 85aaae4

d.rudenko and joein authored
Add resnet (#246)
* Resnet support added
* Tests fixed: shapes matching for Resnet50-onnx
* Example of Resnet50 to onnx conversion (basic)
* Removed optional conversion from PIL to np.ndarray; it is now the default. Fixed test accordingly
* Refactoring of pil2ndarray
* Partial support of convnext preprocessing (resize logic)
* normalize canonical value
* Style changes for review
* new: update resnet repo

---------

Co-authored-by: d.rudenko <[email protected]>
Co-authored-by: George Panchuk <[email protected]>
1 parent dfd25d4 commit 85aaae4

6 files changed: 249 additions, 60 deletions
@@ -0,0 +1,122 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "id": "4bdb2a91-fa2a-4cee-ad5a-176cc957394d",
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2024-05-23T12:15:28.171586Z",
+     "start_time": "2024-05-23T12:15:28.076314Z"
+    }
+   },
+   "outputs": [
+    {
+     "ename": "ModuleNotFoundError",
+     "evalue": "No module named 'torch'",
+     "output_type": "error",
+     "traceback": [
+      "\u001B[0;31m---------------------------------------------------------------------------\u001B[0m",
+      "\u001B[0;31mModuleNotFoundError\u001B[0m Traceback (most recent call last)",
+      "Cell \u001B[0;32mIn[1], line 1\u001B[0m\n\u001B[0;32m----> 1\u001B[0m \u001B[38;5;28;01mimport\u001B[39;00m \u001B[38;5;21;01mtorch\u001B[39;00m\n\u001B[1;32m 2\u001B[0m \u001B[38;5;28;01mimport\u001B[39;00m \u001B[38;5;21;01mtorch\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01monnx\u001B[39;00m\n\u001B[1;32m 3\u001B[0m \u001B[38;5;28;01mimport\u001B[39;00m \u001B[38;5;21;01mtorchvision\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mmodels\u001B[39;00m \u001B[38;5;28;01mas\u001B[39;00m \u001B[38;5;21;01mmodels\u001B[39;00m\n",
+      "\u001B[0;31mModuleNotFoundError\u001B[0m: No module named 'torch'"
+     ]
+    }
+   ],
+   "source": [
+    "import torch\n",
+    "import torch.onnx\n",
+    "import torchvision.models as models\n",
+    "import torchvision.transforms as transforms\n",
+    "from PIL import Image\n",
+    "import numpy as np\n",
+    "from tests.config import TEST_MISC_DIR\n",
+    "\n",
+    "# Load pre-trained ResNet-50 model\n",
+    "resnet = models.resnet50(pretrained=True)\n",
+    "resnet = torch.nn.Sequential(*(list(resnet.children())[:-1])) # Remove the last fully connected layer\n",
+    "resnet.eval()\n",
+    "\n",
+    "# Define preprocessing transform\n",
+    "preprocess = transforms.Compose([\n",
+    "    transforms.Resize(256),\n",
+    "    transforms.CenterCrop(224),\n",
+    "    transforms.ToTensor(),\n",
+    "    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n",
+    "])\n",
+    "\n",
+    "# Load and preprocess the image\n",
+    "def preprocess_image(image_path):\n",
+    "    input_image = Image.open(image_path)\n",
+    "    input_tensor = preprocess(input_image)\n",
+    "    input_batch = input_tensor.unsqueeze(0) # Add batch dimension\n",
+    "    return input_batch\n",
+    "\n",
+    "# Example input for exporting\n",
+    "input_image = preprocess_image('example.jpg')\n",
+    "\n",
+    "# Export the model to ONNX with dynamic axes\n",
+    "torch.onnx.export(\n",
+    "    resnet, \n",
+    "    input_image, \n",
+    "    \"model.onnx\", \n",
+    "    export_params=True, \n",
+    "    opset_version=9, \n",
+    "    input_names=['input'], \n",
+    "    output_names=['output'],\n",
+    "    dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}}\n",
+    ")\n",
+    "\n",
+    "# Load ONNX model\n",
+    "import onnx\n",
+    "import onnxruntime as ort\n",
+    "\n",
+    "onnx_model = onnx.load(\"model.onnx\")\n",
+    "ort_session = ort.InferenceSession(\"model.onnx\")\n",
+    "\n",
+    "# Run inference and extract feature vectors\n",
+    "def extract_feature_vectors(image_paths):\n",
+    "    input_images = [preprocess_image(image_path) for image_path in image_paths]\n",
+    "    input_batch = torch.cat(input_images, dim=0) # Combine images into a single batch\n",
+    "    ort_inputs = {ort_session.get_inputs()[0].name: input_batch.numpy()}\n",
+    "    ort_outs = ort_session.run(None, ort_inputs)\n",
+    "    return ort_outs[0]\n",
+    "\n",
+    "# Example usage\n",
+    "images = [TEST_MISC_DIR / \"image.jpeg\", str(TEST_MISC_DIR / \"small_image.jpeg\")] # Replace with your image paths\n",
+    "feature_vectors = extract_feature_vectors(images)\n",
+    "print(\"Feature vector shape:\", feature_vectors.shape)\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "outputs": [],
+   "source": [],
+   "metadata": {
+    "collapsed": false
+   },
+   "id": "baa650c4cb3e0e6d"
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.12.2"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
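
The ModuleNotFoundError stored in the cell output only reflects the kernel the notebook was last saved from (no torch installed); the export code itself is complete. A hedged sanity check, not part of the commit, could confirm that the exported model.onnx reproduces the torch features and honors the dynamic batch axis:

# Assumes "model.onnx" was produced by the notebook above and that torch,
# torchvision and onnxruntime are installed.
import numpy as np
import onnxruntime as ort
import torch
import torchvision.models as models

# Rebuild the same feature extractor the notebook exported.
resnet = torch.nn.Sequential(*list(models.resnet50(pretrained=True).children())[:-1]).eval()

dummy = torch.rand(2, 3, 224, 224)  # batch of 2 exercises the 'batch_size' dynamic axis
with torch.no_grad():
    torch_features = resnet(dummy).numpy().reshape(2, -1)  # (2, 2048)

session = ort.InferenceSession("model.onnx")
onnx_features = session.run(None, {"input": dummy.numpy()})[0].reshape(2, -1)

print("max abs diff:", np.abs(torch_features - onnx_features).max())  # expected to be tiny (~1e-5)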

fastembed/image/onnx_embedding.py

11 additions, 1 deletion

@@ -18,7 +18,17 @@
             "hf": "Qdrant/clip-ViT-B-32-vision",
         },
         "model_file": "model.onnx",
-    }
+    },
+    {
+        "model": "Qdrant/resnet50-onnx",
+        "dim": 2048,
+        "description": "ResNet-50 from `Deep Residual Learning for Image Recognition <https://arxiv.org/abs/1512.03385>`__.",
+        "size_in_GB": 0.1,
+        "sources": {
+            "hf": "Qdrant/resnet50-onnx",
+        },
+        "model_file": "model.onnx",
+    },
 ]
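
With this registry entry, the ResNet-50 model can be requested by name like any other supported image model. A minimal usage sketch, assuming fastembed's top-level ImageEmbedding class and placeholder local image paths:

# Hedged usage sketch (not part of the diff); assumes `from fastembed import ImageEmbedding`
# works in this version and that the model downloads from the "hf" source listed above.
from fastembed import ImageEmbedding

model = ImageEmbedding(model_name="Qdrant/resnet50-onnx")

# "cat.jpeg" and "dog.jpeg" are placeholder paths to local images.
embeddings = list(model.embed(["cat.jpeg", "dog.jpeg"]))

print(len(embeddings), embeddings[0].shape)  # expected: 2 vectors of 2048 dims, matching "dim" above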

fastembed/image/onnx_image_model.py

7 additions, 4 deletions

@@ -46,15 +46,19 @@ def load_onnx_model(
         )
         self.processor = load_preprocessor(model_dir=model_dir)
 
+    def _build_onnx_input(self, encoded: np.ndarray) -> Dict[str, np.ndarray]:
+        return {node.name: encoded for node in self.model.get_inputs()}
+
     def onnx_embed(self, images: List[PathInput]) -> OnnxOutputContext:
         with contextlib.ExitStack():
             image_files = [Image.open(image) for image in images]
             encoded = self.processor(image_files)
-            onnx_input = {"pixel_values": encoded}
+            onnx_input = self._build_onnx_input(encoded)
             onnx_input = self._preprocess_onnx_input(onnx_input)
-
             model_output = self.model.run(None, onnx_input)
-            embeddings = model_output[0]
+
+            embeddings = model_output[0].reshape(len(images), -1)
+
         return OnnxOutputContext(
             model_output=embeddings
         )
@@ -82,7 +86,6 @@ def _embed_images(
 
         if parallel is None or is_small:
            for batch in iter_batch(images, batch_size):
-                # open and preprocess images
                yield from self._post_process_onnx_output(self.onnx_embed(batch))
        else:
            start_method = "forkserver" if "forkserver" in get_all_start_methods() else "spawn"
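
Two small changes make the CLIP and ResNet graphs interchangeable here: the input feed is keyed by whatever name the ONNX graph gives its input node instead of the hard-coded "pixel_values", and the raw output is flattened so that ResNet-50's pooled (batch, 2048, 1, 1) tensor and CLIP's already-flat (batch, 512) tensor both come out as (batch, dim). A standalone sketch of that reshape step, with made-up arrays:

import numpy as np

resnet_like = np.random.rand(3, 2048, 1, 1).astype(np.float32)  # pooled conv features
clip_like = np.random.rand(3, 512).astype(np.float32)           # already flat

print(resnet_like.reshape(len(resnet_like), -1).shape)  # (3, 2048)
print(clip_like.reshape(len(clip_like), -1).shape)      # (3, 512)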

fastembed/image/transform/functional.py

16 additions, 8 deletions

@@ -14,14 +14,17 @@ def convert_to_rgb(image: Image.Image) -> Image.Image:
 
 
 def center_crop(
-    image: Image.Image,
+    image: Union[Image.Image, np.ndarray],
     size: Tuple[int, int],
 ) -> np.ndarray:
-    orig_height, orig_width = image.height, image.width
-    crop_height, crop_width = size
+    if isinstance(image, np.ndarray):
+        _, orig_height, orig_width = image.shape
+    else:
+        orig_height, orig_width = image.height, image.width
+        # (H, W, C) -> (C, H, W)
+        image = np.array(image).transpose((2, 0, 1))
 
-    # (H, W, C) -> (C, H, W)
-    image = np.array(image).transpose((2, 0, 1))
+    crop_height, crop_width = size
 
     # left upper corner (0, 0)
     top = (orig_height - crop_height) // 2
@@ -96,7 +99,7 @@ def normalize(
 def resize(
     image: Image,
     size: Union[int, Tuple[int, int]],
-    resample: Image.Resampling = Image.Resampling.BICUBIC,
+    resample: Image.Resampling = Image.Resampling.BILINEAR,
 ) -> Image:
     if isinstance(size, tuple):
         return image.resize(size, resample)
@@ -109,9 +112,14 @@ def resize(
         new_size = (new_short, new_long)
     else:
         new_size = (new_long, new_short)
-
-    return image.resize(new_size, Image.Resampling.BICUBIC)
+    return image.resize(new_size, resample)
 
 
 def rescale(image: np.ndarray, scale: float, dtype=np.float32) -> np.ndarray:
     return (image * scale).astype(dtype)
+
+
+def pil2ndarray(image: Union[Image.Image, np.ndarray]):
+    if isinstance(image, Image.Image):
+        return np.asarray(image).transpose((2, 0, 1))
+    return image
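
Since pil2ndarray now runs ahead of cropping, center_crop mostly receives channel-first arrays, but it still accepts PIL images directly. A short sketch exercising both branches, assuming the helpers are importable exactly as added in this diff:

import numpy as np
from PIL import Image

from fastembed.image.transform.functional import center_crop, pil2ndarray

image = Image.new("RGB", (320, 240), color="red")   # synthetic 320x240 image

as_array = pil2ndarray(image)                        # (C, H, W) == (3, 240, 320)
cropped = center_crop(as_array, size=(224, 224))     # ndarray branch: shape read from (C, H, W)
also_cropped = center_crop(image, size=(224, 224))   # PIL branch: converted internally

print(as_array.shape, cropped.shape, also_cropped.shape)  # expected (3, 240, 320) (3, 224, 224) (3, 224, 224)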

fastembed/image/transform/operators.py

89 additions, 47 deletions

@@ -2,13 +2,13 @@
 
 import numpy as np
 from PIL import Image
-
 from fastembed.image.transform.functional import (
     center_crop,
     normalize,
     resize,
     convert_to_rgb,
     rescale,
+    pil2ndarray
 )
 
 
@@ -59,68 +59,110 @@ def __init__(self, scale: float = 1 / 255):
     def __call__(self, images: List[np.ndarray]) -> List[np.ndarray]:
         return [rescale(image, scale=self.scale) for image in images]
 
+class PILtoNDarray(Transform):
+    def __call__(self, images: List[Union[Image.Image, np.ndarray]]) -> List[np.ndarray]:
+        return [pil2ndarray(image) for image in images]
 
 class Compose:
     def __init__(self, transforms: List[Transform]):
         self.transforms = transforms
 
-    def __call__(
-        self, images: Union[List[Image.Image], List[np.ndarray]]
-    ) -> Union[List[np.ndarray], List[Image.Image]]:
+    def __call__(self, images: Union[List[Image.Image], List[np.ndarray]]) -> Union[List[np.ndarray], List[Image.Image]]:
         for transform in self.transforms:
             images = transform(images)
         return images
 
     @classmethod
     def from_config(cls, config: Dict[str, Any]) -> "Compose":
         """Creates processor from a config dict.
-
-        Args:
-            config (Dict[str, Any]): Configuration dictionary.
-
-            Valid keys:
-                - do_resize
-                - size
-                - do_center_crop
-                - crop_size
-                - do_rescale
-                - rescale_factor
-                - do_normalize
-                - image_mean
-                - image_std
-            Valid size keys (nested):
-                - {"height", "width"}
-                - {"shortest_edge"}
-
-        Returns:
-            Compose: Image processor.
+            Args:
+                config (Dict[str, Any]): Configuration dictionary.
+
+                Valid keys:
+                    - do_resize
+                    - size
+                    - do_center_crop
+                    - crop_size
+                    - do_rescale
+                    - rescale_factor
+                    - do_normalize
+                    - image_mean
+                    - image_std
+                Valid size keys (nested):
+                    - {"height", "width"}
+                    - {"shortest_edge"}
+
+            Returns:
+                Compose: Image processor.
         """
-        transforms = [ConvertToRGB()]
-        if config.get("do_resize", False):
-            size = config["size"]
-            if "shortest_edge" in size:
-                size = size["shortest_edge"]
-            elif "height" in size and "width" in size:
-                size = (size["height"], size["width"])
-            else:
-                raise ValueError(
-                    "Size must contain either 'shortest_edge' or 'height' and 'width'."
-                )
-            transforms.append(
-                Resize(size=size, resample=config.get("resample", Image.Resampling.BICUBIC))
-            )
-        if config.get("do_center_crop", False):
-            crop_size = config["crop_size"]
-            if isinstance(crop_size, int):
-                crop_size = (crop_size, crop_size)
-            elif isinstance(crop_size, dict):
-                crop_size = (crop_size["height"], crop_size["width"])
+        transforms = []
+        cls._get_convert_to_rgb(transforms, config)
+        cls._get_resize(transforms, config)
+        cls._get_center_crop(transforms, config)
+        cls._get_pil2ndarray(transforms, config)
+        cls._get_rescale(transforms, config)
+        cls._get_normalize(transforms, config)
+        return cls(transforms=transforms)
+
+    @staticmethod
+    def _get_convert_to_rgb(transforms: List[Transform], config: Dict[str, Any]):
+        transforms.append(ConvertToRGB())
+
+    @staticmethod
+    def _get_resize(transforms: List[Transform], config: Dict[str, Any]):
+        mode = config.get('image_processor_type', 'CLIPImageProcessor')
+        if mode == 'CLIPImageProcessor':
+            if config.get("do_resize", False):
+                size = config["size"]
+                if "shortest_edge" in size:
+                    size = size["shortest_edge"]
+                elif "height" in size and "width" in size:
+                    size = (size["height"], size["width"])
+                else:
+                    raise ValueError("Size must contain either 'shortest_edge' or 'height' and 'width'.")
+                transforms.append(Resize(size=size, resample=config.get("resample", Image.Resampling.BICUBIC)))
+        elif mode == 'ConvNextFeatureExtractor':
+            if 'size' in config and "shortest_edge" not in config['size']:
+                raise ValueError(f"Size dictionary must contain 'shortest_edge' key. Got {config['size'].keys()}")
+            shortest_edge = config['size']["shortest_edge"]
+            crop_pct = config.get("crop_pct", 0.875)
+            if shortest_edge < 384:
+                # maintain same ratio, resizing shortest edge to shortest_edge/crop_pct
+                resize_shortest_edge = int(shortest_edge / crop_pct)
+                transforms.append(Resize(size=resize_shortest_edge, resample=config.get("resample", Image.Resampling.BICUBIC)))
+                transforms.append(CenterCrop(size=(shortest_edge, shortest_edge)))
            else:
-                raise ValueError(f"Invalid crop size: {crop_size}")
-            transforms.append(CenterCrop(size=crop_size))
+                transforms.append(Resize(size=(shortest_edge, shortest_edge), resample=config.get("resample", Image.Resampling.BICUBIC)))
+
+    @staticmethod
+    def _get_center_crop(transforms: List[Transform], config: Dict[str, Any]):
+        mode = config.get('image_processor_type', 'CLIPImageProcessor')
+        if mode == 'CLIPImageProcessor':
+            if config.get("do_center_crop", False):
+                crop_size = config["crop_size"]
+                if isinstance(crop_size, int):
+                    crop_size = (crop_size, crop_size)
+                elif isinstance(crop_size, dict):
+                    crop_size = (crop_size["height"], crop_size["width"])
+                else:
+                    raise ValueError(f"Invalid crop size: {crop_size}")
+                transforms.append(CenterCrop(size=crop_size))
+        elif mode == 'ConvNextFeatureExtractor':
+            pass
+        else:
+            raise ValueError(f"Preprocessor {mode} is not supported")
+
+    @staticmethod
+    def _get_pil2ndarray(transforms: List[Transform], config: Dict[str, Any]):
+        transforms.append(PILtoNDarray())
+
+    @staticmethod
+    def _get_rescale(transforms: List[Transform], config: Dict[str, Any]):
        if config.get("do_rescale", True):
            rescale_factor = config.get("rescale_factor", 1 / 255)
            transforms.append(Rescale(scale=rescale_factor))
+
+    @staticmethod
+    def _get_normalize(transforms: List[Transform], config: Dict[str, Any]):
        if config.get("do_normalize", False):
            transforms.append(Normalize(mean=config["image_mean"], std=config["image_std"]))
-        return cls(transforms=transforms)
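
The from_config split means the ConvNext path builds its resize and crop steps inside _get_resize (the shortest_edge / crop_pct logic above), while _get_center_crop deliberately does nothing for that processor type. A sketch of driving both branches with hand-written configs (illustrative values, not copied from any real preprocessor_config.json):

from fastembed.image.transform.operators import Compose

# CLIP-style config: resize shortest edge to 224, then center-crop 224x224.
clip_cfg = {
    "image_processor_type": "CLIPImageProcessor",
    "do_resize": True,
    "size": {"shortest_edge": 224},
    "do_center_crop": True,
    "crop_size": 224,
    "do_rescale": True,
    "do_normalize": True,
    "image_mean": [0.48145466, 0.4578275, 0.40821073],
    "image_std": [0.26862954, 0.26130258, 0.27577711],
}

# ConvNext-style config: shortest_edge < 384, so the image is resized to
# shortest_edge / crop_pct and then center-cropped back to shortest_edge.
convnext_cfg = {
    "image_processor_type": "ConvNextFeatureExtractor",
    "size": {"shortest_edge": 224},
    "crop_pct": 0.875,
    "do_rescale": True,
    "do_normalize": True,
    "image_mean": [0.485, 0.456, 0.406],
    "image_std": [0.229, 0.224, 0.225],
}

for cfg in (clip_cfg, convnext_cfg):
    pipeline = Compose.from_config(cfg)
    print([type(t).__name__ for t in pipeline.transforms])
# Both should print something along the lines of:
# ['ConvertToRGB', 'Resize', 'CenterCrop', 'PILtoNDarray', 'Rescale', 'Normalize']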

0 commit comments
