Skip to content

ENH: Add a shell data property to DWI data class #74

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 29, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions src/nifreeze/data/dmri.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,42 @@ def to_nifti(
np.savetxt(bvecs_file, self.bvecs, fmt=f"%.{bvecs_dec_places}f")
np.savetxt(bvals_file, self.bvals[np.newaxis, :], fmt=f"%.{bvals_dec_places}f")

def shells(
self,
num_bins: int = DEFAULT_NUM_BINS,
multishell_nonempty_bin_count_thr: int = DEFAULT_MULTISHELL_BIN_COUNT_THR,
bval_cap: int = DEFAULT_HIGHB_THRESHOLD,
) -> list:
"""Get the shell data according to the b-value groups.

Bin the shell data according to the b-value groups found by `~find_shelling_scheme`.

Parameters
----------
num_bins : :obj:`int`, optional
Number of bins.
multishell_nonempty_bin_count_thr : :obj:`int`, optional
Bin count to consider a multi-shell scheme.
bval_cap : :obj:`int`, optional
Maximum b-value to be considered in a multi-shell scheme.

Returns
-------
:obj:`list`
Tuples of binned b-values and corresponding shell data.
"""

_, bval_groups, bval_estimated = find_shelling_scheme(
self.gradients[-1, ...],
num_bins=num_bins,
multishell_nonempty_bin_count_thr=multishell_nonempty_bin_count_thr,
bval_cap=bval_cap,
)
indices = [
np.hstack(np.where(np.isin(self.gradients[-1, ...], bvals))) for bvals in bval_groups
]
return [(bval_estimated[idx], *self[indices]) for idx, indices in enumerate(indices)]


def load(
filename: Path | str,
Expand Down
84 changes: 77 additions & 7 deletions test/test_data_dmri.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,20 +25,28 @@
import nibabel as nb
import numpy as np
import pytest
from dipy.io.gradients import read_bvals_bvecs

from nifreeze.data.dmri import find_shelling_scheme, load
from nifreeze.data.dmri import DWI, find_shelling_scheme, load


def _create_dwi_random_dataobj(request):
def _create_random_gtab_dataobj(request, n_gradients=10, b0s=1):
rng = request.node.rng

n_gradients = 10
b0s = 1
volumes = n_gradients + b0s
b0_thres = 50
bvals = np.hstack([b0s * [0], n_gradients * [1000]])
bvecs = np.hstack([np.zeros((3, b0s)), rng.random((3, n_gradients))])

return bvals, bvecs


def _create_dwi_random_dataobj(request, bvals, bvecs):
rng = request.node.rng

n_gradients = np.count_nonzero(bvals)
b0s = len(bvals) - n_gradients
volumes = n_gradients + b0s
b0_thres = 50

vol_size = (34, 36, 24)

dwi_dataobj = rng.random((*vol_size, volumes), dtype="float32")
Expand Down Expand Up @@ -136,14 +144,16 @@ def test_load(datadir, tmp_path):

def test_equality_operator(tmp_path, request):
# Create some random data
bvals, bvecs = _create_random_gtab_dataobj(request)

(
dwi_dataobj,
affine,
brainmask_dataobj,
b0_dataobj,
gradients,
b0_thres,
) = _create_dwi_random_dataobj(request)
) = _create_dwi_random_dataobj(request, bvals, bvecs)

dwi, brainmask, b0 = _create_dwi_random_data(
dwi_dataobj,
Expand Down Expand Up @@ -182,6 +192,66 @@ def test_equality_operator(tmp_path, request):
assert round_trip_dwi_obj == dwi_obj


def test_shells(request, repodata):
bvals, bvecs = read_bvals_bvecs(
str(repodata / "hcph_multishell.bval"),
str(repodata / "hcph_multishell.bvec"),
)

(
dwi_dataobj,
affine,
brainmask_dataobj,
b0_dataobj,
gradients,
_,
) = _create_dwi_random_dataobj(request, bvals, bvecs.T)

dwi_obj = DWI(
dataobj=dwi_dataobj,
affine=affine,
brainmask=brainmask_dataobj,
bzero=b0_dataobj,
gradients=gradients,
)

num_bins = 3
_, expected_bval_groups, expected_bval_est = find_shelling_scheme(
dwi_obj.gradients[-1, ...], num_bins=num_bins
)

indices = [
np.hstack(np.where(np.isin(dwi_obj.gradients[-1, ...], bvals)))
for bvals in expected_bval_groups
]
expected_dwi_data = [dwi_obj.dataobj[..., idx] for idx in indices]
expected_motion_affines = [
dwi_obj.motion_affines[idx] if dwi_obj.motion_affines else None for idx in indices
]
expected_gradients = [dwi_obj.gradients[..., idx] for idx in indices]

shell_data = dwi_obj.shells(num_bins=num_bins)
obtained_bval_est, obtained_dwi_data, obtained_motion_affines, obtained_gradients = zip(
*shell_data, strict=True
)

assert len(shell_data) == num_bins
assert list(obtained_bval_est) == expected_bval_est
assert all(
np.allclose(arr1, arr2)
for arr1, arr2 in zip(list(obtained_dwi_data), expected_dwi_data, strict=True)
)
assert all(
(arr1 is None and arr2 is None)
or (arr1 is not None and arr2 is not None and np.allclose(arr1, arr2))
for arr1, arr2 in zip(list(obtained_motion_affines), expected_motion_affines, strict=True)
)
assert all(
np.allclose(arr1, arr2)
for arr1, arr2 in zip(list(obtained_gradients), expected_gradients, strict=True)
)


@pytest.mark.parametrize(
("bvals", "exp_scheme", "exp_bval_groups", "exp_bval_estimated"),
[
Expand Down