Skip to content

Commit 4f32934

Browse files
committed
enh: Implicitly-pipelined and modality-agnostic Estimator
Changes our current implementation of the estimator with a new architecture that allows stacking (#12 (comment)): ```Python estimator_level1 = Estimator(model="b0", ...) # e.g., 6 dof registration estimator_level2 = Estimator(model="b0", input=estimator_level1, ...) # e.g., 9 dof registration estimator_level2.fit(dataset_object) # Checks "input" and it's another estimator, so runs that first ``` This allow for adding *filters* (to be implemented): ``` Python downsample = Filter(...) # e.g., downsampling filter estimator_level1 = Estimator(model="b0", input=downsample, ...) # e.g., 6 dof registration estimator_level2 = Estimator(model="b0", input=estimator_level1, ...) # e.g., 9 dof registration estimator_level2.fit(dataset_object) # Checks "input" and it's another estimator, so runs that first ``` In this case, the filtered dataset only feeds ``estimator_level1``. The second level will work on a full dataset. If you want to interleave another downsampling filter: ``` Python downsample1 = Filter(...) # e.g., downsampling filter 1 estimator_level1 = Estimator(model="b0", input=downsample, ...) # e.g., 6 dof registration downsample2 = Filter(input=estimator_level1, ...) # e.g., downsampling filter 2 estimator_level2 = Estimator(model="b0", input=downsaple2, ...) # e.g., 9 dof registration estimator_level2.fit(dataset_object) # Checks "input" and it's another estimator, so runs that first ``` Resolves: #12. Resolves: #21.
1 parent a8b4dd5 commit 4f32934

File tree

2 files changed

+128
-245
lines changed

2 files changed

+128
-245
lines changed

src/nifreeze/estimator.py

Lines changed: 89 additions & 203 deletions
Original file line numberDiff line numberDiff line change
@@ -22,229 +22,115 @@
2222
#
2323
"""A model-based algorithm for the realignment of dMRI data."""
2424

25+
from __future__ import annotations
26+
2527
from pathlib import Path
2628
from tempfile import TemporaryDirectory, mkstemp
29+
from typing import Self
2730

2831
import nibabel as nb
2932
from tqdm import tqdm
3033

34+
from nifreeze.data.base import BaseDataset
3135
from nifreeze.data.splitting import lovo_split
32-
from nifreeze.model.base import ModelFactory
36+
from nifreeze.model.base import BaseModel, ModelFactory
3337
from nifreeze.registration.ants import _prepare_registration_data, _run_registration
3438
from nifreeze.utils import iterators
3539

3640

41+
class Filter:
42+
"""Alters an input data object (e.g., downsampling)."""
43+
3744
class Estimator:
3845
"""Estimates rigid-body head-motion and distortions derived from eddy-currents."""
39-
40-
@staticmethod
41-
def estimate(
42-
data,
43-
*,
44-
align_kwargs=None,
45-
iter_kwargs=None,
46-
models=("b0",),
47-
omp_nthreads=None,
48-
n_jobs=None,
46+
47+
__slots__ = ("_model", "_strategy", "_prev", "_model_kwargs", "_align_kwargs")
48+
49+
def __init__(
50+
self,
51+
model: BaseModel | str,
52+
strategy: str = "random",
53+
prev: Self | None = None,
54+
model_kwargs: dict | None = None,
4955
**kwargs,
5056
):
51-
r"""
52-
Estimate head-motion and Eddy currents.
53-
54-
Parameters
55-
----------
56-
data : :obj:`~nifreeze.dmri.DWI`
57-
The target DWI dataset, represented by this tool's internal
58-
type. The object is used in-place, and will contain the estimated
59-
parameters in its ``motion_affines`` property, as well as the rotated
60-
*b*-vectors within its ``gradients`` property.
61-
n_iter : :obj:`int`
62-
Number of iterations this particular model is going to be repeated.
63-
align_kwargs : :obj:`dict`
64-
Parameters to configure the image registration process.
65-
iter_kwargs : :obj:`dict`
66-
Parameters to configure the iterator strategy to traverse timepoints/orientations.
67-
models : :obj:`list`
68-
Selects the diffusion model that will generate the registration target
69-
corresponding to each gradient map.
70-
See :obj:`~nifreeze.model.ModelFactory` for allowed models (and corresponding
71-
keywords).
72-
omp_nthreads : :obj:`int`
73-
Maximum number of threads an individual process may use.
74-
n_jobs : :obj:`int`
75-
Number of parallel jobs.
76-
77-
Return
78-
------
79-
:obj:`list` of :obj:`numpy.ndarray`
80-
A list of :math:`4 \times 4` affine matrices encoding the estimated
81-
parameters of the deformations caused by head-motion and eddy-currents.
82-
83-
"""
84-
85-
# Massage iterator configuration
86-
iter_kwargs = iter_kwargs or {}
87-
iter_kwargs = {
88-
"seed": None,
89-
"bvals": None, # TODO: extract b-vals here if pertinent
90-
} | iter_kwargs
91-
iter_kwargs["size"] = len(data)
92-
93-
iterfunc = getattr(iterators, f"{iter_kwargs.pop('strategy', 'random')}_iterator")
94-
index_order = list(iterfunc(**iter_kwargs))
57+
self._model = model
58+
self._prev = prev
59+
self._strategy = strategy
60+
self._model_kwargs = model_kwargs
61+
self._align_kwargs = kwargs
62+
63+
def run(dataset: BaseDataset, **kwargs):
64+
if self._prev is not None:
65+
result = self._prev.run(dataset, **kwargs)
66+
if isinstance(self._prev, Filter):
67+
dataset = result
68+
69+
n_jobs = kwargs.get("n_jobs", None)
70+
71+
# Prepare iterator
72+
iterfunc = getattr(iterators, f"{self._strategy}_iterator")
73+
index_iter = iterfunc(dataset, seed=kwargs.get("seed", None))
74+
75+
# Initialize model
76+
if isinstance(self._model, str):
77+
# Factory creates the appropriate model and pipes arguments
78+
self._model = ModelFactory.init(
79+
model=self._model,
80+
dataset=dataset,
81+
**self._model_kwargs,
82+
)
83+
84+
if self._model.is_static:
85+
self._model.fit(dataset, **kwargs)
9586

9687
align_kwargs = align_kwargs or {}
97-
98-
if "num_threads" not in align_kwargs and omp_nthreads is not None:
99-
align_kwargs["num_threads"] = omp_nthreads
100-
101-
n_iter = len(models)
88+
align_kwargs["num_threads"] = (
89+
align_kwargs.pop("omp_nthreads", None)
90+
or align_kwargs.pop("num_threads", None)
91+
)
10292

10393
reg_target_type = (
10494
align_kwargs.pop("fixed_modality", None),
10595
align_kwargs.pop("moving_modality", None),
10696
)
10797

108-
for i_iter, model in enumerate(models):
109-
# When downsampling these need to be set per-level
110-
bmask_img = _prepare_brainmask_data(data.brainmask, data.affine)
111-
112-
_prepare_kwargs(data, kwargs)
113-
114-
single_model = model.lower() in (
115-
"b0",
116-
"s0",
117-
"avg",
118-
"average",
119-
"mean",
120-
"gp",
121-
) or model.lower().startswith("full")
122-
123-
dwmodel = None
124-
if single_model:
125-
if model.lower().startswith("full"):
126-
model = model[4:]
127-
128-
# Factory creates the appropriate model and pipes arguments
129-
dwmodel = ModelFactory.init(
130-
model=model,
131-
**kwargs,
132-
)
133-
dwmodel.fit(data.dataobj, n_jobs=n_jobs)
134-
135-
with TemporaryDirectory() as tmp_dir:
136-
print(f"Processing in <{tmp_dir}>")
137-
ptmp_dir = Path(tmp_dir)
138-
with tqdm(total=len(index_order), unit="dwi") as pbar:
139-
# run a original-to-synthetic affine registration
140-
for i in index_order:
141-
pbar.set_description_str(
142-
f"Pass {i_iter + 1}/{n_iter} | Fit and predict b-index <{i}>"
143-
)
144-
data_train, data_test = lovo_split(data, i)
145-
grad_str = f"{i}, {data_test[-1][:3]}, b={int(data_test[-1][3])}"
146-
pbar.set_description_str(f"[{grad_str}], {n_jobs} jobs")
147-
148-
if not single_model: # A true LOGO estimator
149-
if hasattr(data, "gradients"):
150-
kwargs["gtab"] = data_train[-1]
151-
# Factory creates the appropriate model and pipes arguments
152-
dwmodel = ModelFactory.init(
153-
model=model,
154-
n_jobs=n_jobs,
155-
**kwargs,
156-
)
157-
158-
# fit the model
159-
dwmodel.fit(
160-
data_train[0],
161-
n_jobs=n_jobs,
162-
)
163-
164-
# generate a synthetic dw volume for the test gradient
165-
predicted = dwmodel.predict(data_test[-1])
166-
167-
# prepare data for running ANTs
168-
fixed, moving = _prepare_registration_data(
169-
data_test[0], predicted, data.affine, i, ptmp_dir, reg_target_type
170-
)
171-
172-
pbar.set_description_str(
173-
f"Pass {i_iter + 1}/{n_iter} | Realign b-index <{i}>"
174-
)
175-
176-
xform = _run_registration(
177-
fixed,
178-
moving,
179-
bmask_img,
180-
data.motion_affines,
181-
data.affine,
182-
data.dataobj.shape[:3],
183-
data_test[-1][3],
184-
i_iter,
185-
i,
186-
ptmp_dir,
187-
reg_target_type,
188-
align_kwargs,
189-
)
190-
191-
# update
192-
data.set_transform(i, xform.matrix)
193-
pbar.update()
194-
195-
return data.motion_affines
196-
197-
198-
def _prepare_brainmask_data(brainmask, affine):
199-
"""Prepare the brainmask data: save the data to disk.
200-
201-
Parameters
202-
----------
203-
brainmask : :obj:`numpy.ndarray`
204-
Brainmask data.
205-
affine : :obj:`numpy.ndarray`
206-
Affine transformation matrix.
207-
208-
Returns
209-
-------
210-
bmask_img : :class:`~nibabel.nifti1.Nifti1Image`
211-
Brainmask image.
212-
"""
213-
214-
bmask_img = None
215-
if brainmask is not None:
216-
_, bmask_img = mkstemp(suffix="_bmask.nii.gz")
217-
nb.Nifti1Image(brainmask.astype("uint8"), affine, None).to_filename(bmask_img)
218-
return bmask_img
219-
220-
221-
def _prepare_kwargs(data, kwargs):
222-
"""Prepare the keyword arguments depending on the DWI data: add attributes corresponding to
223-
the ``brainmask``, ``bzero``, ``gradients``, ``frame_time``, and ``total_duration`` DWI data
224-
properties.
225-
226-
Modifies kwargs in-place.
227-
228-
Parameters
229-
----------
230-
data : :class:`nifreeze.data.dmri.DWI`
231-
DWI data object.
232-
kwargs : :obj:`dict`
233-
Keyword arguments.
234-
"""
235-
from nifreeze.data.filtering import advanced_clip as _advanced_clip
236-
237-
if data.brainmask is not None:
238-
kwargs["mask"] = data.brainmask
239-
240-
if hasattr(data, "bzero") and data.bzero is not None:
241-
kwargs["S0"] = _advanced_clip(data.bzero)
242-
243-
if hasattr(data, "gradients"):
244-
kwargs["gtab"] = data.gradients
245-
246-
if hasattr(data, "frame_time"):
247-
kwargs["timepoints"] = data.frame_time
248-
249-
if hasattr(data, "total_duration"):
250-
kwargs["xlim"] = data.total_duration
98+
dataset_length = len(dataset)
99+
with TemporaryDirectory() as tmp_dir:
100+
print(f"Processing in <{tmp_dir}>")
101+
ptmp_dir = Path(tmp_dir)
102+
with tqdm(total=dataset_length, unit="vols.") as pbar:
103+
# run a original-to-synthetic affine registration
104+
for i in index_iter:
105+
pbar.set_description_str(f"Fit and predict vol. <{i}>")
106+
107+
# fit the model
108+
reference, predicted = self._model.fit_predict(
109+
i,
110+
n_jobs=n_jobs,
111+
)
112+
113+
# prepare data for running ANTs
114+
fixed, moving = _prepare_registration_data(
115+
reference, predicted, dataset.affine, i, ptmp_dir, align_kwargs.get("clip", "both")
116+
)
117+
118+
pbar.set_description_str(f"Realign vol. <{i}>")
119+
120+
xform = _run_registration(
121+
fixed,
122+
moving,
123+
dataset.brainmask,
124+
dataset.motion_affines,
125+
dataset.affine,
126+
dataset.dataobj.shape[:3],
127+
i,
128+
ptmp_dir,
129+
**align_kwargs,
130+
)
131+
132+
# update
133+
dataset.set_transform(i, xform.matrix)
134+
pbar.update()
135+
136+
return self

0 commit comments

Comments
 (0)