From d5b63d02037823d0e94e20def434925297384bf7 Mon Sep 17 00:00:00 2001 From: Oscar Esteban Date: Mon, 20 Jan 2025 13:55:36 +0100 Subject: [PATCH 01/11] enh: Implicitly-pipelined and modality-agnostic ``Estimator`` Changes our current implementation of the estimator with a new architecture that allows stacking (https://github.com/nipreps/nifreeze/issues/12#issuecomment-2598471780): ```Python estimator_level1 = Estimator(model="b0", ...) # e.g., 6 dof registration estimator_level2 = Estimator(model="b0", input=estimator_level1, ...) # e.g., 9 dof registration estimator_level2.fit(dataset_object) # Checks "input" and it's another estimator, so runs that first ``` This allow for adding *filters* (to be implemented): ``` Python downsample = Filter(...) # e.g., downsampling filter estimator_level1 = Estimator(model="b0", input=downsample, ...) # e.g., 6 dof registration estimator_level2 = Estimator(model="b0", input=estimator_level1, ...) # e.g., 9 dof registration estimator_level2.fit(dataset_object) # Checks "input" and it's another estimator, so runs that first ``` In this case, the filtered dataset only feeds ``estimator_level1``. The second level will work on a full dataset. If you want to interleave another downsampling filter: ``` Python downsample1 = Filter(...) # e.g., downsampling filter 1 estimator_level1 = Estimator(model="b0", input=downsample, ...) # e.g., 6 dof registration downsample2 = Filter(input=estimator_level1, ...) # e.g., downsampling filter 2 estimator_level2 = Estimator(model="b0", input=downsaple2, ...) # e.g., 9 dof registration estimator_level2.fit(dataset_object) # Checks "input" and it's another estimator, so runs that first ``` Resolves: #12. Resolves: #21. --- src/nifreeze/estimator.py | 307 +++++++++--------------------- src/nifreeze/model/base.py | 5 +- src/nifreeze/registration/ants.py | 83 ++++---- 3 files changed, 137 insertions(+), 258 deletions(-) diff --git a/src/nifreeze/estimator.py b/src/nifreeze/estimator.py index e9d59c39..643d9ed5 100644 --- a/src/nifreeze/estimator.py +++ b/src/nifreeze/estimator.py @@ -22,229 +22,110 @@ # """A model-based algorithm for the realignment of dMRI data.""" +from __future__ import annotations + from pathlib import Path -from tempfile import TemporaryDirectory, mkstemp +from tempfile import TemporaryDirectory +from typing import Self -import nibabel as nb from tqdm import tqdm -from nifreeze.data.splitting import lovo_split -from nifreeze.model.base import ModelFactory +from nifreeze.data.base import BaseDataset +from nifreeze.model.base import BaseModel, ModelFactory from nifreeze.registration.ants import _prepare_registration_data, _run_registration from nifreeze.utils import iterators +class Filter: + """Alters an input data object (e.g., downsampling).""" + + class Estimator: """Estimates rigid-body head-motion and distortions derived from eddy-currents.""" - @staticmethod - def estimate( - data, - *, - align_kwargs=None, - iter_kwargs=None, - models=("b0",), - omp_nthreads=None, - n_jobs=None, + __slots__ = ("_model", "_strategy", "_prev", "_model_kwargs", "_align_kwargs") + + def __init__( + self, + model: BaseModel | str, + strategy: str = "random", + prev: Self | None = None, + model_kwargs: dict | None = None, **kwargs, ): - r""" - Estimate head-motion and Eddy currents. - - Parameters - ---------- - data : :obj:`~nifreeze.dmri.DWI` - The target DWI dataset, represented by this tool's internal - type. The object is used in-place, and will contain the estimated - parameters in its ``motion_affines`` property, as well as the rotated - *b*-vectors within its ``gradients`` property. - n_iter : :obj:`int` - Number of iterations this particular model is going to be repeated. - align_kwargs : :obj:`dict` - Parameters to configure the image registration process. - iter_kwargs : :obj:`dict` - Parameters to configure the iterator strategy to traverse timepoints/orientations. - models : :obj:`list` - Selects the diffusion model that will generate the registration target - corresponding to each gradient map. - See :obj:`~nifreeze.model.ModelFactory` for allowed models (and corresponding - keywords). - omp_nthreads : :obj:`int` - Maximum number of threads an individual process may use. - n_jobs : :obj:`int` - Number of parallel jobs. - - Return - ------ - :obj:`list` of :obj:`numpy.ndarray` - A list of :math:`4 \times 4` affine matrices encoding the estimated - parameters of the deformations caused by head-motion and eddy-currents. - - """ - - # Massage iterator configuration - iter_kwargs = iter_kwargs or {} - iter_kwargs = { - "seed": None, - "bvals": None, # TODO: extract b-vals here if pertinent - } | iter_kwargs - iter_kwargs["size"] = len(data) - - iterfunc = getattr(iterators, f"{iter_kwargs.pop('strategy', 'random')}_iterator") - index_order = list(iterfunc(**iter_kwargs)) - - align_kwargs = align_kwargs or {} - - if "num_threads" not in align_kwargs and omp_nthreads is not None: - align_kwargs["num_threads"] = omp_nthreads - - n_iter = len(models) - - reg_target_type = ( - align_kwargs.pop("fixed_modality", None), - align_kwargs.pop("moving_modality", None), - ) - - for i_iter, model in enumerate(models): - # When downsampling these need to be set per-level - bmask_img = _prepare_brainmask_data(data.brainmask, data.affine) - - _prepare_kwargs(data, kwargs) - - single_model = model.lower() in ( - "b0", - "s0", - "avg", - "average", - "mean", - "gp", - ) or model.lower().startswith("full") - - dwmodel = None - if single_model: - if model.lower().startswith("full"): - model = model[4:] - - # Factory creates the appropriate model and pipes arguments - dwmodel = ModelFactory.init( - model=model, - **kwargs, - ) - dwmodel.fit(data.dataobj, n_jobs=n_jobs) - - with TemporaryDirectory() as tmp_dir: - print(f"Processing in <{tmp_dir}>") - ptmp_dir = Path(tmp_dir) - with tqdm(total=len(index_order), unit="dwi") as pbar: - # run a original-to-synthetic affine registration - for i in index_order: - pbar.set_description_str( - f"Pass {i_iter + 1}/{n_iter} | Fit and predict b-index <{i}>" - ) - data_train, data_test = lovo_split(data, i) - grad_str = f"{i}, {data_test[-1][:3]}, b={int(data_test[-1][3])}" - pbar.set_description_str(f"[{grad_str}], {n_jobs} jobs") - - if not single_model: # A true LOGO estimator - if hasattr(data, "gradients"): - kwargs["gtab"] = data_train[-1] - # Factory creates the appropriate model and pipes arguments - dwmodel = ModelFactory.init( - model=model, - n_jobs=n_jobs, - **kwargs, - ) - - # fit the model - dwmodel.fit( - data_train[0], - n_jobs=n_jobs, - ) - - # generate a synthetic dw volume for the test gradient - predicted = dwmodel.predict(data_test[-1]) - - # prepare data for running ANTs - fixed, moving = _prepare_registration_data( - data_test[0], predicted, data.affine, i, ptmp_dir, reg_target_type - ) - - pbar.set_description_str( - f"Pass {i_iter + 1}/{n_iter} | Realign b-index <{i}>" - ) - - xform = _run_registration( - fixed, - moving, - bmask_img, - data.motion_affines, - data.affine, - data.dataobj.shape[:3], - data_test[-1][3], - i_iter, - i, - ptmp_dir, - reg_target_type, - align_kwargs, - ) - - # update - data.set_transform(i, xform.matrix) - pbar.update() - - return data.motion_affines - - -def _prepare_brainmask_data(brainmask, affine): - """Prepare the brainmask data: save the data to disk. - - Parameters - ---------- - brainmask : :obj:`numpy.ndarray` - Brainmask data. - affine : :obj:`numpy.ndarray` - Affine transformation matrix. - - Returns - ------- - bmask_img : :class:`~nibabel.nifti1.Nifti1Image` - Brainmask image. - """ - - bmask_img = None - if brainmask is not None: - _, bmask_img = mkstemp(suffix="_bmask.nii.gz") - nb.Nifti1Image(brainmask.astype("uint8"), affine, None).to_filename(bmask_img) - return bmask_img - - -def _prepare_kwargs(data, kwargs): - """Prepare the keyword arguments depending on the DWI data: add attributes corresponding to - the ``brainmask``, ``bzero``, ``gradients``, ``frame_time``, and ``total_duration`` DWI data - properties. - - Modifies kwargs in-place. - - Parameters - ---------- - data : :class:`nifreeze.data.dmri.DWI` - DWI data object. - kwargs : :obj:`dict` - Keyword arguments. - """ - from nifreeze.data.filtering import advanced_clip as _advanced_clip - - if data.brainmask is not None: - kwargs["mask"] = data.brainmask - - if hasattr(data, "bzero") and data.bzero is not None: - kwargs["S0"] = _advanced_clip(data.bzero) - - if hasattr(data, "gradients"): - kwargs["gtab"] = data.gradients - - if hasattr(data, "frame_time"): - kwargs["timepoints"] = data.frame_time - - if hasattr(data, "total_duration"): - kwargs["xlim"] = data.total_duration + self._model = model + self._prev = prev + self._strategy = strategy + self._model_kwargs = model_kwargs + self._align_kwargs = kwargs + + def run(self, dataset: BaseDataset, **kwargs): + if self._prev is not None: + result = self._prev.run(dataset, **kwargs) + if isinstance(self._prev, Filter): + dataset = result + + n_jobs = kwargs.get("n_jobs", None) + + # Prepare iterator + iterfunc = getattr(iterators, f"{self._strategy}_iterator") + index_iter = iterfunc(dataset, seed=kwargs.get("seed", None)) + + # Initialize model + if isinstance(self._model, str): + # Factory creates the appropriate model and pipes arguments + self._model = ModelFactory.init( + model=self._model, + dataset=dataset, + **self._model_kwargs, + ) + + if self._model.is_static: + self._model.fit(dataset, **kwargs) + + kwargs["num_threads"] = kwargs.pop("omp_nthreads", None) or kwargs.pop("num_threads", None) + + dataset_length = len(dataset) + with TemporaryDirectory() as tmp_dir: + print(f"Processing in <{tmp_dir}>") + ptmp_dir = Path(tmp_dir) + with tqdm(total=dataset_length, unit="vols.") as pbar: + # run a original-to-synthetic affine registration + for i in index_iter: + pbar.set_description_str(f"Fit and predict vol. <{i}>") + + # fit the model + reference, predicted = self._model.fit_predict( + i, + n_jobs=n_jobs, + ) + + # prepare data for running ANTs + fixed, moving = _prepare_registration_data( + reference, + predicted, + dataset.affine, + i, + ptmp_dir, + kwargs.pop("clip", "both"), + ) + + pbar.set_description_str(f"Realign vol. <{i}>") + + xform = _run_registration( + fixed, + moving, + dataset.brainmask, + dataset.motion_affines, + dataset.affine, + dataset.dataobj.shape[:3], + i, + ptmp_dir, + **kwargs, + ) + + # update + dataset.set_transform(i, xform.matrix) + pbar.update() + + return self diff --git a/src/nifreeze/model/base.py b/src/nifreeze/model/base.py index 748285fe..eca1abeb 100644 --- a/src/nifreeze/model/base.py +++ b/src/nifreeze/model/base.py @@ -96,7 +96,10 @@ def __init__(self, mask=None, **kwargs): # Setup brain mask if mask is None: - warn("No mask provided; consider using a mask to avoid issues in model optimization.") + warn( + "No mask provided; consider using a mask to avoid issues in model optimization.", + stacklevel=2, + ) self._mask = mask diff --git a/src/nifreeze/registration/ants.py b/src/nifreeze/registration/ants.py index f6bba237..8a0976c2 100644 --- a/src/nifreeze/registration/ants.py +++ b/src/nifreeze/registration/ants.py @@ -92,49 +92,57 @@ def _to_nifti( def _prepare_registration_data( - dwframe: np.ndarray, + fixed: np.ndarray, predicted: np.ndarray, affine: np.ndarray, vol_idx: int, dirname: Path | str, - reg_target_type: str, + clip: str | None = None, ) -> tuple[Path, Path]: """ Prepare the registration data: save the fixed and moving images to disk. Parameters ---------- - dwframe : :obj:`~numpy.ndarray` - DWI data object. + fixed : :obj:`~numpy.ndarray` + Reference volume's data array. predicted : :obj:`~numpy.ndarray` - Predicted data. + Predicted volume's data array. affine : :obj:`numpy.ndarray` - Affine transformation matrix. + Orientation affine from the original NIfTI. vol_idx : :obj:`int` - DWI volume index. + Volume index. dirname : :obj:`os.pathlike` Directory name where the data is saved. - reg_target_type : :obj:`str` - Target registration type. + clip : :obj:`str` or ``None`` + Clip intensity of ``"fixed"``, ``"moving"``, ``"both"``, + or ``"none"`` of the images. Returns ------- - fixed : :obj:`~pathlib.Path` + fixed_path : :obj:`~pathlib.Path` Fixed image filename. - moving : :obj:`~pathlib.Path` + moving_path : :obj:`~pathlib.Path` Moving image filename. + """ + clip = clip or "none" - moving = Path(dirname) / f"moving{vol_idx:05d}.nii.gz" - fixed = Path(dirname) / f"fixed{vol_idx:05d}.nii.gz" - _to_nifti(dwframe, affine, moving) + moving_path = Path(dirname) / f"moving{vol_idx:05d}.nii.gz" + fixed_path = Path(dirname) / f"fixed{vol_idx:05d}.nii.gz" + _to_nifti( + fixed, + affine, + moving_path, + clip=clip.lower() in ("fixed", "both"), + ) _to_nifti( predicted, affine, - fixed, - clip=reg_target_type == "dwi", + fixed_path, + clip=clip.lower() in ("moving", "both"), ) - return fixed, moving + return fixed_path, moving_path def _get_ants_settings(settings: str = "b0-to-b0_level0") -> Path: @@ -408,12 +416,9 @@ def _run_registration( em_affines: np.ndarray, affine: np.ndarray, shape: tuple[int, int, int], - bval: int, - i_iter: int, vol_idx: int, dirname: Path, - reg_target_type: str, - align_kwargs: dict, + **kwargs: dict, ) -> nt.base.BaseTransform: """ Register the moving image to the fixed image. @@ -427,22 +432,16 @@ def _run_registration( bmask_img : :class:`~nibabel.spatialimages.SpatialImage` Brainmask image. em_affines : :obj:`numpy.ndarray` - Estimated eddy motion affine transformation matrices. + Estimated head-motion affine transformation matrices. affine : :obj:`numpy.ndarray` - Affine transformation matrix. + Orientation affine from the original NIfTI. shape : :obj:`tuple` - Shape of the DWI frame. - bval : :obj:`int` - b-value of the corresponding DWI volume. - i_iter : :obj:`int` - Iteration number. + 3D shape of dataset. vol_idx : :obj:`int` - DWI frame index. + Dataset volume index. dirname : :obj:`Path` Directory name where the transformation is saved. - reg_target_type : :obj:`str` - Target registration type. - align_kwargs : :obj:`dict` + kwargs : :obj:`dict` Parameters to configure the image registration process. Returns @@ -452,18 +451,16 @@ def _run_registration( """ - if isinstance(reg_target_type, str): - reg_target_type = (reg_target_type, reg_target_type) - + if "config_file" in kwargs: + kwargs["from_file"] = pkg_fn( + "nifreeze.registration", + f"config/{kwargs.pop('config_file')}", + ) registration = Registration( terminal_output="file", - from_file=pkg_fn( - "nifreeze.registration", - f"config/{reg_target_type[0]}-to-{reg_target_type[1]}_level{i_iter}.json", - ), fixed_image=str(fixed.absolute()), moving_image=str(moving.absolute()), - **align_kwargs, + **kwargs, ) if bmask_img: registration.inputs.fixed_image_masks = ["NULL", bmask_img] @@ -472,7 +469,7 @@ def _run_registration( ImageGrid = namedtuple("ImageGrid", ("shape", "affine")) reference = ImageGrid(shape=shape, affine=affine) initial_xform = Affine(matrix=em_affines[vol_idx], reference=reference) - mat_file = dirname / f"init_{i_iter}_{vol_idx:05d}.mat" + mat_file = dirname / f"init_{vol_idx:05d}.mat" initial_xform.to_filename(mat_file, fmt="itk") registration.inputs.initial_moving_transform = str(mat_file) @@ -486,8 +483,6 @@ def _run_registration( ), ) # debugging: generate aligned file for testing - xform.apply(moving, reference=fixed).to_filename( - dirname / f"aligned{vol_idx:05d}_{int(bval):04d}.nii.gz" - ) + xform.apply(moving, reference=fixed).to_filename(dirname / f"aligned{vol_idx:05d}.nii.gz") return xform From 9706c1ea23c0f20733430f6fce1a10e275b794cd Mon Sep 17 00:00:00 2001 From: Oscar Esteban Date: Tue, 21 Jan 2025 10:18:41 +0100 Subject: [PATCH 02/11] fix: update to new Estimator object and dw mentions --- src/nifreeze/cli/run.py | 10 +++++----- test/test_integration.py | 11 +++++------ 2 files changed, 10 insertions(+), 11 deletions(-) diff --git a/src/nifreeze/cli/run.py b/src/nifreeze/cli/run.py index b4c32fc4..7c58cafd 100644 --- a/src/nifreeze/cli/run.py +++ b/src/nifreeze/cli/run.py @@ -25,7 +25,7 @@ from pathlib import Path from nifreeze.cli.parser import parse_args -from nifreeze.data.dmri import DWI +from nifreeze.data.base import BaseDataset from nifreeze.estimator import Estimator @@ -40,12 +40,12 @@ def main(argv=None) -> None: args = parse_args(argv) # Open the data with the given file path - dwi_dataset: DWI = DWI.from_filename(args.input_file) + dataset: BaseDataset = BaseDataset.from_filename(args.input_file) estimator: Estimator = Estimator() - _ = estimator.estimate( - dwi_dataset, + _ = estimator.run( + dataset, align_kwargs=args.align_config, models=args.models, omp_nthreads=args.nthreads, @@ -58,7 +58,7 @@ def main(argv=None) -> None: output_path: Path = Path(args.output_dir) / output_filename # Save the DWI dataset to the output path - dwi_dataset.to_filename(output_path) + dataset.to_filename(output_path) if __name__ == "__main__": diff --git a/test/test_integration.py b/test/test_integration.py index 858c76f8..a6e54f77 100644 --- a/test/test_integration.py +++ b/test/test_integration.py @@ -70,26 +70,25 @@ def test_proximity_estimator_trivial_model(datadir, tmp_path): brainmask=dwdata.brainmask, ) - estimator = Estimator() - em_affines = estimator.estimate( + estimator = Estimator(dwi_motion, model="b0") + estimator.run( data=dwi_motion, models=("b0",), seed=None, align_kwargs={ - "fixed_modality": "b0", - "moving_modality": "b0", + "config_file": "b0-to-b0_level0.json", "num_threads": min(cpu_count(), 8), }, ) # Uncomment to see the realigned dataset nt.linear.LinearTransformsMapping( - em_affines, + dwi_motion.motion_affines, reference=b0nii, ).apply(moved_nii).to_filename(tmp_path / "realigned.nii.gz") # For each moved b0 volume - for i, est in enumerate(em_affines): + for i, est in enumerate(dwi_motion.motion_affines): assert ( displacements_within_mask( masknii, From 544249b5048b8b484435a37e8c35265e00ac3419 Mon Sep 17 00:00:00 2001 From: Oscar Esteban Date: Wed, 22 Jan 2025 08:01:59 +0100 Subject: [PATCH 03/11] fix: continue adaptations to new Estimator --- src/nifreeze/estimator.py | 39 ++++++++++++++++++++++++++++++++++---- src/nifreeze/model/base.py | 7 +++++-- test/test_integration.py | 5 ++--- 3 files changed, 42 insertions(+), 9 deletions(-) diff --git a/src/nifreeze/estimator.py b/src/nifreeze/estimator.py index 643d9ed5..db231de2 100644 --- a/src/nifreeze/estimator.py +++ b/src/nifreeze/estimator.py @@ -39,11 +39,28 @@ class Filter: """Alters an input data object (e.g., downsampling).""" + def run(self, dataset: BaseDataset, **kwargs): + """ + Trigger execution of the designated filter. + + Parameters + ---------- + dataset : :obj:`~nifreeze.data.base.BaseDataset` + The input dataset this estimator operates on. + + Returns + ------- + :obj:`~nifreeze.estimator.Estimator` + The estimator, after fitting. + + """ + return dataset + class Estimator: """Estimates rigid-body head-motion and distortions derived from eddy-currents.""" - __slots__ = ("_model", "_strategy", "_prev", "_model_kwargs", "_align_kwargs") + __slots__ = ("_model", "_strategy", "_dataset", "_prev", "_model_kwargs", "_align_kwargs") def __init__( self, @@ -56,10 +73,24 @@ def __init__( self._model = model self._prev = prev self._strategy = strategy - self._model_kwargs = model_kwargs - self._align_kwargs = kwargs + self._model_kwargs = model_kwargs or {} + self._align_kwargs = kwargs or {} def run(self, dataset: BaseDataset, **kwargs): + """ + Trigger execution of the workflow this estimator belongs. + + Parameters + ---------- + dataset : :obj:`~nifreeze.data.base.BaseDataset` + The input dataset this estimator operates on. + + Returns + ------- + :obj:`~nifreeze.estimator.Estimator` + The estimator, after fitting. + + """ if self._prev is not None: result = self._prev.run(dataset, **kwargs) if isinstance(self._prev, Filter): @@ -69,7 +100,7 @@ def run(self, dataset: BaseDataset, **kwargs): # Prepare iterator iterfunc = getattr(iterators, f"{self._strategy}_iterator") - index_iter = iterfunc(dataset, seed=kwargs.get("seed", None)) + index_iter = iterfunc(len(dataset), seed=kwargs.get("seed", None)) # Initialize model if isinstance(self._model, str): diff --git a/src/nifreeze/model/base.py b/src/nifreeze/model/base.py index eca1abeb..8fc251cf 100644 --- a/src/nifreeze/model/base.py +++ b/src/nifreeze/model/base.py @@ -33,7 +33,7 @@ class ModelFactory: """A factory for instantiating data models.""" @staticmethod - def init(model="DTI", **kwargs): + def init(model=None, **kwargs): """ Instantiate a diffusion model. @@ -49,6 +49,9 @@ def init(model="DTI", **kwargs): A model object compliant with DIPY's interface. """ + if model is None: + raise RuntimeError("No model identifier provided.") + if model.lower() in ("s0", "b0"): return TrivialModel(predicted=kwargs.pop("S0"), gtab=kwargs.pop("gtab")) @@ -143,7 +146,7 @@ def fit(self, data, **kwargs): """Do nothing.""" def predict(self, *_, **kwargs): - """Return the *b=0* map.""" + """Return the reference map.""" # No need to check fit (if not fitted, has raised already) return self._predicted diff --git a/test/test_integration.py b/test/test_integration.py index a6e54f77..33a2bd11 100644 --- a/test/test_integration.py +++ b/test/test_integration.py @@ -70,10 +70,9 @@ def test_proximity_estimator_trivial_model(datadir, tmp_path): brainmask=dwdata.brainmask, ) - estimator = Estimator(dwi_motion, model="b0") + estimator = Estimator("b0") estimator.run( - data=dwi_motion, - models=("b0",), + dwi_motion, seed=None, align_kwargs={ "config_file": "b0-to-b0_level0.json", From 7322e93bc3eda2ebe5d2036b1f90c0f10596716c Mon Sep 17 00:00:00 2001 From: Oscar Esteban Date: Wed, 22 Jan 2025 17:27:54 +0100 Subject: [PATCH 04/11] enh: continue with the refactor --- docs/notebooks/bold_realignment.ipynb | 6 +- scripts/optimize_registration.py | 2 +- src/nifreeze/data/dmri.py | 110 +++++++++- src/nifreeze/estimator.py | 38 ++-- src/nifreeze/model/__init__.py | 4 +- src/nifreeze/model/_dipy.py | 22 -- src/nifreeze/model/base.py | 137 +++++-------- src/nifreeze/model/dmri.py | 285 ++++++++------------------ src/nifreeze/registration/ants.py | 130 ++++++------ test/conftest.py | 2 +- test/test_data_dmri.py | 5 +- test/test_integration.py | 11 +- test/test_model.py | 108 ++++------ 13 files changed, 390 insertions(+), 470 deletions(-) diff --git a/docs/notebooks/bold_realignment.ipynb b/docs/notebooks/bold_realignment.ipynb index 55bf7916..4f491c89 100644 --- a/docs/notebooks/bold_realignment.ipynb +++ b/docs/notebooks/bold_realignment.ipynb @@ -337,7 +337,7 @@ "metadata": {}, "outputs": [], "source": [ - "from nifreeze.model.base import AverageModel\n", + "from nifreeze.model.base import ExpectationModel\n", "from nifreeze.utils.iterators import random_iterator" ] }, @@ -358,7 +358,7 @@ " t_mask[t] = True\n", "\n", " # Fit and predict using the model\n", - " model = AverageModel()\n", + " model = ExpectationModel()\n", " model.fit(\n", " data[..., ~t_mask],\n", " stat=\"median\",\n", @@ -376,7 +376,7 @@ " fixedmask_path=brainmask_path,\n", " output_transform_prefix=f\"conversion-{t:02d}\",\n", " num_threads=8,\n", - " )\n", + " ).cmdline\n", "\n", " # Run the command\n", " proc = await asyncio.create_subprocess_shell(\n", diff --git a/scripts/optimize_registration.py b/scripts/optimize_registration.py index d9f732ad..9232aa6f 100644 --- a/scripts/optimize_registration.py +++ b/scripts/optimize_registration.py @@ -133,7 +133,7 @@ async def train_coro( fixedmask_path=brainmask_path, output_transform_prefix=f"conversion-{index:04d}", **align_kwargs, - ) + ).cmdline tasks.append( ants( diff --git a/src/nifreeze/data/dmri.py b/src/nifreeze/data/dmri.py index 6a32496c..8d033263 100644 --- a/src/nifreeze/data/dmri.py +++ b/src/nifreeze/data/dmri.py @@ -37,6 +37,30 @@ from nifreeze.data.base import BaseDataset, _cmp, _data_repr +DEFAULT_CLIP_PERCENTILE = 75 +"""Upper percentile threshold for intensity clipping.""" + +DEFAULT_MIN_S0 = 1e-5 +"""Minimum value when considering the :math:`S_{0}` DWI signal.""" + +DEFAULT_MAX_S0 = 1.0 +"""Maximum value when considering the :math:`S_{0}` DWI signal.""" + +DEFAULT_LOWB_THRESHOLD = 50 +"""The lower bound for the b-value so that the orientation is considered a DW volume.""" + +DEFAULT_HIGHB_THRESHOLD = 8000 +"""A b-value cap for DWI data.""" + +DEFAULT_NUM_BINS = 15 +"""Number of bins to classify b-values.""" + +DEFAULT_MULTISHELL_BIN_COUNT_THR = 7 +"""Default bin count to consider a multishell scheme.""" + +DTI_MIN_ORIENTATIONS = 6 +"""Minimum number of nonzero b-values in a DWI dataset.""" + @attr.s(slots=True) class DWI(BaseDataset): @@ -221,7 +245,7 @@ def load( bvec_file: Path | str | None = None, bval_file: Path | str | None = None, b0_file: Path | str | None = None, - b0_thres: float = 50.0, + b0_thres: float = DEFAULT_LOWB_THRESHOLD, ) -> DWI: """ Load DWI data and construct a DWI object. @@ -337,3 +361,87 @@ def load( dwi_obj.brainmask = np.asanyarray(mask_img.dataobj, dtype=bool) return dwi_obj + + +def find_shelling_scheme( + bvals, + num_bins=DEFAULT_NUM_BINS, + multishell_nonempty_bin_count_thr=DEFAULT_MULTISHELL_BIN_COUNT_THR, + bval_cap=DEFAULT_HIGHB_THRESHOLD, +): + """ + Find the shelling scheme on the given b-values. + + Computes the histogram of the b-values according to ``num_bins`` + and depending on the nonempty bin count, classify the shelling scheme + as single-shell if they are 2 (low-b and a shell); multi-shell if they are + below the ``multishell_nonempty_bin_count_thr`` value; and DSI otherwise. + + Parameters + ---------- + bvals : :obj:`list` or :obj:`~numpy.ndarray` + List or array of b-values. + num_bins : :obj:`int`, optional + Number of bins. + multishell_nonempty_bin_count_thr : :obj:`int`, optional + Bin count to consider a multi-shell scheme. + + Returns + ------- + scheme : :obj:`str` + Shelling scheme. + bval_groups : :obj:`list` + List of grouped b-values. + bval_estimated : :obj:`list` + List of 'estimated' b-values as the median value of each b-value group. + + """ + + # Bin the b-values: use -1 as the lower bound to be able to appropriately + # include b0 values + hist, bin_edges = np.histogram(bvals, bins=num_bins, range=(-1, min(max(bvals), bval_cap))) + + # Collect values in each bin + bval_groups = [] + bval_estimated = [] + for lower, upper in zip(bin_edges[:-1], bin_edges[1:], strict=False): + # Add only if a nonempty b-values mask + if (mask := (bvals > lower) & (bvals <= upper)).sum(): + bval_groups.append(bvals[mask]) + bval_estimated.append(np.median(bvals[mask])) + + nonempty_bins = len(bval_groups) + + if nonempty_bins < 2: + raise ValueError("DWI must have at least one high-b shell") + + if nonempty_bins == 2: + scheme = "single-shell" + elif nonempty_bins < multishell_nonempty_bin_count_thr: + scheme = "multi-shell" + else: + scheme = "DSI" + + return scheme, bval_groups, bval_estimated + + +def _rasb2dipy(gradient): + import warnings + + gradient = np.asanyarray(gradient) + if gradient.ndim == 1: + if gradient.size != 4: + raise ValueError("Missing gradient information.") + gradient = gradient[..., np.newaxis] + + if gradient.shape[0] != 4: + gradient = gradient.T + elif gradient.shape == (4, 4): + print("Warning: make sure gradient information is not transposed!") + + with warnings.catch_warnings(): + from dipy.core.gradients import gradient_table + + warnings.filterwarnings("ignore", category=UserWarning) + retval = gradient_table(gradient[3, :], bvecs=gradient[:3, :].T) + return retval diff --git a/src/nifreeze/estimator.py b/src/nifreeze/estimator.py index db231de2..244c027f 100644 --- a/src/nifreeze/estimator.py +++ b/src/nifreeze/estimator.py @@ -32,7 +32,10 @@ from nifreeze.data.base import BaseDataset from nifreeze.model.base import BaseModel, ModelFactory -from nifreeze.registration.ants import _prepare_registration_data, _run_registration +from nifreeze.registration.ants import ( + _prepare_registration_data, + _run_registration, +) from nifreeze.utils import iterators @@ -60,7 +63,7 @@ def run(self, dataset: BaseDataset, **kwargs): class Estimator: """Estimates rigid-body head-motion and distortions derived from eddy-currents.""" - __slots__ = ("_model", "_strategy", "_dataset", "_prev", "_model_kwargs", "_align_kwargs") + __slots__ = ("_model", "_strategy", "_prev", "_model_kwargs", "_align_kwargs") def __init__( self, @@ -111,29 +114,37 @@ def run(self, dataset: BaseDataset, **kwargs): **self._model_kwargs, ) - if self._model.is_static: - self._model.fit(dataset, **kwargs) - kwargs["num_threads"] = kwargs.pop("omp_nthreads", None) or kwargs.pop("num_threads", None) dataset_length = len(dataset) with TemporaryDirectory() as tmp_dir: print(f"Processing in <{tmp_dir}>") ptmp_dir = Path(tmp_dir) + + bmask_path = None + if dataset.brainmask is not None: + import nibabel as nb + + bmask_path = ptmp_dir / "brainmask.nii.gz" + nb.Nifti1Image( + dataset.brainmask.astype("uint8"), dataset.affine, None + ).to_filename(bmask_path) + with tqdm(total=dataset_length, unit="vols.") as pbar: # run a original-to-synthetic affine registration for i in index_iter: pbar.set_description_str(f"Fit and predict vol. <{i}>") # fit the model - reference, predicted = self._model.fit_predict( + test_set = dataset[i] + predicted = self._model.fit_predict( i, n_jobs=n_jobs, ) # prepare data for running ANTs - fixed, moving = _prepare_registration_data( - reference, + predicted_path, volume_path, init_path = _prepare_registration_data( + test_set[0], predicted, dataset.affine, i, @@ -144,14 +155,13 @@ def run(self, dataset: BaseDataset, **kwargs): pbar.set_description_str(f"Realign vol. <{i}>") xform = _run_registration( - fixed, - moving, - dataset.brainmask, - dataset.motion_affines, - dataset.affine, - dataset.dataobj.shape[:3], + predicted_path, + volume_path, i, ptmp_dir, + init_affine=init_path, + fixedmask_path=bmask_path, + output_transform_prefix=f"ants-{i:05d}", **kwargs, ) diff --git a/src/nifreeze/model/__init__.py b/src/nifreeze/model/__init__.py index d3b425b0..edbe6d3e 100644 --- a/src/nifreeze/model/__init__.py +++ b/src/nifreeze/model/__init__.py @@ -23,7 +23,7 @@ """Data models.""" from nifreeze.model.base import ( - AverageModel, + ExpectationModel, ModelFactory, TrivialModel, ) @@ -37,7 +37,7 @@ __all__ = ( "ModelFactory", - "AverageModel", + "ExpectationModel", "AverageDWIModel", "DKIModel", "DTIModel", diff --git a/src/nifreeze/model/_dipy.py b/src/nifreeze/model/_dipy.py index b501c8b6..5be3c040 100644 --- a/src/nifreeze/model/_dipy.py +++ b/src/nifreeze/model/_dipy.py @@ -24,8 +24,6 @@ from __future__ import annotations -import warnings - import numpy as np from dipy.core.gradients import GradientTable from dipy.reconst.base import ReconstModel @@ -268,23 +266,3 @@ def predict( """ return gp_prediction(self.model, gtab, mask=self.mask) - - -def _rasb2dipy(gradient): - gradient = np.asanyarray(gradient) - if gradient.ndim == 1: - if gradient.size != 4: - raise ValueError("Missing gradient information.") - gradient = gradient[..., np.newaxis] - - if gradient.shape[0] != 4: - gradient = gradient.T - elif gradient.shape == (4, 4): - print("Warning: make sure gradient information is not transposed!") - - with warnings.catch_warnings(): - from dipy.core.gradients import gradient_table - - warnings.filterwarnings("ignore", category=UserWarning) - retval = gradient_table(gradient[3, :], bvecs=gradient[:3, :].T) - return retval diff --git a/src/nifreeze/model/base.py b/src/nifreeze/model/base.py index 8fc251cf..a0dce618 100644 --- a/src/nifreeze/model/base.py +++ b/src/nifreeze/model/base.py @@ -26,8 +26,6 @@ import numpy as np -from nifreeze.exceptions import ModelNotFittedError - class ModelFactory: """A factory for instantiating data models.""" @@ -53,19 +51,19 @@ def init(model=None, **kwargs): raise RuntimeError("No model identifier provided.") if model.lower() in ("s0", "b0"): - return TrivialModel(predicted=kwargs.pop("S0"), gtab=kwargs.pop("gtab")) + return TrivialModel(kwargs.pop("dataset")) + + if model.lower() in ("avg", "average", "mean"): + return ExpectationModel(kwargs.pop("dataset"), **kwargs) if model.lower() in ("avgdwi", "averagedwi", "meandwi"): from nifreeze.model.dmri import AverageDWIModel - return AverageDWIModel(**kwargs) - - if model.lower() in ("avg", "average", "mean"): - return AverageModel(**kwargs) + return AverageDWIModel(kwargs.pop("dataset"), **kwargs) if model.lower() in ("dti", "dki", "pet"): Model = globals()[f"{model.upper()}Model"] - return Model(**kwargs) + return Model(kwargs.pop("dataset"), **kwargs) raise NotImplementedError(f"Unsupported model <{model}>.") @@ -81,114 +79,81 @@ class BaseModel: """ - __slots__ = ( - "_model", - "_mask", - "_models", - "_datashape", - "_is_fitted", - "_modelargs", - ) + __slots__ = { + "_dataset": "Reference to a :obj:`~nifreeze.data.base.BaseDataset` object.", + } - def __init__(self, mask=None, **kwargs): + def __init__(self, dataset, **kwargs): """Base initialization.""" - # Keep model state - self._model = None # "Main" model - self._models = None # For parallel (chunked) execution - - # Setup brain mask - if mask is None: + self._dataset = dataset + # Warn if mask not present + if dataset.brainmask is None: warn( "No mask provided; consider using a mask to avoid issues in model optimization.", stacklevel=2, ) - self._mask = mask - - self._datashape = None - self._is_fitted = False - - self._modelargs = () - - @property - def is_fitted(self): - return self._is_fitted - - def fit(self, data, **kwargs): - """Abstract member signature of fit().""" - raise NotImplementedError("Cannot call fit() on a BaseModel instance.") - - def predict(self, *args, **kwargs): - """Abstract member signature of predict().""" - raise NotImplementedError("Cannot call predict() on a BaseModel instance.") + def fit_predict(self, *_, **kwargs): + """Fit and predict the indicate index of the dataset (abstract signature).""" + raise NotImplementedError("Cannot call fit_predict() on a BaseModel instance.") class TrivialModel(BaseModel): """A trivial model that returns a given map always.""" - __slots__ = ("_predicted",) + __slots__ = { + "_predicted": "A :obj:`~numpy.ndarray` with shape matching the dataset containing the map" + "that will always be returned as prediction (that is, a reference volume).", + } - def __init__(self, predicted=None, **kwargs): + def __init__(self, dataset, predicted=None, **kwargs): """Implement object initialization.""" - if predicted is None: - raise TypeError("This model requires the predicted map at initialization") - super().__init__(**kwargs) - self._predicted = predicted - self._datashape = predicted.shape + super().__init__(dataset, **kwargs) + self._predicted = ( + predicted + if predicted is not None + # Infer from dataset if not provided at initialization + else getattr(dataset, "reference", getattr(dataset, "bzero", None)) + ) - @property - def is_fitted(self): - return True - - def fit(self, data, **kwargs): - """Do nothing.""" + if self._predicted is None: + raise TypeError("This model requires the predicted map at initialization") - def predict(self, *_, **kwargs): + def fit_predict(self, *_, **kwargs): """Return the reference map.""" # No need to check fit (if not fitted, has raised already) return self._predicted -class AverageModel(BaseModel): - """A trivial model that returns an average map.""" +class ExpectationModel(BaseModel): + """A trivial model that returns an expectation map (for example, average).""" - __slots__ = ("_data",) + __slots__ = {"_stat": "The statistical operation to obtain the expectation map."} - def __init__(self, **kwargs): + def __init__(self, dataset, stat="median", **kwargs): """Initialize a new model.""" - super().__init__(**kwargs) - self._data = None + super().__init__(dataset, **kwargs) + self._stat = stat - def fit(self, data, **kwargs): - """Calculate the average.""" + def fit_predict(self, index, *_, **kwargs): + """ + Return the expectation map. - # Regress out global signal differences - if kwargs.pop("equalize", False): - data = data.copy().astype("float32") - reshaped_data = ( - data.reshape((-1, data.shape[-1])) if self._mask is None else data[self._mask] - ) - p5 = np.percentile(reshaped_data, 5.0, axis=0) - p95 = np.percentile(reshaped_data, 95.0, axis=0) - p5 - data = (data - p5) * p95.mean() / p95 + p5.mean() + Parameters + ---------- + index : :obj:`int` + The volume index that is left-out in fitting, and then predicted. + """ # Select the summary statistic - avg_func = getattr(np, kwargs.pop("stat", "mean")) - - # Calculate the average - self._data = avg_func(data, axis=-1) + avg_func = getattr(np, kwargs.pop("stat", self._stat)) - @property - def is_fitted(self): - return self._data is not None + # Create index mask + mask = np.ones(len(self._dataset), dtype=bool) + mask[index] = False - def predict(self, *_, **kwargs): - """Return the average map.""" - - if self._data is None: - raise ModelNotFittedError(f"{type(self).__name__} must be fitted before predicting") - - return self._data + # Calculate the average + return avg_func(self._dataset.dataobj[mask][0], axis=-1) diff --git a/src/nifreeze/model/dmri.py b/src/nifreeze/model/dmri.py index bc4658ee..6b68765d 100644 --- a/src/nifreeze/model/dmri.py +++ b/src/nifreeze/model/dmri.py @@ -26,9 +26,11 @@ import numpy as np from joblib import Parallel, delayed -from nifreeze.exceptions import ModelNotFittedError -from nifreeze.model._dipy import _rasb2dipy -from nifreeze.model.base import BaseModel +from nifreeze.data.dmri import ( + DEFAULT_CLIP_PERCENTILE, + DTI_MIN_ORIENTATIONS, +) +from nifreeze.model.base import BaseModel, ExpectationModel def _exec_fit(model, data, chunk=None): @@ -41,84 +43,49 @@ def _exec_predict(model, chunk=None, **kwargs): return np.squeeze(model.predict(**kwargs)), chunk -DEFAULT_CLIP_PERCENTILE = 75 -"""Upper percentile threshold for intensity clipping.""" - -DEFAULT_MIN_S0 = 1e-5 -"""Minimum value when considering the :math:`S_{0}` DWI signal.""" - -DEFAULT_MAX_S0 = 1.0 -"""Maximum value when considering the :math:`S_{0}` DWI signal.""" - -DEFAULT_MAX_BVALUE = 1000 -"""Maximum allowed value for the b-value.""" - -DEFAULT_LOWB_THRESHOLD = 50 -"""The lower bound for the b-value so that the orientation is considered a DW volume.""" - -DEFAULT_HIGHB_THRESHOLD = 10000 -"""A b-value cap for DWI data.""" - -DEFAULT_NUM_BINS = 15 -"""Number of bins to classify b-values.""" - -DEFAULT_MULTISHELL_BIN_COUNT_THR = 7 -"""Default bin count to consider a multishell scheme.""" - -DEFAULT_MAX_BVAL = 8000 -"""Maximum b-value cap.""" - - class BaseDWIModel(BaseModel): """Interface and default methods for DWI models.""" - __slots__ = ( - "_gtab", - "_S0", - "_b_max", - "_model_class", # Defining a model class, DIPY models are instantiated automagically - "_modelargs", - ) + __slots__ = { + "_model_class": "Defining a model class, DIPY models are instantiated automagically", + "_modelargs": "Arguments acceptable by the underlying DIPY-like model.", + } - def __init__(self, gtab, S0=None, b_max=None, **kwargs): - """Initialization. + def __init__(self, dataset, **kwargs): + r"""Initialization. Parameters ---------- - gtab : :obj:`numpy.ndarray` - An :math:`N \times 4` table, where rows (*N*) are diffusion gradients and - columns are b-vector components and corresponding b-value, respectively. - S0 : :obj:`numpy.ndarray` - :math:`S_{0}` signal. - b_max : :obj:`int` - Maximum value to cap b-values. + dataset : :obj:`~nifreeze.data.dmri.DWI` + Reference to a DWI object. """ - super().__init__(**kwargs) + # Duck typing, instead of explicitly test for DWI type + if not hasattr(dataset, "bzero"): + raise TypeError("Dataset MUST be a DWI object.") - # Setup B0 map - self._S0 = None - if S0 is not None: - self._S0 = np.clip( - S0.astype("float32") / S0.max(), - a_min=DEFAULT_MIN_S0, - a_max=DEFAULT_MAX_S0, + if not hasattr(dataset, "gradients") or dataset.gradients is None: + raise ValueError("Dataset MUST have a gradient table.") + + if dataset.gradients.shape[0] < DTI_MIN_ORIENTATIONS: + raise ValueError( + f"DWI dataset is too small ({dataset.gradients.shape[0]} directions)." ) - # Cap b-values, if requested - self._gtab = gtab - self._b_max = None - if b_max and b_max > DEFAULT_MAX_BVALUE: - # Saturate b-values at b_max, since signal stops dropping - self._gtab[-1, self._gtab[-1] > b_max] = b_max - # A possibly good alternative is completely remove very high b-values - # bval_mask = gtab[-1] < b_max - # data = data[..., bval_mask] - # gtab = gtab[:, bval_mask] - self._b_max = b_max + super().__init__(dataset, **kwargs) + + def _fit(self, index, n_jobs=None, **kwargs): + """Fit the model chunk-by-chunk asynchronously""" + n_jobs = n_jobs or 1 + + brainmask = self._dataset.brainmask + idxmask = np.ones(len(self._dataset), dtype=bool) + idxmask[index] = False - kwargs = {k: v for k, v in kwargs.items() if k in self._modelargs} + data, _, gtab = self._dataset[idxmask] + # Select voxels within mask or just unravel 3D if no mask + data = data[brainmask, ...] if brainmask is not None else data.reshape(-1, data.shape[-1]) # DIPY models (or one with a fully-compliant interface) model_str = getattr(self, "_model_class", None) @@ -127,18 +94,7 @@ def __init__(self, gtab, S0=None, b_max=None, **kwargs): self._model = getattr( import_module(module_name), class_name, - )(_rasb2dipy(gtab), **kwargs) - - def fit(self, data, n_jobs=None, **kwargs): - """Fit the model chunk-by-chunk asynchronously""" - n_jobs = n_jobs or 1 - - self._datashape = data.shape - - # Select voxels within mask or just unravel 3D if no mask - data = ( - data[self._mask, ...] if self._mask is not None else data.reshape(-1, data.shape[-1]) - ) + )(gtab, **kwargs) # One single CPU - linear execution (full model) if n_jobs == 1: @@ -155,37 +111,31 @@ def fit(self, data, n_jobs=None, **kwargs): results = executor( delayed(_exec_fit)(self._model, dchunk, i) for i, dchunk in enumerate(data_chunks) ) - for submodel, index in results: - self._models[index] = submodel + for submodel, rindex in results: + self._models[rindex] = submodel - self._is_fitted = True self._model = None # Preempt further actions on the model + return n_jobs - def predict(self, gradient=None, **kwargs): - """Predict asynchronously chunk-by-chunk the diffusion signal.""" - - if gradient is None: - raise ValueError("A gradient to be simulated (b-vector, b-value) must be provided") - - if not self._is_fitted: - raise ModelNotFittedError(f"{type(self).__name__} must be fitted before predicting") + def fit_predict(self, index, **kwargs): + """ + Predict asynchronously chunk-by-chunk the diffusion signal. - gradient = np.array(gradient) # Tuples are unmutable + Parameters + ---------- + index : :obj:`int` + The volume index that is left-out in fitting, and then predicted. - # Cap the b-value if b_max is defined - gradient[-1] = min(gradient[-1], self._b_max or gradient[-1]) + """ - gradient = _rasb2dipy(gradient) + n_models = self._fit(index, **kwargs) - S0 = None - if self._S0 is not None: - S0 = ( - self._S0[self._mask, ...] - if self._mask is not None - else self._S0.reshape(-1, self._S0.shape[-1]) - ) + brainmask = self._dataset.brainmask + gradient = self._dataset.gradients[index] - n_models = len(self._models) if self._model is None and self._models else 1 + S0 = self._dataset.bzero + if S0 is not None: + S0 = S0[brainmask, ...] if brainmask is not None else S0.reshape(-1) if n_models == 1: predicted, _ = _exec_predict(self._model, **(kwargs | {"gtab": gradient, "S0": S0})) @@ -209,21 +159,21 @@ def predict(self, gradient=None, **kwargs): predicted = np.hstack(predicted) - if self._mask is not None: - retval = np.zeros_like(self._mask, dtype="float32") - retval[self._mask, ...] = predicted + if brainmask is not None: + retval = np.zeros_like(brainmask, dtype="float32") + retval[brainmask, ...] = predicted else: - retval = predicted.reshape(self._datashape[:-1]) + retval = predicted.reshape(self._dataset.shape[:-1]) return retval -class AverageDWIModel(BaseDWIModel): +class AverageDWIModel(ExpectationModel): """A trivial model that returns an average DWI volume.""" - __slots__ = ("_data", "_th_low", "_th_high", "_bias", "_stat", "_is_fitted") + __slots__ = ("_th_low", "_th_high", "_detrend") - def __init__(self, **kwargs): + def __init__(self, dataset, stat="median", th_low=100, th_high=100, detrend=False, **kwargs): r""" Implement object initialization. @@ -235,7 +185,7 @@ def __init__(self, **kwargs): th_high : :obj:`numbers.Number` An upper bound for the b-value corresponding to the diffusion weighted images that will be averaged. - bias : :obj:`bool` + detrend : :obj:`bool` Whether the overall distribution of each diffusion weighted image will be standardized and centered around the :data:`src.nifreeze.model.base.DEFAULT_CLIP_PERCENTILE` percentile. @@ -243,49 +193,42 @@ def __init__(self, **kwargs): Whether the summary statistic to apply is ``"mean"`` or ``"median"``. """ - super().__init__(**kwargs) + super().__init__(dataset, stat=stat, **kwargs) - self._th_low = kwargs.get("th_low", DEFAULT_LOWB_THRESHOLD) - self._th_high = kwargs.get("th_high", DEFAULT_HIGHB_THRESHOLD) - self._bias = kwargs.get("bias", True) - self._stat = kwargs.get("stat", "median") - self._data = None + self._th_low = th_low + self._th_high = th_high + self._detrend = detrend - def fit(self, data, **kwargs): - """Calculate the average.""" + def fit_predict(self, index, *_, **kwargs): + """Return the average map.""" + + bvalues = self._dataset.gradients[:, -1] + bcenter = bvalues[index] + + shellmask = np.ones(len(self._dataset), dtype=bool) - if (gtab := kwargs.pop("gtab", None)) is None: - raise ValueError("A gradient table must be provided.") + # Keep only bvalues within the range defined by th_high and th_low + shellmask[index] = False + shellmask[bvalues > (bcenter + self._th_high)] = False + shellmask[bvalues < (bcenter - self._th_low)] = False - # Select the interval of b-values for which DWIs will be averaged - b_mask = ( - ((gtab[3] >= self._th_low) & (gtab[3] <= self._th_high)) - if gtab is not None - else np.ones((data.shape[-1],), dtype=bool) - ) - shells = data[..., b_mask] + if not shellmask.sum(): + raise RuntimeError(f"Shell corresponding to index {index} (b={bcenter}) is empty.") + + shelldata = self._dataset.dataobj[..., shellmask] # Regress out global signal differences - if self._bias: - centers = np.median(shells, axis=(0, 1, 2)) + if self._detrend: + centers = np.median(shelldata, axis=(0, 1, 2)) reference = np.percentile(centers[centers >= 1.0], DEFAULT_CLIP_PERCENTILE) centers[centers < 1.0] = reference drift = reference / centers - shells = shells * drift + shelldata = shelldata * drift # Select the summary statistic avg_func = np.median if self._stat == "median" else np.mean # Calculate the average - self._data = avg_func(shells, axis=-1) - self._is_fitted = True - - def predict(self, *_, **kwargs): - """Return the average map.""" - - if not self._is_fitted: - raise ModelNotFittedError(f"{type(self).__name__} must be fitted before predicting") - - return self._data + return avg_func(shelldata, axis=-1) class DTIModel(BaseDWIModel): @@ -314,65 +257,3 @@ class GPModel(BaseDWIModel): _modelargs = ("kernel_model",) _model_class = "nifreeze.model._dipy.GaussianProcessModel" - - -def find_shelling_scheme( - bvals, - num_bins=DEFAULT_NUM_BINS, - multishell_nonempty_bin_count_thr=DEFAULT_MULTISHELL_BIN_COUNT_THR, - bval_cap=DEFAULT_MAX_BVAL, -): - """ - Find the shelling scheme on the given b-values. - - Computes the histogram of the b-values according to ``num_bins`` - and depending on the nonempty bin count, classify the shelling scheme - as single-shell if they are 2 (low-b and a shell); multi-shell if they are - below the ``multishell_nonempty_bin_count_thr`` value; and DSI otherwise. - - Parameters - ---------- - bvals : :obj:`list` or :obj:`~numpy.ndarray` - List or array of b-values. - num_bins : :obj:`int`, optional - Number of bins. - multishell_nonempty_bin_count_thr : :obj:`int`, optional - Bin count to consider a multi-shell scheme. - - Returns - ------- - scheme : :obj:`str` - Shelling scheme. - bval_groups : :obj:`list` - List of grouped b-values. - bval_estimated : :obj:`list` - List of 'estimated' b-values as the median value of each b-value group. - - """ - - # Bin the b-values: use -1 as the lower bound to be able to appropriately - # include b0 values - hist, bin_edges = np.histogram(bvals, bins=num_bins, range=(-1, min(max(bvals), bval_cap))) - - # Collect values in each bin - bval_groups = [] - bval_estimated = [] - for lower, upper in zip(bin_edges[:-1], bin_edges[1:], strict=False): - # Add only if a nonempty b-values mask - if (mask := (bvals > lower) & (bvals <= upper)).sum(): - bval_groups.append(bvals[mask]) - bval_estimated.append(np.median(bvals[mask])) - - nonempty_bins = len(bval_groups) - - if nonempty_bins < 2: - raise ValueError("DWI must have at least one high-b shell") - - if nonempty_bins == 2: - scheme = "single-shell" - elif nonempty_bins < multishell_nonempty_bin_count_thr: - scheme = "multi-shell" - else: - scheme = "DSI" - - return scheme, bval_groups, bval_estimated diff --git a/src/nifreeze/registration/ants.py b/src/nifreeze/registration/ants.py index 8a0976c2..feb69bdb 100644 --- a/src/nifreeze/registration/ants.py +++ b/src/nifreeze/registration/ants.py @@ -92,22 +92,23 @@ def _to_nifti( def _prepare_registration_data( - fixed: np.ndarray, + sample: np.ndarray, predicted: np.ndarray, affine: np.ndarray, vol_idx: int, dirname: Path | str, clip: str | None = None, -) -> tuple[Path, Path]: + init_affine: np.ndarray | None = None, +) -> tuple[Path, Path, Path | None]: """ - Prepare the registration data: save the fixed and moving images to disk. + Prepare the registration data: save the moving and predicted (fixed) images to disk. Parameters ---------- - fixed : :obj:`~numpy.ndarray` - Reference volume's data array. + sample : :obj:`~numpy.ndarray` + Current volume for which a transformation is to be estimated. predicted : :obj:`~numpy.ndarray` - Predicted volume's data array. + Predicted volume's data array (that is, spatial reference). affine : :obj:`numpy.ndarray` Orientation affine from the original NIfTI. vol_idx : :obj:`int` @@ -115,34 +116,45 @@ def _prepare_registration_data( dirname : :obj:`os.pathlike` Directory name where the data is saved. clip : :obj:`str` or ``None`` - Clip intensity of ``"fixed"``, ``"moving"``, ``"both"``, + Clip intensity of ``"sample"``, ``"predicted"``, ``"both"``, or ``"none"`` of the images. Returns ------- - fixed_path : :obj:`~pathlib.Path` - Fixed image filename. - moving_path : :obj:`~pathlib.Path` - Moving image filename. + predicted_path : :obj:`~pathlib.Path` + Predicted image filename. + sample_path : :obj:`~pathlib.Path` + Current volume's filename. + init_path : :obj:`~pathlib.Path` or ``None`` + An initialization affine (for second and further estimators). """ clip = clip or "none" - moving_path = Path(dirname) / f"moving{vol_idx:05d}.nii.gz" - fixed_path = Path(dirname) / f"fixed{vol_idx:05d}.nii.gz" + predicted_path = Path(dirname) / f"predicted_{vol_idx:05d}.nii.gz" + sample_path = Path(dirname) / f"sample_{vol_idx:05d}.nii.gz" _to_nifti( - fixed, + sample, affine, - moving_path, - clip=clip.lower() in ("fixed", "both"), + sample_path, + clip=clip.lower() in ("sample", "both"), ) _to_nifti( predicted, affine, - fixed_path, - clip=clip.lower() in ("moving", "both"), + predicted_path, + clip=clip.lower() in ("predicted", "both"), ) - return fixed_path, moving_path + + init_path = None + if init_affine is not None: + ImageGrid = namedtuple("ImageGrid", ("shape", "affine")) + reference = ImageGrid(shape=sample.shape[:3], affine=affine) + initial_xform = Affine(matrix=init_affine, reference=reference) + init_path = dirname / f"init_{vol_idx:05d}.mat" + initial_xform.to_filename(init_path, fmt="itk") + + return predicted_path, sample_path, init_path def _get_ants_settings(settings: str = "b0-to-b0_level0") -> Path: @@ -222,7 +234,7 @@ def generate_command( init_affine: str | Path | None = None, default: str = "b0-to-b0_level0", **kwargs: dict, -) -> str: +) -> Registration: """ Generate an ANTs' command line. @@ -245,15 +257,15 @@ def generate_command( Returns ------- - :obj:`str` - The ANTs registration command line string. + :obj:`~nipype.interfaces.ants.Registration` + The configured Nipype interface of ANTs registration. Examples -------- >>> generate_command( ... fixed_path=repodata / 'fileA.nii.gz', ... moving_path=repodata / 'fileB.nii.gz', - ... ) # doctest: +NORMALIZE_WHITESPACE + ... ).cmdline # doctest: +NORMALIZE_WHITESPACE 'antsRegistration --collapse-output-transforms 1 --dimensionality 3 \ --initialize-transforms-per-stage 0 --interpolation Linear --output transform \ --transform Rigid[ 12.0 ] \ @@ -275,7 +287,7 @@ def generate_command( ... fixed_path=repodata / 'fileA.nii.gz', ... moving_path=repodata / 'fileB.nii.gz', ... default="dwi-to-b0_level0", - ... ) # doctest: +NORMALIZE_WHITESPACE + ... ).cmdline # doctest: +NORMALIZE_WHITESPACE 'antsRegistration --collapse-output-transforms 1 --dimensionality 3 \ --initialize-transforms-per-stage 0 --interpolation Linear --output transform \ --transform Rigid[ 0.01 ] --metric Mattes[ \ @@ -297,7 +309,7 @@ def generate_command( ... moving_path=repodata / 'fileB.nii.gz', ... fixedmask_path=repodata / 'maskA.nii.gz', ... default="dwi-to-b0_level0", - ... ) # doctest: +NORMALIZE_WHITESPACE + ... ).cmdline # doctest: +NORMALIZE_WHITESPACE 'antsRegistration --collapse-output-transforms 1 --dimensionality 3 \ --initialize-transforms-per-stage 0 --interpolation Linear --output transform \ --transform Rigid[ 0.01 ] --metric Mattes[ \ @@ -320,7 +332,7 @@ def generate_command( ... fixed_path=repodata / 'fileA.nii.gz', ... moving_path=repodata / 'fileB.nii.gz', ... default="dwi-to-b0_level0", - ... ) # doctest: +NORMALIZE_WHITESPACE + ... ).cmdline # doctest: +NORMALIZE_WHITESPACE 'antsRegistration --collapse-output-transforms 1 --dimensionality 3 \ --initialize-transforms-per-stage 0 --interpolation Linear --output transform \ --transform Rigid[ 0.01 ] --metric Mattes[ \ @@ -342,7 +354,7 @@ def generate_command( ... moving_path=repodata / 'fileB.nii.gz', ... fixedmask_path=[repodata / 'maskA.nii.gz'], ... default="dwi-to-b0_level0", - ... ) # doctest: +NORMALIZE_WHITESPACE + ... ).cmdline # doctest: +NORMALIZE_WHITESPACE 'antsRegistration --collapse-output-transforms 1 --dimensionality 3 \ --initialize-transforms-per-stage 0 --interpolation Linear --output transform \ --transform Rigid[ 0.01 ] --metric Mattes[ \ @@ -406,16 +418,12 @@ def generate_command( fixed_image=str(Path(fixed_path).absolute()), moving_image=str(Path(moving_path).absolute()), **settings, - ).cmdline + ) def _run_registration( - fixed: Path, - moving: Path, - bmask_img: nb.spatialimages.SpatialImage, - em_affines: np.ndarray, - affine: np.ndarray, - shape: tuple[int, int, int], + fixed_path: str | Path, + moving_path: str | Path, vol_idx: int, dirname: Path, **kwargs: dict, @@ -425,18 +433,10 @@ def _run_registration( Parameters ---------- - fixed : :obj:`Path` + fixed_path : :obj:`Path` Fixed image filename. - moving : :obj:`Path` + moving_path : :obj:`Path` Moving image filename. - bmask_img : :class:`~nibabel.spatialimages.SpatialImage` - Brainmask image. - em_affines : :obj:`numpy.ndarray` - Estimated head-motion affine transformation matrices. - affine : :obj:`numpy.ndarray` - Orientation affine from the original NIfTI. - shape : :obj:`tuple` - 3D shape of dataset. vol_idx : :obj:`int` Dataset volume index. dirname : :obj:`Path` @@ -451,27 +451,27 @@ def _run_registration( """ - if "config_file" in kwargs: - kwargs["from_file"] = pkg_fn( - "nifreeze.registration", - f"config/{kwargs.pop('config_file')}", - ) - registration = Registration( + align_kwargs = kwargs.copy() + environ = align_kwargs.pop("environ", {}) + num_threads = align_kwargs.pop("num_threads", None) + + if (seed := align_kwargs.pop("seed", None)) is not None: + environ["ANTS_RANDOM_SEED"] = str(seed) + + if "ants_config" in kwargs: + align_kwargs["default"] = align_kwargs.pop("ants_config").replace(".json", "") + + registration = generate_command( + fixed_path, + moving_path, terminal_output="file", - fixed_image=str(fixed.absolute()), - moving_image=str(moving.absolute()), - **kwargs, + environ=environ, + **align_kwargs, ) - if bmask_img: - registration.inputs.fixed_image_masks = ["NULL", bmask_img] + if num_threads: + registration.inputs.num_threads = num_threads - if em_affines is not None and np.any(em_affines[vol_idx, ...]): - ImageGrid = namedtuple("ImageGrid", ("shape", "affine")) - reference = ImageGrid(shape=shape, affine=affine) - initial_xform = Affine(matrix=em_affines[vol_idx], reference=reference) - mat_file = dirname / f"init_{vol_idx:05d}.mat" - initial_xform.to_filename(mat_file, fmt="itk") - registration.inputs.initial_moving_transform = str(mat_file) + (dirname / f"cmd-{vol_idx:05d}.sh").write_text(registration.cmdline) # execute ants command line result = registration.run(cwd=str(dirname)).outputs @@ -479,10 +479,12 @@ def _run_registration( # read output transform xform = nt.linear.Affine( nt.io.itk.ITKLinearTransform.from_filename(result.forward_transforms[0]).to_ras( - reference=fixed, moving=moving + reference=fixed_path, moving=moving_path ), ) # debugging: generate aligned file for testing - xform.apply(moving, reference=fixed).to_filename(dirname / f"aligned{vol_idx:05d}.nii.gz") + xform.apply(moving_path, reference=fixed_path).to_filename( + dirname / f"dbg_{vol_idx:05d}.nii.gz" + ) return xform diff --git a/test/conftest.py b/test/conftest.py index d66738c7..56d0fbf8 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -100,5 +100,5 @@ def pytest_terminal_summary(terminalreporter, exitstatus, config): terminalreporter.ensure_newline() terminalreporter.section("Werrors", sep="=", red=True, bold=True) terminalreporter.line( - f"Warnings as errors: Activated.\n{len(have_warnings)} warnings were raised and treated as errors.\n" + f"{len(have_warnings)} warnings were raised and treated as errors.\n" ) diff --git a/test/test_data_dmri.py b/test/test_data_dmri.py index 565b6d9a..a6afd79e 100644 --- a/test/test_data_dmri.py +++ b/test/test_data_dmri.py @@ -26,10 +26,7 @@ import numpy as np import pytest -from nifreeze.data.dmri import load -from nifreeze.model.dmri import ( - find_shelling_scheme, -) +from nifreeze.data.dmri import find_shelling_scheme, load def _create_dwi_random_dataobj(): diff --git a/test/test_integration.py b/test/test_integration.py index 33a2bd11..d3259133 100644 --- a/test/test_integration.py +++ b/test/test_integration.py @@ -30,6 +30,7 @@ from nifreeze.data.dmri import DWI from nifreeze.estimator import Estimator +from nifreeze.model.base import TrivialModel from nifreeze.registration.utils import displacements_within_mask @@ -70,14 +71,12 @@ def test_proximity_estimator_trivial_model(datadir, tmp_path): brainmask=dwdata.brainmask, ) - estimator = Estimator("b0") + model = TrivialModel(dwdata) + estimator = Estimator(model) estimator.run( dwi_motion, - seed=None, - align_kwargs={ - "config_file": "b0-to-b0_level0.json", - "num_threads": min(cpu_count(), 8), - }, + seed=12345, + num_threads=min(cpu_count(), 8), ) # Uncomment to see the realigned dataset diff --git a/test/test_model.py b/test/test_model.py index 15e77c66..6a5759a4 100644 --- a/test/test_model.py +++ b/test/test_model.py @@ -27,11 +27,8 @@ from dipy.sims.voxel import single_tensor from nifreeze import model -from nifreeze.data.dmri import DWI -from nifreeze.data.splitting import lovo_split -from nifreeze.exceptions import ModelNotFittedError +from nifreeze.data.dmri import DEFAULT_MAX_S0, DEFAULT_MIN_S0, DWI from nifreeze.model._dipy import GaussianProcessModel -from nifreeze.model.dmri import DEFAULT_MAX_S0, DEFAULT_MIN_S0 from nifreeze.testing import simulations as _sim @@ -52,19 +49,21 @@ def test_trivial_model(): a_max=DEFAULT_MAX_S0, ) - tmodel = model.TrivialModel(predicted=_clipped_S0) + data = DWI( + dataobj=(*_S0.shape, 10), + bzero=_clipped_S0, + ) - data = None - assert tmodel.fit(data) is None + tmodel = model.TrivialModel(data) + predicted = tmodel.fit_predict(4) - assert np.all(_clipped_S0 == tmodel.predict((1, 0, 0))) + assert np.all(_clipped_S0 == predicted) def test_average_model(): """Check the implementation of the average DW model.""" - data = np.ones((100, 100, 100, 6), dtype=float) - + data = np.ones((100, 100, 100, 10), dtype=float) gtab = np.array( [ [0, 0, 0, 0], @@ -72,40 +71,37 @@ def test_average_model(): [0.25, 0.565, 0.21, 500], [-0.861, -0.464, 0.564, 1000], [0.307, -0.766, 0.677, 1000], - [0.736, 0.013, 0.774, 1300], + [0.736, 0.013, 0.774, 1000], + [-0.31, 0.933, 0.785, 1000], + [0.25, 0.565, 0.21, 2000], + [-0.861, -0.464, 0.564, 2000], + [0.307, -0.766, 0.677, 2000], ] ) data *= gtab[:, -1] + dataset = DWI(dataobj=data, gradients=gtab) - tmodel_mean = model.AverageDWIModel(gtab=gtab, bias=False, stat="mean") - tmodel_median = model.AverageDWIModel(gtab=gtab, bias=False, stat="median") - tmodel_1000 = model.AverageDWIModel(gtab=gtab, bias=False, th_high=1000, th_low=900) - tmodel_2000 = model.AverageDWIModel( - gtab=gtab, - bias=False, - th_high=2000, - th_low=900, - stat="mean", - ) + tmodel_mean = model.AverageDWIModel(dataset, stat="mean") + tmodel_mean_full = model.AverageDWIModel(dataset, stat="mean", th_low=2000, th_high=2000) + tmodel_median = model.AverageDWIModel(dataset) - with pytest.raises(ModelNotFittedError): - tmodel_mean.predict([0, 0, 0]) + # Verify that average cannot be calculated in shells with one single value + with pytest.raises(RuntimeError): + tmodel_mean.fit_predict(2) - # Verify that fit function returns nothing - assert tmodel_mean.fit(data[..., 1:], gtab=gtab[1:].T) is None + assert np.allclose(tmodel_mean.fit_predict(3), 1000) + assert np.allclose(tmodel_median.fit_predict(3), 1000) - tmodel_median.fit(data[..., 1:], gtab=gtab[1:].T) - tmodel_1000.fit(data[..., 1:], gtab=gtab[1:].T) - tmodel_2000.fit(data[..., 1:], gtab=gtab[1:].T) + grads = list(gtab[:, -1]) + del grads[1] + assert np.allclose(tmodel_mean_full.fit_predict(1), np.mean(grads)) - # Verify that the right statistics is applied and that the model discard b-values < 50 - assert np.all(tmodel_mean.predict([0, 0, 0]) == 950) - assert np.all(tmodel_median.predict([0, 0, 0]) == 1000) + tmodel_mean_2000 = model.AverageDWIModel(dataset, stat="mean", th_low=1100) + tmodel_median_2000 = model.AverageDWIModel(dataset, th_low=1100) - # Verify that the threshold for b-value selection works as expected - assert np.all(tmodel_1000.predict([0, 0, 0]) == 1000) - assert np.all(tmodel_2000.predict([0, 0, 0]) == 1100) + assert np.allclose(tmodel_mean_2000.fit_predict(9), gtab[3:-1, -1].mean()) + assert np.allclose(tmodel_median_2000.fit_predict(9), 1000) @pytest.mark.parametrize( @@ -143,42 +139,26 @@ def test_gp_model(evals, S0, snr, hsph_dirs, bval_shell): assert prediction.shape == (2,) -def test_two_initialisations(datadir): +def test_factory(datadir): """Check that the two different initialisations result in the same models""" # Load test data dmri_dataset = DWI.from_filename(datadir / "dwi.h5") - # Split data into test and train set - data_train, data_test = lovo_split(dmri_dataset, 10) - + modelargs = { + "th_low": 25, + "th_high": 25, + "detrend": True, + "stat": "mean", + } # Direct initialisation - model1 = model.AverageDWIModel( - gtab=data_train[-1], - S0=dmri_dataset.bzero, - th_low=100, - th_high=1000, - bias=False, - stat="mean", - ) - model1.fit(data_train[0], gtab=data_train[-1]) - predicted1 = model1.predict(data_test[-1]) + model1 = model.AverageDWIModel(dmri_dataset, **modelargs) # Initialisation via ModelFactory - model2 = model.ModelFactory.init( - gtab=data_train[-1], - model="avgdwi", - S0=dmri_dataset.bzero, - th_low=100, - th_high=1000, - bias=False, - stat="mean", - ) - - with pytest.raises(ModelNotFittedError): - model2.predict(data_test[-1]) - - model2.fit(data_train[0], gtab=data_train[-1]) - predicted2 = model2.predict(data_test[-1]) + model2 = model.ModelFactory.init(model="avgdwi", dataset=dmri_dataset, **modelargs) - assert np.all(predicted1 == predicted2) + assert model1._dataset == model2._dataset + assert model1._detrend == model2._detrend + assert model1._th_low == model2._th_low + assert model1._th_high == model2._th_high + assert model1._stat == model2._stat From 97e61c392cf7bf358c3a23e1a24197bdb37f84d5 Mon Sep 17 00:00:00 2001 From: Oscar Esteban Date: Thu, 23 Jan 2025 14:12:14 +0100 Subject: [PATCH 05/11] fix: pacify mypy --- src/nifreeze/cli/run.py | 9 +++++++-- src/nifreeze/estimator.py | 2 +- src/nifreeze/registration/ants.py | 3 +-- 3 files changed, 9 insertions(+), 5 deletions(-) diff --git a/src/nifreeze/cli/run.py b/src/nifreeze/cli/run.py index 7c58cafd..ee7eaa37 100644 --- a/src/nifreeze/cli/run.py +++ b/src/nifreeze/cli/run.py @@ -42,12 +42,17 @@ def main(argv=None) -> None: # Open the data with the given file path dataset: BaseDataset = BaseDataset.from_filename(args.input_file) - estimator: Estimator = Estimator() + prev_model: Estimator | None = None + for model in args.models: + estimator: Estimator = Estimator( + args.model, + prev=prev_model, + ) + prev_model = estimator _ = estimator.run( dataset, align_kwargs=args.align_config, - models=args.models, omp_nthreads=args.nthreads, njobs=args.njobs, seed=args.seed, diff --git a/src/nifreeze/estimator.py b/src/nifreeze/estimator.py index 244c027f..99a9d753 100644 --- a/src/nifreeze/estimator.py +++ b/src/nifreeze/estimator.py @@ -137,7 +137,7 @@ def run(self, dataset: BaseDataset, **kwargs): # fit the model test_set = dataset[i] - predicted = self._model.fit_predict( + predicted = self._model.fit_predict( # type: ignore[union-attr] i, n_jobs=n_jobs, ) diff --git a/src/nifreeze/registration/ants.py b/src/nifreeze/registration/ants.py index 69036e85..e7bf65ab 100644 --- a/src/nifreeze/registration/ants.py +++ b/src/nifreeze/registration/ants.py @@ -130,7 +130,6 @@ def _prepare_registration_data( """ clip = clip or "none" - predicted_path = Path(dirname) / f"predicted_{vol_idx:05d}.nii.gz" sample_path = Path(dirname) / f"sample_{vol_idx:05d}.nii.gz" _to_nifti( @@ -151,7 +150,7 @@ def _prepare_registration_data( ImageGrid = namedtuple("ImageGrid", ("shape", "affine")) reference = ImageGrid(shape=sample.shape[:3], affine=affine) initial_xform = Affine(matrix=init_affine, reference=reference) - init_path = dirname / f"init_{vol_idx:05d}.mat" + init_path = Path(dirname) / f"init_{vol_idx:05d}.mat" initial_xform.to_filename(init_path, fmt="itk") return predicted_path, sample_path, init_path From ce143580485230ead5e3e22960ef8c95637ae593 Mon Sep 17 00:00:00 2001 From: Oscar Esteban Date: Thu, 23 Jan 2025 15:28:09 +0100 Subject: [PATCH 06/11] fix: erroneous direct use of argument --- src/nifreeze/cli/run.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/nifreeze/cli/run.py b/src/nifreeze/cli/run.py index 1d080620..d16e1a41 100644 --- a/src/nifreeze/cli/run.py +++ b/src/nifreeze/cli/run.py @@ -45,7 +45,7 @@ def main(argv=None) -> None: prev_model: Estimator | None = None for _model in args.models: estimator: Estimator = Estimator( - args.model, + _model, prev=prev_model, ) prev_model = estimator From 7d44484988406e1831c67c89883b04d65ba4a3d1 Mon Sep 17 00:00:00 2001 From: Oscar Esteban Date: Thu, 23 Jan 2025 15:41:01 +0100 Subject: [PATCH 07/11] fix: failing test after sloppy merge --- test/test_model.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/test_model.py b/test/test_model.py index 87eed40a..b4feddd3 100644 --- a/test/test_model.py +++ b/test/test_model.py @@ -78,10 +78,6 @@ def test_trivial_model(use_mask): def test_average_model(): """Check the implementation of the average DW model.""" - size = (100, 100, 100, 6) - data = np.ones(size, dtype=float) - mask = np.ones(size[:3], dtype=bool) - gtab = np.array( [ [0, 0, 0, 0], @@ -97,6 +93,10 @@ def test_average_model(): ] ) + size = (100, 100, 100, gtab.shape[0]) + data = np.ones(size, dtype=float) + mask = np.ones(size[:3], dtype=bool) + data *= gtab[:, -1] dataset = DWI(dataobj=data, gradients=gtab, brainmask=mask) From 668e240cb15362623d3505ffeb175968fa865c70 Mon Sep 17 00:00:00 2001 From: Oscar Esteban Date: Fri, 24 Jan 2025 09:13:37 +0100 Subject: [PATCH 08/11] Apply suggestions from code review Co-authored-by: Chris Markiewicz --- src/nifreeze/estimator.py | 14 ++++++++------ src/nifreeze/model/base.py | 6 ++++-- 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/src/nifreeze/estimator.py b/src/nifreeze/estimator.py index 99a9d753..4d50b649 100644 --- a/src/nifreeze/estimator.py +++ b/src/nifreeze/estimator.py @@ -26,7 +26,7 @@ from pathlib import Path from tempfile import TemporaryDirectory -from typing import Self +from typing import Self, TypeVar from tqdm import tqdm @@ -38,11 +38,13 @@ ) from nifreeze.utils import iterators +DatasetT = TypeVar("DatasetT", bound=BaseDataset) + class Filter: """Alters an input data object (e.g., downsampling).""" - def run(self, dataset: BaseDataset, **kwargs): + def run(self, dataset: DatasetT, **kwargs) -> DatasetT: """ Trigger execution of the designated filter. @@ -53,8 +55,8 @@ def run(self, dataset: BaseDataset, **kwargs): Returns ------- - :obj:`~nifreeze.estimator.Estimator` - The estimator, after fitting. + dataset : :obj:`~nifreeze.data.base.BaseDataset` + The dataset, after filtering. """ return dataset @@ -69,7 +71,7 @@ def __init__( self, model: BaseModel | str, strategy: str = "random", - prev: Self | None = None, + prev: Estimator | Filter | None = None, model_kwargs: dict | None = None, **kwargs, ): @@ -79,7 +81,7 @@ def __init__( self._model_kwargs = model_kwargs or {} self._align_kwargs = kwargs or {} - def run(self, dataset: BaseDataset, **kwargs): + def run(self, dataset: DatasetT, **kwargs) -> Self: """ Trigger execution of the workflow this estimator belongs. diff --git a/src/nifreeze/model/base.py b/src/nifreeze/model/base.py index c7f419d3..98d1f6ce 100644 --- a/src/nifreeze/model/base.py +++ b/src/nifreeze/model/base.py @@ -22,6 +22,7 @@ # """Base infrastructure for nifreeze's models.""" +from abc import abstractmethod from warnings import warn import numpy as np @@ -95,7 +96,8 @@ def __init__(self, dataset, **kwargs): if dataset.brainmask is None: warn(mask_absence_warn_msg, stacklevel=2) - def fit_predict(self, *_, **kwargs): + @abstractmethod + def fit_predict(self, index, **kwargs) -> np.ndarray: """Fit and predict the indicate index of the dataset (abstract signature).""" raise NotImplementedError("Cannot call fit_predict() on a BaseModel instance.") @@ -139,7 +141,7 @@ def __init__(self, dataset, stat="median", **kwargs): super().__init__(dataset, **kwargs) self._stat = stat - def fit_predict(self, index, *_, **kwargs): + def fit_predict(self, index, **kwargs): """ Return the expectation map. From 38147b71aea96f6da774154b6340233b5972150a Mon Sep 17 00:00:00 2001 From: Oscar Esteban Date: Fri, 24 Jan 2025 09:53:44 +0100 Subject: [PATCH 09/11] doc: minimal documentation edits --- src/nifreeze/model/base.py | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/src/nifreeze/model/base.py b/src/nifreeze/model/base.py index 98d1f6ce..f80394db 100644 --- a/src/nifreeze/model/base.py +++ b/src/nifreeze/model/base.py @@ -44,7 +44,7 @@ def init(model=None, **kwargs): ---------- model : :obj:`str` Diffusion model. - Options: ``"DTI"``, ``"DKI"``, ``"S0"``, ``"AverageDW"`` + Options: ``"DTI"``, ``"DKI"``, ``"S0"``, ``"AverageDWI"`` Return ------ @@ -84,9 +84,7 @@ class BaseModel: """ - __slots__ = { - "_dataset": "Reference to a :obj:`~nifreeze.data.base.BaseDataset` object.", - } + __slots__ = ("_dataset", ) def __init__(self, dataset, **kwargs): """Base initialization.""" @@ -105,10 +103,7 @@ def fit_predict(self, index, **kwargs) -> np.ndarray: class TrivialModel(BaseModel): """A trivial model that returns a given map always.""" - __slots__ = { - "_predicted": "A :obj:`~numpy.ndarray` with shape matching the dataset containing the map" - "that will always be returned as prediction (that is, a reference volume).", - } + __slots__ = ("_predicted", ) def __init__(self, dataset, predicted=None, **kwargs): """Implement object initialization.""" @@ -134,7 +129,7 @@ def fit_predict(self, *_, **kwargs): class ExpectationModel(BaseModel): """A trivial model that returns an expectation map (for example, average).""" - __slots__ = {"_stat": "The statistical operation to obtain the expectation map."} + __slots__ = ("_stat", ) def __init__(self, dataset, stat="median", **kwargs): """Initialize a new model.""" From 24efee46e89a23869f09671b419006d5c852a29c Mon Sep 17 00:00:00 2001 From: Oscar Esteban Date: Fri, 24 Jan 2025 10:00:31 +0100 Subject: [PATCH 10/11] fix: type assignment exception --- src/nifreeze/estimator.py | 2 +- src/nifreeze/model/base.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/nifreeze/estimator.py b/src/nifreeze/estimator.py index 4d50b649..f397603e 100644 --- a/src/nifreeze/estimator.py +++ b/src/nifreeze/estimator.py @@ -99,7 +99,7 @@ def run(self, dataset: DatasetT, **kwargs) -> Self: if self._prev is not None: result = self._prev.run(dataset, **kwargs) if isinstance(self._prev, Filter): - dataset = result + dataset = result # type: ignore[assignment] n_jobs = kwargs.get("n_jobs", None) diff --git a/src/nifreeze/model/base.py b/src/nifreeze/model/base.py index f80394db..cd4ca744 100644 --- a/src/nifreeze/model/base.py +++ b/src/nifreeze/model/base.py @@ -84,7 +84,7 @@ class BaseModel: """ - __slots__ = ("_dataset", ) + __slots__ = ("_dataset",) def __init__(self, dataset, **kwargs): """Base initialization.""" @@ -103,7 +103,7 @@ def fit_predict(self, index, **kwargs) -> np.ndarray: class TrivialModel(BaseModel): """A trivial model that returns a given map always.""" - __slots__ = ("_predicted", ) + __slots__ = ("_predicted",) def __init__(self, dataset, predicted=None, **kwargs): """Implement object initialization.""" @@ -129,7 +129,7 @@ def fit_predict(self, *_, **kwargs): class ExpectationModel(BaseModel): """A trivial model that returns an expectation map (for example, average).""" - __slots__ = ("_stat", ) + __slots__ = ("_stat",) def __init__(self, dataset, stat="median", **kwargs): """Initialize a new model.""" From 6697c7872cf77e3dd35288118f0294d80d6e7f48 Mon Sep 17 00:00:00 2001 From: Oscar Esteban Date: Fri, 24 Jan 2025 12:19:26 +0100 Subject: [PATCH 11/11] enh: add test checking stacked models Related: #12. --- src/nifreeze/estimator.py | 1 + src/nifreeze/registration/ants.py | 34 +++++++++++----- test/conftest.py | 62 +++++++++++++++++++++++++++-- test/test_integration.py | 66 +++++++++++++------------------ 4 files changed, 112 insertions(+), 51 deletions(-) diff --git a/src/nifreeze/estimator.py b/src/nifreeze/estimator.py index f397603e..c197991a 100644 --- a/src/nifreeze/estimator.py +++ b/src/nifreeze/estimator.py @@ -117,6 +117,7 @@ def run(self, dataset: DatasetT, **kwargs) -> Self: ) kwargs["num_threads"] = kwargs.pop("omp_nthreads", None) or kwargs.pop("num_threads", None) + kwargs = self._align_kwargs | kwargs dataset_length = len(dataset) with TemporaryDirectory() as tmp_dir: diff --git a/src/nifreeze/registration/ants.py b/src/nifreeze/registration/ants.py index e7bf65ab..789721cc 100644 --- a/src/nifreeze/registration/ants.py +++ b/src/nifreeze/registration/ants.py @@ -97,7 +97,7 @@ def _prepare_registration_data( affine: np.ndarray, vol_idx: int, dirname: Path | str, - clip: str | None = None, + clip: str | bool | None = None, init_affine: np.ndarray | None = None, ) -> tuple[Path, Path, Path | None]: """ @@ -129,20 +129,21 @@ def _prepare_registration_data( An initialization affine (for second and further estimators). """ - clip = clip or "none" + predicted_path = Path(dirname) / f"predicted_{vol_idx:05d}.nii.gz" sample_path = Path(dirname) / f"sample_{vol_idx:05d}.nii.gz" + _to_nifti( sample, affine, sample_path, - clip=clip.lower() in ("sample", "both"), + clip=str(clip).lower() in ("sample", "both", "true"), ) _to_nifti( predicted, affine, predicted_path, - clip=clip.lower() in ("predicted", "both"), + clip=str(clip).lower() in ("predicted", "both", "true"), ) init_path = None @@ -232,6 +233,9 @@ def generate_command( movingmask_path: str | Path | list[str] | None = None, init_affine: str | Path | None = None, default: str = "b0-to-b0_level0", + terminal_output: str | None = None, + num_threads: int | None = None, + environ: dict | None = None, **kwargs, ) -> Registration: """ @@ -251,6 +255,12 @@ def generate_command( Initial affine transformation. default : :obj:`str`, optional Default settings configuration. + terminal_output : :obj:`str`, optional + Redirect terminal output (Nipype configuration) + environ : :obj:`dict`, optional + Add environment variables to the execution. + num_threads : :obj:`int`, optional + Set the number of threads for ANTs' execution. **kwargs : :obj:`dict` Additional parameters for ANTs registration. @@ -413,11 +423,17 @@ def generate_command( settings["initial_moving_transform"] = str(init_affine) # Generate command line with nipype and return - return Registration( + reg_iface = Registration( fixed_image=str(Path(fixed_path).absolute()), moving_image=str(Path(moving_path).absolute()), + terminal_output=terminal_output, + environ=environ or {}, **settings, ) + if num_threads: + reg_iface.inputs.num_threads = num_threads + + return reg_iface def _run_registration( @@ -451,10 +467,11 @@ def _run_registration( """ align_kwargs = kwargs.copy() - environ = align_kwargs.pop("environ", {}) + environ = align_kwargs.pop("environ", None) num_threads = align_kwargs.pop("num_threads", None) if (seed := align_kwargs.pop("seed", None)) is not None: + environ = environ or {} environ["ANTS_RANDOM_SEED"] = str(seed) if "ants_config" in kwargs: @@ -463,12 +480,11 @@ def _run_registration( registration = generate_command( fixed_path, moving_path, - terminal_output="file", environ=environ, + terminal_output="file_split", + num_threads=num_threads, **align_kwargs, ) - if num_threads: - registration.inputs.num_threads = num_threads (dirname / f"cmd-{vol_idx:05d}.sh").write_text(registration.cmdline) diff --git a/test/conftest.py b/test/conftest.py index f2bbbef0..c10c231c 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -26,9 +26,12 @@ from pathlib import Path import nibabel as nb +import nitransforms as nt import numpy as np import pytest +from nifreeze.data.dmri import DWI + test_data_env = os.getenv("TEST_DATA_HOME", str(Path.home() / "nifreeze-tests")) test_output_dir = os.getenv("TEST_OUTPUT_DIR") test_workdir = os.getenv("TEST_WORK_DIR") @@ -54,19 +57,19 @@ def doctest_imports(doctest_namespace): doctest_namespace["repodata"] = _datadir -@pytest.fixture +@pytest.fixture(scope="session") def outdir(): """Determine if test artifacts should be stored somewhere or deleted.""" return None if test_output_dir is None else Path(test_output_dir) -@pytest.fixture +@pytest.fixture(scope="session") def datadir(): """Return a data path outside the package's structure (i.e., large datasets).""" return Path(test_data_env) -@pytest.fixture +@pytest.fixture(scope="session") def repodata(): """Return the path to this repository's test data folder.""" return _datadir @@ -80,6 +83,59 @@ def pytest_addoption(parser): ) +@pytest.fixture(scope="session") +def motion_data(tmp_path_factory, datadir): + # Temporary directory for session-scoped fixtures + tmp_path = tmp_path_factory.mktemp("motion_test_data") + + dwdata = DWI.from_filename(datadir / "dwi.h5") + b0nii = nb.Nifti1Image(dwdata.bzero, dwdata.affine, None) + masknii = nb.Nifti1Image(dwdata.brainmask.astype("uint8"), dwdata.affine, None) + + # Generate a list of large-yet-plausible bulk-head motion + xfms = nt.linear.LinearTransformsMapping( + [ + nb.affines.from_matvec(nb.eulerangles.euler2mat(x=0.03, z=0.005), (0.8, 0.2, 0.2)), + nb.affines.from_matvec(nb.eulerangles.euler2mat(x=0.02, z=0.005), (0.8, 0.2, 0.2)), + nb.affines.from_matvec(nb.eulerangles.euler2mat(x=0.02, z=0.02), (0.4, 0.2, 0.2)), + nb.affines.from_matvec(nb.eulerangles.euler2mat(x=-0.02, z=0.02), (0.4, 0.2, 0.2)), + nb.affines.from_matvec(nb.eulerangles.euler2mat(x=-0.02, z=0.002), (0.0, 0.2, 0.2)), + nb.affines.from_matvec(nb.eulerangles.euler2mat(y=-0.02, z=0.002), (0.0, 0.2, 0.2)), + nb.affines.from_matvec(nb.eulerangles.euler2mat(y=-0.01, z=0.002), (0.0, 0.4, 0.2)), + ], + reference=b0nii, + ) + + # Induce motion into dataset (i.e., apply the inverse transforms) + moved_nii = (~xfms).apply(b0nii, reference=b0nii) + + # Save the moved dataset for debugging or further processing + moved_path = tmp_path / "test.nii.gz" + ground_truth_path = tmp_path / "ground_truth.nii.gz" + moved_nii.to_filename(moved_path) + xfms.apply(moved_nii).to_filename(ground_truth_path) + + # Wrap into dataset object + dwi_motion = DWI( + dataobj=np.asanyarray(moved_nii.dataobj), + affine=b0nii.affine, + bzero=dwdata.bzero, + gradients=dwdata.gradients[..., : len(xfms)], + brainmask=dwdata.brainmask, + ) + + # Return data as a dictionary (or any format that makes sense for your tests) + return { + "b0nii": b0nii, + "masknii": masknii, + "moved_nii": moved_nii, + "xfms": xfms, + "moved_path": moved_path, + "ground_truth_path": ground_truth_path, + "moved_nifreeze": dwi_motion, + } + + @pytest.hookimpl(trylast=True) def pytest_sessionfinish(session, exitstatus): have_werrors = os.getenv("NIFREEZE_WERRORS", False) diff --git a/test/test_integration.py b/test/test_integration.py index d3259133..84d2e3d2 100644 --- a/test/test_integration.py +++ b/test/test_integration.py @@ -24,54 +24,22 @@ from os import cpu_count -import nibabel as nb import nitransforms as nt -import numpy as np -from nifreeze.data.dmri import DWI from nifreeze.estimator import Estimator from nifreeze.model.base import TrivialModel from nifreeze.registration.utils import displacements_within_mask -def test_proximity_estimator_trivial_model(datadir, tmp_path): +def test_proximity_estimator_trivial_model(motion_data, tmp_path): """Check the proximity of transforms estimated by the estimator with a trivial B0 model.""" - dwdata = DWI.from_filename(datadir / "dwi.h5") - b0nii = nb.Nifti1Image(dwdata.bzero, dwdata.affine, None) - masknii = nb.Nifti1Image(dwdata.brainmask.astype(np.uint8), dwdata.affine, None) + b0nii = motion_data["b0nii"] + moved_nii = motion_data["moved_nii"] + xfms = motion_data["xfms"] + dwi_motion = motion_data["moved_nifreeze"] - # Generate a list of large-yet-plausible bulk-head motion. - xfms = nt.linear.LinearTransformsMapping( - [ - nb.affines.from_matvec(nb.eulerangles.euler2mat(x=0.03, z=0.005), (0.8, 0.2, 0.2)), - nb.affines.from_matvec(nb.eulerangles.euler2mat(x=0.02, z=0.005), (0.8, 0.2, 0.2)), - nb.affines.from_matvec(nb.eulerangles.euler2mat(x=0.02, z=0.02), (0.4, 0.2, 0.2)), - nb.affines.from_matvec(nb.eulerangles.euler2mat(x=-0.02, z=0.02), (0.4, 0.2, 0.2)), - nb.affines.from_matvec(nb.eulerangles.euler2mat(x=-0.02, z=0.002), (0.0, 0.2, 0.2)), - nb.affines.from_matvec(nb.eulerangles.euler2mat(y=-0.02, z=0.002), (0.0, 0.2, 0.2)), - nb.affines.from_matvec(nb.eulerangles.euler2mat(y=-0.01, z=0.002), (0.0, 0.4, 0.2)), - ], - reference=b0nii, - ) - - # Induce motion into dataset (i.e., apply the inverse transforms) - moved_nii = (~xfms).apply(b0nii, reference=b0nii) - - # Uncomment to see the moved dataset - moved_nii.to_filename(tmp_path / "test.nii.gz") - xfms.apply(moved_nii).to_filename(tmp_path / "ground_truth.nii.gz") - - # Wrap into dataset object - dwi_motion = DWI( - dataobj=moved_nii.dataobj, - affine=b0nii.affine, - bzero=dwdata.bzero, - gradients=dwdata.gradients[..., : len(xfms)], - brainmask=dwdata.brainmask, - ) - - model = TrivialModel(dwdata) + model = TrivialModel(dwi_motion) estimator = Estimator(model) estimator.run( dwi_motion, @@ -89,9 +57,29 @@ def test_proximity_estimator_trivial_model(datadir, tmp_path): for i, est in enumerate(dwi_motion.motion_affines): assert ( displacements_within_mask( - masknii, + motion_data["masknii"], nt.linear.Affine(est), xfms[i], ).max() < 0.25 ) + + +def test_stacked_estimators(motion_data): + """Check that models can be stacked.""" + + # Wrap into dataset object + dmri_dataset = motion_data["moved_nifreeze"] + + estimator1 = Estimator( + TrivialModel(dmri_dataset), + ants_config="dwi-to-dwi_level0.json", + clip=False, + ) + estimator2 = Estimator( + TrivialModel(dmri_dataset), + prev=estimator1, + clip=False, + ) + + estimator2.run(dmri_dataset)