# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

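"""Gradio demo: text-prompted segmentation that scores SAM masks with CLIP."""
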
import argparse
import os
import sys
import time

sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), ".."))

import cv2
import numpy as np
import paddle
import paddle.nn.functional as F
from PIL import Image, ImageDraw

from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
from segment_anything.modeling.clip_paddle import build_clip_model, _transform
from segment_anything.utils.sample_tokenizer import tokenize
from paddleseg.utils.visualize import get_pseudo_color_map, get_color_map_list

ID_PHOTO_IMAGE_DEMO = "./examples/cityscapes_demo.png"
CACHE_DIR = ".temp"
model_link = {
    'vit_h':
    "https://bj.bcebos.com/paddleseg/dygraph/paddlesegAnything/vit_h/model.pdparams",
    'vit_l':
    "https://bj.bcebos.com/paddleseg/dygraph/paddlesegAnything/vit_l/model.pdparams",
    'vit_b':
    "https://bj.bcebos.com/paddleseg/dygraph/paddlesegAnything/vit_b/model.pdparams",
    'clip_b_32':
    "https://bj.bcebos.com/paddleseg/dygraph/clip/vit_b_32_pretrain/clip_vit_b_32.pdparams"
}

parser = argparse.ArgumentParser(description=(
    "Launches a Gradio demo that runs SAM automatic mask generation on an "
    "input image and highlights the masks that match a text prompt, scored "
    "by CLIP."))

parser.add_argument(
    "--model-type",
    type=str,
    default="vit_h",
    choices=['vit_h', 'vit_l', 'vit_b'],
    help="The type of model to load, in ['vit_h', 'vit_l', 'vit_b']", )


def download(img):
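    """Save a PIL image into the cache dir and return its path for download.

    The file name is derived from the current timestamp; if a file with that
    name already exists, wait a second and retry with a fresh timestamp.
    """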
    if not os.path.exists(CACHE_DIR):
        os.makedirs(CACHE_DIR)
    while True:
        name = str(int(time.time()))
        tmp_name = os.path.join(CACHE_DIR, name + '.png')
        if not os.path.exists(tmp_name):
            break
        else:
            time.sleep(1)
    # Save as PNG so the RGBA result image keeps its alpha channel.
    img.save(tmp_name, 'png')
    return tmp_name


def segment_image(image, segment_mask):
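    """Composite the pixels under `segment_mask` onto a gray background."""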
    image_array = np.array(image)
    gray_image = Image.new("RGB", image.size, (128, 128, 128))
    segmented_image_array = np.zeros_like(image_array)
    segmented_image_array[segment_mask] = image_array[segment_mask]
    segmented_image = Image.fromarray(segmented_image_array)
    transparency = np.zeros_like(segment_mask, dtype=np.uint8)
    transparency[segment_mask] = 255
    transparency_image = Image.fromarray(transparency, mode='L')
    gray_image.paste(segmented_image, mask=transparency_image)
    return gray_image


def image_text_match(cropped_objects, text_query):
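    """Score each cropped object against `text_query` with the global CLIP model.

    Returns a softmax distribution over the crops, so the scores sum to 1.
    """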
    transformed_images = [transform(image) for image in cropped_objects]
    tokenized_text = tokenize([text_query])
    batch_images = paddle.stack(transformed_images)
    image_features = model.encode_image(batch_images)
    print("encode_image done!")
    text_features = model.encode_text(tokenized_text)
    print("encode_text done!")
    image_features /= image_features.norm(axis=-1, keepdim=True)
    text_features /= text_features.norm(axis=-1, keepdim=True)
    probs = 100. * image_features @ text_features.T
    return F.softmax(probs[:, 0], axis=0)


def masks2pseudomap(masks):
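    """Flatten SAM masks into a single label map (background label 255) and
    render it as a pseudo-color PIL image."""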
    result = np.ones(masks[0]["segmentation"].shape, dtype=np.uint8) * 255
    for i, mask_data in enumerate(masks):
        result[mask_data["segmentation"] == 1] = i + 1
    pred_result = result
    result = get_pseudo_color_map(result)
    return pred_result, result


def visualize(image, result, color_map, weight=0.6):
| 110 | + """ |
| 111 | + Convert predict result to color image, and save added image. |
| 112 | +
|
| 113 | + Args: |
| 114 | + image (str): The path of origin image. |
| 115 | + result (np.ndarray): The predict result of image. |
| 116 | + color_map (list): The color used to save the prediction results. |
| 117 | + save_dir (str): The directory for saving visual image. Default: None. |
| 118 | + weight (float): The image weight of visual image, and the result weight is (1 - weight). Default: 0.6 |
| 119 | +
|
| 120 | + Returns: |
| 121 | + vis_result (np.ndarray): If `save_dir` is None, return the visualized result. |
| 122 | + """ |

    color_map = [color_map[i:i + 3] for i in range(0, len(color_map), 3)]
    color_map = np.array(color_map).astype("uint8")
    # Use an OpenCV LUT per channel to map labels to colors.
    c1 = cv2.LUT(result, color_map[:, 0])
    c2 = cv2.LUT(result, color_map[:, 1])
    c3 = cv2.LUT(result, color_map[:, 2])
    pseudo_img = np.dstack((c3, c2, c1))

    vis_result = cv2.addWeighted(image, weight, pseudo_img, 1 - weight, 0)
    return vis_result


def get_id_photo_output(image, text):
| 137 | + """ |
| 138 | + Get the special size and background photo. |
| 139 | +
|
| 140 | + Args: |
| 141 | + img(numpy:ndarray): The image array. |
| 142 | + size(str): The size user specified. |
| 143 | + bg(str): The background color user specified. |
| 144 | + download_size(str): The size for image saving. |
| 145 | +
|
| 146 | + """ |
    image_ori = image.copy()
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    masks = mask_generator.generate(image)
    pred_result, pseudo_map = masks2pseudomap(masks)  # PIL Image
    added_pseudo_map = visualize(
        image, pred_result, color_map=get_color_map_list(256))
    cropped_objects = []
    image_pil = Image.fromarray(image)
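    # SAM reports bboxes as XYWH; convert to the XYXY box that PIL's crop expects.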
    for mask in masks:
        bbox = [
            mask["bbox"][0], mask["bbox"][1], mask["bbox"][0] + mask["bbox"][2],
            mask["bbox"][1] + mask["bbox"][3]
        ]
        cropped_objects.append(
            segment_image(image_pil, mask["segmentation"]).crop(bbox))

    scores = image_text_match(cropped_objects, str(text))
    text_matching_masks = []
    for idx, score in enumerate(scores):
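        # Keep only the masks whose CLIP score for the prompt clears 0.05.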
        if score < 0.05:
            continue
        text_matching_mask = Image.fromarray(
            masks[idx]["segmentation"].astype('uint8') * 255)
        text_matching_masks.append(text_matching_mask)

    image_pil_ori = Image.fromarray(image_ori)
    alpha_image = Image.new('RGBA', image_pil_ori.size, (0, 0, 0, 0))
    alpha_color = (255, 0, 0, 180)

    draw = ImageDraw.Draw(alpha_image)
    for text_matching_mask in text_matching_masks:
        draw.bitmap((0, 0), text_matching_mask, fill=alpha_color)

    result_image = Image.alpha_composite(
        image_pil_ori.convert('RGBA'), alpha_image)
    res_download = download(result_image)
    return result_image, added_pseudo_map, res_download


def gradio_display():
    import gradio as gr
    examples_sam = [["./examples/cityscapes_demo.png", "a photo of car"],
                    ["examples/dog.jpg", "dog"],
                    ["examples/zixingche.jpeg", "kid"]]

    demo_mask_sam = gr.Interface(
        fn=get_id_photo_output,
        inputs=[
            gr.Image(
                value=ID_PHOTO_IMAGE_DEMO,
                label="Input image").style(height=400), gr.inputs.Textbox(
                    lines=3,
                    placeholder=None,
                    default="a photo of car",
                    label='🔥 Input text prompt 🔥',
                    optional=False)
        ],
        # Three output components to match the three values returned by
        # get_id_photo_output.
        outputs=[
            gr.Image(
                label="Output based on text",
                interactive=False).style(height=300), gr.Image(
                    label="Output mask", interactive=False).style(height=300),
            gr.File(label="Download the result image")
        ],
        examples=examples_sam,
        description="<p> \
            <strong>SAM+CLIP: Text prompt for segmentation.</strong> <br>\
            Choose an example below, or upload your own image: <br>\
            1. Upload the image to be tested to 'Input image'. <br>\
            2. Enter a prompt in 'Input text prompt' and click 'Submit'. <br>\
            </p>",
        cache_examples=False,
        allow_flagging="never", )

    demo = gr.TabbedInterface(
        [demo_mask_sam, ], ['SAM+CLIP (Text to Segment)'],
        title=" 🔥 Text to Segment Anything with PaddleSeg 🔥")
    demo.launch(
        server_name="0.0.0.0", enable_queue=False, server_port=8078, share=True)


args = parser.parse_args()
print("Loading model...")

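# Prefer the GPU when this Paddle build was compiled with CUDA support.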
if paddle.is_compiled_with_cuda():
    paddle.set_device("gpu")
else:
    paddle.set_device("cpu")

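# Build SAM with the chosen backbone and wrap it in the automatic mask
# generator; pretrained weights are loaded from the URLs in model_link.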
sam = sam_model_registry[args.model_type](
    checkpoint=model_link[args.model_type])
mask_generator = SamAutomaticMaskGenerator(sam)

model, transform = build_clip_model(model_link["clip_b_32"])
gradio_display()