-
Notifications
You must be signed in to change notification settings - Fork 2k
/
Copy pathimage_segmentation_transformers.py
190 lines (152 loc) · 5.38 KB
/
image_segmentation_transformers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
# %% [markdown]
# # Set up environment
# %%
# notebook shell command: install Hugging Face transformers into the kernel env
!pip install transformers
# %%
from IPython.display import clear_output
# !pip3 install transformers
# hide the noisy pip-install log from the notebook output
clear_output()
# %%
import numpy as np
import torch
import torch.nn.functional as F
from torchvision import transforms
from transformers import pipeline, SegformerImageProcessor, SegformerForSemanticSegmentation
import requests
from PIL import Image
import urllib.parse as parse
import os
# a function to determine whether a string is a URL or not
def is_url(string):
    """Return True if `string` parses as an absolute URL (scheme + host).

    The previous version also required a non-empty path, which wrongly
    rejected valid URLs such as "https://example.com".
    """
    try:
        parts = parse.urlparse(string)
    except ValueError:  # urlparse raises ValueError on malformed input (e.g. a bad port)
        return False
    return bool(parts.scheme and parts.netloc)
# a function to load an image
def load_image(image_path):
    """Load an image from a URL or a local filesystem path.

    Parameters
    ----------
    image_path : str
        Either an http(s) URL or a path to an existing local file.

    Returns
    -------
    PIL.Image.Image

    Raises
    ------
    ValueError
        If `image_path` is neither a URL nor an existing file.  The previous
        version silently returned None here, which surfaced later as a
        confusing AttributeError in callers.
    """
    if is_url(image_path):
        response = requests.get(image_path, stream=True)
        # fail fast on HTTP errors instead of handing PIL a broken stream
        response.raise_for_status()
        return Image.open(response.raw)
    if os.path.exists(image_path):
        return Image.open(image_path)
    raise ValueError(f"{image_path!r} is neither a URL nor an existing file")
# %% [markdown]
# # Load Image
# %%
# sample image fetched over HTTP (a cat crossing a road, 300x180)
img_path = "https://shorthaircatbreeds.com/wp-content/uploads/2020/06/Urban-cat-crossing-a-road-300x180.jpg"
image = load_image(img_path)
# %%
# bare expression: renders the downloaded PIL image inline in the notebook
image
# %%
# convert PIL Image to pytorch tensors
transform = transforms.ToTensor()
# force 3 RGB channels before the tensor conversion
image_tensor = image.convert("RGB")
image_tensor = transform(image_tensor)
# bare expression: displays the channels-first (C, H, W) tensor shape
image_tensor.shape
# %% [markdown]
# # Helper functions
# %%
def color_palette():
    """Color palette to map each class to its corresponding color."""
    return [[0, 128, 128],
            [255, 170, 0],
            [161, 19, 46],
            [118, 171, 47],
            [255, 255, 0],
            [84, 170, 127],
            [170, 84, 127],
            [33, 138, 200],
            [255, 84, 0],
            [255, 140, 208]]
# %%
def overlay_segments(image, seg_mask):
    """Return different segments predicted by the model overlaid on image.

    Parameters
    ----------
    image : array-like, shape (H, W, 3)
        Image with float channel values in [0, 1] (e.g. a permuted
        torchvision tensor); it is rescaled by 255 before blending.
    seg_mask : torch.Tensor or array-like, shape (H, W)
        Integer segment labels.

    Returns
    -------
    numpy.ndarray of uint8, shape (H, W, 3)
        The 50/50 blend of the image and the colored segment mask.
    """
    H, W = seg_mask.shape
    image_mask = np.zeros((H, W, 3), dtype=np.uint8)
    colors = np.array(color_palette())
    # convert to a pytorch tensor if seg_mask is not one already
    seg_mask = seg_mask if torch.is_tensor(seg_mask) else torch.tensor(seg_mask)
    unique_labels = torch.unique(seg_mask)
    # map each segment label to a palette color; wrap around when there are
    # more labels than palette entries (the previous version indexed past the
    # end of the 10-entry palette and crashed on >10 segments)
    for i, label in enumerate(unique_labels):
        image_mask[seg_mask == label.item(), :] = colors[i % len(colors)]
    # percentage of original image in the final overlaid image
    img_weight = 0.5
    # overlay input image (rescaled to 0-255) and the generated segment mask
    img = img_weight * np.array(image) * 255 + (1 - img_weight) * image_mask
    return img.astype(np.uint8)
