Skip to content

Commit 5025477

Browse files
committed
ENH: Provide masks to models in tests
Provide masks to models in tests. Fixes: ``` test/test_model.py::test_two_initialisations test/test_model.py::test_trivial_model test/test_model.py::test_average_model /home/runner/work/nifreeze/nifreeze/src/nifreeze/model/base.py:99: UserWarning: No mask provided; consider using a mask to avoid issues in model optimization. warn("No mask provided; consider using a mask to avoid issues in model optimization.") ``` raised for example in: https://github.com/nipreps/nifreeze/actions/runs/12769534850/job/35592349975#step:11:1032
1 parent 856a00e commit 5025477

File tree

1 file changed

+13
-6
lines changed

1 file changed

+13
-6
lines changed

test/test_model.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -44,15 +44,18 @@ def test_trivial_model():
4444
with pytest.raises(TypeError):
4545
model.TrivialModel()
4646

47-
_S0 = rng.normal(size=(2, 2, 2))
47+
size = (2, 2, 2)
48+
mask = np.ones(size, dtype=bool)
49+
50+
_S0 = rng.normal(size=size)
4851

4952
_clipped_S0 = np.clip(
5053
_S0.astype("float32") / _S0.max(),
5154
a_min=DEFAULT_MIN_S0,
5255
a_max=DEFAULT_MAX_S0,
5356
)
5457

55-
tmodel = model.TrivialModel(predicted=_clipped_S0)
58+
tmodel = model.TrivialModel(mask=mask, predicted=_clipped_S0)
5659

5760
data = None
5861
assert tmodel.fit(data) is None
@@ -63,7 +66,9 @@ def test_trivial_model():
6366
def test_average_model():
6467
"""Check the implementation of the average DW model."""
6568

66-
data = np.ones((100, 100, 100, 6), dtype=float)
69+
size = (100, 100, 100, 6)
70+
data = np.ones(size, dtype=float)
71+
mask = np.ones(size, dtype=bool)
6772

6873
gtab = np.array(
6974
[
@@ -78,10 +83,11 @@ def test_average_model():
7883

7984
data *= gtab[:, -1]
8085

81-
tmodel_mean = model.AverageDWIModel(gtab=gtab, bias=False, stat="mean")
82-
tmodel_median = model.AverageDWIModel(gtab=gtab, bias=False, stat="median")
83-
tmodel_1000 = model.AverageDWIModel(gtab=gtab, bias=False, th_high=1000, th_low=900)
86+
tmodel_mean = model.AverageDWIModel(mask=mask, gtab=gtab, bias=False, stat="mean")
87+
tmodel_median = model.AverageDWIModel(mask=mask, gtab=gtab, bias=False, stat="median")
88+
tmodel_1000 = model.AverageDWIModel(mask=mask, gtab=gtab, bias=False, th_high=1000, th_low=900)
8489
tmodel_2000 = model.AverageDWIModel(
90+
mask=mask,
8591
gtab=gtab,
8692
bias=False,
8793
th_high=2000,
@@ -154,6 +160,7 @@ def test_two_initialisations(datadir):
154160

155161
# Direct initialisation
156162
model1 = model.AverageDWIModel(
163+
mask=dmri_dataset.brainmask.astype(bool),
157164
gtab=data_train[-1],
158165
S0=dmri_dataset.bzero,
159166
th_low=100,

0 commit comments

Comments
 (0)