Skip to content

Commit 4278c3b

Browse files
jhlegarretaoesteban
andcommitted
ENH: Add a shell data property to DWI data class
Add a shell data property to `DWI` data class that returns a list of pairs consisting of the estimated b-value and the associated DWI data. Co-authored-by: Oscar Esteban <[email protected]>
1 parent 628aa8f commit 4278c3b

File tree

2 files changed

+113
-7
lines changed

2 files changed

+113
-7
lines changed

src/nifreeze/data/dmri.py

+36
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,42 @@ def to_nifti(
256256
np.savetxt(bvecs_file, self.bvecs, fmt=f"%.{bvecs_dec_places}f")
257257
np.savetxt(bvals_file, self.bvals[np.newaxis, :], fmt=f"%.{bvals_dec_places}f")
258258

259+
def shells(
260+
self,
261+
num_bins: int = DEFAULT_NUM_BINS,
262+
multishell_nonempty_bin_count_thr: int = DEFAULT_MULTISHELL_BIN_COUNT_THR,
263+
bval_cap: int = DEFAULT_HIGHB_THRESHOLD,
264+
) -> list:
265+
"""Get the shell data according to the b-value groups.
266+
267+
Bin the shell data according to the b-value groups found by `~find_shelling_scheme`.
268+
269+
Parameters
270+
----------
271+
num_bins : :obj:`int`, optional
272+
Number of bins.
273+
multishell_nonempty_bin_count_thr : :obj:`int`, optional
274+
Bin count to consider a multi-shell scheme.
275+
bval_cap : :obj:`int`, optional
276+
Maximum b-value to be considered in a multi-shell scheme.
277+
278+
Returns
279+
-------
280+
:obj:`list`
281+
Tuples of binned b-values and corresponding shell data.
282+
"""
283+
284+
_, bval_groups, bval_estimated = find_shelling_scheme(
285+
self.gradients[-1, ...],
286+
num_bins=num_bins,
287+
multishell_nonempty_bin_count_thr=multishell_nonempty_bin_count_thr,
288+
bval_cap=bval_cap,
289+
)
290+
indices = [
291+
np.hstack(np.where(np.isin(self.gradients[-1, ...], bvals))) for bvals in bval_groups
292+
]
293+
return [(bval_estimated[idx], *self[indices]) for idx, indices in enumerate(indices)]
294+
259295

260296
def load(
261297
filename: Path | str,

test/test_data_dmri.py

+77-7
Original file line numberDiff line numberDiff line change
@@ -25,20 +25,28 @@
2525
import nibabel as nb
2626
import numpy as np
2727
import pytest
28+
from dipy.io.gradients import read_bvals_bvecs
2829

29-
from nifreeze.data.dmri import find_shelling_scheme, load
30+
from nifreeze.data.dmri import DWI, find_shelling_scheme, load
3031

3132

32-
def _create_dwi_random_dataobj(request):
33+
def _create_random_gtab_dataobj(request, n_gradients=10, b0s=1):
3334
rng = request.node.rng
3435

35-
n_gradients = 10
36-
b0s = 1
37-
volumes = n_gradients + b0s
38-
b0_thres = 50
3936
bvals = np.hstack([b0s * [0], n_gradients * [1000]])
4037
bvecs = np.hstack([np.zeros((3, b0s)), rng.random((3, n_gradients))])
4138

39+
return bvals, bvecs
40+
41+
42+
def _create_dwi_random_dataobj(request, bvals, bvecs):
43+
rng = request.node.rng
44+
45+
n_gradients = np.count_nonzero(bvecs)
46+
b0s = len(bvals) - n_gradients
47+
volumes = n_gradients + b0s
48+
b0_thres = 50
49+
4250
vol_size = (34, 36, 24)
4351

4452
dwi_dataobj = rng.random((*vol_size, volumes), dtype="float32")
@@ -136,14 +144,16 @@ def test_load(datadir, tmp_path):
136144

137145
def test_equality_operator(tmp_path, request):
138146
# Create some random data
147+
bvals, bvecs = _create_random_gtab_dataobj(request)
148+
139149
(
140150
dwi_dataobj,
141151
affine,
142152
brainmask_dataobj,
143153
b0_dataobj,
144154
gradients,
145155
b0_thres,
146-
) = _create_dwi_random_dataobj(request)
156+
) = _create_dwi_random_dataobj(request, bvals, bvecs)
147157

148158
dwi, brainmask, b0 = _create_dwi_random_data(
149159
dwi_dataobj,
@@ -182,6 +192,66 @@ def test_equality_operator(tmp_path, request):
182192
assert round_trip_dwi_obj == dwi_obj
183193

184194

195+
def test_shells(request, repodata):
196+
bvals, bvecs = read_bvals_bvecs(
197+
str(repodata / "hcph_multishell.bval"),
198+
str(repodata / "hcph_multishell.bvec"),
199+
)
200+
201+
(
202+
dwi_dataobj,
203+
affine,
204+
brainmask_dataobj,
205+
b0_dataobj,
206+
gradients,
207+
_,
208+
) = _create_dwi_random_dataobj(request, bvals, bvecs)
209+
210+
dwi_obj = DWI(
211+
dataobj=dwi_dataobj,
212+
affine=affine,
213+
brainmask=brainmask_dataobj,
214+
bzero=b0_dataobj,
215+
gradients=gradients,
216+
)
217+
218+
num_bins = 3
219+
_, expected_bval_groups, expected_bval_est = find_shelling_scheme(
220+
dwi_obj.gradients[-1, ...], num_bins=num_bins
221+
)
222+
223+
indices = [
224+
np.hstack(np.where(np.isin(dwi_obj.gradients[-1, ...], bvals)))
225+
for bvals in expected_bval_groups
226+
]
227+
expected_dwi_data = [dwi_obj.dataobj[..., idx] for idx in indices]
228+
expected_motion_affines = [
229+
dwi_obj.motion_affines[idx] if dwi_obj.motion_affines else None for idx in indices
230+
]
231+
expected_gradients = [dwi_obj.gradients[..., idx] for idx in indices]
232+
233+
shell_data = dwi_obj.shells(num_bins=num_bins)
234+
obtained_bval_est, obtained_dwi_data, obtained_motion_affines, obtained_gradients = zip(
235+
*shell_data, strict=True
236+
)
237+
238+
assert len(shell_data) == num_bins
239+
assert list(obtained_bval_est) == expected_bval_est
240+
assert all(
241+
np.allclose(arr1, arr2)
242+
for arr1, arr2 in zip(list(obtained_dwi_data), expected_dwi_data, strict=True)
243+
)
244+
assert all(
245+
(arr1 is None and arr2 is None)
246+
or (arr1 is not None and arr2 is not None and np.allclose(arr1, arr2))
247+
for arr1, arr2 in zip(list(obtained_motion_affines), expected_motion_affines, strict=True)
248+
)
249+
assert all(
250+
np.allclose(arr1, arr2)
251+
for arr1, arr2 in zip(list(obtained_gradients), expected_gradients, strict=True)
252+
)
253+
254+
185255
@pytest.mark.parametrize(
186256
("bvals", "exp_scheme", "exp_bval_groups", "exp_bval_estimated"),
187257
[

0 commit comments

Comments
 (0)