# %%
def replace_label(mask, label):
    """Return a copy of `mask` in which every 255-valued pixel becomes `label`."""
    arr = np.asarray(mask)
    # preserve the input's dtype so callers can accumulate masks in a uint8 buffer
    return np.where(arr == 255, label, arr).astype(arr.dtype)
# %% [markdown]
# # Image segmentation using Hugging Face Pipeline
# %%
# load the entire image segmentation pipeline
img_segmentation_pipeline = pipeline('image-segmentation',
                                     model="nvidia/segformer-b5-finetuned-ade-640-640")
# %%
# run inference; returns a list of per-segment dicts, each with a 'mask' entry
output = img_segmentation_pipeline(image)
output
# %%
# mask of the first predicted segment (presumably 255 inside the segment,
# given how replace_label consumes it below — verify against pipeline docs)
output[0]['mask']
# %%
# NOTE(review): assumes the pipeline found at least 3 segments; IndexError otherwise
output[2]['mask']
# %%
# combine the per-segment binary masks into a single (H, W) label map
W, H = image.size
segmentation_mask = np.zeros((H, W), dtype=np.uint8)
for i in range(len(output)):
    # NOTE(review): += assumes the segment masks are disjoint; any overlapping
    # pixel would sum two labels into a value belonging to neither — confirm
    segmentation_mask += replace_label(output[i]['mask'], i)
# %%
# overlay the predicted segmentation masks on the original image
segmented_img = overlay_segments(image_tensor.permute(1, 2, 0), segmentation_mask)
# convert to PIL Image
Image.fromarray(segmented_img)
# %% [markdown]
# # Image segmentation using custom Hugging Face models
# %%
# load the feature extractor (to preprocess images) and the model (to get outputs);
# both come from the same checkpoint so preprocessing matches the model's training
feature_extractor = SegformerImageProcessor.from_pretrained("nvidia/segformer-b5-finetuned-ade-640-640")
model = SegformerForSemanticSegmentation.from_pretrained("nvidia/segformer-b5-finetuned-ade-640-640")
# %%
def to_tensor(image):
    """Convert a PIL Image (forced to RGB) to a pytorch tensor."""
    return transforms.ToTensor()(image.convert("RGB"))
# a function that takes an image and return the segmented image
def get_segmented_image(model, feature_extractor, image_path):
    """Segment the image at `image_path` and return it as a PIL Image with
    the predicted segments overlaid on top of the original pixels."""
    # fetch the image and build both the model input and a tensor copy
    img = load_image(image_path)
    tensor_img = to_tensor(img)
    encoded = feature_extractor(images=img, return_tensors="pt")
    # forward pass through the segmentation model
    outputs = model(**encoded)
    print("outputs.logits.shape:", outputs.logits.shape)
    # interpolate output logits to the same shape as the input image
    upsampled = F.interpolate(outputs.logits,          # tensor to be interpolated
                              size=tensor_img.shape[1:],  # target (H, W)
                              mode='bilinear',            # bilinear interpolation
                              align_corners=False)
    # per-pixel argmax over the class dimension gives the label map
    segmentation_mask = upsampled.argmax(dim=1)[0]
    print(f"{segmentation_mask.shape=}")
    # blend the label map onto the (H, W, C) image and wrap it back into PIL
    overlaid = overlay_segments(tensor_img.permute(1, 2, 0), segmentation_mask)
    return Image.fromarray(overlaid)
# %%
# demo: the same cat image used earlier in the notebook
get_segmented_image(model, feature_extractor, "https://shorthaircatbreeds.com/wp-content/uploads/2020/06/Urban-cat-crossing-a-road-300x180.jpg")
# %%
# demo: a COCO test-stuff 2017 image
get_segmented_image(model, feature_extractor, "http://images.cocodataset.org/test-stuff2017/000000000001.jpg")
# %%