Skip to content

Commit 6697c78

Browse files
committed
enh: add test checking stacked models
Related: #12.
1 parent 24efee4 commit 6697c78

File tree

4 files changed

+112
-51
lines changed

4 files changed

+112
-51
lines changed

src/nifreeze/estimator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@ def run(self, dataset: DatasetT, **kwargs) -> Self:
117117
)
118118

119119
kwargs["num_threads"] = kwargs.pop("omp_nthreads", None) or kwargs.pop("num_threads", None)
120+
kwargs = self._align_kwargs | kwargs
120121

121122
dataset_length = len(dataset)
122123
with TemporaryDirectory() as tmp_dir:

src/nifreeze/registration/ants.py

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def _prepare_registration_data(
9797
affine: np.ndarray,
9898
vol_idx: int,
9999
dirname: Path | str,
100-
clip: str | None = None,
100+
clip: str | bool | None = None,
101101
init_affine: np.ndarray | None = None,
102102
) -> tuple[Path, Path, Path | None]:
103103
"""
@@ -129,20 +129,21 @@ def _prepare_registration_data(
129129
An initialization affine (for second and further estimators).
130130
131131
"""
132-
clip = clip or "none"
132+
133133
predicted_path = Path(dirname) / f"predicted_{vol_idx:05d}.nii.gz"
134134
sample_path = Path(dirname) / f"sample_{vol_idx:05d}.nii.gz"
135+
135136
_to_nifti(
136137
sample,
137138
affine,
138139
sample_path,
139-
clip=clip.lower() in ("sample", "both"),
140+
clip=str(clip).lower() in ("sample", "both", "true"),
140141
)
141142
_to_nifti(
142143
predicted,
143144
affine,
144145
predicted_path,
145-
clip=clip.lower() in ("predicted", "both"),
146+
clip=str(clip).lower() in ("predicted", "both", "true"),
146147
)
147148

148149
init_path = None
@@ -232,6 +233,9 @@ def generate_command(
232233
movingmask_path: str | Path | list[str] | None = None,
233234
init_affine: str | Path | None = None,
234235
default: str = "b0-to-b0_level0",
236+
terminal_output: str | None = None,
237+
num_threads: int | None = None,
238+
environ: dict | None = None,
235239
**kwargs,
236240
) -> Registration:
237241
"""
@@ -251,6 +255,12 @@ def generate_command(
251255
Initial affine transformation.
252256
default : :obj:`str`, optional
253257
Default settings configuration.
258+
terminal_output : :obj:`str`, optional
259+
Redirect terminal output (Nipype configuration)
260+
environ : :obj:`dict`, optional
261+
Add environment variables to the execution.
262+
num_threads : :obj:`int`, optional
263+
Set the number of threads for ANTs' execution.
254264
**kwargs : :obj:`dict`
255265
Additional parameters for ANTs registration.
256266
@@ -413,11 +423,17 @@ def generate_command(
413423
settings["initial_moving_transform"] = str(init_affine)
414424

415425
# Generate command line with nipype and return
416-
return Registration(
426+
reg_iface = Registration(
417427
fixed_image=str(Path(fixed_path).absolute()),
418428
moving_image=str(Path(moving_path).absolute()),
429+
terminal_output=terminal_output,
430+
environ=environ or {},
419431
**settings,
420432
)
433+
if num_threads:
434+
reg_iface.inputs.num_threads = num_threads
435+
436+
return reg_iface
421437

422438

423439
def _run_registration(
@@ -451,10 +467,11 @@ def _run_registration(
451467
"""
452468

453469
align_kwargs = kwargs.copy()
454-
environ = align_kwargs.pop("environ", {})
470+
environ = align_kwargs.pop("environ", None)
455471
num_threads = align_kwargs.pop("num_threads", None)
456472

457473
if (seed := align_kwargs.pop("seed", None)) is not None:
474+
environ = environ or {}
458475
environ["ANTS_RANDOM_SEED"] = str(seed)
459476

