 
 import numpy as np
 from PIL import Image
-
 from fastembed.image.transform.functional import (
     center_crop,
     normalize,
     resize,
     convert_to_rgb,
     rescale,
+    pil2ndarray
 )
 
 
@@ -59,68 +59,110 @@ def __init__(self, scale: float = 1 / 255):
     def __call__(self, images: List[np.ndarray]) -> List[np.ndarray]:
         return [rescale(image, scale=self.scale) for image in images]
 
+class PILtoNDarray(Transform):
+    def __call__(self, images: List[Union[Image.Image, np.ndarray]]) -> List[np.ndarray]:
+        return [pil2ndarray(image) for image in images]
 
 class Compose:
     def __init__(self, transforms: List[Transform]):
         self.transforms = transforms
 
-    def __call__(
-        self, images: Union[List[Image.Image], List[np.ndarray]]
-    ) -> Union[List[np.ndarray], List[Image.Image]]:
+    def __call__(self, images: Union[List[Image.Image], List[np.ndarray]]) -> Union[List[np.ndarray], List[Image.Image]]:
         for transform in self.transforms:
             images = transform(images)
         return images
 
     @classmethod
     def from_config(cls, config: Dict[str, Any]) -> "Compose":
         """Creates processor from a config dict.
-
-        Args:
-            config (Dict[str, Any]): Configuration dictionary.
-
-        Valid keys:
-            - do_resize
-            - size
-            - do_center_crop
-            - crop_size
-            - do_rescale
-            - rescale_factor
-            - do_normalize
-            - image_mean
-            - image_std
-        Valid size keys (nested):
-            - {"height", "width"}
-            - {"shortest_edge"}
-
-        Returns:
-            Compose: Image processor.
+        Args:
+            config (Dict[str, Any]): Configuration dictionary.
+
+        Valid keys:
+            - do_resize
+            - size
+            - do_center_crop
+            - crop_size
+            - do_rescale
+            - rescale_factor
+            - do_normalize
+            - image_mean
+            - image_std
+        Valid size keys (nested):
+            - {"height", "width"}
+            - {"shortest_edge"}
+
+        Returns:
+            Compose: Image processor.
         """
-        transforms = [ConvertToRGB()]
-        if config.get("do_resize", False):
-            size = config["size"]
-            if "shortest_edge" in size:
-                size = size["shortest_edge"]
-            elif "height" in size and "width" in size:
-                size = (size["height"], size["width"])
-            else:
-                raise ValueError(
-                    "Size must contain either 'shortest_edge' or 'height' and 'width'."
-                )
-            transforms.append(
-                Resize(size=size, resample=config.get("resample", Image.Resampling.BICUBIC))
-            )
-        if config.get("do_center_crop", False):
-            crop_size = config["crop_size"]
-            if isinstance(crop_size, int):
-                crop_size = (crop_size, crop_size)
-            elif isinstance(crop_size, dict):
-                crop_size = (crop_size["height"], crop_size["width"])
+        transforms = []
+        cls._get_convert_to_rgb(transforms, config)
+        cls._get_resize(transforms, config)
+        cls._get_center_crop(transforms, config)
+        cls._get_pil2ndarray(transforms, config)
+        cls._get_rescale(transforms, config)
+        cls._get_normalize(transforms, config)
+        return cls(transforms=transforms)
+
+    @staticmethod
+    def _get_convert_to_rgb(transforms: List[Transform], config: Dict[str, Any]):
+        transforms.append(ConvertToRGB())
+
+    @staticmethod
+    def _get_resize(transforms: List[Transform], config: Dict[str, Any]):
+        mode = config.get('image_processor_type', 'CLIPImageProcessor')
+        if mode == 'CLIPImageProcessor':
+            if config.get("do_resize", False):
+                size = config["size"]
+                if "shortest_edge" in size:
+                    size = size["shortest_edge"]
+                elif "height" in size and "width" in size:
+                    size = (size["height"], size["width"])
+                else:
+                    raise ValueError("Size must contain either 'shortest_edge' or 'height' and 'width'.")
+                transforms.append(Resize(size=size, resample=config.get("resample", Image.Resampling.BICUBIC)))
+        elif mode == 'ConvNextFeatureExtractor':
+            if 'size' in config and "shortest_edge" not in config['size']:
+                raise ValueError(f"Size dictionary must contain 'shortest_edge' key. Got {config['size'].keys()}")
+            shortest_edge = config['size']["shortest_edge"]
+            crop_pct = config.get("crop_pct", 0.875)
+            if shortest_edge < 384:
+                # maintain same ratio, resizing shortest edge to shortest_edge/crop_pct
+                resize_shortest_edge = int(shortest_edge / crop_pct)
+                transforms.append(Resize(size=resize_shortest_edge, resample=config.get("resample", Image.Resampling.BICUBIC)))
+                transforms.append(CenterCrop(size=(shortest_edge, shortest_edge)))
             else:
-                raise ValueError(f"Invalid crop size: {crop_size}")
-            transforms.append(CenterCrop(size=crop_size))
+                transforms.append(Resize(size=(shortest_edge, shortest_edge), resample=config.get("resample", Image.Resampling.BICUBIC)))
+
+    @staticmethod
+    def _get_center_crop(transforms: List[Transform], config: Dict[str, Any]):
+        mode = config.get('image_processor_type', 'CLIPImageProcessor')
+        if mode == 'CLIPImageProcessor':
+            if config.get("do_center_crop", False):
+                crop_size = config["crop_size"]
+                if isinstance(crop_size, int):
+                    crop_size = (crop_size, crop_size)
+                elif isinstance(crop_size, dict):
+                    crop_size = (crop_size["height"], crop_size["width"])
+                else:
+                    raise ValueError(f"Invalid crop size: {crop_size}")
+                transforms.append(CenterCrop(size=crop_size))
+        elif mode == 'ConvNextFeatureExtractor':
+            pass
+        else:
+            raise ValueError(f"Preprocessor {mode} is not supported")
+
+    @staticmethod
+    def _get_pil2ndarray(transforms: List[Transform], config: Dict[str, Any]):
+        transforms.append(PILtoNDarray())
+
+    @staticmethod
+    def _get_rescale(transforms: List[Transform], config: Dict[str, Any]):
         if config.get("do_rescale", True):
             rescale_factor = config.get("rescale_factor", 1 / 255)
             transforms.append(Rescale(scale=rescale_factor))
+
+    @staticmethod
+    def _get_normalize(transforms: List[Transform], config: Dict[str, Any]):
         if config.get("do_normalize", False):
             transforms.append(Normalize(mean=config["image_mean"], std=config["image_std"]))
-        return cls(transforms=transforms)
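For the ConvNext branch above, the resize step follows the usual crop_pct trick: when the target shortest edge is below 384 px, the image is first resized so its shortest edge is `size / crop_pct` and then center-cropped back to the target. A quick arithmetic sketch with the default `crop_pct` from the diff and an assumed target of 224 px (both the 224 value and the printed result are illustrative, not taken from this PR):

```python
# Worked example (assumed values): a ConvNext-style config with
# size={"shortest_edge": 224} and the default crop_pct=0.875 from the diff above.
shortest_edge = 224
crop_pct = 0.875
resize_shortest_edge = int(shortest_edge / crop_pct)  # int(256.0) == 256
# Resulting pipeline: Resize(size=256) followed by CenterCrop(size=(224, 224));
# a target of 384 px or more would instead resize straight to (384, 384).
print(resize_shortest_edge)  # 256
```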
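And a minimal usage sketch of the refactored `Compose.from_config`. The module path `fastembed.image.transform.operators` and every config value below are assumptions for illustration (the mean/std are the commonly used CLIP normalization constants), not values taken from this PR:

```python
# Minimal sketch, not a test from this PR: drive the refactored Compose.from_config
# with a CLIP-style config. All values here are illustrative assumptions.
from PIL import Image

from fastembed.image.transform.operators import Compose  # assumed module path

clip_like_config = {
    "image_processor_type": "CLIPImageProcessor",  # default branch in _get_resize/_get_center_crop
    "do_resize": True,
    "size": {"shortest_edge": 224},
    "do_center_crop": True,
    "crop_size": 224,
    "do_rescale": True,
    "rescale_factor": 1 / 255,
    "do_normalize": True,
    "image_mean": [0.48145466, 0.4578275, 0.40821073],  # common CLIP constants
    "image_std": [0.26862954, 0.26130258, 0.27577711],
}

processor = Compose.from_config(clip_like_config)
# Pipeline: ConvertToRGB -> Resize -> CenterCrop -> PILtoNDarray -> Rescale -> Normalize
arrays = processor([Image.new("RGB", (640, 480))])
print(arrays[0].shape, arrays[0].dtype)
# A ConvNext-style config would instead set image_processor_type to
# "ConvNextFeatureExtractor" and provide size={"shortest_edge": ...} plus crop_pct.
```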