Skip to content

Commit 90f4a57

Browse files
committed
enh: add test checking stacked models
Related: #12.
1 parent 24efee4 commit 90f4a57

File tree

2 files changed

+26
-5
lines changed

2 files changed

+26
-5
lines changed

src/nifreeze/registration/ants.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def _prepare_registration_data(
9797
affine: np.ndarray,
9898
vol_idx: int,
9999
dirname: Path | str,
100-
clip: str | None = None,
100+
clip: str | bool | None = None,
101101
init_affine: np.ndarray | None = None,
102102
) -> tuple[Path, Path, Path | None]:
103103
"""
@@ -129,20 +129,20 @@ def _prepare_registration_data(
129129
An initialization affine (for second and further estimators).
130130
131131
"""
132-
clip = clip or "none"
132+
133133
predicted_path = Path(dirname) / f"predicted_{vol_idx:05d}.nii.gz"
134134
sample_path = Path(dirname) / f"sample_{vol_idx:05d}.nii.gz"
135135
_to_nifti(
136136
sample,
137137
affine,
138138
sample_path,
139-
clip=clip.lower() in ("sample", "both"),
139+
clip=str(clip).lower() in ("sample", "both", "true"),
140140
)
141141
_to_nifti(
142142
predicted,
143143
affine,
144144
predicted_path,
145-
clip=clip.lower() in ("predicted", "both"),
145+
clip=str(clip).lower() in ("predicted", "both", "true"),
146146
)
147147

148148
init_path = None
@@ -463,10 +463,11 @@ def _run_registration(
463463
registration = generate_command(
464464
fixed_path,
465465
moving_path,
466-
terminal_output="file",
467466
environ=environ,
468467
**align_kwargs,
469468
)
469+
registration.terminal_output = "file"
470+
470471
if num_threads:
471472
registration.inputs.num_threads = num_threads
472473

test/test_integration.py

+20
Original file line numberDiff line numberDiff line change
@@ -95,3 +95,23 @@ def test_proximity_estimator_trivial_model(datadir, tmp_path):
9595
).max()
9696
< 0.25
9797
)
98+
99+
100+
def test_stacked_estimators(datadir):
101+
"""Check that models can be stacked."""
102+
103+
# Load test data
104+
dmri_dataset = DWI.from_filename(datadir / "dwi.h5")
105+
106+
estimator1 = Estimator(
107+
TrivialModel(dmri_dataset),
108+
ants_config="dwi-to-dwi_level0.json",
109+
clip=False,
110+
)
111+
estimator2 = Estimator(
112+
TrivialModel(dmri_dataset),
113+
prev=estimator1,
114+
clip=False,
115+
)
116+
117+
estimator2.run(dmri_dataset)

0 commit comments

Comments
 (0)