|
22 | 22 | #
|
23 | 23 | """A model-based algorithm for the realignment of dMRI data."""
|
24 | 24 |
|
| 25 | +from __future__ import annotations |
| 26 | + |
25 | 27 | from pathlib import Path
|
26 | 28 | from tempfile import TemporaryDirectory, mkstemp
|
| 29 | +from typing import Self |
27 | 30 |
|
28 | 31 | import nibabel as nb
|
29 | 32 | from tqdm import tqdm
|
30 | 33 |
|
| 34 | +from nifreeze.data.base import BaseDataset |
31 | 35 | from nifreeze.data.splitting import lovo_split
|
32 |
| -from nifreeze.model.base import ModelFactory |
| 36 | +from nifreeze.model.base import BaseModel, ModelFactory |
33 | 37 | from nifreeze.registration.ants import _prepare_registration_data, _run_registration
|
34 | 38 | from nifreeze.utils import iterators
|
35 | 39 |
|
36 | 40 |
|
| 41 | +class Filter: |
| 42 | + """Alters an input data object (e.g., downsampling).""" |
| 43 | + |
37 | 44 | class Estimator:
|
38 | 45 | """Estimates rigid-body head-motion and distortions derived from eddy-currents."""
|
39 |
| - |
40 |
| - @staticmethod |
41 |
| - def estimate( |
42 |
| - data, |
43 |
| - *, |
44 |
| - align_kwargs=None, |
45 |
| - iter_kwargs=None, |
46 |
| - models=("b0",), |
47 |
| - omp_nthreads=None, |
48 |
| - n_jobs=None, |
| 46 | + |
| 47 | + __slots__ = ("_model", "_strategy", "_prev", "_model_kwargs", "_align_kwargs") |
| 48 | + |
| 49 | + def __init__( |
| 50 | + self, |
| 51 | + model: BaseModel | str, |
| 52 | + strategy: str = "random", |
| 53 | + prev: Self | None = None, |
| 54 | + model_kwargs: dict | None = None, |
49 | 55 | **kwargs,
|
50 | 56 | ):
|
51 |
| - r""" |
52 |
| - Estimate head-motion and Eddy currents. |
53 |
| -
|
54 |
| - Parameters |
55 |
| - ---------- |
56 |
| - data : :obj:`~nifreeze.dmri.DWI` |
57 |
| - The target DWI dataset, represented by this tool's internal |
58 |
| - type. The object is used in-place, and will contain the estimated |
59 |
| - parameters in its ``motion_affines`` property, as well as the rotated |
60 |
| - *b*-vectors within its ``gradients`` property. |
61 |
| - n_iter : :obj:`int` |
62 |
| - Number of iterations this particular model is going to be repeated. |
63 |
| - align_kwargs : :obj:`dict` |
64 |
| - Parameters to configure the image registration process. |
65 |
| - iter_kwargs : :obj:`dict` |
66 |
| - Parameters to configure the iterator strategy to traverse timepoints/orientations. |
67 |
| - models : :obj:`list` |
68 |
| - Selects the diffusion model that will generate the registration target |
69 |
| - corresponding to each gradient map. |
70 |
| - See :obj:`~nifreeze.model.ModelFactory` for allowed models (and corresponding |
71 |
| - keywords). |
72 |
| - omp_nthreads : :obj:`int` |
73 |
| - Maximum number of threads an individual process may use. |
74 |
| - n_jobs : :obj:`int` |
75 |
| - Number of parallel jobs. |
76 |
| -
|
77 |
| - Return |
78 |
| - ------ |
79 |
| - :obj:`list` of :obj:`numpy.ndarray` |
80 |
| - A list of :math:`4 \times 4` affine matrices encoding the estimated |
81 |
| - parameters of the deformations caused by head-motion and eddy-currents. |
82 |
| -
|
83 |
| - """ |
84 |
| - |
85 |
| - # Massage iterator configuration |
86 |
| - iter_kwargs = iter_kwargs or {} |
87 |
| - iter_kwargs = { |
88 |
| - "seed": None, |
89 |
| - "bvals": None, # TODO: extract b-vals here if pertinent |
90 |
| - } | iter_kwargs |
91 |
| - iter_kwargs["size"] = len(data) |
92 |
| - |
93 |
| - iterfunc = getattr(iterators, f"{iter_kwargs.pop('strategy', 'random')}_iterator") |
94 |
| - index_order = list(iterfunc(**iter_kwargs)) |
| 57 | + self._model = model |
| 58 | + self._prev = prev |
| 59 | + self._strategy = strategy |
| 60 | + self._model_kwargs = model_kwargs |
| 61 | + self._align_kwargs = kwargs |
| 62 | + |
| 63 | + def run(dataset: BaseDataset, **kwargs): |
| 64 | + if self._prev is not None: |
| 65 | + result = self._prev.run(dataset, **kwargs) |
| 66 | + if isinstance(self._prev, Filter): |
| 67 | + dataset = result |
| 68 | + |
| 69 | + n_jobs = kwargs.get("n_jobs", None) |
| 70 | + |
| 71 | + # Prepare iterator |
| 72 | + iterfunc = getattr(iterators, f"{self._strategy}_iterator") |
| 73 | + index_iter = iterfunc(dataset, seed=kwargs.get("seed", None)) |
| 74 | + |
| 75 | + # Initialize model |
| 76 | + if isinstance(self._model, str): |
| 77 | + # Factory creates the appropriate model and pipes arguments |
| 78 | + self._model = ModelFactory.init( |
| 79 | + model=self._model, |
| 80 | + dataset=dataset, |
| 81 | + **self._model_kwargs, |
| 82 | + ) |
| 83 | + |
| 84 | + if self._model.is_static: |
| 85 | + self._model.fit(dataset, **kwargs) |
95 | 86 |
|
96 | 87 | align_kwargs = align_kwargs or {}
|
97 |
| - |
98 |
| - if "num_threads" not in align_kwargs and omp_nthreads is not None: |
99 |
| - align_kwargs["num_threads"] = omp_nthreads |
100 |
| - |
101 |
| - n_iter = len(models) |
| 88 | + align_kwargs["num_threads"] = ( |
| 89 | + align_kwargs.pop("omp_nthreads", None) |
| 90 | + or align_kwargs.pop("num_threads", None) |
| 91 | + ) |
102 | 92 |
|
103 | 93 | reg_target_type = (
|
104 | 94 | align_kwargs.pop("fixed_modality", None),
|
105 | 95 | align_kwargs.pop("moving_modality", None),
|
106 | 96 | )
|
107 | 97 |
|
108 |
| - for i_iter, model in enumerate(models): |
109 |
| - # When downsampling these need to be set per-level |
110 |
| - bmask_img = _prepare_brainmask_data(data.brainmask, data.affine) |
111 |
| - |
112 |
| - _prepare_kwargs(data, kwargs) |
113 |
| - |
114 |
| - single_model = model.lower() in ( |
115 |
| - "b0", |
116 |
| - "s0", |
117 |
| - "avg", |
118 |
| - "average", |
119 |
| - "mean", |
120 |
| - "gp", |
121 |
| - ) or model.lower().startswith("full") |
122 |
| - |
123 |
| - dwmodel = None |
124 |
| - if single_model: |
125 |
| - if model.lower().startswith("full"): |
126 |
| - model = model[4:] |
127 |
| - |
128 |
| - # Factory creates the appropriate model and pipes arguments |
129 |
| - dwmodel = ModelFactory.init( |
130 |
| - model=model, |
131 |
| - **kwargs, |
132 |
| - ) |
133 |
| - dwmodel.fit(data.dataobj, n_jobs=n_jobs) |
134 |
| - |
135 |
| - with TemporaryDirectory() as tmp_dir: |
136 |
| - print(f"Processing in <{tmp_dir}>") |
137 |
| - ptmp_dir = Path(tmp_dir) |
138 |
| - with tqdm(total=len(index_order), unit="dwi") as pbar: |
139 |
| - # run a original-to-synthetic affine registration |
140 |
| - for i in index_order: |
141 |
| - pbar.set_description_str( |
142 |
| - f"Pass {i_iter + 1}/{n_iter} | Fit and predict b-index <{i}>" |
143 |
| - ) |
144 |
| - data_train, data_test = lovo_split(data, i) |
145 |
| - grad_str = f"{i}, {data_test[-1][:3]}, b={int(data_test[-1][3])}" |
146 |
| - pbar.set_description_str(f"[{grad_str}], {n_jobs} jobs") |
147 |
| - |
148 |
| - if not single_model: # A true LOGO estimator |
149 |
| - if hasattr(data, "gradients"): |
150 |
| - kwargs["gtab"] = data_train[-1] |
151 |
| - # Factory creates the appropriate model and pipes arguments |
152 |
| - dwmodel = ModelFactory.init( |
153 |
| - model=model, |
154 |
| - n_jobs=n_jobs, |
155 |
| - **kwargs, |
156 |
| - ) |
157 |
| - |
158 |
| - # fit the model |
159 |
| - dwmodel.fit( |
160 |
| - data_train[0], |
161 |
| - n_jobs=n_jobs, |
162 |
| - ) |
163 |
| - |
164 |
| - # generate a synthetic dw volume for the test gradient |
165 |
| - predicted = dwmodel.predict(data_test[-1]) |
166 |
| - |
167 |
| - # prepare data for running ANTs |
168 |
| - fixed, moving = _prepare_registration_data( |
169 |
| - data_test[0], predicted, data.affine, i, ptmp_dir, reg_target_type |
170 |
| - ) |
171 |
| - |
172 |
| - pbar.set_description_str( |
173 |
| - f"Pass {i_iter + 1}/{n_iter} | Realign b-index <{i}>" |
174 |
| - ) |
175 |
| - |
176 |
| - xform = _run_registration( |
177 |
| - fixed, |
178 |
| - moving, |
179 |
| - bmask_img, |
180 |
| - data.motion_affines, |
181 |
| - data.affine, |
182 |
| - data.dataobj.shape[:3], |
183 |
| - data_test[-1][3], |
184 |
| - i_iter, |
185 |
| - i, |
186 |
| - ptmp_dir, |
187 |
| - reg_target_type, |
188 |
| - align_kwargs, |
189 |
| - ) |
190 |
| - |
191 |
| - # update |
192 |
| - data.set_transform(i, xform.matrix) |
193 |
| - pbar.update() |
194 |
| - |
195 |
| - return data.motion_affines |
196 |
| - |
197 |
| - |
198 |
| -def _prepare_brainmask_data(brainmask, affine): |
199 |
| - """Prepare the brainmask data: save the data to disk. |
200 |
| -
|
201 |
| - Parameters |
202 |
| - ---------- |
203 |
| - brainmask : :obj:`numpy.ndarray` |
204 |
| - Brainmask data. |
205 |
| - affine : :obj:`numpy.ndarray` |
206 |
| - Affine transformation matrix. |
207 |
| -
|
208 |
| - Returns |
209 |
| - ------- |
210 |
| - bmask_img : :class:`~nibabel.nifti1.Nifti1Image` |
211 |
| - Brainmask image. |
212 |
| - """ |
213 |
| - |
214 |
| - bmask_img = None |
215 |
| - if brainmask is not None: |
216 |
| - _, bmask_img = mkstemp(suffix="_bmask.nii.gz") |
217 |
| - nb.Nifti1Image(brainmask.astype("uint8"), affine, None).to_filename(bmask_img) |
218 |
| - return bmask_img |
219 |
| - |
220 |
| - |
221 |
| -def _prepare_kwargs(data, kwargs): |
222 |
| - """Prepare the keyword arguments depending on the DWI data: add attributes corresponding to |
223 |
| - the ``brainmask``, ``bzero``, ``gradients``, ``frame_time``, and ``total_duration`` DWI data |
224 |
| - properties. |
225 |
| -
|
226 |
| - Modifies kwargs in-place. |
227 |
| -
|
228 |
| - Parameters |
229 |
| - ---------- |
230 |
| - data : :class:`nifreeze.data.dmri.DWI` |
231 |
| - DWI data object. |
232 |
| - kwargs : :obj:`dict` |
233 |
| - Keyword arguments. |
234 |
| - """ |
235 |
| - from nifreeze.data.filtering import advanced_clip as _advanced_clip |
236 |
| - |
237 |
| - if data.brainmask is not None: |
238 |
| - kwargs["mask"] = data.brainmask |
239 |
| - |
240 |
| - if hasattr(data, "bzero") and data.bzero is not None: |
241 |
| - kwargs["S0"] = _advanced_clip(data.bzero) |
242 |
| - |
243 |
| - if hasattr(data, "gradients"): |
244 |
| - kwargs["gtab"] = data.gradients |
245 |
| - |
246 |
| - if hasattr(data, "frame_time"): |
247 |
| - kwargs["timepoints"] = data.frame_time |
248 |
| - |
249 |
| - if hasattr(data, "total_duration"): |
250 |
| - kwargs["xlim"] = data.total_duration |
| 98 | + dataset_length = len(dataset) |
| 99 | + with TemporaryDirectory() as tmp_dir: |
| 100 | + print(f"Processing in <{tmp_dir}>") |
| 101 | + ptmp_dir = Path(tmp_dir) |
| 102 | + with tqdm(total=dataset_length, unit="vols.") as pbar: |
| 103 | + # run a original-to-synthetic affine registration |
| 104 | + for i in index_iter: |
| 105 | + pbar.set_description_str(f"Fit and predict vol. <{i}>") |
| 106 | + |
| 107 | + # fit the model |
| 108 | + reference, predicted = self._model.fit_predict( |
| 109 | + i, |
| 110 | + n_jobs=n_jobs, |
| 111 | + ) |
| 112 | + |
| 113 | + # prepare data for running ANTs |
| 114 | + fixed, moving = _prepare_registration_data( |
| 115 | + reference, predicted, dataset.affine, i, ptmp_dir, align_kwargs.get("clip", "both") |
| 116 | + ) |
| 117 | + |
| 118 | + pbar.set_description_str(f"Realign vol. <{i}>") |
| 119 | + |
| 120 | + xform = _run_registration( |
| 121 | + fixed, |
| 122 | + moving, |
| 123 | + dataset.brainmask, |
| 124 | + dataset.motion_affines, |
| 125 | + dataset.affine, |
| 126 | + dataset.dataobj.shape[:3], |
| 127 | + i, |
| 128 | + ptmp_dir, |
| 129 | + **align_kwargs, |
| 130 | + ) |
| 131 | + |
| 132 | + # update |
| 133 | + dataset.set_transform(i, xform.matrix) |
| 134 | + pbar.update() |
| 135 | + |
| 136 | + return self |
0 commit comments