2222#
2323"""Unit tests exercising models."""
2424
25+ import contextlib
26+
2527import numpy as np
2628import pytest
2729from dipy .sims .voxel import single_tensor
3133from nifreeze .data .splitting import lovo_split
3234from nifreeze .exceptions import ModelNotFittedError
3335from nifreeze .model ._dipy import GaussianProcessModel
36+ from nifreeze .model .base import mask_absence_warn_msg
3437from nifreeze .model .dmri import DEFAULT_MAX_S0 , DEFAULT_MIN_S0
3538from nifreeze .testing import simulations as _sim
3639
3740
38- def test_trivial_model ():
41+ @pytest .mark .parametrize ("use_mask" , (False , True ))
42+ def test_trivial_model (use_mask ):
3943 """Check the implementation of the trivial B0 model."""
4044
4145 rng = np .random .default_rng (1234 )
@@ -45,7 +49,12 @@ def test_trivial_model():
4549 model .TrivialModel ()
4650
4751 size = (2 , 2 , 2 )
48- mask = np .ones (size , dtype = bool )
52+ mask = None
53+ if use_mask :
54+ mask = np .ones (size , dtype = bool )
55+ context = contextlib .nullcontext ()
56+ else :
57+ context = pytest .warns (UserWarning , match = mask_absence_warn_msg )
4958
5059 _S0 = rng .normal (size = size )
5160
@@ -55,7 +64,8 @@ def test_trivial_model():
5564 a_max = DEFAULT_MAX_S0 ,
5665 )
5766
58- tmodel = model .TrivialModel (mask = mask , predicted = _clipped_S0 )
67+ with context :
68+ tmodel = model .TrivialModel (mask = mask , predicted = _clipped_S0 )
5969
6070 data = None
6171 assert tmodel .fit (data ) is None
0 commit comments