460477
if "ants_config" in kwargs:
@@ -463,12 +480,11 @@ def _run_registration(
463480
registration = generate_command(
464481
fixed_path,
465482
moving_path,
466-
terminal_output="file",
467483
environ=environ,
484+
terminal_output="file_split",
485+
num_threads=num_threads,
468486
**align_kwargs,
469487
)
470-
if num_threads:
471-
registration.inputs.num_threads = num_threads
472488

473489
(dirname / f"cmd-{vol_idx:05d}.sh").write_text(registration.cmdline)
474490

test/conftest.py

Lines changed: 59 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,12 @@
2626
from pathlib import Path
2727

2828
import nibabel as nb
29+
import nitransforms as nt
2930
import numpy as np
3031
import pytest
3132

33+
from nifreeze.data.dmri import DWI
34+
3235
test_data_env = os.getenv("TEST_DATA_HOME", str(Path.home() / "nifreeze-tests"))
3336
test_output_dir = os.getenv("TEST_OUTPUT_DIR")
3437
test_workdir = os.getenv("TEST_WORK_DIR")
@@ -54,19 +57,19 @@ def doctest_imports(doctest_namespace):
5457
doctest_namespace["repodata"] = _datadir
5558

5659

57-
@pytest.fixture
60+
@pytest.fixture(scope="session")
5861
def outdir():
5962
"""Determine if test artifacts should be stored somewhere or deleted."""
6063
return None if test_output_dir is None else Path(test_output_dir)
6164

6265

63-
@pytest.fixture
66+
@pytest.fixture(scope="session")
6467
def datadir():
6568
"""Return a data path outside the package's structure (i.e., large datasets)."""
6669
return Path(test_data_env)
6770

6871

69-
@pytest.fixture
72+
@pytest.fixture(scope="session")
7073
def repodata():
7174
"""Return the path to this repository's test data folder."""
7275
return _datadir
@@ -80,6 +83,59 @@ def pytest_addoption(parser):
8083
)
8184

8285

86+
@pytest.fixture(scope="session")
87+
def motion_data(tmp_path_factory, datadir):
88+
# Temporary directory for session-scoped fixtures
89+
tmp_path = tmp_path_factory.mktemp("motion_test_data")
90+
91+
dwdata = DWI.from_filename(datadir / "dwi.h5")
92+
b0nii = nb.Nifti1Image(dwdata.bzero, dwdata.affine, None)
93+
masknii = nb.Nifti1Image(dwdata.brainmask.astype("uint8"), dwdata.affine, None)
94+
95+
# Generate a list of large-yet-plausible bulk-head motion
96+
xfms = nt.linear.LinearTransformsMapping(
97+
[
98+
nb.affines.from_matvec(nb.eulerangles.euler2mat(x=0.03, z=0.005), (0.8, 0.2, 0.2)),
99+
nb.affines.from_matvec(nb.eulerangles.euler2mat(x=0.02, z=0.005), (0.8, 0.2, 0.2)),
100+
nb.affines.from_matvec(nb.eulerangles.euler2mat(x=0.02, z=0.02), (0.4, 0.2, 0.2)),
101+
nb.affines.from_matvec(nb.eulerangles.euler2mat(x=-0.02, z=0.02), (0.4, 0.2, 0.2)),
102+
nb.affines.from_matvec(nb.eulerangles.euler2mat(x=-0.02, z=0.002), (0.0, 0.2, 0.2)),
103+
nb.affines.from_matvec(nb.eulerangles.euler2mat(y=-0.02, z=0.002), (0.0, 0.2, 0.2)),
104+
nb.affines.from_matvec(nb.eulerangles.euler2mat(y=-0.01, z=0.002), (0.0, 0.4, 0.2)),
105+
],
106+
reference=b0nii,
107+
)
108+
109+
# Induce motion into dataset (i.e., apply the inverse transforms)
110+
moved_nii = (~xfms).apply(b0nii, reference=b0nii)
111+
112+
# Save the moved dataset for debugging or further processing
113+
moved_path = tmp_path / "test.nii.gz"
114+
ground_truth_path = tmp_path / "ground_truth.nii.gz"
115+
moved_nii.to_filename(moved_path)
116+
xfms.apply(moved_nii).to_filename(ground_truth_path)
117+
118+
# Wrap into dataset object
119+
dwi_motion = DWI(
120+
dataobj=np.asanyarray(moved_nii.dataobj),
121+
affine=b0nii.affine,
122+
bzero=dwdata.bzero,
123+
gradients=dwdata.gradients[..., : len(xfms)],
124+
brainmask=dwdata.brainmask,
125+
)
126+
127+
# Return data as a dictionary (or any format that makes sense for your tests)
128+
return {
129+
"b0nii": b0nii,
130+
"masknii": masknii,
131+
"moved_nii": moved_nii,
132+
"xfms": xfms,
133+
"moved_path": moved_path,
134+
"ground_truth_path": ground_truth_path,
135+
"moved_nifreeze": dwi_motion,
136+
}
137+
138+
83139
@pytest.hookimpl(trylast=True)
84140
def pytest_sessionfinish(session, exitstatus):
85141
have_werrors = os.getenv("NIFREEZE_WERRORS", False)

