@@ -44,15 +44,18 @@ def test_trivial_model():
44
44
with pytest .raises (TypeError ):
45
45
model .TrivialModel ()
46
46
47
- _S0 = rng .normal (size = (2 , 2 , 2 ))
47
+ size = (2 , 2 , 2 )
48
+ mask = np .ones (size , dtype = bool )
49
+
50
+ _S0 = rng .normal (size = size )
48
51
49
52
_clipped_S0 = np .clip (
50
53
_S0 .astype ("float32" ) / _S0 .max (),
51
54
a_min = DEFAULT_MIN_S0 ,
52
55
a_max = DEFAULT_MAX_S0 ,
53
56
)
54
57
55
- tmodel = model .TrivialModel (predicted = _clipped_S0 )
58
+ tmodel = model .TrivialModel (mask = mask , predicted = _clipped_S0 )
56
59
57
60
data = None
58
61
assert tmodel .fit (data ) is None
@@ -63,7 +66,9 @@ def test_trivial_model():
63
66
def test_average_model ():
64
67
"""Check the implementation of the average DW model."""
65
68
66
- data = np .ones ((100 , 100 , 100 , 6 ), dtype = float )
69
+ size = (100 , 100 , 100 , 6 )
70
+ data = np .ones (size , dtype = float )
71
+ mask = np .ones (size , dtype = bool )
67
72
68
73
gtab = np .array (
69
74
[
@@ -78,10 +83,11 @@ def test_average_model():
78
83
79
84
data *= gtab [:, - 1 ]
80
85
81
- tmodel_mean = model .AverageDWIModel (gtab = gtab , bias = False , stat = "mean" )
82
- tmodel_median = model .AverageDWIModel (gtab = gtab , bias = False , stat = "median" )
83
- tmodel_1000 = model .AverageDWIModel (gtab = gtab , bias = False , th_high = 1000 , th_low = 900 )
86
+ tmodel_mean = model .AverageDWIModel (mask = mask , gtab = gtab , bias = False , stat = "mean" )
87
+ tmodel_median = model .AverageDWIModel (mask = mask , gtab = gtab , bias = False , stat = "median" )
88
+ tmodel_1000 = model .AverageDWIModel (mask = mask , gtab = gtab , bias = False , th_high = 1000 , th_low = 900 )
84
89
tmodel_2000 = model .AverageDWIModel (
90
+ mask = mask ,
85
91
gtab = gtab ,
86
92
bias = False ,
87
93
th_high = 2000 ,
@@ -154,6 +160,7 @@ def test_two_initialisations(datadir):
154
160
155
161
# Direct initialisation
156
162
model1 = model .AverageDWIModel (
163
+ mask = dmri_dataset .brainmask .astype (bool ),
157
164
gtab = data_train [- 1 ],
158
165
S0 = dmri_dataset .bzero ,
159
166
th_low = 100 ,
0 commit comments