|
29 | 29 | import nibabel as nb |
30 | 30 | import numpy as np |
31 | 31 | import pytest |
| 32 | +from dipy.core.geometry import normalized_vector |
32 | 33 |
|
33 | 34 | from nifreeze.data import load |
34 | 35 | from nifreeze.data.dmri.base import ( |
|
39 | 40 | GRADIENT_BVAL_BVEC_PRIORITY_WARN_MSG, |
40 | 41 | GRADIENT_DATA_MISSING_ERROR, |
41 | 42 | from_nii, |
| 43 | + to_nifti, |
42 | 44 | ) |
43 | 45 | from nifreeze.data.dmri.utils import ( |
| 46 | + DEFAULT_LOWB_THRESHOLD, |
44 | 47 | DTI_MIN_ORIENTATIONS, |
45 | 48 | GRADIENT_ABSENCE_ERROR_MSG, |
46 | 49 | GRADIENT_EXPECTED_COLUMNS_ERROR_MSG, |
@@ -109,6 +112,89 @@ def _serialize_dwi_data( |
109 | 112 | ) |
110 | 113 |
|
111 | 114 |
|
| 115 | +@pytest.mark.parametrize("vol_size", [(11, 11, 7)]) |
| 116 | +@pytest.mark.parametrize("b0_count", [0, 1]) |
| 117 | +@pytest.mark.parametrize("bval_min, bval_max", [(800.0, 1200.0)]) |
| 118 | +@pytest.mark.parametrize("provide_bzero", [False, True]) |
| 119 | +def test_dwi_post_init_b0_handling(request, vol_size, b0_count, bval_min, bval_max, provide_bzero): |
| 120 | + """Check b0 handling when instantiating the DWI class. |
| 121 | +
|
| 122 | + For each parameter combination: |
| 123 | + - Build a gradient table whose first `b0_count` volumes have b=0 |
| 124 | + and the rest have b-values in the range (bvalmin, bvalmax); |
| 125 | + - Build a random dataobj of shape (**vol_size, N) where N is the number |
| 126 | + of DWI volumes; |
| 127 | + - If `provide_bzero` is True, pass explicit bzero data that must be |
| 128 | + preserved; else, rely on the bzero computed at instantiation, i.e. |
| 129 | + if a single bzero is provided, set the attribute to that value; if there |
| 130 | + are multiple bzeros, set the attribute to the median value. |
| 131 | + """ |
| 132 | + rng = request.node.rng |
| 133 | + |
| 134 | + # Choose n_vols safely above the minimum DTI orientations |
| 135 | + n_vols = max(10, DTI_MIN_ORIENTATIONS + 2) |
| 136 | + |
| 137 | + # Build b-values array: first b0_count are zeros |
| 138 | + non_b0_count = n_vols - b0_count |
| 139 | + # Sample non-b0 bvals between min and max values |
| 140 | + rest_bvals = rng.uniform(bval_min, bval_max, size=non_b0_count) |
| 141 | + bvals = np.concatenate((np.zeros(b0_count), rest_bvals)).astype(int) |
| 142 | + |
| 143 | + # Create bvecs and assemble gradients |
| 144 | + bzeros = np.zeros((b0_count, 3)) |
| 145 | + bvecs = normalized_vector(rng.random((3, non_b0_count)), axis=0).T |
| 146 | + bvecs = np.vstack((bzeros, bvecs)) |
| 147 | + gradients = np.column_stack((bvecs, bvals)) |
| 148 | + |
| 149 | + # Create random dataobj with shape |
| 150 | + dataobj = rng.standard_normal((*vol_size, n_vols)).astype(float) |
| 151 | + |
| 152 | + # Optionally supply a bzero |
| 153 | + provided = None |
| 154 | + affine = np.eye(4) |
| 155 | + if provide_bzero: |
| 156 | + # Use a constant map so it's easy to assert equality |
| 157 | + provided = np.full((*vol_size, max(1, b0_count)), 42.0, dtype=float).squeeze() |
| 158 | + dwi_obj = DWI(dataobj=dataobj, affine=affine, gradients=gradients, bzero=provided) |
| 159 | + else: |
| 160 | + dwi_obj = DWI(dataobj=dataobj, affine=affine, gradients=gradients) |
| 161 | + |
| 162 | + # Count expected b0 frames according to the same threshold used by the code |
| 163 | + b0_mask = bvals <= DEFAULT_LOWB_THRESHOLD |
| 164 | + expected_b0_num = int(np.sum(b0_mask)) |
| 165 | + # In all cases where b0 frames existed (whether provided externally or not), |
| 166 | + # they should have been removed from the DWI object's internal gradients and |
| 167 | + # dataobj arrays |
| 168 | + expected_non_b0_count = n_vols - expected_b0_num |
| 169 | + |
| 170 | + # If no b0 frames expected, bzero should be None (unless user provided one) |
| 171 | + if expected_b0_num == 0 and not provide_bzero: |
| 172 | + assert dwi_obj.bzero is None, ( |
| 173 | + "Expected bzero to be None when no low-b frames and no provided bzero" |
| 174 | + ) |
| 175 | + else: |
| 176 | + assert dwi_obj.bzero is not None |
| 177 | + # If provided_bzero is True, it must be preserved exactly |
| 178 | + if provide_bzero: |
| 179 | + assert provided is not None |
| 180 | + assert np.allclose(dwi_obj.bzero, provided) |
| 181 | + else: |
| 182 | + # When there are b0 frames and no provided bzero: |
| 183 | + # - If exactly one b0 frame, the stored bzero should be the 3D volume |
| 184 | + # - If multiple b0 frames, the stored bzero should be the median along last axis |
| 185 | + b0_vols = dataobj[ |
| 186 | + ..., b0_mask |
| 187 | + ].squeeze() # shape (X,Y,Z,expected_b0_num) or (X,Y,Z) if 1 |
| 188 | + expected_bzero = b0_vols if b0_vols.ndim == 3 else np.median(b0_vols, axis=-1) |
| 189 | + assert np.allclose(dwi_obj.bzero, expected_bzero) |
| 190 | + |
| 191 | + assert dwi_obj.gradients.shape[0] == expected_non_b0_count |
| 192 | + assert dwi_obj.dataobj.shape[-1] == expected_non_b0_count |
| 193 | + |
| 194 | + assert np.allclose(dwi_obj.gradients, gradients[~b0_mask]) |
| 195 | + assert np.allclose(dwi_obj.dataobj, dataobj[..., ~b0_mask]) |
| 196 | + |
| 197 | + |
112 | 198 | def test_main(datadir): |
113 | 199 | input_file = datadir / "dwi.h5" |
114 | 200 |
|
@@ -176,8 +262,20 @@ def test_format_gradients_basic(value, expect_transpose): |
176 | 262 | assert np.allclose(obtained, np.asarray(value)) |
177 | 263 |
|
178 | 264 |
|
179 | | -@pytest.mark.random_uniform_spatial_data((2, 2, 2, 4), 0.0, 1.0) |
180 | | -def test_dwi_post_init_errors(setup_random_uniform_spatial_data): |
| 265 | +@pytest.mark.parametrize( |
| 266 | + "case_mark", |
| 267 | + [ |
| 268 | + pytest.param( |
| 269 | + None, |
| 270 | + marks=pytest.mark.random_uniform_spatial_data((2, 2, 2, 6), 0.0, 1.0), |
| 271 | + ), |
| 272 | + pytest.param( |
| 273 | + None, |
| 274 | + marks=pytest.mark.random_uniform_spatial_data((2, 2, 2, 4), 0.0, 1.0), |
| 275 | + ), |
| 276 | + ], |
| 277 | +) |
| 278 | +def test_dwi_post_init_errors(setup_random_uniform_spatial_data, case_mark): |
181 | 279 | data, affine = setup_random_uniform_spatial_data |
182 | 280 | with pytest.raises(ValueError, match=GRADIENT_ABSENCE_ERROR_MSG): |
183 | 281 | DWI(dataobj=data, affine=affine) |
|
0 commit comments