|
2 | 2 |
|
3 | 3 | import numpy as np |
4 | 4 | from PIL import Image |
5 | | - |
6 | 5 | from fastembed.image.transform.functional import ( |
7 | 6 | center_crop, |
8 | 7 | normalize, |
9 | 8 | resize, |
10 | 9 | convert_to_rgb, |
11 | 10 | rescale, |
| 11 | + pil2ndarray |
12 | 12 | ) |
13 | 13 |
|
14 | 14 |
|
@@ -59,68 +59,110 @@ def __init__(self, scale: float = 1 / 255): |
59 | 59 | def __call__(self, images: List[np.ndarray]) -> List[np.ndarray]: |
60 | 60 | return [rescale(image, scale=self.scale) for image in images] |
61 | 61 |
|
| 62 | +class PILtoNDarray(Transform): |
| 63 | + def __call__(self, images: List[Union[Image.Image, np.ndarray]]) -> List[np.ndarray]: |
| 64 | + return [pil2ndarray(image) for image in images] |
62 | 65 |
|
63 | 66 | class Compose: |
64 | 67 | def __init__(self, transforms: List[Transform]): |
65 | 68 | self.transforms = transforms |
66 | 69 |
|
67 | | - def __call__( |
68 | | - self, images: Union[List[Image.Image], List[np.ndarray]] |
69 | | - ) -> Union[List[np.ndarray], List[Image.Image]]: |
| 70 | + def __call__(self, images: Union[List[Image.Image], List[np.ndarray]]) -> Union[List[np.ndarray], List[Image.Image]]: |
70 | 71 | for transform in self.transforms: |
71 | 72 | images = transform(images) |
72 | 73 | return images |
73 | 74 |
|
74 | 75 | @classmethod |
75 | 76 | def from_config(cls, config: Dict[str, Any]) -> "Compose": |
76 | 77 | """Creates processor from a config dict. |
77 | | -
|
78 | | - Args: |
79 | | - config (Dict[str, Any]): Configuration dictionary. |
80 | | -
|
81 | | - Valid keys: |
82 | | - - do_resize |
83 | | - - size |
84 | | - - do_center_crop |
85 | | - - crop_size |
86 | | - - do_rescale |
87 | | - - rescale_factor |
88 | | - - do_normalize |
89 | | - - image_mean |
90 | | - - image_std |
91 | | - Valid size keys (nested): |
92 | | - - {"height", "width"} |
93 | | - - {"shortest_edge"} |
94 | | -
|
95 | | - Returns: |
96 | | - Compose: Image processor. |
| 78 | + Args: |
| 79 | + config (Dict[str, Any]): Configuration dictionary. |
| 80 | +
|
| 81 | + Valid keys: |
| 82 | + - do_resize |
| 83 | + - size |
| 84 | + - do_center_crop |
| 85 | + - crop_size |
| 86 | + - do_rescale |
| 87 | + - rescale_factor |
| 88 | + - do_normalize |
| 89 | + - image_mean |
| 90 | + - image_std |
| 91 | + Valid size keys (nested): |
| 92 | + - {"height", "width"} |
| 93 | + - {"shortest_edge"} |
| 94 | +
|
| 95 | + Returns: |
| 96 | + Compose: Image processor. |
97 | 97 | """ |
98 | | - transforms = [ConvertToRGB()] |
99 | | - if config.get("do_resize", False): |
100 | | - size = config["size"] |
101 | | - if "shortest_edge" in size: |
102 | | - size = size["shortest_edge"] |
103 | | - elif "height" in size and "width" in size: |
104 | | - size = (size["height"], size["width"]) |
105 | | - else: |
106 | | - raise ValueError( |
107 | | - "Size must contain either 'shortest_edge' or 'height' and 'width'." |
108 | | - ) |
109 | | - transforms.append( |
110 | | - Resize(size=size, resample=config.get("resample", Image.Resampling.BICUBIC)) |
111 | | - ) |
112 | | - if config.get("do_center_crop", False): |
113 | | - crop_size = config["crop_size"] |
114 | | - if isinstance(crop_size, int): |
115 | | - crop_size = (crop_size, crop_size) |
116 | | - elif isinstance(crop_size, dict): |
117 | | - crop_size = (crop_size["height"], crop_size["width"]) |
| 98 | + transforms = [] |
| 99 | + cls._get_convert_to_rgb(transforms, config) |
| 100 | + cls._get_resize(transforms, config) |
| 101 | + cls._get_center_crop(transforms, config) |
| 102 | + cls._get_pil2ndarray(transforms, config) |
| 103 | + cls._get_rescale(transforms, config) |
| 104 | + cls._get_normalize(transforms, config) |
| 105 | + return cls(transforms=transforms) |
| 106 | + |
| 107 | + @staticmethod |
| 108 | + def _get_convert_to_rgb(transforms: List[Transform], config: Dict[str, Any]): |
| 109 | + transforms.append(ConvertToRGB()) |
| 110 | + |
| 111 | + @staticmethod |
| 112 | + def _get_resize(transforms: List[Transform], config: Dict[str, Any]): |
| 113 | + mode = config.get('image_processor_type', 'CLIPImageProcessor') |
| 114 | + if mode == 'CLIPImageProcessor': |
| 115 | + if config.get("do_resize", False): |
| 116 | + size = config["size"] |
| 117 | + if "shortest_edge" in size: |
| 118 | + size = size["shortest_edge"] |
| 119 | + elif "height" in size and "width" in size: |
| 120 | + size = (size["height"], size["width"]) |
| 121 | + else: |
| 122 | + raise ValueError("Size must contain either 'shortest_edge' or 'height' and 'width'.") |
| 123 | + transforms.append(Resize(size=size, resample=config.get("resample", Image.Resampling.BICUBIC))) |
| 124 | + elif mode == 'ConvNextFeatureExtractor': |
| 125 | + if 'size' in config and "shortest_edge" not in config['size']: |
| 126 | + raise ValueError(f"Size dictionary must contain 'shortest_edge' key. Got {config['size'].keys()}") |
| 127 | + shortest_edge = config['size']["shortest_edge"] |
| 128 | + crop_pct = config.get("crop_pct", 0.875) |
| 129 | + if shortest_edge < 384: |
| 130 | + # maintain same ratio, resizing shortest edge to shortest_edge/crop_pct |
| 131 | + resize_shortest_edge = int(shortest_edge / crop_pct) |
| 132 | + transforms.append(Resize(size=resize_shortest_edge, resample=config.get("resample", Image.Resampling.BICUBIC))) |
| 133 | + transforms.append(CenterCrop(size=(shortest_edge, shortest_edge))) |
118 | 134 | else: |
119 | | - raise ValueError(f"Invalid crop size: {crop_size}") |
120 | | - transforms.append(CenterCrop(size=crop_size)) |
| 135 | + transforms.append(Resize(size=(shortest_edge, shortest_edge), resample=config.get("resample", Image.Resampling.BICUBIC))) |
| 136 | + |
| 137 | + @staticmethod |
| 138 | + def _get_center_crop(transforms: List[Transform], config: Dict[str, Any]): |
| 139 | + mode = config.get('image_processor_type', 'CLIPImageProcessor') |
| 140 | + if mode == 'CLIPImageProcessor': |
| 141 | + if config.get("do_center_crop", False): |
| 142 | + crop_size = config["crop_size"] |
| 143 | + if isinstance(crop_size, int): |
| 144 | + crop_size = (crop_size, crop_size) |
| 145 | + elif isinstance(crop_size, dict): |
| 146 | + crop_size = (crop_size["height"], crop_size["width"]) |
| 147 | + else: |
| 148 | + raise ValueError(f"Invalid crop size: {crop_size}") |
| 149 | + transforms.append(CenterCrop(size=crop_size)) |
| 150 | + elif mode == 'ConvNextFeatureExtractor': |
| 151 | + pass |
| 152 | + else: |
| 153 | + raise ValueError(f"Preprocessor {mode} is not supported") |
| 154 | + |
| 155 | + @staticmethod |
| 156 | + def _get_pil2ndarray(transforms: List[Transform], config: Dict[str, Any]): |
| 157 | + transforms.append(PILtoNDarray()) |
| 158 | + |
| 159 | + @staticmethod |
| 160 | + def _get_rescale(transforms: List[Transform], config: Dict[str, Any]): |
121 | 161 | if config.get("do_rescale", True): |
122 | 162 | rescale_factor = config.get("rescale_factor", 1 / 255) |
123 | 163 | transforms.append(Rescale(scale=rescale_factor)) |
| 164 | + |
| 165 | + @staticmethod |
| 166 | + def _get_normalize(transforms: List[Transform], config: Dict[str, Any]): |
124 | 167 | if config.get("do_normalize", False): |
125 | 168 | transforms.append(Normalize(mean=config["image_mean"], std=config["image_std"])) |
126 | | - return cls(transforms=transforms) |
|
0 commit comments