Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
173 changes: 173 additions & 0 deletions test/test_data_dmri.py
Original file line number Diff line number Diff line change
Expand Up @@ -910,6 +910,179 @@ def test_find_shelling_scheme_array(bvals, exp_scheme, exp_bval_groups, exp_bval
assert np.allclose(obt_bval_estimated, exp_bval_estimated)


@pytest.mark.parametrize(
"bvals",
[
np.asarray(
[
5,
300,
305,
5,
1005,
995,
1000,
1005,
5,
995,
1000,
995,
995,
5,
2005,
2000,
2005,
2005,
1995,
]
)
],
)
@pytest.mark.parametrize(
"num_bins, multishell_nonempty_bin_count_thr, bval_cap, exp_scheme, exp_bval_groups, exp_bval_estimated",
[
# Low multi-shell bin count threshold value
(
4,
3,
2500,
"DSI",
[
[5, 300, 305, 5, 5, 5],
[995, 1000, 995, 1000, 995, 995],
[1005, 1005],
[2005, 2000, 2005, 2005, 1995],
],
[5.0, 995.0, 1005.0, 2005.0],
),
(
5,
3,
2500,
"DSI",
[
[5, 300, 305, 5, 5, 5],
[1005, 995, 1000, 1005, 995, 1000, 995, 995],
[2005, 2000, 2005, 2005, 1995],
],
[5.0, 997.5, 2005.0],
),
# Fewer bins: ensure function still returns a consistent scheme
(
3,
3,
2500,
"DSI",
[
[5, 300, 305, 5, 5, 5],
[1005, 995, 1000, 1005, 995, 1000, 995, 995],
[2005, 2000, 2005, 2005, 1995],
],
[5.0, 997.5, 2005.0],
),
# Tighter cap: high shells beyond cap should be handled
(
5,
3,
1500,
"DSI",
[[5, 5, 5, 5], [300, 305], [1005, 995, 1000, 1005, 995, 1000, 995, 995]],
[5.0, 302.5, 997.5],
),
# Increase threshold to determine as multi-shell
(
3,
6,
2500,
"multi-shell",
[
[5, 300, 305, 5, 5, 5],
[1005, 995, 1000, 1005, 995, 1000, 995, 995],
[2005, 2000, 2005, 2005, 1995],
],
[5.0, 997.5, 2005.0],
),
(
4,
10,
2500,
"multi-shell",
[
[5, 300, 305, 5, 5, 5],
[995, 1000, 995, 1000, 995, 995],
[1005, 1005],
[2005, 2000, 2005, 2005, 1995],
],
[5.0, 995.0, 1005.0, 2005.0],
),
# Decrease num bins to determine as single-shell
(
2,
10,
2500,
"single-shell",
[
[5, 300, 305, 5, 995, 1000, 5, 995, 1000, 995, 995, 5],
[1005, 1005, 2005, 2000, 2005, 2005, 1995],
],
[650.0, 2000.0],
),
# Limit high-shell cap
(
2,
10,
1000,
"single-shell",
[[5, 300, 305, 5, 5, 5], [995, 1000, 995, 1000, 995, 995]],
[5.0, 995.0],
),
],
)
def test_find_shelling_scheme_params(
bvals,
num_bins,
multishell_nonempty_bin_count_thr,
bval_cap,
exp_scheme,
exp_bval_groups,
exp_bval_estimated,
):
"""Test find_shelling_scheme on the same bvals vector with different
parameter settings.

For the baseline parameter set we assert exact equality against the known expected
scheme, groups and estimated b-values. For other parameter sets we assert structural
invariants and basic sanity checks (no unexpected shapes, estimated b-values within
the provided cap when applicable, and that groups contain only b-values from the input).
"""
obt_scheme, obt_bval_groups, obt_bval_estimated = find_shelling_scheme(
bvals,
num_bins=num_bins,
multishell_nonempty_bin_count_thr=multishell_nonempty_bin_count_thr,
bval_cap=bval_cap,
)

# Basic structural checks
assert obt_scheme == exp_scheme
assert isinstance(obt_bval_groups, list)
assert isinstance(obt_bval_estimated, list)

# Estimated values length should match number of groups
assert len(obt_bval_estimated) == len(obt_bval_groups)

# Compare groups and estimated bvals: same number and same elements (order-preserving)
assert all(
np.allclose(obt_arr, exp_arr)
for obt_arr, exp_arr in zip(obt_bval_groups, exp_bval_groups, strict=True)
)

# If a finite bval_cap is given, make sure estimated bvals don't exceed it
if np.isfinite(bval_cap):
for est in np.asarray(obt_bval_estimated).ravel():
assert est <= bval_cap + 1e-8 # Tiny tolerance
assert np.allclose(obt_bval_estimated, exp_bval_estimated)


@pytest.mark.parametrize(
("dwi_btable", "exp_scheme", "exp_bval_groups", "exp_bval_estimated"),
[
Expand Down