test/test_integration.py

Lines changed: 27 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -24,54 +24,22 @@
2424

2525
from os import cpu_count
2626

27-
import nibabel as nb
2827
import nitransforms as nt
29-
import numpy as np
3028

31-
from nifreeze.data.dmri import DWI
3229
from nifreeze.estimator import Estimator
3330
from nifreeze.model.base import TrivialModel
3431
from nifreeze.registration.utils import displacements_within_mask
3532

3633

37-
def test_proximity_estimator_trivial_model(datadir, tmp_path):
34+
def test_proximity_estimator_trivial_model(motion_data, tmp_path):
3835
"""Check the proximity of transforms estimated by the estimator with a trivial B0 model."""
3936

40-
dwdata = DWI.from_filename(datadir / "dwi.h5")
41-
b0nii = nb.Nifti1Image(dwdata.bzero, dwdata.affine, None)
42-
masknii = nb.Nifti1Image(dwdata.brainmask.astype(np.uint8), dwdata.affine, None)
37+
b0nii = motion_data["b0nii"]
38+
moved_nii = motion_data["moved_nii"]
39+
xfms = motion_data["xfms"]
40+
dwi_motion = motion_data["moved_nifreeze"]
4341

44-
# Generate a list of large-yet-plausible bulk-head motion.
45-
xfms = nt.linear.LinearTransformsMapping(
46-
[
47-
nb.affines.from_matvec(nb.eulerangles.euler2mat(x=0.03, z=0.005), (0.8, 0.2, 0.2)),
48-
nb.affines.from_matvec(nb.eulerangles.euler2mat(x=0.02, z=0.005), (0.8, 0.2, 0.2)),
49-
nb.affines.from_matvec(nb.eulerangles.euler2mat(x=0.02, z=0.02), (0.4, 0.2, 0.2)),
50-
nb.affines.from_matvec(nb.eulerangles.euler2mat(x=-0.02, z=0.02), (0.4, 0.2, 0.2)),
51-
nb.affines.from_matvec(nb.eulerangles.euler2mat(x=-0.02, z=0.002), (0.0, 0.2, 0.2)),
52-
nb.affines.from_matvec(nb.eulerangles.euler2mat(y=-0.02, z=0.002), (0.0, 0.2, 0.2)),
53-
nb.affines.from_matvec(nb.eulerangles.euler2mat(y=-0.01, z=0.002), (0.0, 0.4, 0.2)),
54-
],
55-
reference=b0nii,
56-
)
57-
58-
# Induce motion into dataset (i.e., apply the inverse transforms)
59-
moved_nii = (~xfms).apply(b0nii, reference=b0nii)
60-
61-
# Uncomment to see the moved dataset
62-
moved_nii.to_filename(tmp_path / "test.nii.gz")
63-
xfms.apply(moved_nii).to_filename(tmp_path / "ground_truth.nii.gz")
64-
65-
# Wrap into dataset object
66-
dwi_motion = DWI(
67-
dataobj=moved_nii.dataobj,
68-
affine=b0nii.affine,
69-
bzero=dwdata.bzero,
70-
gradients=dwdata.gradients[..., : len(xfms)],
71-
brainmask=dwdata.brainmask,
72-
)
73-
74-
model = TrivialModel(dwdata)
42+
model = TrivialModel(dwi_motion)
7543
estimator = Estimator(model)
7644
estimator.run(
7745
dwi_motion,
@@ -89,9 +57,29 @@ def test_proximity_estimator_trivial_model(datadir, tmp_path):
8957
for i, est in enumerate(dwi_motion.motion_affines):
9058
assert (
9159
displacements_within_mask(
92-
masknii,
60+
motion_data["masknii"],
9361
nt.linear.Affine(est),
9462
xfms[i],
9563
).max()
9664
< 0.25
9765
)
66+
67+
68+
def test_stacked_estimators(motion_data):
69+
"""Check that models can be stacked."""
70+
71+
# Wrap into dataset object
72+
dmri_dataset = motion_data["moved_nifreeze"]
73+
74+
estimator1 = Estimator(
75+
TrivialModel(dmri_dataset),
76+
ants_config="dwi-to-dwi_level0.json",
77+
clip=False,
78+
)
79+
estimator2 = Estimator(
80+
TrivialModel(dmri_dataset),
81+
prev=estimator1,
82+
clip=False,
83+
)
84+
85+
estimator2.run(dmri_dataset)

0 commit comments

Comments
 (0)