Skip to content

Commit 301a0f5

Browse files
authored
Merge pull request #1367 from rubencardenes/fix_fill_holes_and_remove_small_masks
fix(utils): fixes a bug in fill_holes_and_remove_small_masks Added test
2 parents 7a0f2c4 + 8b3930e commit 301a0f5

File tree

2 files changed

+36
-5
lines changed

2 files changed

+36
-5
lines changed

cellpose/utils.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -642,10 +642,15 @@ def fill_holes_and_remove_small_masks(masks, min_size=15):
642642

643643
# Filter small masks
644644
if min_size > 0:
645-
counts = fastremap.unique(masks, return_counts=True)[1][1:]
646-
masks = fastremap.mask(masks, np.nonzero(counts < min_size)[0] + 1)
645+
uniq, counts = fastremap.unique(masks, return_counts=True)
646+
# uniq[0] is background (0), so uniq[1:] are the actual mask labels
647+
# counts[1:] are the corresponding counts
648+
small_mask_indices = np.nonzero(counts[1:] < min_size)[0]
649+
# Get the actual label values to remove (not indices)
650+
labels_to_remove = uniq[1:][small_mask_indices]
651+
masks = fastremap.mask(masks, labels_to_remove)
647652
fastremap.renumber(masks, in_place=True)
648-
653+
649654
slices = find_objects(masks)
650655
j = 0
651656
for i, slc in enumerate(slices):
@@ -656,8 +661,11 @@ def fill_holes_and_remove_small_masks(masks, min_size=15):
656661
j += 1
657662

658663
if min_size > 0:
659-
counts = fastremap.unique(masks, return_counts=True)[1][1:]
660-
masks = fastremap.mask(masks, np.nonzero(counts < min_size)[0] + 1)
664+
uniq, counts = fastremap.unique(masks, return_counts=True)
665+
small_mask_indices = np.nonzero(counts[1:] < min_size)[0]
666+
labels_to_remove = uniq[1:][small_mask_indices]
667+
masks = fastremap.mask(masks, labels_to_remove)
661668
fastremap.renumber(masks, in_place=True)
662669

670+
663671
return masks

tests/test_utils.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
import numpy as np
2+
from cellpose.utils import fill_holes_and_remove_small_masks
3+
import fastremap
4+
5+
6+
def test_fill_holes_and_remove_small_masks():
7+
# make a 2-channel mask with holes and small objects. The first channel is the
8+
# "ground truth" without holes or small objects, the second channel needs to be cleaned.
9+
masks = np.zeros((2, 100, 100), dtype=np.uint16)
10+
masks[:, 10:30, 10:30] = 1 # object 1
11+
masks[1, 15:25, 15:25] = 0 # hole in object 1
12+
masks[1, 40:45, 40:45] = 2 # small object 2
13+
masks[:, 60:90, 60:90] = 4 # object 4 (skip 3)
14+
masks[1, 70:80, 70:80] = 0 # hole in object 4
15+
masks[1, 10:15, 80:82] = 5 # small object 4
16+
17+
# apply function
18+
min_size = 30
19+
masks_cleaned = fill_holes_and_remove_small_masks(masks[1], min_size=min_size)
20+
21+
gt_masks = fastremap.renumber(masks[0], in_place=False)[0]
22+
23+
assert (gt_masks == masks_cleaned).all()

0 commit comments

Comments
 (0)