26
26
27
27
import numpy as np
28
28
29
- from nifreeze .exceptions import ModelNotFittedError
30
-
31
29
32
30
class ModelFactory :
33
31
"""A factory for instantiating data models."""
@@ -61,7 +59,7 @@ def init(model=None, **kwargs):
61
59
return AverageDWIModel (** kwargs )
62
60
63
61
if model .lower () in ("avg" , "average" , "mean" ):
64
- return AverageModel (** kwargs )
62
+ return ExpectationModel (** kwargs )
65
63
66
64
if model .lower () in ("dti" , "dki" , "pet" ):
67
65
Model = globals ()[f"{ model .upper ()} Model" ]
@@ -81,114 +79,80 @@ class BaseModel:
81
79
82
80
"""
83
81
84
- __slots__ = (
85
- "_model" ,
86
- "_mask" ,
87
- "_models" ,
88
- "_datashape" ,
89
- "_is_fitted" ,
90
- "_modelargs" ,
91
- )
82
+ __slots__ = {
83
+ "_dataset" : "Reference to a :obj:`~nifreeze.data.base.BaseDataset` object." ,
84
+ }
92
85
93
- def __init__ (self , mask = None , ** kwargs ):
86
+ def __init__ (self , dataset , ** kwargs ):
94
87
"""Base initialization."""
95
88
96
- # Keep model state
97
- self ._model = None # "Main" model
98
- self ._models = None # For parallel (chunked) execution
99
-
100
89
# Setup brain mask
101
- if mask is None :
90
+ if dataset . brainmask is None :
102
91
warn (
103
92
"No mask provided; consider using a mask to avoid issues in model optimization." ,
104
93
stacklevel = 2 ,
105
94
)
106
95
107
- self ._mask = mask
108
-
109
- self ._datashape = None
110
- self ._is_fitted = False
111
-
112
- self ._modelargs = ()
113
-
114
- @property
115
- def is_fitted (self ):
116
- return self ._is_fitted
117
-
118
- def fit (self , data , ** kwargs ):
119
- """Abstract member signature of fit()."""
120
- raise NotImplementedError ("Cannot call fit() on a BaseModel instance." )
121
-
122
- def predict (self , * args , ** kwargs ):
123
- """Abstract member signature of predict()."""
124
- raise NotImplementedError ("Cannot call predict() on a BaseModel instance." )
96
+ def fit_predict (self , * _ , ** kwargs ):
97
+ """Fit and predict the indicate index of the dataset (abstract signature)."""
98
+ raise NotImplementedError ("Cannot call fit_predict() on a BaseModel instance." )
125
99
126
100
127
101
class TrivialModel (BaseModel ):
128
102
"""A trivial model that returns a given map always."""
129
103
130
- __slots__ = ("_predicted" ,)
104
+ __slots__ = {
105
+ "_predicted" : "A :obj:`~numpy.ndarray` with shape matching the dataset containing the map"
106
+ "that will always be returned as prediction (that is, a reference volume)." ,
107
+ }
131
108
132
- def __init__ (self , predicted = None , ** kwargs ):
109
+ def __init__ (self , dataset , predicted = None , ** kwargs ):
133
110
"""Implement object initialization."""
134
- if predicted is None :
135
- raise TypeError ("This model requires the predicted map at initialization" )
136
111
137
- super ().__init__ (** kwargs )
138
- self ._predicted = predicted
139
- self ._datashape = predicted .shape
112
+ super ().__init__ (dataset , ** kwargs )
113
+ self ._predicted = (
114
+ predicted
115
+ if predicted is not None
116
+ # Infer from dataset if not provided at initialization
117
+ else getattr (dataset , "reference" , getattr (dataset , "bzero" , None ))
118
+ )
140
119
141
- @property
142
- def is_fitted (self ):
143
- return True
144
-
145
- def fit (self , data , ** kwargs ):
146
- """Do nothing."""
120
+ if self ._predicted is None :
121
+ raise TypeError ("This model requires the predicted map at initialization" )
147
122
148
- def predict (self , * _ , ** kwargs ):
123
+ def fit_predict (self , * _ , ** kwargs ):
149
124
"""Return the reference map."""
150
125
151
126
# No need to check fit (if not fitted, has raised already)
152
127
return self ._predicted
153
128
154
129
155
- class AverageModel (BaseModel ):
156
- """A trivial model that returns an average map."""
130
+ class ExpectationModel (BaseModel ):
131
+ """A trivial model that returns an expectation map (for example, average) ."""
157
132
158
- __slots__ = ( "_data" ,)
133
+ __slots__ = { "_stat" : "The statistical operation to obtain the expectation map." }
159
134
160
- def __init__ (self , ** kwargs ):
135
+ def __init__ (self , dataset , stat = "median" , ** kwargs ):
161
136
"""Initialize a new model."""
162
- super ().__init__ (** kwargs )
163
- self ._data = None
137
+ super ().__init__ (dataset , ** kwargs )
138
+ self ._stat = stat
164
139
165
- def fit (self , data , ** kwargs ):
166
- """Calculate the average."""
140
+ def fit_predict (self , index , * _ , ** kwargs ):
141
+ """
142
+ Return the expectation map.
167
143
168
- # Regress out global signal differences
169
- if kwargs .pop ("equalize" , False ):
170
- data = data .copy ().astype ("float32" )
171
- reshaped_data = (
172
- data .reshape ((- 1 , data .shape [- 1 ])) if self ._mask is None else data [self ._mask ]
173
- )
174
- p5 = np .percentile (reshaped_data , 5.0 , axis = 0 )
175
- p95 = np .percentile (reshaped_data , 95.0 , axis = 0 ) - p5
176
- data = (data - p5 ) * p95 .mean () / p95 + p5 .mean ()
144
+ Parameters
145
+ ----------
146
+ index : :obj:`int`
147
+ The volume index that is left-out in fitting, and then predicted.
177
148
149
+ """
178
150
# Select the summary statistic
179
- avg_func = getattr (np , kwargs .pop ("stat" , "mean" ))
151
+ avg_func = getattr (np , kwargs .pop ("stat" , self . _stat ))
180
152
181
- # Calculate the average
182
- self ._data = avg_func (data , axis = - 1 )
183
-
184
- @property
185
- def is_fitted (self ):
186
- return self ._data is not None
187
-
188
- def predict (self , * _ , ** kwargs ):
189
- """Return the average map."""
153
+ # Create index mask
154
+ mask = np .ones (len (self ._dataset ), dtype = bool )
155
+ mask [index ] = False
190
156
191
- if self ._data is None :
192
- raise ModelNotFittedError (f"{ type (self ).__name__ } must be fitted before predicting" )
193
-
194
- return self ._data
157
+ # Calculate the average
158
+ return avg_func (self ._dataset .dataobj [mask ][0 ], axis = - 1 )
0 commit comments