diff --git a/test/test_data_dmri.py b/test/test_data_dmri.py index faf764622..604eec63c 100644 --- a/test/test_data_dmri.py +++ b/test/test_data_dmri.py @@ -910,6 +910,179 @@ def test_find_shelling_scheme_array(bvals, exp_scheme, exp_bval_groups, exp_bval assert np.allclose(obt_bval_estimated, exp_bval_estimated) +@pytest.mark.parametrize( + "bvals", + [ + np.asarray( + [ + 5, + 300, + 305, + 5, + 1005, + 995, + 1000, + 1005, + 5, + 995, + 1000, + 995, + 995, + 5, + 2005, + 2000, + 2005, + 2005, + 1995, + ] + ) + ], +) +@pytest.mark.parametrize( + "num_bins, multishell_nonempty_bin_count_thr, bval_cap, exp_scheme, exp_bval_groups, exp_bval_estimated", + [ + # Low multi-shell bin count threshold value + ( + 4, + 3, + 2500, + "DSI", + [ + [5, 300, 305, 5, 5, 5], + [995, 1000, 995, 1000, 995, 995], + [1005, 1005], + [2005, 2000, 2005, 2005, 1995], + ], + [5.0, 995.0, 1005.0, 2005.0], + ), + ( + 5, + 3, + 2500, + "DSI", + [ + [5, 300, 305, 5, 5, 5], + [1005, 995, 1000, 1005, 995, 1000, 995, 995], + [2005, 2000, 2005, 2005, 1995], + ], + [5.0, 997.5, 2005.0], + ), + # Fewer bins: ensure function still returns a consistent scheme + ( + 3, + 3, + 2500, + "DSI", + [ + [5, 300, 305, 5, 5, 5], + [1005, 995, 1000, 1005, 995, 1000, 995, 995], + [2005, 2000, 2005, 2005, 1995], + ], + [5.0, 997.5, 2005.0], + ), + # Tighter cap: high shells beyond cap should be handled + ( + 5, + 3, + 1500, + "DSI", + [[5, 5, 5, 5], [300, 305], [1005, 995, 1000, 1005, 995, 1000, 995, 995]], + [5.0, 302.5, 997.5], + ), + # Increase threshold to determine as multi-shell + ( + 3, + 6, + 2500, + "multi-shell", + [ + [5, 300, 305, 5, 5, 5], + [1005, 995, 1000, 1005, 995, 1000, 995, 995], + [2005, 2000, 2005, 2005, 1995], + ], + [5.0, 997.5, 2005.0], + ), + ( + 4, + 10, + 2500, + "multi-shell", + [ + [5, 300, 305, 5, 5, 5], + [995, 1000, 995, 1000, 995, 995], + [1005, 1005], + [2005, 2000, 2005, 2005, 1995], + ], + [5.0, 995.0, 1005.0, 2005.0], + ), + # Decrease num bins to determine as single-shell + ( + 2, + 10, + 2500, + "single-shell", + [ + [5, 300, 305, 5, 995, 1000, 5, 995, 1000, 995, 995, 5], + [1005, 1005, 2005, 2000, 2005, 2005, 1995], + ], + [650.0, 2000.0], + ), + # Limit high-shell cap + ( + 2, + 10, + 1000, + "single-shell", + [[5, 300, 305, 5, 5, 5], [995, 1000, 995, 1000, 995, 995]], + [5.0, 995.0], + ), + ], +) +def test_find_shelling_scheme_params( + bvals, + num_bins, + multishell_nonempty_bin_count_thr, + bval_cap, + exp_scheme, + exp_bval_groups, + exp_bval_estimated, +): + """Test find_shelling_scheme on the same bvals vector with different + parameter settings. + + For the baseline parameter set we assert exact equality against the known expected + scheme, groups and estimated b-values. For other parameter sets we assert structural + invariants and basic sanity checks (no unexpected shapes, estimated b-values within + the provided cap when applicable, and that groups contain only b-values from the input). + """ + obt_scheme, obt_bval_groups, obt_bval_estimated = find_shelling_scheme( + bvals, + num_bins=num_bins, + multishell_nonempty_bin_count_thr=multishell_nonempty_bin_count_thr, + bval_cap=bval_cap, + ) + + # Basic structural checks + assert obt_scheme == exp_scheme + assert isinstance(obt_bval_groups, list) + assert isinstance(obt_bval_estimated, list) + + # Estimated values length should match number of groups + assert len(obt_bval_estimated) == len(obt_bval_groups) + + # Compare groups and estimated bvals: same number and same elements (order-preserving) + assert all( + np.allclose(obt_arr, exp_arr) + for obt_arr, exp_arr in zip(obt_bval_groups, exp_bval_groups, strict=True) + ) + + # If a finite bval_cap is given, make sure estimated bvals don't exceed it + if np.isfinite(bval_cap): + for est in np.asarray(obt_bval_estimated).ravel(): + assert est <= bval_cap + 1e-8 # Tiny tolerance + assert np.allclose(obt_bval_estimated, exp_bval_estimated) + + @pytest.mark.parametrize( ("dwi_btable", "exp_scheme", "exp_bval_groups", "exp_bval_estimated"), [