|
25 | 25 | import nibabel as nb
|
26 | 26 | import numpy as np
|
27 | 27 | import pytest
|
| 28 | +from dipy.io.gradients import read_bvals_bvecs |
28 | 29 |
|
29 |
| -from nifreeze.data.dmri import find_shelling_scheme, load |
| 30 | +from nifreeze.data.dmri import DWI, find_shelling_scheme, load |
30 | 31 |
|
31 | 32 |
|
32 |
| -def _create_dwi_random_dataobj(request): |
| 33 | +def _create_random_gtab_dataobj(request, n_gradients=10, b0s=1): |
33 | 34 | rng = request.node.rng
|
34 | 35 |
|
35 |
| - n_gradients = 10 |
36 |
| - b0s = 1 |
37 |
| - volumes = n_gradients + b0s |
38 |
| - b0_thres = 50 |
39 | 36 | bvals = np.hstack([b0s * [0], n_gradients * [1000]])
|
40 | 37 | bvecs = np.hstack([np.zeros((3, b0s)), rng.random((3, n_gradients))])
|
41 | 38 |
|
| 39 | + return bvals, bvecs |
| 40 | + |
| 41 | + |
| 42 | +def _create_dwi_random_dataobj(request, bvals, bvecs): |
| 43 | + rng = request.node.rng |
| 44 | + |
| 45 | + n_gradients = np.count_nonzero(bvecs) |
| 46 | + b0s = len(bvals) - n_gradients |
| 47 | + volumes = n_gradients + b0s |
| 48 | + b0_thres = 50 |
| 49 | + |
42 | 50 | vol_size = (34, 36, 24)
|
43 | 51 |
|
44 | 52 | dwi_dataobj = rng.random((*vol_size, volumes), dtype="float32")
|
@@ -136,14 +144,16 @@ def test_load(datadir, tmp_path):
|
136 | 144 |
|
137 | 145 | def test_equality_operator(tmp_path, request):
|
138 | 146 | # Create some random data
|
| 147 | + bvals, bvecs = _create_random_gtab_dataobj(request) |
| 148 | + |
139 | 149 | (
|
140 | 150 | dwi_dataobj,
|
141 | 151 | affine,
|
142 | 152 | brainmask_dataobj,
|
143 | 153 | b0_dataobj,
|
144 | 154 | gradients,
|
145 | 155 | b0_thres,
|
146 |
| - ) = _create_dwi_random_dataobj(request) |
| 156 | + ) = _create_dwi_random_dataobj(request, bvals, bvecs) |
147 | 157 |
|
148 | 158 | dwi, brainmask, b0 = _create_dwi_random_data(
|
149 | 159 | dwi_dataobj,
|
@@ -182,6 +192,66 @@ def test_equality_operator(tmp_path, request):
|
182 | 192 | assert round_trip_dwi_obj == dwi_obj
|
183 | 193 |
|
184 | 194 |
|
| 195 | +def test_shells(request, repodata): |
| 196 | + bvals, bvecs = read_bvals_bvecs( |
| 197 | + str(repodata / "hcph_multishell.bval"), |
| 198 | + str(repodata / "hcph_multishell.bvec"), |
| 199 | + ) |
| 200 | + |
| 201 | + ( |
| 202 | + dwi_dataobj, |
| 203 | + affine, |
| 204 | + brainmask_dataobj, |
| 205 | + b0_dataobj, |
| 206 | + gradients, |
| 207 | + _, |
| 208 | + ) = _create_dwi_random_dataobj(request, bvals, bvecs) |
| 209 | + |
| 210 | + dwi_obj = DWI( |
| 211 | + dataobj=dwi_dataobj, |
| 212 | + affine=affine, |
| 213 | + brainmask=brainmask_dataobj, |
| 214 | + bzero=b0_dataobj, |
| 215 | + gradients=gradients, |
| 216 | + ) |
| 217 | + |
| 218 | + num_bins = 3 |
| 219 | + _, expected_bval_groups, expected_bval_est = find_shelling_scheme( |
| 220 | + dwi_obj.gradients[-1, ...], num_bins=num_bins |
| 221 | + ) |
| 222 | + |
| 223 | + indices = [ |
| 224 | + np.hstack(np.where(np.isin(dwi_obj.gradients[-1, ...], bvals))) |
| 225 | + for bvals in expected_bval_groups |
| 226 | + ] |
| 227 | + expected_dwi_data = [dwi_obj.dataobj[..., idx] for idx in indices] |
| 228 | + expected_motion_affines = [ |
| 229 | + dwi_obj.motion_affines[idx] if dwi_obj.motion_affines else None for idx in indices |
| 230 | + ] |
| 231 | + expected_gradients = [dwi_obj.gradients[..., idx] for idx in indices] |
| 232 | + |
| 233 | + shell_data = dwi_obj.shells(num_bins=num_bins) |
| 234 | + obtained_bval_est, obtained_dwi_data, obtained_motion_affines, obtained_gradients = zip( |
| 235 | + *shell_data, strict=True |
| 236 | + ) |
| 237 | + |
| 238 | + assert len(shell_data) == num_bins |
| 239 | + assert list(obtained_bval_est) == expected_bval_est |
| 240 | + assert all( |
| 241 | + np.allclose(arr1, arr2) |
| 242 | + for arr1, arr2 in zip(list(obtained_dwi_data), expected_dwi_data, strict=True) |
| 243 | + ) |
| 244 | + assert all( |
| 245 | + (arr1 is None and arr2 is None) |
| 246 | + or (arr1 is not None and arr2 is not None and np.allclose(arr1, arr2)) |
| 247 | + for arr1, arr2 in zip(list(obtained_motion_affines), expected_motion_affines, strict=True) |
| 248 | + ) |
| 249 | + assert all( |
| 250 | + np.allclose(arr1, arr2) |
| 251 | + for arr1, arr2 in zip(list(obtained_gradients), expected_gradients, strict=True) |
| 252 | + ) |
| 253 | + |
| 254 | + |
185 | 255 | @pytest.mark.parametrize(
|
186 | 256 | ("bvals", "exp_scheme", "exp_bval_groups", "exp_bval_estimated"),
|
187 | 257 | [
|
|
0 commit comments