Skip to content

Commit cd0703e

Browse files
yuminguwclaude
andcommitted
Add soft IoU validation to test_fire_2d_angle.py
Adds class TestSoftIoU with three test tiers that measure spatial accuracy of the FIRE fiber extraction pipeline by comparing 1-pixel-thick skeleton images. Soft IoU metric --------------- Both skeletons (ground truth and extracted) are first smoothed with a Gaussian at sigma=5 (≈10-px FWHM) and rescaled to [0, 1]. The IoU is then computed as: IoU = (m1 * m2).sum() / (m1² + m2² - m1 * m2).sum() Operating at ~10-px FWHM measures agreement at fiber scale rather than pixel scale, tolerating the sub-pixel registration differences that are normal for a distance-transform medial-axis algorithm. Thresholds: 0.70 for synthetic images (known GT), 0.60 for MATLAB comparison (real images, higher variation expected). Test tiers ---------- A. test_soft_iou_synthetic (seed=42, always runs) - Generates 8 straight-line Gaussian-profile fibers on a low background - Records the exact 1-px skeleton as ground truth at generation time - Asserts soft IoU > 0.70, total extracted length > 20% of GT length, and angle std > 0.3 rad (confirms angle computation is non-degenerate) - Uses thresh_im=0.2 (fractional) to cleanly separate fiber signal (peak ~180) from background (mean ~10); thresh_im2=5 (absolute) would include virtually all background pixels and produce hundreds of spurious short zigzag fibers B. test_soft_iou_multiple_seeds (seeds 0, 7, 99) - Same structure as A; guards against a lucky pass on a single seed C. test_soft_iou_matlab_vs_python (@pytest.mark.matlab, skips if absent) - Rasterizes MATLAB Xa/Fa as ground truth, compares to Python Xf/Ff - Also checks total length within ±40% and mean |angle| within 10° Supporting additions -------------------- - _smooth_mask, _soft_iou, _rasterize_fibers, _make_synthetic_fiber_image helper functions (pattern adapted from tme_quant/tests/test_ctfire.py) - _default_fire_params() reads JSON params for real-image tests - _synthetic_fire_params() overrides thresh_im=0.2 for synthetic tests - __main__ block replaced with standalone demo (saves overlay PNG via plot_fiber_overlay from src/ctfire_py/test_fire_2d.py) - pyproject.toml: register 'matlab' pytest marker Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent b5f81f3 commit cd0703e

2 files changed

Lines changed: 332 additions & 2 deletions

File tree

pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,3 +52,6 @@ napari_curvealign = ["napari.yaml"]
5252
testpaths = [
5353
"tests",
5454
]
55+
markers = [
56+
"matlab: tests requiring MATLAB reference .mat files (skip if absent)",
57+
]

tests/test_fire_2d_angle.py

Lines changed: 329 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,150 @@ def load_matlab_reference(mat_file_path):
186186
pytest.skip("Neither h5py nor scipy.io available for loading MATLAB files")
187187

188188

