diff --git a/src/nifreeze/data/dmri.py b/src/nifreeze/data/dmri.py index 878c72c6..d2ec95b6 100644 --- a/src/nifreeze/data/dmri.py +++ b/src/nifreeze/data/dmri.py @@ -256,6 +256,42 @@ def to_nifti( np.savetxt(bvecs_file, self.bvecs, fmt=f"%.{bvecs_dec_places}f") np.savetxt(bvals_file, self.bvals[np.newaxis, :], fmt=f"%.{bvals_dec_places}f") + def shells( + self, + num_bins: int = DEFAULT_NUM_BINS, + multishell_nonempty_bin_count_thr: int = DEFAULT_MULTISHELL_BIN_COUNT_THR, + bval_cap: int = DEFAULT_HIGHB_THRESHOLD, + ) -> list: + """Get the shell data according to the b-value groups. + + Bin the shell data according to the b-value groups found by `~find_shelling_scheme`. + + Parameters + ---------- + num_bins : :obj:`int`, optional + Number of bins. + multishell_nonempty_bin_count_thr : :obj:`int`, optional + Bin count to consider a multi-shell scheme. + bval_cap : :obj:`int`, optional + Maximum b-value to be considered in a multi-shell scheme. + + Returns + ------- + :obj:`list` + Tuples of binned b-values and corresponding shell data. + """ + + _, bval_groups, bval_estimated = find_shelling_scheme( + self.gradients[-1, ...], + num_bins=num_bins, + multishell_nonempty_bin_count_thr=multishell_nonempty_bin_count_thr, + bval_cap=bval_cap, + ) + indices = [ + np.hstack(np.where(np.isin(self.gradients[-1, ...], bvals))) for bvals in bval_groups + ] + return [(bval_estimated[idx], *self[indices]) for idx, indices in enumerate(indices)] + def load( filename: Path | str, diff --git a/test/test_data_dmri.py b/test/test_data_dmri.py index 741c4ff0..dd65acd3 100644 --- a/test/test_data_dmri.py +++ b/test/test_data_dmri.py @@ -25,20 +25,28 @@ import nibabel as nb import numpy as np import pytest +from dipy.io.gradients import read_bvals_bvecs -from nifreeze.data.dmri import find_shelling_scheme, load +from nifreeze.data.dmri import DWI, find_shelling_scheme, load -def _create_dwi_random_dataobj(request): +def _create_random_gtab_dataobj(request, n_gradients=10, b0s=1): rng = request.node.rng - n_gradients = 10 - b0s = 1 - volumes = n_gradients + b0s - b0_thres = 50 bvals = np.hstack([b0s * [0], n_gradients * [1000]]) bvecs = np.hstack([np.zeros((3, b0s)), rng.random((3, n_gradients))]) + return bvals, bvecs + + +def _create_dwi_random_dataobj(request, bvals, bvecs): + rng = request.node.rng + + n_gradients = np.count_nonzero(bvals) + b0s = len(bvals) - n_gradients + volumes = n_gradients + b0s + b0_thres = 50 + vol_size = (34, 36, 24) dwi_dataobj = rng.random((*vol_size, volumes), dtype="float32") @@ -136,6 +144,8 @@ def test_load(datadir, tmp_path): def test_equality_operator(tmp_path, request): # Create some random data + bvals, bvecs = _create_random_gtab_dataobj(request) + ( dwi_dataobj, affine, @@ -143,7 +153,7 @@ def test_equality_operator(tmp_path, request): b0_dataobj, gradients, b0_thres, - ) = _create_dwi_random_dataobj(request) + ) = _create_dwi_random_dataobj(request, bvals, bvecs) dwi, brainmask, b0 = _create_dwi_random_data( dwi_dataobj, @@ -182,6 +192,66 @@ def test_equality_operator(tmp_path, request): assert round_trip_dwi_obj == dwi_obj +def test_shells(request, repodata): + bvals, bvecs = read_bvals_bvecs( + str(repodata / "hcph_multishell.bval"), + str(repodata / "hcph_multishell.bvec"), + ) + + ( + dwi_dataobj, + affine, + brainmask_dataobj, + b0_dataobj, + gradients, + _, + ) = _create_dwi_random_dataobj(request, bvals, bvecs.T) + + dwi_obj = DWI( + dataobj=dwi_dataobj, + affine=affine, + brainmask=brainmask_dataobj, + bzero=b0_dataobj, + gradients=gradients, + ) + + num_bins = 3 + _, expected_bval_groups, expected_bval_est = find_shelling_scheme( + dwi_obj.gradients[-1, ...], num_bins=num_bins + ) + + indices = [ + np.hstack(np.where(np.isin(dwi_obj.gradients[-1, ...], bvals))) + for bvals in expected_bval_groups + ] + expected_dwi_data = [dwi_obj.dataobj[..., idx] for idx in indices] + expected_motion_affines = [ + dwi_obj.motion_affines[idx] if dwi_obj.motion_affines else None for idx in indices + ] + expected_gradients = [dwi_obj.gradients[..., idx] for idx in indices] + + shell_data = dwi_obj.shells(num_bins=num_bins) + obtained_bval_est, obtained_dwi_data, obtained_motion_affines, obtained_gradients = zip( + *shell_data, strict=True + ) + + assert len(shell_data) == num_bins + assert list(obtained_bval_est) == expected_bval_est + assert all( + np.allclose(arr1, arr2) + for arr1, arr2 in zip(list(obtained_dwi_data), expected_dwi_data, strict=True) + ) + assert all( + (arr1 is None and arr2 is None) + or (arr1 is not None and arr2 is not None and np.allclose(arr1, arr2)) + for arr1, arr2 in zip(list(obtained_motion_affines), expected_motion_affines, strict=True) + ) + assert all( + np.allclose(arr1, arr2) + for arr1, arr2 in zip(list(obtained_gradients), expected_gradients, strict=True) + ) + + @pytest.mark.parametrize( ("bvals", "exp_scheme", "exp_bval_groups", "exp_bval_estimated"), [