Skip to content

Commit 72ae039

Browse files
authored
Merge pull request #50 from jhlegarreta/ProvideMasksToModelsInTests
ENH: Provide masks to models in tests
2 parents d9b390b + dfb90e1 commit 72ae039

File tree

2 files changed

+16
-7
lines changed

2 files changed

+16
-7
lines changed

src/nifreeze/model/base.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,9 @@ def init(model="DTI", **kwargs):
5050
5151
"""
5252
if model.lower() in ("s0", "b0"):
53-
return TrivialModel(predicted=kwargs.pop("S0"), gtab=kwargs.pop("gtab"))
53+
return TrivialModel(
54+
mask=kwargs.pop("mask"), predicted=kwargs.pop("S0"), gtab=kwargs.pop("gtab")
55+
)
5456

5557
if model.lower() in ("avgdwi", "averagedwi", "meandwi"):
5658
from nifreeze.model.dmri import AverageDWIModel

test/test_model.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -44,15 +44,18 @@ def test_trivial_model():
4444
with pytest.raises(TypeError):
4545
model.TrivialModel()
4646

47-
_S0 = rng.normal(size=(2, 2, 2))
47+
size = (2, 2, 2)
48+
mask = np.ones(size, dtype=bool)
49+
50+
_S0 = rng.normal(size=size)
4851

4952
_clipped_S0 = np.clip(
5053
_S0.astype("float32") / _S0.max(),
5154
a_min=DEFAULT_MIN_S0,
5255
a_max=DEFAULT_MAX_S0,
5356
)
5457

55-
tmodel = model.TrivialModel(predicted=_clipped_S0)
58+
tmodel = model.TrivialModel(mask=mask, predicted=_clipped_S0)
5659

5760
data = None
5861
assert tmodel.fit(data) is None
@@ -63,7 +66,9 @@ def test_trivial_model():
6366
def test_average_model():
6467
"""Check the implementation of the average DW model."""
6568

66-
data = np.ones((100, 100, 100, 6), dtype=float)
69+
size = (100, 100, 100, 6)
70+
data = np.ones(size, dtype=float)
71+
mask = np.ones(size, dtype=bool)
6772

6873
gtab = np.array(
6974
[
@@ -78,10 +83,11 @@ def test_average_model():
7883

7984
data *= gtab[:, -1]
8085

81-
tmodel_mean = model.AverageDWIModel(gtab=gtab, bias=False, stat="mean")
82-
tmodel_median = model.AverageDWIModel(gtab=gtab, bias=False, stat="median")
83-
tmodel_1000 = model.AverageDWIModel(gtab=gtab, bias=False, th_high=1000, th_low=900)
86+
tmodel_mean = model.AverageDWIModel(mask=mask, gtab=gtab, bias=False, stat="mean")
87+
tmodel_median = model.AverageDWIModel(mask=mask, gtab=gtab, bias=False, stat="median")
88+
tmodel_1000 = model.AverageDWIModel(mask=mask, gtab=gtab, bias=False, th_high=1000, th_low=900)
8489
tmodel_2000 = model.AverageDWIModel(
90+
mask=mask,
8591
gtab=gtab,
8692
bias=False,
8793
th_high=2000,
@@ -154,6 +160,7 @@ def test_two_initialisations(datadir):
154160

155161
# Direct initialisation
156162
model1 = model.AverageDWIModel(
163+
mask=dmri_dataset.brainmask.astype(bool),
157164
gtab=data_train[-1],
158165
S0=dmri_dataset.bzero,
159166
th_low=100,

0 commit comments

Comments
 (0)