189+
# ============================================================================
190+
# Soft IoU helpers
191+
# ============================================================================
192+
193+
# Gaussian sigma for smoothing 1-px skeletons before computing soft IoU.
194+
# sigma=5 → ~10-px FWHM, which measures agreement at fiber scale rather than
195+
# pixel scale. Adapted from tme_quant/tests/test_ctfire.py.
196+
_SMOOTH_SIGMA = 5.0
197+
198+
SOFT_IOU_THRESHOLD_SYNTHETIC = 0.70 # extracted skeleton vs known GT skeleton
199+
SOFT_IOU_THRESHOLD_MATLAB = 0.60 # Python centerlines vs MATLAB centerlines
200+
201+
202+
def _smooth_mask(mask, smooth_sigma=_SMOOTH_SIGMA):
203+
"""Gaussian-smooth a binary mask and rescale to [0, 1]."""
204+
from skimage import exposure, filters
205+
mask = mask.astype(np.float32)
206+
mask = exposure.rescale_intensity(mask, out_range=(0.0, 1.0))
207+
density = filters.gaussian(mask, sigma=smooth_sigma, preserve_range=False)
208+
return exposure.rescale_intensity(density, out_range=(0.0, 1.0)).astype(np.float32)
209+
210+
211+
def _soft_iou(mask_1, mask_2, beta=1e-3):
212+
"""
213+
Soft IoU on two float masks already smoothed to [0, 1].
214+
215+
Formula: (m1·m2).sum() / (m1²+m2²-m1·m2).sum()
216+
beta prevents 0/0 when both masks are empty.
217+
"""
218+
intersection = mask_1 * mask_2
219+
union = mask_1**2 + mask_2**2 - mask_1 * mask_2
220+
return float((intersection.sum() + beta) / (union.sum() + beta))
221+
222+
223+
def _rasterize_fibers(X, F, image_shape):
224+
"""
225+
Rasterize fire_2d_angle fiber output to a 1-px skeleton image.
226+
227+
Uses Bresenham line drawing between consecutive fiber vertices then
228+
morphological skeletonize to guarantee 1-px thickness. Vertex indices
229+
follow the 0-based convention (X[v] = coordinates for vertex v) as
230+
established in CLAUDE_CTFIRE.md.
231+
"""
232+
from skimage.draw import line as draw_line
233+
from skimage.morphology import skeletonize
234+
canvas = np.zeros(image_shape, dtype=np.uint8)
235+
H, W = image_shape
236+
X_arr = np.asarray(X)
237+
for fiber in F:
238+
v_list = fiber['v'] if isinstance(fiber, dict) else list(fiber)
239+
for seg in range(len(v_list) - 1):
240+
v0, v1 = v_list[seg], v_list[seg + 1]
241+
if v0 >= len(X_arr) or v1 >= len(X_arr):
242+
continue
243+
r0 = int(np.clip(round(float(X_arr[v0, 0])), 0, H - 1))
244+
c0 = int(np.clip(round(float(X_arr[v0, 1])), 0, W - 1))
245+
r1 = int(np.clip(round(float(X_arr[v1, 0])), 0, H - 1))
246+
c1 = int(np.clip(round(float(X_arr[v1, 1])), 0, W - 1))
247+
rr, cc = draw_line(r0, c0, r1, c1)
248+
canvas[rr, cc] = 1
249+
return skeletonize(canvas > 0)
250+
251+
252+
def _make_synthetic_fiber_image(shape=(256, 256), n_fibers=8, fiber_sigma=2.5, rng_seed=42):
253+
"""
254+
Generate a synthetic fiber image with a known ground-truth skeleton.
255+
256+
Straight-line fibers are drawn with Gaussian cross-section profiles
257+
(signal amplitude 180, background mean 10 with Gaussian noise σ=4)
258+
so the FIRE distance-transform pipeline can find them at thresh_im2=5.
259+
Returns the image and the skeletonized 1-px ground-truth centerline mask.
260+
261+
Why synthetic images:
262+
- Exact centerline positions are known at generation time — soft IoU
263+
measures genuine spatial recovery, not just "did any fibers come out".
264+
- Deterministic RNG seeds make CI results reproducible.
265+
- No MATLAB .mat reference files required.
266+
- Straight-line Gaussian fibers are the canonical FIRE input; regressions
267+
in extend_xlink, trimxfv, or filtering stages show up as IoU drops.
268+
"""
269+
from scipy.ndimage import distance_transform_edt
270+
from skimage.draw import line as draw_line
271+
from skimage.morphology import skeletonize
272+
rng = np.random.default_rng(rng_seed)
273+
H, W = shape
274+
margin = 20
275+
skeleton = np.zeros(shape, dtype=bool)
276+
generated = 0
277+
while generated < n_fibers:
278+
r0 = int(rng.integers(margin, H - margin))
279+
c0 = int(rng.integers(margin, W - margin))
280+
angle = rng.uniform(0, np.pi)
281+
length = int(rng.integers(60, min(H, W) - 2 * margin))
282+
r1 = int(np.clip(r0 + length * np.sin(angle), margin, H - margin))
283+
c1 = int(np.clip(c0 + length * np.cos(angle), margin, W - margin))
284+
if np.hypot(r1 - r0, c1 - c0) < 50:
285+
continue
286+
rr, cc = draw_line(r0, c0, r1, c1)
287+
skeleton[rr, cc] = True
288+
generated += 1
289+
dist = distance_transform_edt(~skeleton).astype(np.float32)
290+
signal = 180.0 * np.exp(-dist**2 / (2.0 * fiber_sigma**2))
291+
bg = rng.normal(10.0, 4.0, size=shape).astype(np.float32)
292+
image = np.clip(bg + signal, 0.0, 255.0).astype(np.float32)
293+
return image, skeletonize(skeleton)
294+
295+
296+
def _default_fire_params():
297+
"""
298+
Return the standard FIRE algorithm parameters from test_cases_fire_2d.json.
299+
300+
Uses the 'real1_ctfire_params' test case (the only one whose image is
301+
available; 2B_D9_ROI1.tif is not present in the repo).
302+
Intended for use with real biological images.
303+
"""
304+
config_path = (
305+
Path(__file__).parent
306+
/ "test_results"
307+
/ "fire_2d_test_files"
308+
/ "test_cases_fire_2d.json"
309+
)
310+
with open(config_path, "r") as f:
311+
cfg = json.load(f)
312+
case = next(c for c in cfg["test_cases"] if c["name"] == "real1_ctfire_params")
313+
return dict(case["params"])
314+
315+
316+
def _synthetic_fire_params():
317+
"""
318+
Return FIRE algorithm parameters tuned for synthetic test images.
319+
320+
Synthetic images have Gaussian-profile fibers (peak ~180, background ~10).
321+
The JSON default of thresh_im2=5 (absolute) includes nearly all background
322+
pixels at that level, producing hundreds of spurious short zigzag fibers.
323+
Using a fractional threshold (thresh_im=0.2, i.e. 20% of image maximum)
324+
raises the effective cutoff to ~36, cleanly separating fiber signal from
325+
background and suppressing spurious detections.
326+
"""
327+
p = _default_fire_params()
328+
p["thresh_im"] = 0.2 # fractional: keep pixels > 20% of max (~36 for peak 180)
329+
p["thresh_im2"] = [] # disable absolute threshold when thresh_im is set
330+
return p
331+
332+
189333
# ============================================================================
190334
# Basic Functionality Tests
191335
# ============================================================================
@@ -445,6 +589,162 @@ def test_fire_2d_matches_matlab_angles(test_name, test_case):
445589
print(f"\nAngle distribution correlation: {correlation:.3f}")
446590

