Skip to content

Commit 40a017a

Browse files
committed
add skeleton-aware distance transform
1 parent 5c29c0a commit 40a017a

File tree

1 file changed

+129
-5
lines changed

1 file changed

+129
-5
lines changed

connectomics/data/utils/data_transform.py

Lines changed: 129 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,23 +5,26 @@
55
import scipy
66
import numpy as np
77
from scipy.ndimage import distance_transform_edt
8-
from skimage.morphology import remove_small_holes
8+
from skimage.morphology import remove_small_holes, skeletonize
99
from skimage.measure import label as label_cc # avoid namespace conflict
10+
from skimage.filters import gaussian
1011

1112
from .data_misc import get_padsize, array_unpad
1213

1314
__all__ = [
1415
'edt_semantic',
1516
'edt_instance',
17+
'sdt_instance',
1618
'decode_quantize',
1719
]
1820

1921

2022
def edt_semantic(
21-
label: np.ndarray,
22-
mode: str = '2d',
23-
alpha_fore: float = 8.0,
24-
alpha_back: float = 50.0):
23+
label: np.ndarray,
24+
mode: str = '2d',
25+
alpha_fore: float = 8.0,
26+
alpha_back: float = 50.0
27+
):
2528
"""Euclidean distance transform (DT or EDT) for binary semantic mask.
2629
"""
2730
assert mode in ['2d', '3d']
@@ -84,6 +87,35 @@ def edt_instance(label: np.ndarray,
8487
return vol_distance
8588

8689

90+
def sdt_instance(label: np.ndarray,
91+
mode: str = '2d',
92+
quantize: bool = True,
93+
resolution: Tuple[float] = (1.0, 1.0),
94+
padding: bool = True):
95+
"""Skeleton-based distance transform (SDT) for a stack of label images.
96+
97+
Lin, Zudi, et al. "Structure-Preserving Instance Segmentation via Skeleton-Aware
98+
Distance Transform." International Conference on Medical Image Computing and
99+
Computer-Assisted Intervention. Cham: Springer Nature Switzerland, 2023.
100+
"""
101+
assert mode == "2d", "Only 2d skeletonization is currently supported."
102+
103+
vol_distance = []
104+
vol_semantic = []
105+
for i in range(label.shape[0]):
106+
label_img = label[i].copy()
107+
distance, semantic = skeleton_aware_distance_transform(label_img, padding=padding)
108+
vol_distance.append(distance)
109+
vol_semantic.append(semantic)
110+
111+
vol_distance = np.stack(vol_distance, 0)
112+
vol_semantic = np.stack(vol_semantic, 0)
113+
if quantize:
114+
vol_distance = energy_quantize(vol_distance)
115+
116+
return vol_distance
117+
118+
87119
def distance_transform(label: np.ndarray,
88120
bg_value: float = -1.0,
89121
relabel: bool = True,
@@ -135,6 +167,98 @@ def distance_transform(label: np.ndarray,
135167
return distance, semantic
136168

137169

170+
def smooth_edge(binary, smooth_sigma: float = 2.0, smooth_threshold: float = 0.5):
171+
"""Smooth the object contour."""
172+
for _ in range(2):
173+
binary = gaussian(binary, sigma=smooth_sigma, preserve_range=True)
174+
binary = (binary > smooth_threshold).astype(np.uint8)
175+
176+
return binary
177+
178+
179+
def skeleton_aware_distance_transform(
180+
label: np.ndarray,
181+
bg_value: float = -1.0,
182+
relabel: bool = True,
183+
padding: bool = False,
184+
resolution: Tuple[float] = (1.0, 1.0),
185+
alpha: float = 0.8,
186+
smooth: bool = True,
187+
smooth_skeleton_only: bool = True,
188+
):
189+
"""Skeleton-based distance transform (SDT).
190+
191+
Lin, Zudi, et al. "Structure-Preserving Instance Segmentation via Skeleton-Aware
192+
Distance Transform." International Conference on Medical Image Computing and
193+
Computer-Assisted Intervention. Cham: Springer Nature Switzerland, 2023.
194+
"""
195+
eps = 1e-6
196+
pad_size = 2
197+
198+
if relabel:
199+
label = label_cc(label)
200+
201+
if padding:
202+
# The distance_transform_edt function does not treat image border
203+
# as background. If image border needs to be considered as background
204+
# in distance calculation, set padding to True.
205+
label = np.pad(label, pad_size, mode='constant', constant_values=0)
206+
207+
label_shape = label.shape
208+
all_bg_sample = False
209+
210+
skeleton = np.zeros(label_shape, dtype=np.uint8)
211+
distance = np.zeros(label_shape, dtype=np.float32) + bg_value
212+
semantic = np.zeros(label_shape, dtype=np.uint8)
213+
214+
indices = np.unique(label)
215+
if indices[0] == 0:
216+
if len(indices) > 1: # exclude background
217+
indices = indices[1:]
218+
else: # all-background sample
219+
all_bg_sample = True
220+
221+
if not all_bg_sample:
222+
for idx in indices:
223+
temp1 = label.copy() == idx
224+
temp2 = remove_small_holes(temp1, 16, connectivity=1)
225+
binary = temp2.copy()
226+
227+
if smooth:
228+
binary = smooth_edge(binary)
229+
if binary.astype(int).sum() <= 32:
230+
# Reverse the smoothing operation if it makes
231+
# the output mask empty (or very small).
232+
binary = temp2.copy()
233+
else:
234+
if smooth_skeleton_only:
235+
binary = binary * temp2
236+
else:
237+
temp2 = binary.copy()
238+
239+
semantic += temp2.astype(np.uint8)
240+
241+
skeleton_mask = skeletonize(binary)
242+
skeleton_mask = (skeleton_mask != 0).astype(np.uint8)
243+
skeleton += skeleton_mask
244+
245+
skeleton_edt = distance_transform_edt(1-skeleton_mask, resolution)
246+
boundary_edt = distance_transform_edt(temp2, resolution)
247+
248+
energy = boundary_edt / (skeleton_edt + boundary_edt + eps) # normalize
249+
energy = energy ** alpha
250+
distance = np.maximum(distance, energy * temp2.astype(np.float32))
251+
252+
if padding:
253+
# Unpad the output array to preserve original shape.
254+
distance = array_unpad(distance, get_padsize(
255+
pad_size, ndim=distance.ndim))
256+
semantic = array_unpad(semantic, get_padsize(
257+
pad_size, ndim=distance.ndim))
258+
259+
return distance, semantic
260+
261+
138262
def energy_quantize(energy, levels=10):
139263
"""Convert the continuous energy map into the quantized version.
140264
"""

0 commit comments

Comments
 (0)