|
5 | 5 | import scipy
|
6 | 6 | import numpy as np
|
7 | 7 | from scipy.ndimage import distance_transform_edt
|
8 |
| -from skimage.morphology import remove_small_holes |
| 8 | +from skimage.morphology import remove_small_holes, skeletonize |
9 | 9 | from skimage.measure import label as label_cc # avoid namespace conflict
|
| 10 | +from skimage.filters import gaussian |
10 | 11 |
|
11 | 12 | from .data_misc import get_padsize, array_unpad
|
12 | 13 |
|
13 | 14 | __all__ = [
|
14 | 15 | 'edt_semantic',
|
15 | 16 | 'edt_instance',
|
| 17 | + 'sdt_instance', |
16 | 18 | 'decode_quantize',
|
17 | 19 | ]
|
18 | 20 |
|
19 | 21 |
|
20 | 22 | def edt_semantic(
|
21 |
| - label: np.ndarray, |
22 |
| - mode: str = '2d', |
23 |
| - alpha_fore: float = 8.0, |
24 |
| - alpha_back: float = 50.0): |
| 23 | + label: np.ndarray, |
| 24 | + mode: str = '2d', |
| 25 | + alpha_fore: float = 8.0, |
| 26 | + alpha_back: float = 50.0 |
| 27 | +): |
25 | 28 | """Euclidean distance transform (DT or EDT) for binary semantic mask.
|
26 | 29 | """
|
27 | 30 | assert mode in ['2d', '3d']
|
@@ -84,6 +87,35 @@ def edt_instance(label: np.ndarray,
|
84 | 87 | return vol_distance
|
85 | 88 |
|
86 | 89 |
|
| 90 | +def sdt_instance(label: np.ndarray, |
| 91 | + mode: str = '2d', |
| 92 | + quantize: bool = True, |
| 93 | + resolution: Tuple[float] = (1.0, 1.0), |
| 94 | + padding: bool = True): |
| 95 | + """Skeleton-based distance transform (SDT) for a stack of label images. |
| 96 | +
|
| 97 | + Lin, Zudi, et al. "Structure-Preserving Instance Segmentation via Skeleton-Aware |
| 98 | + Distance Transform." International Conference on Medical Image Computing and |
| 99 | + Computer-Assisted Intervention. Cham: Springer Nature Switzerland, 2023. |
| 100 | + """ |
| 101 | + assert mode == "2d", "Only 2d skeletonization is currently supported." |
| 102 | + |
| 103 | + vol_distance = [] |
| 104 | + vol_semantic = [] |
| 105 | + for i in range(label.shape[0]): |
| 106 | + label_img = label[i].copy() |
| 107 | + distance, semantic = skeleton_aware_distance_transform(label_img, padding=padding) |
| 108 | + vol_distance.append(distance) |
| 109 | + vol_semantic.append(semantic) |
| 110 | + |
| 111 | + vol_distance = np.stack(vol_distance, 0) |
| 112 | + vol_semantic = np.stack(vol_semantic, 0) |
| 113 | + if quantize: |
| 114 | + vol_distance = energy_quantize(vol_distance) |
| 115 | + |
| 116 | + return vol_distance |
| 117 | + |
| 118 | + |
87 | 119 | def distance_transform(label: np.ndarray,
|
88 | 120 | bg_value: float = -1.0,
|
89 | 121 | relabel: bool = True,
|
@@ -135,6 +167,98 @@ def distance_transform(label: np.ndarray,
|
135 | 167 | return distance, semantic
|
136 | 168 |
|
137 | 169 |
|
| 170 | +def smooth_edge(binary, smooth_sigma: float = 2.0, smooth_threshold: float = 0.5): |
| 171 | + """Smooth the object contour.""" |
| 172 | + for _ in range(2): |
| 173 | + binary = gaussian(binary, sigma=smooth_sigma, preserve_range=True) |
| 174 | + binary = (binary > smooth_threshold).astype(np.uint8) |
| 175 | + |
| 176 | + return binary |
| 177 | + |
| 178 | + |
| 179 | +def skeleton_aware_distance_transform( |
| 180 | + label: np.ndarray, |
| 181 | + bg_value: float = -1.0, |
| 182 | + relabel: bool = True, |
| 183 | + padding: bool = False, |
| 184 | + resolution: Tuple[float] = (1.0, 1.0), |
| 185 | + alpha: float = 0.8, |
| 186 | + smooth: bool = True, |
| 187 | + smooth_skeleton_only: bool = True, |
| 188 | +): |
| 189 | + """Skeleton-based distance transform (SDT). |
| 190 | +
|
| 191 | + Lin, Zudi, et al. "Structure-Preserving Instance Segmentation via Skeleton-Aware |
| 192 | + Distance Transform." International Conference on Medical Image Computing and |
| 193 | + Computer-Assisted Intervention. Cham: Springer Nature Switzerland, 2023. |
| 194 | + """ |
| 195 | + eps = 1e-6 |
| 196 | + pad_size = 2 |
| 197 | + |
| 198 | + if relabel: |
| 199 | + label = label_cc(label) |
| 200 | + |
| 201 | + if padding: |
| 202 | + # The distance_transform_edt function does not treat image border |
| 203 | + # as background. If image border needs to be considered as background |
| 204 | + # in distance calculation, set padding to True. |
| 205 | + label = np.pad(label, pad_size, mode='constant', constant_values=0) |
| 206 | + |
| 207 | + label_shape = label.shape |
| 208 | + all_bg_sample = False |
| 209 | + |
| 210 | + skeleton = np.zeros(label_shape, dtype=np.uint8) |
| 211 | + distance = np.zeros(label_shape, dtype=np.float32) + bg_value |
| 212 | + semantic = np.zeros(label_shape, dtype=np.uint8) |
| 213 | + |
| 214 | + indices = np.unique(label) |
| 215 | + if indices[0] == 0: |
| 216 | + if len(indices) > 1: # exclude background |
| 217 | + indices = indices[1:] |
| 218 | + else: # all-background sample |
| 219 | + all_bg_sample = True |
| 220 | + |
| 221 | + if not all_bg_sample: |
| 222 | + for idx in indices: |
| 223 | + temp1 = label.copy() == idx |
| 224 | + temp2 = remove_small_holes(temp1, 16, connectivity=1) |
| 225 | + binary = temp2.copy() |
| 226 | + |
| 227 | + if smooth: |
| 228 | + binary = smooth_edge(binary) |
| 229 | + if binary.astype(int).sum() <= 32: |
| 230 | + # Reverse the smoothing operation if it makes |
| 231 | + # the output mask empty (or very small). |
| 232 | + binary = temp2.copy() |
| 233 | + else: |
| 234 | + if smooth_skeleton_only: |
| 235 | + binary = binary * temp2 |
| 236 | + else: |
| 237 | + temp2 = binary.copy() |
| 238 | + |
| 239 | + semantic += temp2.astype(np.uint8) |
| 240 | + |
| 241 | + skeleton_mask = skeletonize(binary) |
| 242 | + skeleton_mask = (skeleton_mask != 0).astype(np.uint8) |
| 243 | + skeleton += skeleton_mask |
| 244 | + |
| 245 | + skeleton_edt = distance_transform_edt(1-skeleton_mask, resolution) |
| 246 | + boundary_edt = distance_transform_edt(temp2, resolution) |
| 247 | + |
| 248 | + energy = boundary_edt / (skeleton_edt + boundary_edt + eps) # normalize |
| 249 | + energy = energy ** alpha |
| 250 | + distance = np.maximum(distance, energy * temp2.astype(np.float32)) |
| 251 | + |
| 252 | + if padding: |
| 253 | + # Unpad the output array to preserve original shape. |
| 254 | + distance = array_unpad(distance, get_padsize( |
| 255 | + pad_size, ndim=distance.ndim)) |
| 256 | + semantic = array_unpad(semantic, get_padsize( |
| 257 | + pad_size, ndim=distance.ndim)) |
| 258 | + |
| 259 | + return distance, semantic |
| 260 | + |
| 261 | + |
138 | 262 | def energy_quantize(energy, levels=10):
|
139 | 263 | """Convert the continuous energy map into the quantized version.
|
140 | 264 | """
|
|
0 commit comments