@@ -165,6 +165,20 @@ def save_mask(self, mask: np.ndarray, output_path: str):
165
165
mask_image = (mask * 255 ).astype (np .uint8 )
166
166
cv2 .imwrite (output_path , mask_image )
167
167
168
+ def apply_mask_to_image (self , image_path , mask ):
169
+ image = cv2 .imread (image_path )
170
+ mask_binary = mask .astype (np .uint8 ) * 255
171
+ segmented = cv2 .bitwise_and (image , image , mask = mask_binary )
172
+
173
+ # Create white background for transparency
174
+ white_background = np .ones_like (image ) * 255
175
+ background = cv2 .bitwise_and (
176
+ white_background , white_background , mask = ~ mask_binary
177
+ )
178
+ # Combine segmented image with white background
179
+ final_image = cv2 .add (segmented , background )
180
+ return final_image
181
+
168
182
169
183
class PointSelector :
170
184
def __init__ (self , image_path , max_points = 2 ):
@@ -300,7 +314,13 @@ def main():
300
314
mask = sam .get_mask (original_size )
301
315
302
316
if mask is not None :
317
+ # Save the mask
303
318
sam .save_mask (mask , "output_mask.png" )
319
+
320
+ # Save segmented image
321
+ segmented_image = sam .apply_mask_to_image (image_path , mask )
322
+ cv2 .imwrite ("output_segmented.png" , segmented_image )
323
+
304
324
cv2 .destroyAllWindows ()
305
325
306
326
except Exception as e :
0 commit comments