Skip to content

Commit ce0eb08

Browse files
committed
tests: add CTFireExtraction soft IoU validation (test_ctfire.py)
- Synthetic SHG-like images with Gaussian fiber cross-sections and non-zero background to match CT-FIRE percentile logic - Soft IoU metric: Gaussian-smoothed (sigma=5) 1-px skeletons, threshold 0.7 - 5 tests: single seed, centerline presence, 3-seed parametrize - plot_centerline_overlay(): tab20 per-fiber colors, GT in green
1 parent 827ae47 commit ce0eb08

1 file changed

Lines changed: 271 additions & 0 deletions

File tree

src/tme_quant/tests/test_ctfire.py

Lines changed: 271 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,271 @@
1+
"""
2+
Tests for CTFireExtraction -- soft IoU against ground truth centerline.
3+
4+
Soft IoU formula (Gaussian-smoothed, following centerline.py):
5+
smooth each 1-px skeleton with a Gaussian (sigma=5) rescaled to [0,1],
6+
then IoU = (m1*m2).sum() / (m1²+m2²-m1*m2).sum()
7+
8+
sigma=5 measures spatial agreement at the ~10-px scale, which is the right
9+
tolerance for testing "did CT-FIRE find the fibers" rather than pixel precision.
10+
"""
11+
from __future__ import annotations
12+
import numpy as np
13+
import pytest
14+
from scipy.ndimage import distance_transform_edt
15+
from skimage import draw, filters, exposure, morphology
16+
from skimage.transform import resize as sk_resize
17+
from tme_quant.fiber_analysis.methods.ctfire import CTFireExtraction
18+
from tme_quant.fiber_analysis.config import CTFireParams
19+
20+
# Gaussian sigma for smoothing 1-px skeletons before computing soft IoU.
21+
# sigma=5 → ~10-px FWHM, appropriate for fiber-level (not pixel-level) matching.
22+
_SMOOTH_SIGMA = 2.0
23+
24+
25+
def _smooth_mask(mask, smooth_sigma=_SMOOTH_SIGMA):
26+
mask = mask.astype(np.float32)
27+
mask = exposure.rescale_intensity(mask, out_range=(0.0, 1.0))
28+
density = filters.gaussian(mask, sigma=smooth_sigma, preserve_range=False)
29+
return exposure.rescale_intensity(density, out_range=(0.0, 1.0)).astype(np.float32)
30+
31+
32+
def _soft_iou(mask_1, mask_2, beta=1e-3):
33+
if mask_1.shape != mask_2.shape:
34+
mask_2 = sk_resize(
35+
mask_2, mask_1.shape, anti_aliasing=True, preserve_range=True
36+
).astype(np.float32)
37+
intersection = mask_1 * mask_2
38+
union = mask_1**2 + mask_2**2 - mask_1 * mask_2
39+
return float((intersection.sum() + beta) / (union.sum() + beta))
40+
41+
42+
def _rasterize_centerlines(fibers, image_shape):
43+
canvas = np.zeros(image_shape, dtype=np.uint8)
44+
H, W = image_shape
45+
for fiber in fibers:
46+
pts = fiber.centerline
47+
if pts is None or len(pts) < 2:
48+
continue
49+
pts = np.round(pts[:, :2]).astype(int)
50+
for i in range(len(pts) - 1):
51+
r0 = int(np.clip(pts[i, 0], 0, H - 1))
52+
c0 = int(np.clip(pts[i, 1], 0, W - 1))
53+
r1 = int(np.clip(pts[i + 1, 0], 0, H - 1))
54+
c1 = int(np.clip(pts[i + 1, 1], 0, W - 1))
55+
rr, cc = draw.line(r0, c0, r1, c1)
56+
canvas[rr, cc] = 1
57+
return morphology.skeletonize(canvas > 0)
58+
59+
60+
def _make_synthetic_fiber_image(
61+
shape=(256, 256),
62+
n_fibers=10,
63+
fiber_sigma=2.0,
64+
min_length=60,
65+
rng_seed=42,
66+
):
67+
"""
68+
Synthetic fiber image with Gaussian cross-section profiles.
69+
70+
Fibers are drawn as straight lines; each pixel's intensity is
71+
exp(-d^2 / (2*fiber_sigma^2)) where d is the distance to the nearest
72+
centerline. This matches the Gaussian ridge appearance of real SHG
73+
collagen images that CT-FIRE was designed for.
74+
75+
Returns
76+
-------
77+
image : (H, W) float32, values in [0, 1]
78+
gt_skeleton : (H, W) bool -- 1-px-thick ground-truth centerlines
79+
"""
80+
rng = np.random.default_rng(rng_seed)
81+
H, W = shape
82+
margin = max(10, min_length // 4)
83+
skeleton_map = np.zeros(shape, dtype=bool)
84+
85+
generated = 0
86+
attempts = 0
87+
while generated < n_fibers and attempts < n_fibers * 10:
88+
attempts += 1
89+
r0 = int(rng.integers(margin, H - margin))
90+
c0 = int(rng.integers(margin, W - margin))
91+
angle = rng.uniform(0, np.pi) # avoid duplicate reversed segments
92+
length = int(rng.integers(min_length, max(min_length + 1, min(H, W) - 2 * margin)))
93+
r1 = int(np.clip(r0 + length * np.sin(angle), margin, H - margin))
94+
c1 = int(np.clip(c0 + length * np.cos(angle), margin, W - margin))
95+
actual_len = np.hypot(r1 - r0, c1 - c0)
96+
if actual_len < min_length:
97+
continue
98+
rr, cc = draw.line(r0, c0, r1, c1)
99+
skeleton_map[rr, cc] = True
100+
generated += 1
101+
102+
# Gaussian cross-section via distance transform (fiber amplitude 0.85)
103+
dist = distance_transform_edt(~skeleton_map).astype(np.float32)
104+
fiber_signal = 0.85 * np.exp(-dist**2 / (2.0 * fiber_sigma**2))
105+
106+
# SHG-like background: non-zero mean with Gaussian noise, as CT-FIRE's
107+
# percentile-based bright-pixel logic requires that fibers occupy the
108+
# top ~8% of pixel intensities (impossible on a near-zero background).
109+
background = rng.normal(0.15, 0.04, size=shape).astype(np.float32)
110+
image = np.clip(background + fiber_signal, 0.0, 1.0)
111+
112+
gt_skeleton = morphology.skeletonize(skeleton_map)
113+
return image, gt_skeleton
114+
115+
116+
# CT-FIRE parameters tuned for the synthetic images above.
117+
# ctfire_threshold=0.25 keeps only strong Frangi ridges, suppressing the
118+
# many short false-positive strands that appear at lower thresholds.
119+
# min_fiber_width=2.0 further rejects thin noise strands (< 1 px radius).
120+
_CTFIRE_PARAMS = CTFireParams(
121+
pixel_size=1.0,
122+
min_fiber_length=25.0,
123+
max_fiber_length=1000.0,
124+
min_fiber_width=2.0,
125+
max_fiber_width=15.0,
126+
ctfire_threshold=0.25,
127+
mask_closing_radius=1,
128+
spur_length_px=5,
129+
extract_centerlines=True,
130+
)
131+
132+
SOFT_IOU_THRESHOLD = 0.7
133+
134+
135+
# ─────────────────────────────────────────────────────────────────────────────
136+
# Visualization helper
137+
# ─────────────────────────────────────────────────────────────────────────────
138+
139+
def plot_centerline_overlay(
140+
image: np.ndarray,
141+
fibers,
142+
gt_skeleton: np.ndarray | None = None,
143+
title: str = "CT-FIRE centerline overlay",
144+
save_path: str | None = None,
145+
):
146+
"""
147+
Display extracted fiber centerlines overlaid on the source image.
148+
149+
Each detected fiber is drawn in a distinct color from a qualitative
150+
colormap. The optional ground-truth skeleton is shown in white.
151+
152+
Parameters
153+
----------
154+
image : (H, W) float32
155+
Grayscale source image (values in [0, 1]).
156+
fibers : list of FiberProperties
157+
CT-FIRE output fibers; only those with a ``centerline`` are drawn.
158+
gt_skeleton : (H, W) bool, optional
159+
Ground-truth 1-px skeleton. Drawn in white if provided.
160+
title : str
161+
Figure title (also used as the window title).
162+
save_path : str, optional
163+
If given, save the figure to this path instead of showing it.
164+
"""
165+
import matplotlib.pyplot as plt
166+
import matplotlib.colors as mcolors
167+
168+
fig, ax = plt.subplots(figsize=(7, 7))
169+
ax.imshow(image, cmap="gray", vmin=0, vmax=1, interpolation="nearest")
170+
171+
if gt_skeleton is not None:
172+
# Green semi-transparent overlay for ground truth
173+
gt_rgba = np.zeros((*gt_skeleton.shape, 4), dtype=np.float32)
174+
gt_rgba[gt_skeleton, :] = [0.0, 1.0, 0.0, 0.8]
175+
ax.imshow(gt_rgba, interpolation="nearest")
176+
177+
# Pick a qualitative colormap with enough distinct colors
178+
cmap = plt.get_cmap("tab20")
179+
fibers_with_cl = [f for f in fibers if f.centerline is not None and len(f.centerline) >= 2]
180+
H, W = image.shape
181+
182+
for idx, fiber in enumerate(fibers_with_cl):
183+
color = cmap(idx % 20)
184+
pts = np.round(fiber.centerline[:, :2]).astype(int)
185+
# col (x) first, then row (y) for matplotlib
186+
xs = np.clip(pts[:, 1], 0, W - 1)
187+
ys = np.clip(pts[:, 0], 0, H - 1)
188+
ax.plot(xs, ys, "-", color=color, linewidth=1.2, alpha=0.9)
189+
# Mark the start point
190+
ax.plot(xs[0], ys[0], "o", color=color, markersize=3, alpha=0.9)
191+
192+
n_gt = int(gt_skeleton.sum()) if gt_skeleton is not None else 0
193+
ax.set_title(
194+
f"{title}\n"
195+
f"{len(fibers_with_cl)} detected fibers"
196+
+ (f" | GT pixels: {n_gt}" if gt_skeleton is not None else ""),
197+
fontsize=10,
198+
)
199+
ax.axis("off")
200+
fig.tight_layout()
201+
202+
if save_path is not None:
203+
fig.savefig(save_path, dpi=150, bbox_inches="tight")
204+
plt.close(fig)
205+
else:
206+
plt.show()
207+
return fig
208+
209+
210+
# ─────────────────────────────────────────────────────────────────────────────
211+
# Standalone demo (python tests/test_ctfire.py)
212+
# ─────────────────────────────────────────────────────────────────────────────
213+
214+
def _demo():
215+
"""Generate a synthetic image, run CT-FIRE, and show the overlay."""
216+
image, gt_skeleton = _make_synthetic_fiber_image(
217+
shape=(256, 256), n_fibers=10, fiber_sigma=2.0, rng_seed=42
218+
)
219+
result = CTFireExtraction().extract_2d(image, _CTFIRE_PARAMS)
220+
pred_skeleton = _rasterize_centerlines(result.fibers, image.shape)
221+
ratio = _soft_iou(_smooth_mask(gt_skeleton), _smooth_mask(pred_skeleton))
222+
print(f"Detected {len(result.fibers)} fibers | soft IoU = {ratio:.4f}")
223+
plot_centerline_overlay(
224+
image,
225+
result.fibers,
226+
gt_skeleton=gt_skeleton,
227+
title=f"CT-FIRE overlay (soft IoU = {ratio:.4f})",
228+
)
229+
230+
231+
if __name__ == "__main__":
232+
_demo()
233+
234+
235+
class TestCTFireSoftIoU:
236+
def test_soft_iou_synthetic(self):
237+
image, gt_skeleton = _make_synthetic_fiber_image(
238+
shape=(256, 256), n_fibers=10, fiber_sigma=2.0, rng_seed=42
239+
)
240+
result = CTFireExtraction().extract_2d(image, _CTFIRE_PARAMS)
241+
assert result.fibers, "CT-FIRE returned no fibers."
242+
pred_skeleton = _rasterize_centerlines(result.fibers, image.shape)
243+
ratio = _soft_iou(_smooth_mask(gt_skeleton), _smooth_mask(pred_skeleton))
244+
assert ratio > SOFT_IOU_THRESHOLD, (
245+
f"Soft IoU {ratio:.4f} < {SOFT_IOU_THRESHOLD}"
246+
)
247+
248+
def test_soft_iou_uses_centerlines(self):
249+
image, _ = _make_synthetic_fiber_image(
250+
shape=(128, 128), n_fibers=5, fiber_sigma=2.0, rng_seed=7
251+
)
252+
result = CTFireExtraction().extract_2d(image, _CTFIRE_PARAMS)
253+
fibers_with_cl = [f for f in result.fibers if f.centerline is not None]
254+
assert fibers_with_cl, "No fibers have centerlines."
255+
assert _rasterize_centerlines(result.fibers, image.shape).any(), (
256+
"Skeleton is blank."
257+
)
258+
259+
@pytest.mark.parametrize("n_fibers,seed", [(5, 0), (8, 13), (10, 99)])
260+
def test_soft_iou_multiple_configurations(self, n_fibers, seed):
261+
image, gt_skeleton = _make_synthetic_fiber_image(
262+
shape=(256, 256), n_fibers=n_fibers, fiber_sigma=2.0, rng_seed=seed
263+
)
264+
result = CTFireExtraction().extract_2d(image, _CTFIRE_PARAMS)
265+
assert result.fibers, f"No fibers (n_fibers={n_fibers}, seed={seed})."
266+
pred_skeleton = _rasterize_centerlines(result.fibers, image.shape)
267+
ratio = _soft_iou(_smooth_mask(gt_skeleton), _smooth_mask(pred_skeleton))
268+
assert ratio > SOFT_IOU_THRESHOLD, (
269+
f"Soft IoU {ratio:.4f} < {SOFT_IOU_THRESHOLD} "
270+
f"(n_fibers={n_fibers}, seed={seed})"
271+
)

0 commit comments

Comments
 (0)