Commit bfe9c7e
[New features] text prompt to SAM (#3186)
1 parent ef78d24 commit bfe9c7e

10 files changed: +798 −36 lines changed

contrib/SegmentAnything/README.md

Lines changed: 30 additions & 25 deletions
@@ -1,25 +1,24 @@
# Segment Anything with PaddleSeg

-## Reference
-
-> Alexander Kirillov, Eric Mintun, Nikhila Ravi, Hanzi Mao, Chloe Rolland, Laura Gustafson, Tete Xiao, Spencer Whitehead, Alexander C. Berg, Wan-Yen Lo, Piotr Dollár, Ross Girshick. [Segment Anything](https://ai.facebook.com/research/publications/segment-anything/).
-

## Contents
1. Overview
2. Performance
3. Try it by yourself with one line of code
+4. Reference

## <img src="https://user-images.githubusercontent.com/34859558/190043857-bfbdaf8b-d2dc-4fff-81c7-e0aac50851f9.png" width="25"/> Overview

-We implemente the segment anything with the PaddlePaddle framework. **Segment Anything Model (SAM)** is a new task, model, and dataset for image segmentation. It can produce high quality object masks from different types of prompts including points, boxes, masks and text. Further, SAM can generate masks for all objects in whole image. It built a largest segmentation [dataset](https://segment-anything.com/dataset/index.html) to date (by far), with over 1 billion masks on 11M licensed and privacy respecting images. SAM has impressive zero-shot performance on a variety of tasks, even often competitive with or even superior to prior fully supervised results.
+We implement Segment Anything with the PaddlePaddle framework. **Segment Anything Model (SAM)** is a new task, model, and dataset for image segmentation. The SAM project also built the largest segmentation [dataset](https://segment-anything.com/dataset/index.html) to date, with over 1 billion masks on 11M licensed and privacy-respecting images. SAM can produce high-quality object masks from different types of prompts, including points, boxes, masks and text, and it shows impressive zero-shot performance on a variety of tasks, often competitive with or even superior to prior fully supervised results. However, the text-prompt variant of SAM has not been released yet, so we combine **SAM** and **CLIP** to calculate the similarity between the output masks and the text prompt. In this way, you can use a **text prompt** to segment anything. In addition, we also provide SAM that generates masks for all objects in the whole image.

-We provide the pretrained model parameters of PaddlePaddle format, including [vit_b](https://bj.bcebos.com/paddleseg/dygraph/paddlesegAnything/vit_b/model.pdparams), [vit_l](https://bj.bcebos.com/paddleseg/dygraph/paddlesegAnything/vit_l/model.pdparams) and [vit_h](https://bj.bcebos.com/paddleseg/dygraph/paddlesegAnything/vit_h/model.pdparams).
+We provide pretrained model parameters in PaddlePaddle format, including [vit_b](https://bj.bcebos.com/paddleseg/dygraph/paddlesegAnything/vit_b/model.pdparams), [vit_l](https://bj.bcebos.com/paddleseg/dygraph/paddlesegAnything/vit_l/model.pdparams) and [vit_h](https://bj.bcebos.com/paddleseg/dygraph/paddlesegAnything/vit_h/model.pdparams). For the text prompt, we also provide the [CLIP_ViT_B](https://bj.bcebos.com/paddleseg/dygraph/clip/vit_b_32_pretrain/clip_vit_b_32.pdparams) model parameters in PaddlePaddle format.

## <img src="https://user-images.githubusercontent.com/34859558/190044217-8f6befc2-7f20-473d-b356-148e06265205.png" width="25"/> Performance

<div align="center">
-<img src="https://github.com/Sunting78/images/blob/master/sam_new.gif" width="1000" />
+<img src="https://user-images.githubusercontent.com/18344247/232466911-f8d1c016-2eb2-46aa-94e2-3ec435f38502.gif" width="1000" />
</div>

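The Overview above describes the key idea of this commit: score every mask that SAM produces against the text prompt with CLIP and keep the best-matching ones. A minimal sketch of that scoring step follows; the committed `scripts/text_to_sam_clip.py` further down implements it as `image_text_match`, and the function and variable names here are illustrative only:

```python
import paddle
import paddle.nn.functional as F

def score_masks_against_text(clip_model, mask_crops, text_tokens):
    # mask_crops: preprocessed crops of each SAM mask, stacked into one batch.
    # text_tokens: the tokenized text prompt.
    image_feat = clip_model.encode_image(mask_crops)   # [N, D]
    text_feat = clip_model.encode_text(text_tokens)    # [1, D]
    # Normalize both sides so the dot product becomes a cosine similarity.
    image_feat = image_feat / image_feat.norm(axis=-1, keepdim=True)
    text_feat = text_feat / text_feat.norm(axis=-1, keepdim=True)
    logits = 100.0 * image_feat @ text_feat.T          # [N, 1]
    # Softmax over the N candidate masks yields a relative match score per mask.
    return F.softmax(logits[:, 0], axis=0)
```

Masks whose score clears a small threshold are then overlaid on the original image.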

@@ -33,44 +32,51 @@ We provide the pretrained model parameters of PaddlePaddle format, including [vi
git clone https://github.com/PaddlePaddle/PaddleSeg.git
cd PaddleSeg
pip install -r requirements.txt
+pip install ftfy regex
cd contrib/SegmentAnything/
```
-* Download the example image to ```contrib/SegmentAnything/examples```, and the file structure is as following:
+* Download the example image to ```contrib/SegmentAnything/examples``` and the vocab file to ```contrib/SegmentAnything/```:
```bash
wget https://paddleseg.bj.bcebos.com/dygraph/demo/cityscapes_demo.png
+wget https://bj.bcebos.com/paddleseg/dygraph/bpe_vocab_16e6/bpe_simple_vocab_16e6.txt.gz
```
+Then, the file structure is as follows:

```
PaddleSeg/contrib
├── SegmentAnything
│   ├── examples
│   │   └── cityscapes_demo.png
│   ├── segment_anything
-│   └── scripts
+│   ├── scripts
+│   └── bpe_simple_vocab_16e6.txt.gz

```
+### 2. Segment Anything on the webpage

-### 2. Segment the whole image on webpage.
In this step, we start a gradio service with the following script on your local machine, and you can try out our project with your own images.
+Based on this service, you can experience the ability to **segment the whole image** and **segment objects based on a text prompt**.

1. Run the following script:
```bash
-python scripts/amg_paddle.py --model-type [vit_l/vit_b/vit_h] # default is vit_h
+python scripts/text_to_sam_clip.py --model-type [vit_l/vit_b/vit_h] # default is vit_h
```
Note:
-* There are three model options for you, vit_b, vit_l and vit_h, represent vit_base, vit_large and vit_huge. Large model is more accurate and also slower. You can choose the model size based on your device.
-* The test result shows that vit_h needs 16G video memory and needs around 10s to infer an image on V100.
-
-2. Open the webpage on your localhost: ```http://0.0.0.0:8017```
+* There are three SAM model options, `vit_b`, `vit_l` and `vit_h`, standing for vit_base, vit_large and vit_huge. Larger models are more accurate but slower. You can choose a suitable model size based on your device.
+* We use the `CLIP ViT-B` model to extract text and image features.
+* `SAM vit_h` needs 16G of GPU memory and takes around 10s to infer an image on a V100.

+2. Open the webpage on your localhost: ```http://0.0.0.0:8078```
3. Try it out by clearing and uploading your own test image! Our example looks like:

<div align="center">
-<img src="https://user-images.githubusercontent.com/34859558/230873989-9597527e-bef6-47ce-988b-977198794d75.jpg" width = "1000" />
+<img src="https://user-images.githubusercontent.com/18344247/232427677-a7f913df-4abf-46ce-be2c-e37cbd495105.png" width = "1000" />
</div>

-### 3. Segment the object with prompts
-You can run the following commands to produce masks from different types of prompts including points, boxes, and masks, as follow:
+### 3. Segment the object with point or box prompts
+
+You can run the following commands to produce masks from different types of prompts, including points and boxes:

1. Box prompt
@@ -84,10 +90,9 @@ python scripts/promt_predict.py --input_path xxx.png --box_prompt 1050 370 1500
python scripts/promt_predict.py --input_path xxx.png --point_prompt 1200 450 --model-type [vit_l/vit_b/vit_h] # default is vit_h
```

-3. Mask prompt
-```bash
-python scripts/promt_predict.py --input_path xxx.png --mask_prompt xxx.png --model-type [vit_l/vit_b/vit_h] # default is vit_h
-```

-Note:
-* mask_prompt is the path of a binary image.
+## Reference
+
+> Alexander Kirillov, Eric Mintun, Nikhila Ravi, Hanzi Mao, Chloe Rolland, Laura Gustafson, Tete Xiao, Spencer Whitehead, Alexander C. Berg, Wan-Yen Lo, Piotr Dollár, Ross Girshick. [Segment Anything](https://ai.facebook.com/research/publications/segment-anything/).
+
+> Alec Radford, Jong Wook Kim, Chris Hallacy, Aditya Ramesh, Gabriel Goh, Sandhini Agarwal, Girish Sastry, Amanda Askell, Pamela Mishkin, Jack Clark, Gretchen Krueger, Ilya Sutskever. Learning Transferable Visual Models From Natural Language Supervision. Proceedings of the 38th International Conference on Machine Learning, PMLR 139:8748-8763, 2021. [CLIP](https://github.com/openai/CLIP)
Two image files changed (97.5 KB and 48.9 KB); previews not shown.

contrib/SegmentAnything/scripts/promt_predict.py

Lines changed: 2 additions & 9 deletions
@@ -39,7 +39,7 @@

def get_args():
    parser = argparse.ArgumentParser(
-        description='Segment image with point promp, box or mask')
+        description='Segment image with point prompt or box')
    # Parameters
    parser.add_argument(
        '--input_path', type=str, required=True, help='The directory of image.')
@@ -61,8 +61,6 @@ def get_args():
        nargs='+',
        default=None,
        help='box prompt format as xyxy.')
-    parser.add_argument(
-        '--mask_prompt', type=str, default=None, help='The path of mask.')
    parser.add_argument(
        '--output_path',
        type=str,
@@ -88,18 +86,14 @@ def main(args):
        paddle.set_device("cpu")
    input_path = args.input_path
    output_path = args.output_path
-    point, box, mask_path = args.point_prompt, args.box_prompt, args.mask_prompt
+    point, box = args.point_prompt, args.box_prompt
    if point is not None:
        point = np.array([point])
        input_label = np.array([1])
    else:
        input_label = None
    if box is not None:
        box = np.array([[box[0], box[1]], [box[2], box[3]]])
-    if mask_path is not None:
-        mask = cv2.imread(mask_path, -1)
-    else:
-        mask = None

    image = cv2.imread(input_path)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
@@ -112,7 +106,6 @@ def main(args):
        point_coords=point,
        point_labels=input_label,
        box=box,
-        mask_input=mask,
        multimask_output=True, )

    plt.figure(figsize=(10, 10))
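For context, the point/box flow of `promt_predict.py` boils down to a handful of predictor calls. A minimal sketch, assuming the repo's `segment_anything` package exposes a `SamPredictor` wrapper like the original SAM release (the class name, `set_image` call, and return values are assumptions, not verified against this commit; the checkpoint path is a placeholder):

```python
import cv2
import numpy as np
from segment_anything import sam_model_registry, SamPredictor  # SamPredictor assumed

# Build a SAM backbone and wrap it in a promptable predictor.
sam = sam_model_registry["vit_b"](checkpoint="path/to/vit_b/model.pdparams")
predictor = SamPredictor(sam)

image = cv2.cvtColor(cv2.imread("examples/cityscapes_demo.png"), cv2.COLOR_BGR2RGB)
predictor.set_image(image)

# One foreground point prompt (x, y) with label 1, as in --point_prompt 1200 450.
masks, scores, _ = predictor.predict(
    point_coords=np.array([[1200, 450]]),
    point_labels=np.array([1]),
    multimask_output=True)
best_mask = masks[np.argmax(scores)]
```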
contrib/SegmentAnything/scripts/text_to_sam_clip.py (new file; this is the script referenced in the README above)

Lines changed: 239 additions & 0 deletions
@@ -0,0 +1,239 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import cv2
import time
import sys
import argparse
sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), ".."))

import paddle
import paddle.nn.functional as F
import numpy as np
from PIL import Image, ImageDraw
import matplotlib.pyplot as plt

from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
from segment_anything.modeling.clip_paddle import build_clip_model, _transform
from segment_anything.utils.sample_tokenizer import tokenize
from paddleseg.utils.visualize import get_pseudo_color_map, get_color_map_list

ID_PHOTO_IMAGE_DEMO = "./examples/cityscapes_demo.png"
CACHE_DIR = ".temp"
model_link = {
    'vit_h':
    "https://bj.bcebos.com/paddleseg/dygraph/paddlesegAnything/vit_h/model.pdparams",
    'vit_l':
    "https://bj.bcebos.com/paddleseg/dygraph/paddlesegAnything/vit_l/model.pdparams",
    'vit_b':
    "https://bj.bcebos.com/paddleseg/dygraph/paddlesegAnything/vit_b/model.pdparams",
    'clip_b_32':
    "https://bj.bcebos.com/paddleseg/dygraph/clip/vit_b_32_pretrain/clip_vit_b_32.pdparams"
}

parser = argparse.ArgumentParser(description=(
    "Runs automatic mask generation on an input image or directory of images, "
    "and outputs masks as either PNGs or COCO-style RLEs. Requires open-cv, "
    "as well as pycocotools if saving in RLE format."))

parser.add_argument(
    "--model-type",
    type=str,
    default="vit_h",
    required=True,
    help="The type of model to load, in ['vit_h', 'vit_l', 'vit_b']", )


def download(img):
    # Save the result image to a uniquely named file in the cache directory.
    if not os.path.exists(CACHE_DIR):
        os.makedirs(CACHE_DIR)
    while True:
        name = str(int(time.time()))
        tmp_name = os.path.join(CACHE_DIR, name + '.jpg')
        if not os.path.exists(tmp_name):
            break
        else:
            time.sleep(1)
    img.save(tmp_name, 'png')
    return tmp_name


def segment_image(image, segment_mask):
    # Keep only the masked pixels and paste them onto a gray background.
    image_array = np.array(image)
    gray_image = Image.new("RGB", image.size, (128, 128, 128))
    segmented_image_array = np.zeros_like(image_array)
    segmented_image_array[segment_mask] = image_array[segment_mask]
    segmented_image = Image.fromarray(segmented_image_array)
    transparency = np.zeros_like(segment_mask, dtype=np.uint8)
    transparency[segment_mask] = 255
    transparency_image = Image.fromarray(transparency, mode='L')
    gray_image.paste(segmented_image, mask=transparency_image)
    return gray_image


def image_text_match(cropped_objects, text_query):
    # Encode the cropped objects and the text query with CLIP, then return a
    # softmax score per crop based on cosine similarity.
    transformed_images = [transform(image) for image in cropped_objects]
    tokenized_text = tokenize([text_query])
    batch_images = paddle.stack(transformed_images)
    image_features = model.encode_image(batch_images)
    print("encode_image done!")
    text_features = model.encode_text(tokenized_text)
    print("encode_text done!")
    image_features /= image_features.norm(axis=-1, keepdim=True)
    text_features /= text_features.norm(axis=-1, keepdim=True)
    probs = 100. * image_features @ text_features.T
    return F.softmax(probs[:, 0], axis=0)


def masks2pseudomap(masks):
    # Merge per-object binary masks into one label map (background = 255)
    # and its pseudo-color visualization.
    result = np.ones(masks[0]["segmentation"].shape, dtype=np.uint8) * 255
    for i, mask_data in enumerate(masks):
        result[mask_data["segmentation"] == 1] = i + 1
    pred_result = result
    result = get_pseudo_color_map(result)
    return pred_result, result


def visualize(image, result, color_map, weight=0.6):
    """
    Convert the predicted result to a color image and blend it with the input image.

    Args:
        image (np.ndarray): The origin image.
        result (np.ndarray): The predicted label map of the image.
        color_map (list): The color used to save the prediction results.
        weight (float): The image weight of the visual image, and the result weight is (1 - weight). Default: 0.6

    Returns:
        vis_result (np.ndarray): The visualized result.
    """
    color_map = [color_map[i:i + 3] for i in range(0, len(color_map), 3)]
    color_map = np.array(color_map).astype("uint8")
    # Use OpenCV LUT for color mapping
    c1 = cv2.LUT(result, color_map[:, 0])
    c2 = cv2.LUT(result, color_map[:, 1])
    c3 = cv2.LUT(result, color_map[:, 2])
    pseudo_img = np.dstack((c3, c2, c1))

    vis_result = cv2.addWeighted(image, weight, pseudo_img, 1 - weight, 0)
    return vis_result


def get_id_photo_output(image, text):
    """
    Segment everything in the image, then highlight the masks that match the text prompt.

    Args:
        image (np.ndarray): The input image array.
        text (str): The text prompt used to select masks.

    Returns:
        The text-matched overlay image, the pseudo-color map of all masks,
        and the path of the saved result.
    """
    image_ori = image.copy()
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    masks = mask_generator.generate(image)
    pred_result, pseudo_map = masks2pseudomap(masks)  # PIL Image
    added_pseudo_map = visualize(
        image, pred_result, color_map=get_color_map_list(256))
    cropped_objects = []
    image_pil = Image.fromarray(image)
    for mask in masks:
        bbox = [
            mask["bbox"][0], mask["bbox"][1], mask["bbox"][0] + mask["bbox"][2],
            mask["bbox"][1] + mask["bbox"][3]
        ]
        cropped_objects.append(
            segment_image(image_pil, mask["segmentation"]).crop(bbox))

    scores = image_text_match(cropped_objects, str(text))
    text_matching_masks = []
    for idx, score in enumerate(scores):
        if score < 0.05:
            continue
        text_matching_mask = Image.fromarray(
            masks[idx]["segmentation"].astype('uint8') * 255)
        text_matching_masks.append(text_matching_mask)

    image_pil_ori = Image.fromarray(image_ori)
    alpha_image = Image.new('RGBA', image_pil_ori.size, (0, 0, 0, 0))
    alpha_color = (255, 0, 0, 180)

    draw = ImageDraw.Draw(alpha_image)
    for text_matching_mask in text_matching_masks:
        draw.bitmap((0, 0), text_matching_mask, fill=alpha_color)

    result_image = Image.alpha_composite(
        image_pil_ori.convert('RGBA'), alpha_image)
    res_download = download(result_image)
    return result_image, added_pseudo_map, res_download


def gradio_display():
    import gradio as gr
    examples_sam = [["./examples/cityscapes_demo.png", "a photo of car"],
                    ["examples/dog.jpg", "dog"],
                    ["examples/zixingche.jpeg", "kid"]]

    demo_mask_sam = gr.Interface(
        fn=get_id_photo_output,
        inputs=[
            gr.Image(
                value=ID_PHOTO_IMAGE_DEMO,
                label="Input image").style(height=400), gr.inputs.Textbox(
                    lines=3,
                    placeholder=None,
                    default="a photo of car",
                    label='🔥 Input text prompt 🔥',
                    optional=False)
        ],
        outputs=[
            gr.Image(
                label="Output based on text",
                interactive=False).style(height=300), gr.Image(
                    label="Output mask", interactive=False).style(height=300)
        ],
        examples=examples_sam,
        description="<p> \
            <strong>SAM+CLIP: Text prompt for segmentation. </strong> <br>\
            Choose an example below, or upload your own image: <br>\
            1. Upload images to be tested to 'input image'. 2. Input a text prompt to 'input text prompt' and click 'submit'. <br>\
            </p>",
        cache_examples=False,
        allow_flagging="never", )

    demo = gr.TabbedInterface(
        [demo_mask_sam, ], ['SAM+CLIP(Text to Segment)'],
        title=" 🔥 Text to Segment Anything with PaddleSeg 🔥")
    demo.launch(
        server_name="0.0.0.0", enable_queue=False, server_port=8078, share=True)


args = parser.parse_args()
print("Loading model...")

if paddle.is_compiled_with_cuda():
    paddle.set_device("gpu")
else:
    paddle.set_device("cpu")

sam = sam_model_registry[args.model_type](
    checkpoint=model_link[args.model_type])
mask_generator = SamAutomaticMaskGenerator(sam)

model, transform = build_clip_model(model_link["clip_b_32"])
gradio_display()
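For readers who want the same pipeline without the Gradio UI, here is a minimal sketch of how the pieces above fit together for one image. It assumes the helper functions (`segment_image`, `image_text_match`) and the loaded `mask_generator` are available in scope; the image path and prompt are placeholders, and the 0.05 threshold mirrors `get_id_photo_output`:

```python
import cv2
from PIL import Image

img_rgb = cv2.cvtColor(cv2.imread("examples/cityscapes_demo.png"), cv2.COLOR_BGR2RGB)
masks = mask_generator.generate(img_rgb)  # one dict per detected object

# Crop each masked object onto a gray background and score it against the prompt.
image_pil = Image.fromarray(img_rgb)
crops = [
    segment_image(image_pil, m["segmentation"]).crop(
        (m["bbox"][0], m["bbox"][1],
         m["bbox"][0] + m["bbox"][2], m["bbox"][1] + m["bbox"][3]))
    for m in masks
]
scores = image_text_match(crops, "a photo of car")

# Keep the masks whose softmax score clears the same 0.05 threshold used above.
selected = [m for m, s in zip(masks, scores) if float(s) >= 0.05]
print(f"{len(selected)} of {len(masks)} masks match the prompt")
```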