447591

592+
# ============================================================================
593+
# Soft IoU Validation
594+
# ============================================================================
595+
596+
597+
class TestSoftIoU:
598+
"""
599+
Validate that fire_2d_angle recovers fiber centerlines at fiber scale.
600+
601+
Soft IoU at sigma=5 (≈10-px FWHM) measures spatial overlap of 1-px
602+
skeleton images after Gaussian smoothing — a score > 0.30 means the
603+
extracted skeleton substantially overlaps the ground truth at fiber
604+
granularity, not pixel precision.
605+
"""
606+
607+
def test_soft_iou_synthetic(self):
608+
"""
609+
Synthetic image with known GT skeleton — always runs, no .mat needed.
610+
611+
Validates:
612+
1. Soft IoU of extracted centerlines vs ground-truth skeleton > 0.30
613+
2. Total extracted fiber length is at least 20% of GT skeleton length
614+
3. Extracted fiber angles span a reasonable range (std > 0.3 rad),
615+
confirming the angle computation is not degenerate
616+
"""
617+
if not CPP_AVAILABLE:
618+
pytest.skip("C++ backend not available")
619+
620+
image, gt_skeleton = _make_synthetic_fiber_image(rng_seed=42)
621+
p = _synthetic_fire_params()
622+
data = fire_2d_angle(p=p, im=image, plotflag=0)
623+
624+
assert len(data["Ff"]) > 0, "no filtered fibers extracted"
625+
626+
# 1. Soft IoU
627+
pred = _rasterize_fibers(data["Xf"], data["Ff"], image.shape)
628+
iou = _soft_iou(_smooth_mask(gt_skeleton.astype(np.float32)),
629+
_smooth_mask(pred.astype(np.float32)))
630+
assert iou > SOFT_IOU_THRESHOLD_SYNTHETIC, (
631+
f"soft IoU {iou:.3f} < {SOFT_IOU_THRESHOLD_SYNTHETIC} — "
632+
"extracted centerlines do not overlap ground truth at fiber scale"
633+
)
634+
635+
# 2. Total extracted length vs GT skeleton pixel count (arc-length proxy)
636+
tot_L = float(data["M"]["totL"])
637+
assert tot_L > 0, "total extracted fiber length is zero"
638+
gt_length = float(gt_skeleton.sum())
639+
assert tot_L > 0.2 * gt_length, (
640+
f"extracted total length {tot_L:.1f} px < 20% of GT skeleton "
641+
f"length {gt_length:.1f} px — pipeline may be discarding too many fibers"
642+
)
643+
644+
# 3. Angle range sanity check
645+
angles = np.asarray(data["M"].get("angle_xy", []))
646+
assert len(angles) > 0, "no fiber angles in M['angle_xy']"
647+
assert np.all(np.abs(angles) <= np.pi + 1e-6), "angle value outside [-π, π]"
648+
assert angles.std() > 0.3, (
649+
f"angle std {angles.std():.3f} rad unexpectedly small — "
650+
"all fibers have nearly the same orientation on a random image"
651+
)
652+
653+
@pytest.mark.parametrize("seed", [0, 7, 99])
654+
def test_soft_iou_multiple_seeds(self, seed):
655+
"""
656+
Soft IoU > 0.30 across multiple random synthetic configurations.
657+
658+
Guards against a lucky pass on a single seed by checking that the
659+
pipeline recovers fibers consistently across different random images.
660+
"""
661+
if not CPP_AVAILABLE:
662+
pytest.skip("C++ backend not available")
663+
664+
image, gt_skeleton = _make_synthetic_fiber_image(rng_seed=seed)
665+
p = _synthetic_fire_params()
666+
data = fire_2d_angle(p=p, im=image, plotflag=0)
667+
668+
assert len(data["Ff"]) > 0, f"no fibers extracted (seed={seed})"
669+
pred = _rasterize_fibers(data["Xf"], data["Ff"], image.shape)
670+
iou = _soft_iou(_smooth_mask(gt_skeleton.astype(np.float32)),
671+
_smooth_mask(pred.astype(np.float32)))
672+
assert iou > SOFT_IOU_THRESHOLD_SYNTHETIC, (
673+
f"soft IoU {iou:.3f} < {SOFT_IOU_THRESHOLD_SYNTHETIC} (seed={seed})"
674+
)
675+
676+
@pytest.mark.matlab
677+
@pytest.mark.parametrize(
678+
"test_name,test_case",
679+
load_test_cases(matlab_only=True),
680+
ids=[name for name, _ in load_test_cases(matlab_only=True)],
681+
)
682+
def test_soft_iou_matlab_vs_python(self, test_name, test_case):
683+
"""
684+
Compare Python centerlines against MATLAB Xa/Fa via soft IoU.
685+
686+
Skips gracefully if the .mat reference file is absent.
687+
Also checks total length (within ±40%) and mean |angle| (within 10°).
688+
"""
689+
if not CPP_AVAILABLE:
690+
pytest.skip("C++ backend not available")
691+
692+
# Skip if test image is not present (e.g. 2B_D9_ROI1.tif is not in the repo)
693+
img_path = Path(__file__).parent / "test_images" / test_case["image"]
694+
if not img_path.exists():
695+
pytest.skip(f"Test image not found: {test_case['image']}")
696+
697+
# Load image and run Python extraction
698+
img = load_test_image(test_case["image"])
699+
if img.ndim == 2:
700+
img = img[np.newaxis, :, :]
701+
702+
data_py = fire_2d_angle(p=test_case["params"], im=img, plotflag=0)
703+
704+
# Load MATLAB reference (skips if file absent)
705+
mat_path = (
706+
Path(__file__).parent
707+
/ "test_results"
708+
/ "fire_2d_test_files"
709+
/ test_case["matlab_reference_mat"]
710+
)
711+
data_mat = load_matlab_reference(mat_path)
712+
713+
image_2d = img[0] if img.ndim == 3 else img
714+
H, W = image_2d.shape
715+
716+
# Soft IoU — rasterize MATLAB Xa/Fa vs Python Xf/Ff
717+
if data_mat["Xa"] is not None and data_mat.get("Fa") is not None:
718+
mat_skel = _rasterize_fibers(data_mat["Xa"], data_mat["Fa"], (H, W))
719+
py_skel = _rasterize_fibers(data_py["Xf"], data_py["Ff"], (H, W))
720+
iou = _soft_iou(_smooth_mask(mat_skel.astype(np.float32)),
721+
_smooth_mask(py_skel.astype(np.float32)))
722+
assert iou > SOFT_IOU_THRESHOLD_MATLAB, (
723+
f"soft IoU {iou:.3f} < {SOFT_IOU_THRESHOLD_MATLAB} "
724+
f"({test_name}): Python and MATLAB centerlines diverge spatially"
725+
)
726+
print(f"\n{test_name} soft IoU = {iou:.3f}")
727+
728+
# Total length within 40% of MATLAB reference
729+
mat_totL = data_mat["M"].get("totL", 0)
730+
py_totL = float(data_py["M"]["totL"])
731+
if mat_totL > 0 and py_totL > 0:
732+
assert 0.6 * mat_totL <= py_totL <= 1.4 * mat_totL, (
733+
f"total length {py_totL:.1f} not within 40% of MATLAB {mat_totL:.1f}"
734+
)
735+
736+
# Mean |angle| within 10° of MATLAB reference
737+
mat_angles = np.asarray(data_mat["M"].get("angle_xy", []))
738+
py_angles = np.asarray(data_py["M"].get("angle_xy", []))
739+
if len(mat_angles) > 0 and len(py_angles) > 0:
740+
delta_deg = np.degrees(
741+
abs(np.mean(np.abs(py_angles)) - np.mean(np.abs(mat_angles)))
742+
)
743+
assert delta_deg < 10.0, (
744+
f"mean |angle| differs by {delta_deg:.1f}° > 10° ({test_name})"
745+
)
746+
747+
448748
# ============================================================================
449749
# Utility Tests
450750
# ============================================================================
@@ -514,5 +814,32 @@ def test_implementation_status_documented():
514814

