|
| 1 | +""" |
| 2 | +Tests for CTFireExtraction -- soft IoU against ground truth centerline. |
| 3 | +
|
| 4 | +Soft IoU formula (Gaussian-smoothed, following centerline.py): |
| 5 | + smooth each 1-px skeleton with a Gaussian (sigma=5) rescaled to [0,1], |
| 6 | + then IoU = (m1*m2).sum() / (m1²+m2²-m1*m2).sum() |
| 7 | +
|
| 8 | +sigma=5 measures spatial agreement at the ~10-px scale, which is the right |
| 9 | +tolerance for testing "did CT-FIRE find the fibers" rather than pixel precision. |
| 10 | +""" |
| 11 | +from __future__ import annotations |
| 12 | +import numpy as np |
| 13 | +import pytest |
| 14 | +from scipy.ndimage import distance_transform_edt |
| 15 | +from skimage import draw, filters, exposure, morphology |
| 16 | +from skimage.transform import resize as sk_resize |
| 17 | +from tme_quant.fiber_analysis.methods.ctfire import CTFireExtraction |
| 18 | +from tme_quant.fiber_analysis.config import CTFireParams |
| 19 | + |
| 20 | +# Gaussian sigma for smoothing 1-px skeletons before computing soft IoU. |
| 21 | +# sigma=5 → ~10-px FWHM, appropriate for fiber-level (not pixel-level) matching. |
| 22 | +_SMOOTH_SIGMA = 2.0 |
| 23 | + |
| 24 | + |
| 25 | +def _smooth_mask(mask, smooth_sigma=_SMOOTH_SIGMA): |
| 26 | + mask = mask.astype(np.float32) |
| 27 | + mask = exposure.rescale_intensity(mask, out_range=(0.0, 1.0)) |
| 28 | + density = filters.gaussian(mask, sigma=smooth_sigma, preserve_range=False) |
| 29 | + return exposure.rescale_intensity(density, out_range=(0.0, 1.0)).astype(np.float32) |
| 30 | + |
| 31 | + |
| 32 | +def _soft_iou(mask_1, mask_2, beta=1e-3): |
| 33 | + if mask_1.shape != mask_2.shape: |
| 34 | + mask_2 = sk_resize( |
| 35 | + mask_2, mask_1.shape, anti_aliasing=True, preserve_range=True |
| 36 | + ).astype(np.float32) |
| 37 | + intersection = mask_1 * mask_2 |
| 38 | + union = mask_1**2 + mask_2**2 - mask_1 * mask_2 |
| 39 | + return float((intersection.sum() + beta) / (union.sum() + beta)) |
| 40 | + |
| 41 | + |
| 42 | +def _rasterize_centerlines(fibers, image_shape): |
| 43 | + canvas = np.zeros(image_shape, dtype=np.uint8) |
| 44 | + H, W = image_shape |
| 45 | + for fiber in fibers: |
| 46 | + pts = fiber.centerline |
| 47 | + if pts is None or len(pts) < 2: |
| 48 | + continue |
| 49 | + pts = np.round(pts[:, :2]).astype(int) |
| 50 | + for i in range(len(pts) - 1): |
| 51 | + r0 = int(np.clip(pts[i, 0], 0, H - 1)) |
| 52 | + c0 = int(np.clip(pts[i, 1], 0, W - 1)) |
| 53 | + r1 = int(np.clip(pts[i + 1, 0], 0, H - 1)) |
| 54 | + c1 = int(np.clip(pts[i + 1, 1], 0, W - 1)) |
| 55 | + rr, cc = draw.line(r0, c0, r1, c1) |
| 56 | + canvas[rr, cc] = 1 |
| 57 | + return morphology.skeletonize(canvas > 0) |
| 58 | + |
| 59 | + |
| 60 | +def _make_synthetic_fiber_image( |
| 61 | + shape=(256, 256), |
| 62 | + n_fibers=10, |
| 63 | + fiber_sigma=2.0, |
| 64 | + min_length=60, |
| 65 | + rng_seed=42, |
| 66 | +): |
| 67 | + """ |
| 68 | + Synthetic fiber image with Gaussian cross-section profiles. |
| 69 | +
|
| 70 | + Fibers are drawn as straight lines; each pixel's intensity is |
| 71 | + exp(-d^2 / (2*fiber_sigma^2)) where d is the distance to the nearest |
| 72 | + centerline. This matches the Gaussian ridge appearance of real SHG |
| 73 | + collagen images that CT-FIRE was designed for. |
| 74 | +
|
| 75 | + Returns |
| 76 | + ------- |
| 77 | + image : (H, W) float32, values in [0, 1] |
| 78 | + gt_skeleton : (H, W) bool -- 1-px-thick ground-truth centerlines |
| 79 | + """ |
| 80 | + rng = np.random.default_rng(rng_seed) |
| 81 | + H, W = shape |
| 82 | + margin = max(10, min_length // 4) |
| 83 | + skeleton_map = np.zeros(shape, dtype=bool) |
| 84 | + |
| 85 | + generated = 0 |
| 86 | + attempts = 0 |
| 87 | + while generated < n_fibers and attempts < n_fibers * 10: |
| 88 | + attempts += 1 |
| 89 | + r0 = int(rng.integers(margin, H - margin)) |
| 90 | + c0 = int(rng.integers(margin, W - margin)) |
| 91 | + angle = rng.uniform(0, np.pi) # avoid duplicate reversed segments |
| 92 | + length = int(rng.integers(min_length, max(min_length + 1, min(H, W) - 2 * margin))) |
| 93 | + r1 = int(np.clip(r0 + length * np.sin(angle), margin, H - margin)) |
| 94 | + c1 = int(np.clip(c0 + length * np.cos(angle), margin, W - margin)) |
| 95 | + actual_len = np.hypot(r1 - r0, c1 - c0) |
| 96 | + if actual_len < min_length: |
| 97 | + continue |
| 98 | + rr, cc = draw.line(r0, c0, r1, c1) |
| 99 | + skeleton_map[rr, cc] = True |
| 100 | + generated += 1 |
| 101 | + |
| 102 | + # Gaussian cross-section via distance transform (fiber amplitude 0.85) |
| 103 | + dist = distance_transform_edt(~skeleton_map).astype(np.float32) |
| 104 | + fiber_signal = 0.85 * np.exp(-dist**2 / (2.0 * fiber_sigma**2)) |
| 105 | + |
| 106 | + # SHG-like background: non-zero mean with Gaussian noise, as CT-FIRE's |
| 107 | + # percentile-based bright-pixel logic requires that fibers occupy the |
| 108 | + # top ~8% of pixel intensities (impossible on a near-zero background). |
| 109 | + background = rng.normal(0.15, 0.04, size=shape).astype(np.float32) |
| 110 | + image = np.clip(background + fiber_signal, 0.0, 1.0) |
| 111 | + |
| 112 | + gt_skeleton = morphology.skeletonize(skeleton_map) |
| 113 | + return image, gt_skeleton |
| 114 | + |
| 115 | + |
| 116 | +# CT-FIRE parameters tuned for the synthetic images above. |
| 117 | +# ctfire_threshold=0.25 keeps only strong Frangi ridges, suppressing the |
| 118 | +# many short false-positive strands that appear at lower thresholds. |
| 119 | +# min_fiber_width=2.0 further rejects thin noise strands (< 1 px radius). |
| 120 | +_CTFIRE_PARAMS = CTFireParams( |
| 121 | + pixel_size=1.0, |
| 122 | + min_fiber_length=25.0, |
| 123 | + max_fiber_length=1000.0, |
| 124 | + min_fiber_width=2.0, |
| 125 | + max_fiber_width=15.0, |
| 126 | + ctfire_threshold=0.25, |
| 127 | + mask_closing_radius=1, |
| 128 | + spur_length_px=5, |
| 129 | + extract_centerlines=True, |
| 130 | +) |
| 131 | + |
| 132 | +SOFT_IOU_THRESHOLD = 0.7 |
| 133 | + |
| 134 | + |
| 135 | +# ───────────────────────────────────────────────────────────────────────────── |
| 136 | +# Visualization helper |
| 137 | +# ───────────────────────────────────────────────────────────────────────────── |
| 138 | + |
| 139 | +def plot_centerline_overlay( |
| 140 | + image: np.ndarray, |
| 141 | + fibers, |
| 142 | + gt_skeleton: np.ndarray | None = None, |
| 143 | + title: str = "CT-FIRE centerline overlay", |
| 144 | + save_path: str | None = None, |
| 145 | +): |
| 146 | + """ |
| 147 | + Display extracted fiber centerlines overlaid on the source image. |
| 148 | +
|
| 149 | + Each detected fiber is drawn in a distinct color from a qualitative |
| 150 | + colormap. The optional ground-truth skeleton is shown in white. |
| 151 | +
|
| 152 | + Parameters |
| 153 | + ---------- |
| 154 | + image : (H, W) float32 |
| 155 | + Grayscale source image (values in [0, 1]). |
| 156 | + fibers : list of FiberProperties |
| 157 | + CT-FIRE output fibers; only those with a ``centerline`` are drawn. |
| 158 | + gt_skeleton : (H, W) bool, optional |
| 159 | + Ground-truth 1-px skeleton. Drawn in white if provided. |
| 160 | + title : str |
| 161 | + Figure title (also used as the window title). |
| 162 | + save_path : str, optional |
| 163 | + If given, save the figure to this path instead of showing it. |
| 164 | + """ |
| 165 | + import matplotlib.pyplot as plt |
| 166 | + import matplotlib.colors as mcolors |
| 167 | + |
| 168 | + fig, ax = plt.subplots(figsize=(7, 7)) |
| 169 | + ax.imshow(image, cmap="gray", vmin=0, vmax=1, interpolation="nearest") |
| 170 | + |
| 171 | + if gt_skeleton is not None: |
| 172 | + # Green semi-transparent overlay for ground truth |
| 173 | + gt_rgba = np.zeros((*gt_skeleton.shape, 4), dtype=np.float32) |
| 174 | + gt_rgba[gt_skeleton, :] = [0.0, 1.0, 0.0, 0.8] |
| 175 | + ax.imshow(gt_rgba, interpolation="nearest") |
| 176 | + |
| 177 | + # Pick a qualitative colormap with enough distinct colors |
| 178 | + cmap = plt.get_cmap("tab20") |
| 179 | + fibers_with_cl = [f for f in fibers if f.centerline is not None and len(f.centerline) >= 2] |
| 180 | + H, W = image.shape |
| 181 | + |
| 182 | + for idx, fiber in enumerate(fibers_with_cl): |
| 183 | + color = cmap(idx % 20) |
| 184 | + pts = np.round(fiber.centerline[:, :2]).astype(int) |
| 185 | + # col (x) first, then row (y) for matplotlib |
| 186 | + xs = np.clip(pts[:, 1], 0, W - 1) |
| 187 | + ys = np.clip(pts[:, 0], 0, H - 1) |
| 188 | + ax.plot(xs, ys, "-", color=color, linewidth=1.2, alpha=0.9) |
| 189 | + # Mark the start point |
| 190 | + ax.plot(xs[0], ys[0], "o", color=color, markersize=3, alpha=0.9) |
| 191 | + |
| 192 | + n_gt = int(gt_skeleton.sum()) if gt_skeleton is not None else 0 |
| 193 | + ax.set_title( |
| 194 | + f"{title}\n" |
| 195 | + f"{len(fibers_with_cl)} detected fibers" |
| 196 | + + (f" | GT pixels: {n_gt}" if gt_skeleton is not None else ""), |
| 197 | + fontsize=10, |
| 198 | + ) |
| 199 | + ax.axis("off") |
| 200 | + fig.tight_layout() |
| 201 | + |
| 202 | + if save_path is not None: |
| 203 | + fig.savefig(save_path, dpi=150, bbox_inches="tight") |
| 204 | + plt.close(fig) |
| 205 | + else: |
| 206 | + plt.show() |
| 207 | + return fig |
| 208 | + |
| 209 | + |
| 210 | +# ───────────────────────────────────────────────────────────────────────────── |
| 211 | +# Standalone demo (python tests/test_ctfire.py) |
| 212 | +# ───────────────────────────────────────────────────────────────────────────── |
| 213 | + |
| 214 | +def _demo(): |
| 215 | + """Generate a synthetic image, run CT-FIRE, and show the overlay.""" |
| 216 | + image, gt_skeleton = _make_synthetic_fiber_image( |
| 217 | + shape=(256, 256), n_fibers=10, fiber_sigma=2.0, rng_seed=42 |
| 218 | + ) |
| 219 | + result = CTFireExtraction().extract_2d(image, _CTFIRE_PARAMS) |
| 220 | + pred_skeleton = _rasterize_centerlines(result.fibers, image.shape) |
| 221 | + ratio = _soft_iou(_smooth_mask(gt_skeleton), _smooth_mask(pred_skeleton)) |
| 222 | + print(f"Detected {len(result.fibers)} fibers | soft IoU = {ratio:.4f}") |
| 223 | + plot_centerline_overlay( |
| 224 | + image, |
| 225 | + result.fibers, |
| 226 | + gt_skeleton=gt_skeleton, |
| 227 | + title=f"CT-FIRE overlay (soft IoU = {ratio:.4f})", |
| 228 | + ) |
| 229 | + |
| 230 | + |
| 231 | +if __name__ == "__main__": |
| 232 | + _demo() |
| 233 | + |
| 234 | + |
| 235 | +class TestCTFireSoftIoU: |
| 236 | + def test_soft_iou_synthetic(self): |
| 237 | + image, gt_skeleton = _make_synthetic_fiber_image( |
| 238 | + shape=(256, 256), n_fibers=10, fiber_sigma=2.0, rng_seed=42 |
| 239 | + ) |
| 240 | + result = CTFireExtraction().extract_2d(image, _CTFIRE_PARAMS) |
| 241 | + assert result.fibers, "CT-FIRE returned no fibers." |
| 242 | + pred_skeleton = _rasterize_centerlines(result.fibers, image.shape) |
| 243 | + ratio = _soft_iou(_smooth_mask(gt_skeleton), _smooth_mask(pred_skeleton)) |
| 244 | + assert ratio > SOFT_IOU_THRESHOLD, ( |
| 245 | + f"Soft IoU {ratio:.4f} < {SOFT_IOU_THRESHOLD}" |
| 246 | + ) |
| 247 | + |
| 248 | + def test_soft_iou_uses_centerlines(self): |
| 249 | + image, _ = _make_synthetic_fiber_image( |
| 250 | + shape=(128, 128), n_fibers=5, fiber_sigma=2.0, rng_seed=7 |
| 251 | + ) |
| 252 | + result = CTFireExtraction().extract_2d(image, _CTFIRE_PARAMS) |
| 253 | + fibers_with_cl = [f for f in result.fibers if f.centerline is not None] |
| 254 | + assert fibers_with_cl, "No fibers have centerlines." |
| 255 | + assert _rasterize_centerlines(result.fibers, image.shape).any(), ( |
| 256 | + "Skeleton is blank." |
| 257 | + ) |
| 258 | + |
| 259 | + @pytest.mark.parametrize("n_fibers,seed", [(5, 0), (8, 13), (10, 99)]) |
| 260 | + def test_soft_iou_multiple_configurations(self, n_fibers, seed): |
| 261 | + image, gt_skeleton = _make_synthetic_fiber_image( |
| 262 | + shape=(256, 256), n_fibers=n_fibers, fiber_sigma=2.0, rng_seed=seed |
| 263 | + ) |
| 264 | + result = CTFireExtraction().extract_2d(image, _CTFIRE_PARAMS) |
| 265 | + assert result.fibers, f"No fibers (n_fibers={n_fibers}, seed={seed})." |
| 266 | + pred_skeleton = _rasterize_centerlines(result.fibers, image.shape) |
| 267 | + ratio = _soft_iou(_smooth_mask(gt_skeleton), _smooth_mask(pred_skeleton)) |
| 268 | + assert ratio > SOFT_IOU_THRESHOLD, ( |
| 269 | + f"Soft IoU {ratio:.4f} < {SOFT_IOU_THRESHOLD} " |
| 270 | + f"(n_fibers={n_fibers}, seed={seed})" |
| 271 | + ) |
0 commit comments