Skip to content

Commit 7a489e3

Browse files
committed
enh: continue with the refactor
1 parent 544249b commit 7a489e3

File tree

9 files changed

+153
-175
lines changed

9 files changed

+153
-175
lines changed

docs/notebooks/bold_realignment.ipynb

+3-3
Original file line numberDiff line numberDiff line change
@@ -337,7 +337,7 @@
337337
"metadata": {},
338338
"outputs": [],
339339
"source": [
340-
"from nifreeze.model.base import AverageModel\n",
340+
"from nifreeze.model.base import ExpectationModel\n",
341341
"from nifreeze.utils.iterators import random_iterator"
342342
]
343343
},
@@ -358,7 +358,7 @@
358358
" t_mask[t] = True\n",
359359
"\n",
360360
" # Fit and predict using the model\n",
361-
" model = AverageModel()\n",
361+
" model = ExpectationModel()\n",
362362
" model.fit(\n",
363363
" data[..., ~t_mask],\n",
364364
" stat=\"median\",\n",
@@ -376,7 +376,7 @@
376376
" fixedmask_path=brainmask_path,\n",
377377
" output_transform_prefix=f\"conversion-{t:02d}\",\n",
378378
" num_threads=8,\n",
379-
" )\n",
379+
" ).cmdline\n",
380380
"\n",
381381
" # Run the command\n",
382382
" proc = await asyncio.create_subprocess_shell(\n",

scripts/optimize_registration.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ async def train_coro(
133133
fixedmask_path=brainmask_path,
134134
output_transform_prefix=f"conversion-{index:04d}",
135135
**align_kwargs,
136-
)
136+
).cmdline
137137