515815

516816
if __name__ == "__main__":
517-
# Run tests with verbose output
518-
pytest.main([__file__, "-v", "-s", "--tb=short"])
817+
# Standalone demo: generate synthetic image, run FIRE, print metrics, save overlay.
818+
# Run with: python tests/test_fire_2d_angle.py
819+
from ctfire_py.test_fire_2d import plot_fiber_overlay
820+
821+
image, gt_skeleton = _make_synthetic_fiber_image(rng_seed=42)
822+
p = _synthetic_fire_params()
823+
data = fire_2d_angle(p=p, im=image, plotflag=0)
824+
825+
pred = _rasterize_fibers(data["Xf"], data["Ff"], image.shape)
826+
iou = _soft_iou(_smooth_mask(gt_skeleton.astype(np.float32)),
827+
_smooth_mask(pred.astype(np.float32)))
828+
829+
angles = np.asarray(data["M"].get("angle_xy", []))
830+
mean_angle_deg = float(np.degrees(np.mean(np.abs(angles)))) if len(angles) > 0 else float("nan")
831+
832+
print(
833+
f"Soft IoU = {iou:.4f} (threshold {SOFT_IOU_THRESHOLD_SYNTHETIC})\n"
834+
f"Fibers = {len(data['Ff'])}\n"
835+
f"Total length= {data['M']['totL']:.1f} px\n"
836+
f"Mean |angle|= {mean_angle_deg:.1f}°"
837+
)
838+
839+
plot_fiber_overlay(
840+
image,
841+
data["Xf"],
842+
data["Ff"],
843+
title=f"FIRE overlay (soft IoU={iou:.3f})",
844+
save_path="fire_2d_soft_iou_overlay.png",
845+
)

0 commit comments

Comments
 (0)