Skip to content

Commit c9fac88

Browse files
committed
TST: Test comprehensively dMRI b0 volume handling
Test comprehensively dMRI b0 volume handling: - Add a test that checks the behavior of the DWI class `__attrs_post_init__` method that handles setting the `bzero` attribute and masking the `dataobj` and gradient data corresponding to the b=0 volumes. - Further parametrize the post init error handling function to check the behavior of the `__attrs_post_init__` method in a broader range of values. Specifically, check that the initialization fails when the number of provided volumes is right at the limit of the number of the required DTI orientations.
1 parent 301429b commit c9fac88

1 file changed

Lines changed: 100 additions & 2 deletions

File tree

test/test_data_dmri.py

Lines changed: 100 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
import nibabel as nb
3030
import numpy as np
3131
import pytest
32+
from dipy.core.geometry import normalized_vector
3233

3334
from nifreeze.data import load
3435
from nifreeze.data.dmri.base import (
@@ -39,8 +40,10 @@
3940
GRADIENT_BVAL_BVEC_PRIORITY_WARN_MSG,
4041
GRADIENT_DATA_MISSING_ERROR,
4142
from_nii,
43+
to_nifti,
4244
)
4345
from nifreeze.data.dmri.utils import (
46+
DEFAULT_LOWB_THRESHOLD,
4447
DTI_MIN_ORIENTATIONS,
4548
GRADIENT_ABSENCE_ERROR_MSG,
4649
GRADIENT_EXPECTED_COLUMNS_ERROR_MSG,
@@ -109,6 +112,89 @@ def _serialize_dwi_data(
109112
)
110113

111114

115+
@pytest.mark.parametrize("vol_size", [(11, 11, 7)])
116+
@pytest.mark.parametrize("b0_count", [0, 1])
117+
@pytest.mark.parametrize("bval_min, bval_max", [(800.0, 1200.0)])
118+
@pytest.mark.parametrize("provide_bzero", [False, True])
119+
def test_dwi_post_init_b0_handling(request, vol_size, b0_count, bval_min, bval_max, provide_bzero):
120+
"""Check b0 handling when instantiating the DWI class.
121+
122+
For each parameter combination:
123+
- Build a gradient table whose first `b0_count` volumes have b=0
124+
and the rest have b-values in the range (bvalmin, bvalmax);
125+
- Build a random dataobj of shape (**vol_size, N) where N is the number
126+
of DWI volumes;
127+
- If `provide_bzero` is True, pass explicit bzero data that must be
128+
preserved; else, rely on the bzero computed at instantiation, i.e.
129+
if a single bzero is provided, set the attribute to that value; if there
130+
are multiple bzeros, set the attribute to the median value.
131+
"""
132+
rng = request.node.rng
133+
134+
# Choose n_vols safely above the minimum DTI orientations
135+
n_vols = max(10, DTI_MIN_ORIENTATIONS + 2)
136+
137+
# Build b-values array: first b0_count are zeros
138+
non_b0_count = n_vols - b0_count
139+
# Sample non-b0 bvals between min and max values
140+
rest_bvals = rng.uniform(bval_min, bval_max, size=non_b0_count)
141+
bvals = np.concatenate((np.zeros(b0_count), rest_bvals)).astype(int)
142+
143+
# Create bvecs and assemble gradients
144+
bzeros = np.zeros((b0_count, 3))
145+
bvecs = normalized_vector(rng.random((3, non_b0_count)), axis=0).T
146+
bvecs = np.vstack((bzeros, bvecs))
147+
gradients = np.column_stack((bvecs, bvals))
148+
149+
# Create random dataobj with shape
150+
dataobj = rng.standard_normal((*vol_size, n_vols)).astype(float)
151+
152+
# Optionally supply a bzero
153+
provided = None
154+
affine = np.eye(4)
155+
if provide_bzero:
156+
# Use a constant map so it's easy to assert equality
157+
provided = np.full((*vol_size, max(1, b0_count)), 42.0, dtype=float).squeeze()
158+
dwi_obj = DWI(dataobj=dataobj, affine=affine, gradients=gradients, bzero=provided)
159+
else:
160+
dwi_obj = DWI(dataobj=dataobj, affine=affine, gradients=gradients)
161+
162+
# Count expected b0 frames according to the same threshold used by the code
163+
b0_mask = bvals <= DEFAULT_LOWB_THRESHOLD
164+
expected_b0_num = int(np.sum(b0_mask))
165+
# In all cases where b0 frames existed (whether provided externally or not),
166+
# they should have been removed from the DWI object's internal gradients and
167+
# dataobj arrays
168+
expected_non_b0_count = n_vols - expected_b0_num
169+
170+
# If no b0 frames expected, bzero should be None (unless user provided one)
171+
if expected_b0_num == 0 and not provide_bzero:
172+
assert dwi_obj.bzero is None, (
173+
"Expected bzero to be None when no low-b frames and no provided bzero"
174+
)
175+
else:
176+
assert dwi_obj.bzero is not None
177+
# If provided_bzero is True, it must be preserved exactly
178+
if provide_bzero:
179+
assert provided is not None
180+
assert np.allclose(dwi_obj.bzero, provided)
181+
else:
182+
# When there are b0 frames and no provided bzero:
183+
# - If exactly one b0 frame, the stored bzero should be the 3D volume
184+
# - If multiple b0 frames, the stored bzero should be the median along last axis
185+
b0_vols = dataobj[
186+
..., b0_mask
187+
].squeeze() # shape (X,Y,Z,expected_b0_num) or (X,Y,Z) if 1
188+
expected_bzero = b0_vols if b0_vols.ndim == 3 else np.median(b0_vols, axis=-1)
189+
assert np.allclose(dwi_obj.bzero, expected_bzero)
190+
191+
assert dwi_obj.gradients.shape[0] == expected_non_b0_count
192+
assert dwi_obj.dataobj.shape[-1] == expected_non_b0_count
193+
194+
assert np.allclose(dwi_obj.gradients, gradients[~b0_mask])
195+
assert np.allclose(dwi_obj.dataobj, dataobj[..., ~b0_mask])
196+
197+
112198
def test_main(datadir):
113199
input_file = datadir / "dwi.h5"
114200

@@ -176,8 +262,20 @@ def test_format_gradients_basic(value, expect_transpose):
176262
assert np.allclose(obtained, np.asarray(value))
177263

178264

179-
@pytest.mark.random_uniform_spatial_data((2, 2, 2, 4), 0.0, 1.0)
180-
def test_dwi_post_init_errors(setup_random_uniform_spatial_data):
265+
@pytest.mark.parametrize(
266+
"case_mark",
267+
[
268+
pytest.param(
269+
None,
270+
marks=pytest.mark.random_uniform_spatial_data((2, 2, 2, 6), 0.0, 1.0),
271+
),
272+
pytest.param(
273+
None,
274+
marks=pytest.mark.random_uniform_spatial_data((2, 2, 2, 4), 0.0, 1.0),
275+
),
276+
],
277+
)
278+
def test_dwi_post_init_errors(setup_random_uniform_spatial_data, case_mark):
181279
data, affine = setup_random_uniform_spatial_data
182280
with pytest.raises(ValueError, match=GRADIENT_ABSENCE_ERROR_MSG):
183281
DWI(dataobj=data, affine=affine)

0 commit comments

Comments
 (0)