138138
tasks.append(
139139
ants(

src/nifreeze/estimator.py

+24-14
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,10 @@
3232

3333
from nifreeze.data.base import BaseDataset
3434
from nifreeze.model.base import BaseModel, ModelFactory
35-
from nifreeze.registration.ants import _prepare_registration_data, _run_registration
35+
from nifreeze.registration.ants import (
36+
_prepare_registration_data,
37+
_run_registration,
38+
)
3639
from nifreeze.utils import iterators
3740

3841

@@ -60,7 +63,7 @@ def run(self, dataset: BaseDataset, **kwargs):
6063
class Estimator:
6164
"""Estimates rigid-body head-motion and distortions derived from eddy-currents."""
6265

63-
__slots__ = ("_model", "_strategy", "_dataset", "_prev", "_model_kwargs", "_align_kwargs")
66+
__slots__ = ("_model", "_strategy", "_prev", "_model_kwargs", "_align_kwargs")
6467

6568
def __init__(
6669
self,
@@ -111,29 +114,37 @@ def run(self, dataset: BaseDataset, **kwargs):
111114
**self._model_kwargs,
112115
)
113116

114-
if self._model.is_static:
115-
self._model.fit(dataset, **kwargs)
116-
117117
kwargs["num_threads"] = kwargs.pop("omp_nthreads", None) or kwargs.pop("num_threads", None)
118118

119119
dataset_length = len(dataset)
120120
with TemporaryDirectory() as tmp_dir:
121121
print(f"Processing in <{tmp_dir}>")
122122
ptmp_dir = Path(tmp_dir)
123+
124+
bmask_path = None
125+
if dataset.brainmask is not None:
126+
import nibabel as nb
127+
128+
bmask_path = ptmp_dir / "brainmask.nii.gz"
129+
nb.Nifti1Image(
130+
dataset.brainmask.astype("uint8"), dataset.affine, None
131+
).to_filename(bmask_path)
132+
123133
with tqdm(total=dataset_length, unit="vols.") as pbar:
124134
# run a original-to-synthetic affine registration
125135
for i in index_iter:
126136
pbar.set_description_str(f"Fit and predict vol. <{i}>")
127137

128138
# fit the model
129-
reference, predicted = self._model.fit_predict(
139+
test_set = dataset[i]
140+
predicted = self._model.fit_predict(
130141
i,
131142
n_jobs=n_jobs,
132143
)
133144

134145
# prepare data for running ANTs
135-
fixed, moving = _prepare_registration_data(
136-
reference,
146+
predicted_path, volume_path, init_path = _prepare_registration_data(
147+
test_set[0],
137148
predicted,
138149
dataset.affine,
139150
i,
@@ -144,14 +155,13 @@ def run(self, dataset: BaseDataset, **kwargs):
144155
pbar.set_description_str(f"Realign vol. <{i}>")
145156

146157
xform = _run_registration(
147-
fixed,
148-
moving,
149-
dataset.brainmask,
150-
dataset.motion_affines,
151-
dataset.affine,
152-
dataset.dataobj.shape[:3],
158+
predicted_path,
159+
volume_path,
153160
i,
154161
ptmp_dir,
162+
init_affine=init_path,
163+
fixedmask_path=bmask_path,
164+
output_transform_prefix=f"ants-{i:05d}",
155165
**kwargs,
156166
)
157167

src/nifreeze/model/__init__.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
"""Data models."""
2424

2525
from nifreeze.model.base import (
26-
AverageModel,
26+
ExpectationModel,
2727
ModelFactory,
2828
TrivialModel,
2929
)
@@ -37,7 +37,7 @@
3737

3838
__all__ = (
3939
"ModelFactory",
40-
"AverageModel",
40+
"ExpectationModel",
4141
"AverageDWIModel",
4242
"DKIModel",
4343
"DTIModel",

src/nifreeze/model/base.py

+44-80
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,6 @@
2626

2727
import numpy as np
2828

29-
from nifreeze.exceptions import ModelNotFittedError
30-
3129

3230
class ModelFactory:
3331
"""A factory for instantiating data models."""
@@ -61,7 +59,7 @@ def init(model=None, **kwargs):
6159
return AverageDWIModel(**kwargs)
6260

6361
if model.lower() in ("avg", "average", "mean"):
64-
return AverageModel(**kwargs)
62+
return ExpectationModel(**kwargs)
6563

6664
if model.lower() in ("dti", "dki", "pet"):
6765
Model = globals()[f"{model.upper()}Model"]
@@ -81,114 +79,80 @@ class BaseModel:
8179
8280
"""
8381

84-
__slots__ = (
85-
"_model",
86-
"_mask",
87-
"_models",
88-
"_datashape",
89-
"_is_fitted",
90-
"_modelargs",
91-
)
82+
__slots__ = {
83+
"_dataset": "Reference to a :obj:`~nifreeze.data.base.BaseDataset` object.",
84+
}
9285

93-
def __init__(self, mask=None, **kwargs):
86+
def __init__(self, dataset, **kwargs):
9487
"""Base initialization."""
9588

96-
# Keep model state
97-
self._model = None # "Main" model
98-
self._models = None # For parallel (chunked) execution
99-
10089
# Setup brain mask
101-
if mask is None:
90+
if dataset.brainmask is None:
10291
warn(
10392
"No mask provided; consider using a mask to avoid issues in model optimization.",
10493
stacklevel=2,
10594
)
10695

107-
self._mask = mask
108-
109-
self._datashape = None
110-
self._is_fitted = False
111-
112-
self._modelargs = ()
113-
114-
@property
115-
def is_fitted(self):
116-
return self._is_fitted
117-
118-
def fit(self, data, **kwargs):
119-
"""Abstract member signature of fit()."""
120-
raise NotImplementedError("Cannot call fit() on a BaseModel instance.")
121-
122-
def predict(self, *args, **kwargs):
123-
"""Abstract member signature of predict()."""
124-
raise NotImplementedError("Cannot call predict() on a BaseModel instance.")
96+
def fit_predict(self, *_, **kwargs):
97+
"""Fit and predict the indicate index of the dataset (abstract signature)."""
98+
raise NotImplementedError("Cannot call fit_predict() on a BaseModel instance.")
12599

126100

127101
class TrivialModel(BaseModel):
128102
"""A trivial model that returns a given map always."""
129103

130-
__slots__ = ("_predicted",)
104+
__slots__ = {
105+
"_predicted": "A :obj:`~numpy.ndarray` with shape matching the dataset containing the map"
106+
"that will always be returned as prediction (that is, a reference volume).",
107+
}
131108

132-
def __init__(self, predicted=None, **kwargs):
109+
def __init__(self, dataset, predicted=None, **kwargs):
133110
"""Implement object initialization."""
134-
if predicted is None:
135-
raise TypeError("This model requires the predicted map at initialization")
136111

137-
super().__init__(**kwargs)
138-
self._predicted = predicted
139-
self._datashape = predicted.shape
112+
super().__init__(dataset, **kwargs)
113+
self._predicted = (
114+
predicted
115+
if predicted is not None
116+
# Infer from dataset if not provided at initialization
117+
else getattr(dataset, "reference", getattr(dataset, "bzero", None))
118+
)
140119

141-
@property
142-
def is_fitted(self):
143-
return True
144-
145-
def fit(self, data, **kwargs):
146-
"""Do nothing."""
120+
if self._predicted is None:
121+
raise TypeError("This model requires the predicted map at initialization")
147122

148-
def predict(self, *_, **kwargs):
123+
def fit_predict(self, *_, **kwargs):
149124
"""Return the reference map."""
150125

151126
# No need to check fit (if not fitted, has raised already)
152127
return self._predicted
153128

154129

155-
class AverageModel(BaseModel):
156-
"""A trivial model that returns an average map."""
130+
class ExpectationModel(BaseModel):
131+
"""A trivial model that returns an expectation map (for example, average)."""
157132

158-
__slots__ = ("_data",)
133+
__slots__ = {"_stat": "The statistical operation to obtain the expectation map."}
159134

160-
def __init__(self, **kwargs):
135+
def __init__(self, dataset, stat="median", **kwargs):
161136
"""Initialize a new model."""
162-
super().__init__(**kwargs)
163-
self._data = None
137+
super().__init__(dataset, **kwargs)
138+
self._stat = stat
164139

165-
def fit(self, data, **kwargs):
166-
"""Calculate the average."""
140+
def fit_predict(self, index, *_, **kwargs):
141+
"""
142+
Return the expectation map.
167143
168-
# Regress out global signal differences
169-
if kwargs.pop("equalize", False):
170-
data = data.copy().astype("float32")
171-
reshaped_data = (
172-
data.reshape((-1, data.shape[-1])) if self._mask is None else data[self._mask]
173-
)
174-
p5 = np.percentile(reshaped_data, 5.0, axis=0)
175-
p95 = np.percentile(reshaped_data, 95.0, axis=0) - p5
176-
data = (data - p5) * p95.mean() / p95 + p5.mean()
144+
Parameters
145+
----------
146+
index : :obj:`int`
147+
The volume index that is left-out in fitting, and then predicted.
177148
149+
"""
178150
# Select the summary statistic
179-
avg_func = getattr(np, kwargs.pop("stat", "mean"))
151+
avg_func = getattr(np, kwargs.pop("stat", self._stat))
180152

181-
# Calculate the average
182-
self._data = avg_func(data, axis=-1)
183-
184-
@property
185-
def is_fitted(self):
186-
return self._data is not None
187-
188-
def predict(self, *_, **kwargs):
189-
"""Return the average map."""
153+
# Create index mask
154+
mask = np.ones(len(self._dataset), dtype=bool)
155+
mask[index] = False
190156

191-
if self._data is None:
192-
raise ModelNotFittedError(f"{type(self).__name__} must be fitted before predicting")
193-
194-
return self._data
157+
# Calculate the average
158+
return avg_func(self._dataset.dataobj[mask][0], axis=-1)

0 commit comments

Comments
 (0)