diff --git a/docs/notebooks/bold_realignment.ipynb b/docs/notebooks/bold_realignment.ipynb index 55bf7916..4f491c89 100644 --- a/docs/notebooks/bold_realignment.ipynb +++ b/docs/notebooks/bold_realignment.ipynb @@ -337,7 +337,7 @@ "metadata": {}, "outputs": [], "source": [ - "from nifreeze.model.base import AverageModel\n", + "from nifreeze.model.base import ExpectationModel\n", "from nifreeze.utils.iterators import random_iterator" ] }, @@ -358,7 +358,7 @@ " t_mask[t] = True\n", "\n", " # Fit and predict using the model\n", - " model = AverageModel()\n", + " model = ExpectationModel()\n", " model.fit(\n", " data[..., ~t_mask],\n", " stat=\"median\",\n", @@ -376,7 +376,7 @@ " fixedmask_path=brainmask_path,\n", " output_transform_prefix=f\"conversion-{t:02d}\",\n", " num_threads=8,\n", - " )\n", + " ).cmdline\n", "\n", " # Run the command\n", " proc = await asyncio.create_subprocess_shell(\n", diff --git a/scripts/optimize_registration.py b/scripts/optimize_registration.py index abc05d6f..7cf38ae3 100644 --- a/scripts/optimize_registration.py +++ b/scripts/optimize_registration.py @@ -134,7 +134,7 @@ async def train_coro( moving_path, fixedmask_path=brainmask_path, **_kwargs, - ) + ).cmdline tasks.append( ants( diff --git a/src/nifreeze/cli/run.py b/src/nifreeze/cli/run.py index b4c32fc4..d16e1a41 100644 --- a/src/nifreeze/cli/run.py +++ b/src/nifreeze/cli/run.py @@ -25,7 +25,7 @@ from pathlib import Path from nifreeze.cli.parser import parse_args -from nifreeze.data.dmri import DWI +from nifreeze.data.base import BaseDataset from nifreeze.estimator import Estimator @@ -40,14 +40,19 @@ def main(argv=None) -> None: args = parse_args(argv) # Open the data with the given file path - dwi_dataset: DWI = DWI.from_filename(args.input_file) + dataset: BaseDataset = BaseDataset.from_filename(args.input_file) - estimator: Estimator = Estimator() + prev_model: Estimator | None = None + for _model in args.models: + estimator: Estimator = Estimator( + _model, + prev=prev_model, + ) + prev_model = estimator - _ = estimator.estimate( - dwi_dataset, + _ = estimator.run( + dataset, align_kwargs=args.align_config, - models=args.models, omp_nthreads=args.nthreads, njobs=args.njobs, seed=args.seed, @@ -58,7 +63,7 @@ def main(argv=None) -> None: output_path: Path = Path(args.output_dir) / output_filename # Save the DWI dataset to the output path - dwi_dataset.to_filename(output_path) + dataset.to_filename(output_path) if __name__ == "__main__": diff --git a/src/nifreeze/data/dmri.py b/src/nifreeze/data/dmri.py index 69423c8c..8c2e3d08 100644 --- a/src/nifreeze/data/dmri.py +++ b/src/nifreeze/data/dmri.py @@ -39,6 +39,30 @@ from nifreeze.data.base import BaseDataset, _cmp, _data_repr from nifreeze.utils.ndimage import load_api +DEFAULT_CLIP_PERCENTILE = 75 +"""Upper percentile threshold for intensity clipping.""" + +DEFAULT_MIN_S0 = 1e-5 +"""Minimum value when considering the :math:`S_{0}` DWI signal.""" + +DEFAULT_MAX_S0 = 1.0 +"""Maximum value when considering the :math:`S_{0}` DWI signal.""" + +DEFAULT_LOWB_THRESHOLD = 50 +"""The lower bound for the b-value so that the orientation is considered a DW volume.""" + +DEFAULT_HIGHB_THRESHOLD = 8000 +"""A b-value cap for DWI data.""" + +DEFAULT_NUM_BINS = 15 +"""Number of bins to classify b-values.""" + +DEFAULT_MULTISHELL_BIN_COUNT_THR = 7 +"""Default bin count to consider a multishell scheme.""" + +DTI_MIN_ORIENTATIONS = 6 +"""Minimum number of nonzero b-values in a DWI dataset.""" + @attr.s(slots=True) class DWI(BaseDataset[np.ndarray | None]): @@ -226,7 +250,7 
@@ def load( bvec_file: Path | str | None = None, bval_file: Path | str | None = None, b0_file: Path | str | None = None, - b0_thres: float = 50.0, + b0_thres: float = DEFAULT_LOWB_THRESHOLD, ) -> DWI: """ Load DWI data and construct a DWI object. @@ -342,3 +366,87 @@ def load( dwi_obj.brainmask = np.asanyarray(mask_img.dataobj, dtype=bool) return dwi_obj + + +def find_shelling_scheme( + bvals, + num_bins=DEFAULT_NUM_BINS, + multishell_nonempty_bin_count_thr=DEFAULT_MULTISHELL_BIN_COUNT_THR, + bval_cap=DEFAULT_HIGHB_THRESHOLD, +): + """ + Find the shelling scheme on the given b-values. + + Computes the histogram of the b-values according to ``num_bins`` + and depending on the nonempty bin count, classify the shelling scheme + as single-shell if they are 2 (low-b and a shell); multi-shell if they are + below the ``multishell_nonempty_bin_count_thr`` value; and DSI otherwise. + + Parameters + ---------- + bvals : :obj:`list` or :obj:`~numpy.ndarray` + List or array of b-values. + num_bins : :obj:`int`, optional + Number of bins. + multishell_nonempty_bin_count_thr : :obj:`int`, optional + Bin count to consider a multi-shell scheme. + + Returns + ------- + scheme : :obj:`str` + Shelling scheme. + bval_groups : :obj:`list` + List of grouped b-values. + bval_estimated : :obj:`list` + List of 'estimated' b-values as the median value of each b-value group. + + """ + + # Bin the b-values: use -1 as the lower bound to be able to appropriately + # include b0 values + hist, bin_edges = np.histogram(bvals, bins=num_bins, range=(-1, min(max(bvals), bval_cap))) + + # Collect values in each bin + bval_groups = [] + bval_estimated = [] + for lower, upper in zip(bin_edges[:-1], bin_edges[1:], strict=False): + # Add only if a nonempty b-values mask + if (mask := (bvals > lower) & (bvals <= upper)).sum(): + bval_groups.append(bvals[mask]) + bval_estimated.append(np.median(bvals[mask])) + + nonempty_bins = len(bval_groups) + + if nonempty_bins < 2: + raise ValueError("DWI must have at least one high-b shell") + + if nonempty_bins == 2: + scheme = "single-shell" + elif nonempty_bins < multishell_nonempty_bin_count_thr: + scheme = "multi-shell" + else: + scheme = "DSI" + + return scheme, bval_groups, bval_estimated + + +def _rasb2dipy(gradient): + import warnings + + gradient = np.asanyarray(gradient) + if gradient.ndim == 1: + if gradient.size != 4: + raise ValueError("Missing gradient information.") + gradient = gradient[..., np.newaxis] + + if gradient.shape[0] != 4: + gradient = gradient.T + elif gradient.shape == (4, 4): + print("Warning: make sure gradient information is not transposed!") + + with warnings.catch_warnings(): + from dipy.core.gradients import gradient_table + + warnings.filterwarnings("ignore", category=UserWarning) + retval = gradient_table(gradient[3, :], bvecs=gradient[:3, :].T) + return retval diff --git a/src/nifreeze/estimator.py b/src/nifreeze/estimator.py index e9d59c39..c197991a 100644 --- a/src/nifreeze/estimator.py +++ b/src/nifreeze/estimator.py @@ -22,229 +22,154 @@ # """A model-based algorithm for the realignment of dMRI data.""" +from __future__ import annotations + from pathlib import Path -from tempfile import TemporaryDirectory, mkstemp +from tempfile import TemporaryDirectory +from typing import Self, TypeVar -import nibabel as nb from tqdm import tqdm -from nifreeze.data.splitting import lovo_split -from nifreeze.model.base import ModelFactory -from nifreeze.registration.ants import _prepare_registration_data, _run_registration +from nifreeze.data.base import BaseDataset 
+from nifreeze.model.base import BaseModel, ModelFactory +from nifreeze.registration.ants import ( + _prepare_registration_data, + _run_registration, +) from nifreeze.utils import iterators +DatasetT = TypeVar("DatasetT", bound=BaseDataset) + + +class Filter: + """Alters an input data object (e.g., downsampling).""" + + def run(self, dataset: DatasetT, **kwargs) -> DatasetT: + """ + Trigger execution of the designated filter. + + Parameters + ---------- + dataset : :obj:`~nifreeze.data.base.BaseDataset` + The input dataset this estimator operates on. + + Returns + ------- + dataset : :obj:`~nifreeze.data.base.BaseDataset` + The dataset, after filtering. + + """ + return dataset + class Estimator: """Estimates rigid-body head-motion and distortions derived from eddy-currents.""" - @staticmethod - def estimate( - data, - *, - align_kwargs=None, - iter_kwargs=None, - models=("b0",), - omp_nthreads=None, - n_jobs=None, + __slots__ = ("_model", "_strategy", "_prev", "_model_kwargs", "_align_kwargs") + + def __init__( + self, + model: BaseModel | str, + strategy: str = "random", + prev: Estimator | Filter | None = None, + model_kwargs: dict | None = None, **kwargs, ): - r""" - Estimate head-motion and Eddy currents. + self._model = model + self._prev = prev + self._strategy = strategy + self._model_kwargs = model_kwargs or {} + self._align_kwargs = kwargs or {} + + def run(self, dataset: DatasetT, **kwargs) -> Self: + """ + Trigger execution of the workflow this estimator belongs. Parameters ---------- - data : :obj:`~nifreeze.dmri.DWI` - The target DWI dataset, represented by this tool's internal - type. The object is used in-place, and will contain the estimated - parameters in its ``motion_affines`` property, as well as the rotated - *b*-vectors within its ``gradients`` property. - n_iter : :obj:`int` - Number of iterations this particular model is going to be repeated. - align_kwargs : :obj:`dict` - Parameters to configure the image registration process. - iter_kwargs : :obj:`dict` - Parameters to configure the iterator strategy to traverse timepoints/orientations. - models : :obj:`list` - Selects the diffusion model that will generate the registration target - corresponding to each gradient map. - See :obj:`~nifreeze.model.ModelFactory` for allowed models (and corresponding - keywords). - omp_nthreads : :obj:`int` - Maximum number of threads an individual process may use. - n_jobs : :obj:`int` - Number of parallel jobs. - - Return - ------ - :obj:`list` of :obj:`numpy.ndarray` - A list of :math:`4 \times 4` affine matrices encoding the estimated - parameters of the deformations caused by head-motion and eddy-currents. + dataset : :obj:`~nifreeze.data.base.BaseDataset` + The input dataset this estimator operates on. - """ + Returns + ------- + :obj:`~nifreeze.estimator.Estimator` + The estimator, after fitting. 
- # Massage iterator configuration - iter_kwargs = iter_kwargs or {} - iter_kwargs = { - "seed": None, - "bvals": None, # TODO: extract b-vals here if pertinent - } | iter_kwargs - iter_kwargs["size"] = len(data) - - iterfunc = getattr(iterators, f"{iter_kwargs.pop('strategy', 'random')}_iterator") - index_order = list(iterfunc(**iter_kwargs)) - - align_kwargs = align_kwargs or {} - - if "num_threads" not in align_kwargs and omp_nthreads is not None: - align_kwargs["num_threads"] = omp_nthreads - - n_iter = len(models) - - reg_target_type = ( - align_kwargs.pop("fixed_modality", None), - align_kwargs.pop("moving_modality", None), - ) - - for i_iter, model in enumerate(models): - # When downsampling these need to be set per-level - bmask_img = _prepare_brainmask_data(data.brainmask, data.affine) - - _prepare_kwargs(data, kwargs) - - single_model = model.lower() in ( - "b0", - "s0", - "avg", - "average", - "mean", - "gp", - ) or model.lower().startswith("full") - - dwmodel = None - if single_model: - if model.lower().startswith("full"): - model = model[4:] - - # Factory creates the appropriate model and pipes arguments - dwmodel = ModelFactory.init( - model=model, - **kwargs, - ) - dwmodel.fit(data.dataobj, n_jobs=n_jobs) - - with TemporaryDirectory() as tmp_dir: - print(f"Processing in <{tmp_dir}>") - ptmp_dir = Path(tmp_dir) - with tqdm(total=len(index_order), unit="dwi") as pbar: - # run a original-to-synthetic affine registration - for i in index_order: - pbar.set_description_str( - f"Pass {i_iter + 1}/{n_iter} | Fit and predict b-index <{i}>" - ) - data_train, data_test = lovo_split(data, i) - grad_str = f"{i}, {data_test[-1][:3]}, b={int(data_test[-1][3])}" - pbar.set_description_str(f"[{grad_str}], {n_jobs} jobs") - - if not single_model: # A true LOGO estimator - if hasattr(data, "gradients"): - kwargs["gtab"] = data_train[-1] - # Factory creates the appropriate model and pipes arguments - dwmodel = ModelFactory.init( - model=model, - n_jobs=n_jobs, - **kwargs, - ) - - # fit the model - dwmodel.fit( - data_train[0], - n_jobs=n_jobs, - ) - - # generate a synthetic dw volume for the test gradient - predicted = dwmodel.predict(data_test[-1]) - - # prepare data for running ANTs - fixed, moving = _prepare_registration_data( - data_test[0], predicted, data.affine, i, ptmp_dir, reg_target_type - ) - - pbar.set_description_str( - f"Pass {i_iter + 1}/{n_iter} | Realign b-index <{i}>" - ) - - xform = _run_registration( - fixed, - moving, - bmask_img, - data.motion_affines, - data.affine, - data.dataobj.shape[:3], - data_test[-1][3], - i_iter, - i, - ptmp_dir, - reg_target_type, - align_kwargs, - ) - - # update - data.set_transform(i, xform.matrix) - pbar.update() - - return data.motion_affines - - -def _prepare_brainmask_data(brainmask, affine): - """Prepare the brainmask data: save the data to disk. - - Parameters - ---------- - brainmask : :obj:`numpy.ndarray` - Brainmask data. - affine : :obj:`numpy.ndarray` - Affine transformation matrix. - - Returns - ------- - bmask_img : :class:`~nibabel.nifti1.Nifti1Image` - Brainmask image. - """ - - bmask_img = None - if brainmask is not None: - _, bmask_img = mkstemp(suffix="_bmask.nii.gz") - nb.Nifti1Image(brainmask.astype("uint8"), affine, None).to_filename(bmask_img) - return bmask_img - - -def _prepare_kwargs(data, kwargs): - """Prepare the keyword arguments depending on the DWI data: add attributes corresponding to - the ``brainmask``, ``bzero``, ``gradients``, ``frame_time``, and ``total_duration`` DWI data - properties. 
- - Modifies kwargs in-place. - - Parameters - ---------- - data : :class:`nifreeze.data.dmri.DWI` - DWI data object. - kwargs : :obj:`dict` - Keyword arguments. - """ - from nifreeze.data.filtering import advanced_clip as _advanced_clip - - if data.brainmask is not None: - kwargs["mask"] = data.brainmask - - if hasattr(data, "bzero") and data.bzero is not None: - kwargs["S0"] = _advanced_clip(data.bzero) - - if hasattr(data, "gradients"): - kwargs["gtab"] = data.gradients - - if hasattr(data, "frame_time"): - kwargs["timepoints"] = data.frame_time - - if hasattr(data, "total_duration"): - kwargs["xlim"] = data.total_duration + """ + if self._prev is not None: + result = self._prev.run(dataset, **kwargs) + if isinstance(self._prev, Filter): + dataset = result # type: ignore[assignment] + + n_jobs = kwargs.get("n_jobs", None) + + # Prepare iterator + iterfunc = getattr(iterators, f"{self._strategy}_iterator") + index_iter = iterfunc(len(dataset), seed=kwargs.get("seed", None)) + + # Initialize model + if isinstance(self._model, str): + # Factory creates the appropriate model and pipes arguments + self._model = ModelFactory.init( + model=self._model, + dataset=dataset, + **self._model_kwargs, + ) + + kwargs["num_threads"] = kwargs.pop("omp_nthreads", None) or kwargs.pop("num_threads", None) + kwargs = self._align_kwargs | kwargs + + dataset_length = len(dataset) + with TemporaryDirectory() as tmp_dir: + print(f"Processing in <{tmp_dir}>") + ptmp_dir = Path(tmp_dir) + + bmask_path = None + if dataset.brainmask is not None: + import nibabel as nb + + bmask_path = ptmp_dir / "brainmask.nii.gz" + nb.Nifti1Image( + dataset.brainmask.astype("uint8"), dataset.affine, None + ).to_filename(bmask_path) + + with tqdm(total=dataset_length, unit="vols.") as pbar: + # run a original-to-synthetic affine registration + for i in index_iter: + pbar.set_description_str(f"Fit and predict vol. <{i}>") + + # fit the model + test_set = dataset[i] + predicted = self._model.fit_predict( # type: ignore[union-attr] + i, + n_jobs=n_jobs, + ) + + # prepare data for running ANTs + predicted_path, volume_path, init_path = _prepare_registration_data( + test_set[0], + predicted, + dataset.affine, + i, + ptmp_dir, + kwargs.pop("clip", "both"), + ) + + pbar.set_description_str(f"Realign vol. 
<{i}>") + + xform = _run_registration( + predicted_path, + volume_path, + i, + ptmp_dir, + init_affine=init_path, + fixedmask_path=bmask_path, + output_transform_prefix=f"ants-{i:05d}", + **kwargs, + ) + + # update + dataset.set_transform(i, xform.matrix) + pbar.update() + + return self diff --git a/src/nifreeze/model/__init__.py b/src/nifreeze/model/__init__.py index d3b425b0..edbe6d3e 100644 --- a/src/nifreeze/model/__init__.py +++ b/src/nifreeze/model/__init__.py @@ -23,7 +23,7 @@ """Data models.""" from nifreeze.model.base import ( - AverageModel, + ExpectationModel, ModelFactory, TrivialModel, ) @@ -37,7 +37,7 @@ __all__ = ( "ModelFactory", - "AverageModel", + "ExpectationModel", "AverageDWIModel", "DKIModel", "DTIModel", diff --git a/src/nifreeze/model/_dipy.py b/src/nifreeze/model/_dipy.py index 4f4b3020..adc1b15b 100644 --- a/src/nifreeze/model/_dipy.py +++ b/src/nifreeze/model/_dipy.py @@ -24,7 +24,6 @@ from __future__ import annotations -import warnings from typing import Any import numpy as np @@ -272,23 +271,3 @@ def predict( """ return gp_prediction(self.model, gtab, mask=self.mask) - - -def _rasb2dipy(gradient): - gradient = np.asanyarray(gradient) - if gradient.ndim == 1: - if gradient.size != 4: - raise ValueError("Missing gradient information.") - gradient = gradient[..., np.newaxis] - - if gradient.shape[0] != 4: - gradient = gradient.T - elif gradient.shape == (4, 4): - print("Warning: make sure gradient information is not transposed!") - - with warnings.catch_warnings(): - from dipy.core.gradients import gradient_table - - warnings.filterwarnings("ignore", category=UserWarning) - retval = gradient_table(gradient[3, :], bvecs=gradient[:3, :].T) - return retval diff --git a/src/nifreeze/model/base.py b/src/nifreeze/model/base.py index 7c12bfd8..cd4ca744 100644 --- a/src/nifreeze/model/base.py +++ b/src/nifreeze/model/base.py @@ -22,12 +22,11 @@ # """Base infrastructure for nifreeze's models.""" +from abc import abstractmethod from warnings import warn import numpy as np -from nifreeze.exceptions import ModelNotFittedError - mask_absence_warn_msg = ( "No mask provided; consider using a mask to avoid issues in model optimization." ) @@ -37,7 +36,7 @@ class ModelFactory: """A factory for instantiating data models.""" @staticmethod - def init(model="DTI", **kwargs): + def init(model=None, **kwargs): """ Instantiate a diffusion model. @@ -45,7 +44,7 @@ def init(model="DTI", **kwargs): ---------- model : :obj:`str` Diffusion model. - Options: ``"DTI"``, ``"DKI"``, ``"S0"``, ``"AverageDW"`` + Options: ``"DTI"``, ``"DKI"``, ``"S0"``, ``"AverageDWI"`` Return ------ @@ -53,22 +52,23 @@ def init(model="DTI", **kwargs): A model object compliant with DIPY's interface. 
""" + if model is None: + raise RuntimeError("No model identifier provided.") + if model.lower() in ("s0", "b0"): - return TrivialModel( - mask=kwargs.pop("mask"), predicted=kwargs.pop("S0"), gtab=kwargs.pop("gtab") - ) + return TrivialModel(kwargs.pop("dataset")) + + if model.lower() in ("avg", "average", "mean"): + return ExpectationModel(kwargs.pop("dataset"), **kwargs) if model.lower() in ("avgdwi", "averagedwi", "meandwi"): from nifreeze.model.dmri import AverageDWIModel - return AverageDWIModel(**kwargs) - - if model.lower() in ("avg", "average", "mean"): - return AverageModel(**kwargs) + return AverageDWIModel(kwargs.pop("dataset"), **kwargs) if model.lower() in ("dti", "dki", "pet"): Model = globals()[f"{model.upper()}Model"] - return Model(**kwargs) + return Model(kwargs.pop("dataset"), **kwargs) raise NotImplementedError(f"Unsupported model <{model}>.") @@ -84,44 +84,20 @@ class BaseModel: """ - __slots__ = ( - "_model", - "_mask", - "_models", - "_datashape", - "_is_fitted", - "_modelargs", - ) + __slots__ = ("_dataset",) - def __init__(self, mask=None, **kwargs): + def __init__(self, dataset, **kwargs): """Base initialization.""" - # Keep model state - self._model = None # "Main" model - self._models = None # For parallel (chunked) execution - - # Setup brain mask - if mask is None: + self._dataset = dataset + # Warn if mask not present + if dataset.brainmask is None: warn(mask_absence_warn_msg, stacklevel=2) - self._mask = mask - - self._datashape = None - self._is_fitted = False - - self._modelargs = () - - @property - def is_fitted(self): - return self._is_fitted - - def fit(self, data, **kwargs): - """Abstract member signature of fit().""" - raise NotImplementedError("Cannot call fit() on a BaseModel instance.") - - def predict(self, *args, **kwargs): - """Abstract member signature of predict().""" - raise NotImplementedError("Cannot call predict() on a BaseModel instance.") + @abstractmethod + def fit_predict(self, index, **kwargs) -> np.ndarray: + """Fit and predict the indicate index of the dataset (abstract signature).""" + raise NotImplementedError("Cannot call fit_predict() on a BaseModel instance.") class TrivialModel(BaseModel): @@ -129,66 +105,53 @@ class TrivialModel(BaseModel): __slots__ = ("_predicted",) - def __init__(self, predicted=None, **kwargs): + def __init__(self, dataset, predicted=None, **kwargs): """Implement object initialization.""" - if predicted is None: - raise TypeError("This model requires the predicted map at initialization") - - super().__init__(**kwargs) - self._predicted = predicted - self._datashape = predicted.shape - @property - def is_fitted(self): - return True + super().__init__(dataset, **kwargs) + self._predicted = ( + predicted + if predicted is not None + # Infer from dataset if not provided at initialization + else getattr(dataset, "reference", getattr(dataset, "bzero", None)) + ) - def fit(self, data, **kwargs): - """Do nothing.""" + if self._predicted is None: + raise TypeError("This model requires the predicted map at initialization") - def predict(self, *_, **kwargs): - """Return the *b=0* map.""" + def fit_predict(self, *_, **kwargs): + """Return the reference map.""" # No need to check fit (if not fitted, has raised already) return self._predicted -class AverageModel(BaseModel): - """A trivial model that returns an average map.""" +class ExpectationModel(BaseModel): + """A trivial model that returns an expectation map (for example, average).""" - __slots__ = ("_data",) + __slots__ = ("_stat",) - def __init__(self, 
**kwargs): + def __init__(self, dataset, stat="median", **kwargs): """Initialize a new model.""" - super().__init__(**kwargs) - self._data = None - - def fit(self, data, **kwargs): - """Calculate the average.""" - - # Regress out global signal differences - if kwargs.pop("equalize", False): - data = data.copy().astype("float32") - reshaped_data = ( - data.reshape((-1, data.shape[-1])) if self._mask is None else data[self._mask] - ) - p5 = np.percentile(reshaped_data, 5.0, axis=0) - p95 = np.percentile(reshaped_data, 95.0, axis=0) - p5 - data = (data - p5) * p95.mean() / p95 + p5.mean() + super().__init__(dataset, **kwargs) + self._stat = stat - # Select the summary statistic - avg_func = getattr(np, kwargs.pop("stat", "mean")) - - # Calculate the average - self._data = avg_func(data, axis=-1) + def fit_predict(self, index, **kwargs): + """ + Return the expectation map. - @property - def is_fitted(self): - return self._data is not None + Parameters + ---------- + index : :obj:`int` + The volume index that is left-out in fitting, and then predicted. - def predict(self, *_, **kwargs): - """Return the average map.""" + """ + # Select the summary statistic + avg_func = getattr(np, kwargs.pop("stat", self._stat)) - if self._data is None: - raise ModelNotFittedError(f"{type(self).__name__} must be fitted before predicting") + # Create index mask + mask = np.ones(len(self._dataset), dtype=bool) + mask[index] = False - return self._data + # Calculate the average + return avg_func(self._dataset.dataobj[mask][0], axis=-1) diff --git a/src/nifreeze/model/dmri.py b/src/nifreeze/model/dmri.py index bc4658ee..6b68765d 100644 --- a/src/nifreeze/model/dmri.py +++ b/src/nifreeze/model/dmri.py @@ -26,9 +26,11 @@ import numpy as np from joblib import Parallel, delayed -from nifreeze.exceptions import ModelNotFittedError -from nifreeze.model._dipy import _rasb2dipy -from nifreeze.model.base import BaseModel +from nifreeze.data.dmri import ( + DEFAULT_CLIP_PERCENTILE, + DTI_MIN_ORIENTATIONS, +) +from nifreeze.model.base import BaseModel, ExpectationModel def _exec_fit(model, data, chunk=None): @@ -41,84 +43,49 @@ def _exec_predict(model, chunk=None, **kwargs): return np.squeeze(model.predict(**kwargs)), chunk -DEFAULT_CLIP_PERCENTILE = 75 -"""Upper percentile threshold for intensity clipping.""" - -DEFAULT_MIN_S0 = 1e-5 -"""Minimum value when considering the :math:`S_{0}` DWI signal.""" - -DEFAULT_MAX_S0 = 1.0 -"""Maximum value when considering the :math:`S_{0}` DWI signal.""" - -DEFAULT_MAX_BVALUE = 1000 -"""Maximum allowed value for the b-value.""" - -DEFAULT_LOWB_THRESHOLD = 50 -"""The lower bound for the b-value so that the orientation is considered a DW volume.""" - -DEFAULT_HIGHB_THRESHOLD = 10000 -"""A b-value cap for DWI data.""" - -DEFAULT_NUM_BINS = 15 -"""Number of bins to classify b-values.""" - -DEFAULT_MULTISHELL_BIN_COUNT_THR = 7 -"""Default bin count to consider a multishell scheme.""" - -DEFAULT_MAX_BVAL = 8000 -"""Maximum b-value cap.""" - - class BaseDWIModel(BaseModel): """Interface and default methods for DWI models.""" - __slots__ = ( - "_gtab", - "_S0", - "_b_max", - "_model_class", # Defining a model class, DIPY models are instantiated automagically - "_modelargs", - ) + __slots__ = { + "_model_class": "Defining a model class, DIPY models are instantiated automagically", + "_modelargs": "Arguments acceptable by the underlying DIPY-like model.", + } - def __init__(self, gtab, S0=None, b_max=None, **kwargs): - """Initialization. 
+ def __init__(self, dataset, **kwargs): + r"""Initialization. Parameters ---------- - gtab : :obj:`numpy.ndarray` - An :math:`N \times 4` table, where rows (*N*) are diffusion gradients and - columns are b-vector components and corresponding b-value, respectively. - S0 : :obj:`numpy.ndarray` - :math:`S_{0}` signal. - b_max : :obj:`int` - Maximum value to cap b-values. + dataset : :obj:`~nifreeze.data.dmri.DWI` + Reference to a DWI object. """ - super().__init__(**kwargs) + # Duck typing, instead of explicitly test for DWI type + if not hasattr(dataset, "bzero"): + raise TypeError("Dataset MUST be a DWI object.") - # Setup B0 map - self._S0 = None - if S0 is not None: - self._S0 = np.clip( - S0.astype("float32") / S0.max(), - a_min=DEFAULT_MIN_S0, - a_max=DEFAULT_MAX_S0, + if not hasattr(dataset, "gradients") or dataset.gradients is None: + raise ValueError("Dataset MUST have a gradient table.") + + if dataset.gradients.shape[0] < DTI_MIN_ORIENTATIONS: + raise ValueError( + f"DWI dataset is too small ({dataset.gradients.shape[0]} directions)." ) - # Cap b-values, if requested - self._gtab = gtab - self._b_max = None - if b_max and b_max > DEFAULT_MAX_BVALUE: - # Saturate b-values at b_max, since signal stops dropping - self._gtab[-1, self._gtab[-1] > b_max] = b_max - # A possibly good alternative is completely remove very high b-values - # bval_mask = gtab[-1] < b_max - # data = data[..., bval_mask] - # gtab = gtab[:, bval_mask] - self._b_max = b_max + super().__init__(dataset, **kwargs) + + def _fit(self, index, n_jobs=None, **kwargs): + """Fit the model chunk-by-chunk asynchronously""" + n_jobs = n_jobs or 1 + + brainmask = self._dataset.brainmask + idxmask = np.ones(len(self._dataset), dtype=bool) + idxmask[index] = False - kwargs = {k: v for k, v in kwargs.items() if k in self._modelargs} + data, _, gtab = self._dataset[idxmask] + # Select voxels within mask or just unravel 3D if no mask + data = data[brainmask, ...] if brainmask is not None else data.reshape(-1, data.shape[-1]) # DIPY models (or one with a fully-compliant interface) model_str = getattr(self, "_model_class", None) @@ -127,18 +94,7 @@ def __init__(self, gtab, S0=None, b_max=None, **kwargs): self._model = getattr( import_module(module_name), class_name, - )(_rasb2dipy(gtab), **kwargs) - - def fit(self, data, n_jobs=None, **kwargs): - """Fit the model chunk-by-chunk asynchronously""" - n_jobs = n_jobs or 1 - - self._datashape = data.shape - - # Select voxels within mask or just unravel 3D if no mask - data = ( - data[self._mask, ...] 
if self._mask is not None else data.reshape(-1, data.shape[-1]) - ) + )(gtab, **kwargs) # One single CPU - linear execution (full model) if n_jobs == 1: @@ -155,37 +111,31 @@ def fit(self, data, n_jobs=None, **kwargs): results = executor( delayed(_exec_fit)(self._model, dchunk, i) for i, dchunk in enumerate(data_chunks) ) - for submodel, index in results: - self._models[index] = submodel + for submodel, rindex in results: + self._models[rindex] = submodel - self._is_fitted = True self._model = None # Preempt further actions on the model + return n_jobs - def predict(self, gradient=None, **kwargs): - """Predict asynchronously chunk-by-chunk the diffusion signal.""" - - if gradient is None: - raise ValueError("A gradient to be simulated (b-vector, b-value) must be provided") - - if not self._is_fitted: - raise ModelNotFittedError(f"{type(self).__name__} must be fitted before predicting") + def fit_predict(self, index, **kwargs): + """ + Predict asynchronously chunk-by-chunk the diffusion signal. - gradient = np.array(gradient) # Tuples are unmutable + Parameters + ---------- + index : :obj:`int` + The volume index that is left-out in fitting, and then predicted. - # Cap the b-value if b_max is defined - gradient[-1] = min(gradient[-1], self._b_max or gradient[-1]) + """ - gradient = _rasb2dipy(gradient) + n_models = self._fit(index, **kwargs) - S0 = None - if self._S0 is not None: - S0 = ( - self._S0[self._mask, ...] - if self._mask is not None - else self._S0.reshape(-1, self._S0.shape[-1]) - ) + brainmask = self._dataset.brainmask + gradient = self._dataset.gradients[index] - n_models = len(self._models) if self._model is None and self._models else 1 + S0 = self._dataset.bzero + if S0 is not None: + S0 = S0[brainmask, ...] if brainmask is not None else S0.reshape(-1) if n_models == 1: predicted, _ = _exec_predict(self._model, **(kwargs | {"gtab": gradient, "S0": S0})) @@ -209,21 +159,21 @@ def predict(self, gradient=None, **kwargs): predicted = np.hstack(predicted) - if self._mask is not None: - retval = np.zeros_like(self._mask, dtype="float32") - retval[self._mask, ...] = predicted + if brainmask is not None: + retval = np.zeros_like(brainmask, dtype="float32") + retval[brainmask, ...] = predicted else: - retval = predicted.reshape(self._datashape[:-1]) + retval = predicted.reshape(self._dataset.shape[:-1]) return retval -class AverageDWIModel(BaseDWIModel): +class AverageDWIModel(ExpectationModel): """A trivial model that returns an average DWI volume.""" - __slots__ = ("_data", "_th_low", "_th_high", "_bias", "_stat", "_is_fitted") + __slots__ = ("_th_low", "_th_high", "_detrend") - def __init__(self, **kwargs): + def __init__(self, dataset, stat="median", th_low=100, th_high=100, detrend=False, **kwargs): r""" Implement object initialization. @@ -235,7 +185,7 @@ def __init__(self, **kwargs): th_high : :obj:`numbers.Number` An upper bound for the b-value corresponding to the diffusion weighted images that will be averaged. - bias : :obj:`bool` + detrend : :obj:`bool` Whether the overall distribution of each diffusion weighted image will be standardized and centered around the :data:`src.nifreeze.model.base.DEFAULT_CLIP_PERCENTILE` percentile. @@ -243,49 +193,42 @@ def __init__(self, **kwargs): Whether the summary statistic to apply is ``"mean"`` or ``"median"``. 
""" - super().__init__(**kwargs) + super().__init__(dataset, stat=stat, **kwargs) - self._th_low = kwargs.get("th_low", DEFAULT_LOWB_THRESHOLD) - self._th_high = kwargs.get("th_high", DEFAULT_HIGHB_THRESHOLD) - self._bias = kwargs.get("bias", True) - self._stat = kwargs.get("stat", "median") - self._data = None + self._th_low = th_low + self._th_high = th_high + self._detrend = detrend - def fit(self, data, **kwargs): - """Calculate the average.""" + def fit_predict(self, index, *_, **kwargs): + """Return the average map.""" + + bvalues = self._dataset.gradients[:, -1] + bcenter = bvalues[index] + + shellmask = np.ones(len(self._dataset), dtype=bool) - if (gtab := kwargs.pop("gtab", None)) is None: - raise ValueError("A gradient table must be provided.") + # Keep only bvalues within the range defined by th_high and th_low + shellmask[index] = False + shellmask[bvalues > (bcenter + self._th_high)] = False + shellmask[bvalues < (bcenter - self._th_low)] = False - # Select the interval of b-values for which DWIs will be averaged - b_mask = ( - ((gtab[3] >= self._th_low) & (gtab[3] <= self._th_high)) - if gtab is not None - else np.ones((data.shape[-1],), dtype=bool) - ) - shells = data[..., b_mask] + if not shellmask.sum(): + raise RuntimeError(f"Shell corresponding to index {index} (b={bcenter}) is empty.") + + shelldata = self._dataset.dataobj[..., shellmask] # Regress out global signal differences - if self._bias: - centers = np.median(shells, axis=(0, 1, 2)) + if self._detrend: + centers = np.median(shelldata, axis=(0, 1, 2)) reference = np.percentile(centers[centers >= 1.0], DEFAULT_CLIP_PERCENTILE) centers[centers < 1.0] = reference drift = reference / centers - shells = shells * drift + shelldata = shelldata * drift # Select the summary statistic avg_func = np.median if self._stat == "median" else np.mean # Calculate the average - self._data = avg_func(shells, axis=-1) - self._is_fitted = True - - def predict(self, *_, **kwargs): - """Return the average map.""" - - if not self._is_fitted: - raise ModelNotFittedError(f"{type(self).__name__} must be fitted before predicting") - - return self._data + return avg_func(shelldata, axis=-1) class DTIModel(BaseDWIModel): @@ -314,65 +257,3 @@ class GPModel(BaseDWIModel): _modelargs = ("kernel_model",) _model_class = "nifreeze.model._dipy.GaussianProcessModel" - - -def find_shelling_scheme( - bvals, - num_bins=DEFAULT_NUM_BINS, - multishell_nonempty_bin_count_thr=DEFAULT_MULTISHELL_BIN_COUNT_THR, - bval_cap=DEFAULT_MAX_BVAL, -): - """ - Find the shelling scheme on the given b-values. - - Computes the histogram of the b-values according to ``num_bins`` - and depending on the nonempty bin count, classify the shelling scheme - as single-shell if they are 2 (low-b and a shell); multi-shell if they are - below the ``multishell_nonempty_bin_count_thr`` value; and DSI otherwise. - - Parameters - ---------- - bvals : :obj:`list` or :obj:`~numpy.ndarray` - List or array of b-values. - num_bins : :obj:`int`, optional - Number of bins. - multishell_nonempty_bin_count_thr : :obj:`int`, optional - Bin count to consider a multi-shell scheme. - - Returns - ------- - scheme : :obj:`str` - Shelling scheme. - bval_groups : :obj:`list` - List of grouped b-values. - bval_estimated : :obj:`list` - List of 'estimated' b-values as the median value of each b-value group. 
- - """ - - # Bin the b-values: use -1 as the lower bound to be able to appropriately - # include b0 values - hist, bin_edges = np.histogram(bvals, bins=num_bins, range=(-1, min(max(bvals), bval_cap))) - - # Collect values in each bin - bval_groups = [] - bval_estimated = [] - for lower, upper in zip(bin_edges[:-1], bin_edges[1:], strict=False): - # Add only if a nonempty b-values mask - if (mask := (bvals > lower) & (bvals <= upper)).sum(): - bval_groups.append(bvals[mask]) - bval_estimated.append(np.median(bvals[mask])) - - nonempty_bins = len(bval_groups) - - if nonempty_bins < 2: - raise ValueError("DWI must have at least one high-b shell") - - if nonempty_bins == 2: - scheme = "single-shell" - elif nonempty_bins < multishell_nonempty_bin_count_thr: - scheme = "multi-shell" - else: - scheme = "DSI" - - return scheme, bval_groups, bval_estimated diff --git a/src/nifreeze/registration/ants.py b/src/nifreeze/registration/ants.py index 712a0874..789721cc 100644 --- a/src/nifreeze/registration/ants.py +++ b/src/nifreeze/registration/ants.py @@ -92,49 +92,69 @@ def _to_nifti( def _prepare_registration_data( - dwframe: np.ndarray, + sample: np.ndarray, predicted: np.ndarray, affine: np.ndarray, vol_idx: int, dirname: Path | str, - reg_target_type: str, -) -> tuple[Path, Path]: + clip: str | bool | None = None, + init_affine: np.ndarray | None = None, +) -> tuple[Path, Path, Path | None]: """ - Prepare the registration data: save the fixed and moving images to disk. + Prepare the registration data: save the moving and predicted (fixed) images to disk. Parameters ---------- - dwframe : :obj:`~numpy.ndarray` - DWI data object. + sample : :obj:`~numpy.ndarray` + Current volume for which a transformation is to be estimated. predicted : :obj:`~numpy.ndarray` - Predicted data. + Predicted volume's data array (that is, spatial reference). affine : :obj:`numpy.ndarray` - Affine transformation matrix. + Orientation affine from the original NIfTI. vol_idx : :obj:`int` - DWI volume index. + Volume index. dirname : :obj:`os.pathlike` Directory name where the data is saved. - reg_target_type : :obj:`str` - Target registration type. + clip : :obj:`str` or ``None`` + Clip intensity of ``"sample"``, ``"predicted"``, ``"both"``, + or ``"none"`` of the images. Returns ------- - fixed : :obj:`~pathlib.Path` - Fixed image filename. - moving : :obj:`~pathlib.Path` - Moving image filename. + predicted_path : :obj:`~pathlib.Path` + Predicted image filename. + sample_path : :obj:`~pathlib.Path` + Current volume's filename. + init_path : :obj:`~pathlib.Path` or ``None`` + An initialization affine (for second and further estimators). 
+ """ - moving = Path(dirname) / f"moving{vol_idx:05d}.nii.gz" - fixed = Path(dirname) / f"fixed{vol_idx:05d}.nii.gz" - _to_nifti(dwframe, affine, moving) + predicted_path = Path(dirname) / f"predicted_{vol_idx:05d}.nii.gz" + sample_path = Path(dirname) / f"sample_{vol_idx:05d}.nii.gz" + + _to_nifti( + sample, + affine, + sample_path, + clip=str(clip).lower() in ("sample", "both", "true"), + ) _to_nifti( predicted, affine, - fixed, - clip=reg_target_type == "dwi", + predicted_path, + clip=str(clip).lower() in ("predicted", "both", "true"), ) - return fixed, moving + + init_path = None + if init_affine is not None: + ImageGrid = namedtuple("ImageGrid", ("shape", "affine")) + reference = ImageGrid(shape=sample.shape[:3], affine=affine) + initial_xform = Affine(matrix=init_affine, reference=reference) + init_path = Path(dirname) / f"init_{vol_idx:05d}.mat" + initial_xform.to_filename(init_path, fmt="itk") + + return predicted_path, sample_path, init_path def _get_ants_settings(settings: str = "b0-to-b0_level0") -> Path: @@ -213,8 +233,11 @@ def generate_command( movingmask_path: str | Path | list[str] | None = None, init_affine: str | Path | None = None, default: str = "b0-to-b0_level0", + terminal_output: str | None = None, + num_threads: int | None = None, + environ: dict | None = None, **kwargs, -) -> str: +) -> Registration: """ Generate an ANTs' command line. @@ -232,20 +255,26 @@ def generate_command( Initial affine transformation. default : :obj:`str`, optional Default settings configuration. + terminal_output : :obj:`str`, optional + Redirect terminal output (Nipype configuration) + environ : :obj:`dict`, optional + Add environment variables to the execution. + num_threads : :obj:`int`, optional + Set the number of threads for ANTs' execution. **kwargs : :obj:`dict` Additional parameters for ANTs registration. Returns ------- - :obj:`str` - The ANTs registration command line string. + :obj:`~nipype.interfaces.ants.Registration` + The configured Nipype interface of ANTs registration. Examples -------- >>> generate_command( ... fixed_path=repodata / 'fileA.nii.gz', ... moving_path=repodata / 'fileB.nii.gz', - ... ) # doctest: +NORMALIZE_WHITESPACE + ... ).cmdline # doctest: +NORMALIZE_WHITESPACE 'antsRegistration --collapse-output-transforms 1 --dimensionality 3 \ --initialize-transforms-per-stage 0 --interpolation Linear --output transform \ --transform Rigid[ 12.0 ] \ @@ -267,7 +296,7 @@ def generate_command( ... fixed_path=repodata / 'fileA.nii.gz', ... moving_path=repodata / 'fileB.nii.gz', ... default="dwi-to-b0_level0", - ... ) # doctest: +NORMALIZE_WHITESPACE + ... ).cmdline # doctest: +NORMALIZE_WHITESPACE 'antsRegistration --collapse-output-transforms 1 --dimensionality 3 \ --initialize-transforms-per-stage 0 --interpolation Linear --output transform \ --transform Rigid[ 0.01 ] --metric Mattes[ \ @@ -289,7 +318,7 @@ def generate_command( ... moving_path=repodata / 'fileB.nii.gz', ... fixedmask_path=repodata / 'maskA.nii.gz', ... default="dwi-to-b0_level0", - ... ) # doctest: +NORMALIZE_WHITESPACE + ... ).cmdline # doctest: +NORMALIZE_WHITESPACE 'antsRegistration --collapse-output-transforms 1 --dimensionality 3 \ --initialize-transforms-per-stage 0 --interpolation Linear --output transform \ --transform Rigid[ 0.01 ] --metric Mattes[ \ @@ -312,7 +341,7 @@ def generate_command( ... fixed_path=repodata / 'fileA.nii.gz', ... moving_path=repodata / 'fileB.nii.gz', ... default="dwi-to-b0_level0", - ... ) # doctest: +NORMALIZE_WHITESPACE + ... 
).cmdline # doctest: +NORMALIZE_WHITESPACE 'antsRegistration --collapse-output-transforms 1 --dimensionality 3 \ --initialize-transforms-per-stage 0 --interpolation Linear --output transform \ --transform Rigid[ 0.01 ] --metric Mattes[ \ @@ -334,7 +363,7 @@ def generate_command( ... moving_path=repodata / 'fileB.nii.gz', ... fixedmask_path=[repodata / 'maskA.nii.gz'], ... default="dwi-to-b0_level0", - ... ) # doctest: +NORMALIZE_WHITESPACE + ... ).cmdline # doctest: +NORMALIZE_WHITESPACE 'antsRegistration --collapse-output-transforms 1 --dimensionality 3 \ --initialize-transforms-per-stage 0 --interpolation Linear --output transform \ --transform Rigid[ 0.01 ] --metric Mattes[ \ @@ -394,55 +423,40 @@ def generate_command( settings["initial_moving_transform"] = str(init_affine) # Generate command line with nipype and return - return Registration( + reg_iface = Registration( fixed_image=str(Path(fixed_path).absolute()), moving_image=str(Path(moving_path).absolute()), + terminal_output=terminal_output, + environ=environ or {}, **settings, - ).cmdline + ) + if num_threads: + reg_iface.inputs.num_threads = num_threads + + return reg_iface def _run_registration( - fixed: Path, - moving: Path, - bmask_img: nb.spatialimages.SpatialImage, - em_affines: np.ndarray, - affine: np.ndarray, - shape: tuple[int, int, int], - bval: int, - i_iter: int, + fixed_path: str | Path, + moving_path: str | Path, vol_idx: int, dirname: Path, - reg_target_type: str | tuple[str, str], - align_kwargs: dict, + **kwargs, ) -> nt.base.BaseTransform: """ Register the moving image to the fixed image. Parameters ---------- - fixed : :obj:`Path` + fixed_path : :obj:`Path` Fixed image filename. - moving : :obj:`Path` + moving_path : :obj:`Path` Moving image filename. - bmask_img : :class:`~nibabel.spatialimages.SpatialImage` - Brainmask image. - em_affines : :obj:`numpy.ndarray` - Estimated eddy motion affine transformation matrices. - affine : :obj:`numpy.ndarray` - Affine transformation matrix. - shape : :obj:`tuple` - Shape of the DWI frame. - bval : :obj:`int` - b-value of the corresponding DWI volume. - i_iter : :obj:`int` - Iteration number. vol_idx : :obj:`int` - DWI frame index. + Dataset volume index. dirname : :obj:`Path` Directory name where the transformation is saved. - reg_target_type : :obj:`str` or tuple of :obj:`str` - Target registration type. - align_kwargs : :obj:`dict` + kwargs : :obj:`dict` Parameters to configure the image registration process. 
Returns @@ -452,29 +466,27 @@ def _run_registration( """ - if isinstance(reg_target_type, str): - reg_target_type = (reg_target_type, reg_target_type) + align_kwargs = kwargs.copy() + environ = align_kwargs.pop("environ", None) + num_threads = align_kwargs.pop("num_threads", None) - registration = Registration( - terminal_output="file", - from_file=pkg_fn( - "nifreeze.registration", - f"config/{reg_target_type[0]}-to-{reg_target_type[1]}_level{i_iter}.json", - ), - fixed_image=str(fixed.absolute()), - moving_image=str(moving.absolute()), + if (seed := align_kwargs.pop("seed", None)) is not None: + environ = environ or {} + environ["ANTS_RANDOM_SEED"] = str(seed) + + if "ants_config" in kwargs: + align_kwargs["default"] = align_kwargs.pop("ants_config").replace(".json", "") + + registration = generate_command( + fixed_path, + moving_path, + environ=environ, + terminal_output="file_split", + num_threads=num_threads, **align_kwargs, ) - if bmask_img: - registration.inputs.fixed_image_masks = ["NULL", bmask_img] - if em_affines is not None and np.any(em_affines[vol_idx, ...]): - ImageGrid = namedtuple("ImageGrid", ("shape", "affine")) - reference = ImageGrid(shape=shape, affine=affine) - initial_xform = Affine(matrix=em_affines[vol_idx], reference=reference) - mat_file = dirname / f"init_{i_iter}_{vol_idx:05d}.mat" - initial_xform.to_filename(mat_file, fmt="itk") - registration.inputs.initial_moving_transform = str(mat_file) + (dirname / f"cmd-{vol_idx:05d}.sh").write_text(registration.cmdline) # execute ants command line result = registration.run(cwd=str(dirname)).outputs @@ -482,12 +494,12 @@ def _run_registration( # read output transform xform = nt.linear.Affine( nt.io.itk.ITKLinearTransform.from_filename(result.forward_transforms[0]).to_ras( - reference=fixed, moving=moving + reference=fixed_path, moving=moving_path ), ) # debugging: generate aligned file for testing - xform.apply(moving, reference=fixed).to_filename( - dirname / f"aligned{vol_idx:05d}_{int(bval):04d}.nii.gz" + xform.apply(moving_path, reference=fixed_path).to_filename( + dirname / f"dbg_{vol_idx:05d}.nii.gz" ) return xform diff --git a/test/conftest.py b/test/conftest.py index f2bbbef0..c10c231c 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -26,9 +26,12 @@ from pathlib import Path import nibabel as nb +import nitransforms as nt import numpy as np import pytest +from nifreeze.data.dmri import DWI + test_data_env = os.getenv("TEST_DATA_HOME", str(Path.home() / "nifreeze-tests")) test_output_dir = os.getenv("TEST_OUTPUT_DIR") test_workdir = os.getenv("TEST_WORK_DIR") @@ -54,19 +57,19 @@ def doctest_imports(doctest_namespace): doctest_namespace["repodata"] = _datadir -@pytest.fixture +@pytest.fixture(scope="session") def outdir(): """Determine if test artifacts should be stored somewhere or deleted.""" return None if test_output_dir is None else Path(test_output_dir) -@pytest.fixture +@pytest.fixture(scope="session") def datadir(): """Return a data path outside the package's structure (i.e., large datasets).""" return Path(test_data_env) -@pytest.fixture +@pytest.fixture(scope="session") def repodata(): """Return the path to this repository's test data folder.""" return _datadir @@ -80,6 +83,59 @@ def pytest_addoption(parser): ) +@pytest.fixture(scope="session") +def motion_data(tmp_path_factory, datadir): + # Temporary directory for session-scoped fixtures + tmp_path = tmp_path_factory.mktemp("motion_test_data") + + dwdata = DWI.from_filename(datadir / "dwi.h5") + b0nii = nb.Nifti1Image(dwdata.bzero, dwdata.affine, 
None) + masknii = nb.Nifti1Image(dwdata.brainmask.astype("uint8"), dwdata.affine, None) + + # Generate a list of large-yet-plausible bulk-head motion + xfms = nt.linear.LinearTransformsMapping( + [ + nb.affines.from_matvec(nb.eulerangles.euler2mat(x=0.03, z=0.005), (0.8, 0.2, 0.2)), + nb.affines.from_matvec(nb.eulerangles.euler2mat(x=0.02, z=0.005), (0.8, 0.2, 0.2)), + nb.affines.from_matvec(nb.eulerangles.euler2mat(x=0.02, z=0.02), (0.4, 0.2, 0.2)), + nb.affines.from_matvec(nb.eulerangles.euler2mat(x=-0.02, z=0.02), (0.4, 0.2, 0.2)), + nb.affines.from_matvec(nb.eulerangles.euler2mat(x=-0.02, z=0.002), (0.0, 0.2, 0.2)), + nb.affines.from_matvec(nb.eulerangles.euler2mat(y=-0.02, z=0.002), (0.0, 0.2, 0.2)), + nb.affines.from_matvec(nb.eulerangles.euler2mat(y=-0.01, z=0.002), (0.0, 0.4, 0.2)), + ], + reference=b0nii, + ) + + # Induce motion into dataset (i.e., apply the inverse transforms) + moved_nii = (~xfms).apply(b0nii, reference=b0nii) + + # Save the moved dataset for debugging or further processing + moved_path = tmp_path / "test.nii.gz" + ground_truth_path = tmp_path / "ground_truth.nii.gz" + moved_nii.to_filename(moved_path) + xfms.apply(moved_nii).to_filename(ground_truth_path) + + # Wrap into dataset object + dwi_motion = DWI( + dataobj=np.asanyarray(moved_nii.dataobj), + affine=b0nii.affine, + bzero=dwdata.bzero, + gradients=dwdata.gradients[..., : len(xfms)], + brainmask=dwdata.brainmask, + ) + + # Return data as a dictionary (or any format that makes sense for your tests) + return { + "b0nii": b0nii, + "masknii": masknii, + "moved_nii": moved_nii, + "xfms": xfms, + "moved_path": moved_path, + "ground_truth_path": ground_truth_path, + "moved_nifreeze": dwi_motion, + } + + @pytest.hookimpl(trylast=True) def pytest_sessionfinish(session, exitstatus): have_werrors = os.getenv("NIFREEZE_WERRORS", False) diff --git a/test/test_data_dmri.py b/test/test_data_dmri.py index 565b6d9a..a6afd79e 100644 --- a/test/test_data_dmri.py +++ b/test/test_data_dmri.py @@ -26,10 +26,7 @@ import numpy as np import pytest -from nifreeze.data.dmri import load -from nifreeze.model.dmri import ( - find_shelling_scheme, -) +from nifreeze.data.dmri import find_shelling_scheme, load def _create_dwi_random_dataobj(): diff --git a/test/test_integration.py b/test/test_integration.py index 858c76f8..84d2e3d2 100644 --- a/test/test_integration.py +++ b/test/test_integration.py @@ -24,77 +24,62 @@ from os import cpu_count -import nibabel as nb import nitransforms as nt -import numpy as np -from nifreeze.data.dmri import DWI from nifreeze.estimator import Estimator +from nifreeze.model.base import TrivialModel from nifreeze.registration.utils import displacements_within_mask -def test_proximity_estimator_trivial_model(datadir, tmp_path): +def test_proximity_estimator_trivial_model(motion_data, tmp_path): """Check the proximity of transforms estimated by the estimator with a trivial B0 model.""" - dwdata = DWI.from_filename(datadir / "dwi.h5") - b0nii = nb.Nifti1Image(dwdata.bzero, dwdata.affine, None) - masknii = nb.Nifti1Image(dwdata.brainmask.astype(np.uint8), dwdata.affine, None) + b0nii = motion_data["b0nii"] + moved_nii = motion_data["moved_nii"] + xfms = motion_data["xfms"] + dwi_motion = motion_data["moved_nifreeze"] - # Generate a list of large-yet-plausible bulk-head motion. 
- xfms = nt.linear.LinearTransformsMapping( - [ - nb.affines.from_matvec(nb.eulerangles.euler2mat(x=0.03, z=0.005), (0.8, 0.2, 0.2)), - nb.affines.from_matvec(nb.eulerangles.euler2mat(x=0.02, z=0.005), (0.8, 0.2, 0.2)), - nb.affines.from_matvec(nb.eulerangles.euler2mat(x=0.02, z=0.02), (0.4, 0.2, 0.2)), - nb.affines.from_matvec(nb.eulerangles.euler2mat(x=-0.02, z=0.02), (0.4, 0.2, 0.2)), - nb.affines.from_matvec(nb.eulerangles.euler2mat(x=-0.02, z=0.002), (0.0, 0.2, 0.2)), - nb.affines.from_matvec(nb.eulerangles.euler2mat(y=-0.02, z=0.002), (0.0, 0.2, 0.2)), - nb.affines.from_matvec(nb.eulerangles.euler2mat(y=-0.01, z=0.002), (0.0, 0.4, 0.2)), - ], - reference=b0nii, - ) - - # Induce motion into dataset (i.e., apply the inverse transforms) - moved_nii = (~xfms).apply(b0nii, reference=b0nii) - - # Uncomment to see the moved dataset - moved_nii.to_filename(tmp_path / "test.nii.gz") - xfms.apply(moved_nii).to_filename(tmp_path / "ground_truth.nii.gz") - - # Wrap into dataset object - dwi_motion = DWI( - dataobj=moved_nii.dataobj, - affine=b0nii.affine, - bzero=dwdata.bzero, - gradients=dwdata.gradients[..., : len(xfms)], - brainmask=dwdata.brainmask, - ) - - estimator = Estimator() - em_affines = estimator.estimate( - data=dwi_motion, - models=("b0",), - seed=None, - align_kwargs={ - "fixed_modality": "b0", - "moving_modality": "b0", - "num_threads": min(cpu_count(), 8), - }, + model = TrivialModel(dwi_motion) + estimator = Estimator(model) + estimator.run( + dwi_motion, + seed=12345, + num_threads=min(cpu_count(), 8), ) # Uncomment to see the realigned dataset nt.linear.LinearTransformsMapping( - em_affines, + dwi_motion.motion_affines, reference=b0nii, ).apply(moved_nii).to_filename(tmp_path / "realigned.nii.gz") # For each moved b0 volume - for i, est in enumerate(em_affines): + for i, est in enumerate(dwi_motion.motion_affines): assert ( displacements_within_mask( - masknii, + motion_data["masknii"], nt.linear.Affine(est), xfms[i], ).max() < 0.25 ) + + +def test_stacked_estimators(motion_data): + """Check that models can be stacked.""" + + # Wrap into dataset object + dmri_dataset = motion_data["moved_nifreeze"] + + estimator1 = Estimator( + TrivialModel(dmri_dataset), + ants_config="dwi-to-dwi_level0.json", + clip=False, + ) + estimator2 = Estimator( + TrivialModel(dmri_dataset), + prev=estimator1, + clip=False, + ) + + estimator2.run(dmri_dataset) diff --git a/test/test_model.py b/test/test_model.py index 7f16a5a0..b4feddd3 100644 --- a/test/test_model.py +++ b/test/test_model.py @@ -29,12 +29,9 @@ from dipy.sims.voxel import single_tensor from nifreeze import model -from nifreeze.data.dmri import DWI -from nifreeze.data.splitting import lovo_split -from nifreeze.exceptions import ModelNotFittedError +from nifreeze.data.dmri import DEFAULT_MAX_S0, DEFAULT_MIN_S0, DWI from nifreeze.model._dipy import GaussianProcessModel from nifreeze.model.base import mask_absence_warn_msg -from nifreeze.model.dmri import DEFAULT_MAX_S0, DEFAULT_MIN_S0 from nifreeze.testing import simulations as _sim @@ -64,22 +61,23 @@ def test_trivial_model(use_mask): a_max=DEFAULT_MAX_S0, ) + data = DWI( + dataobj=(*_S0.shape, 10), + bzero=_clipped_S0, + brainmask=mask, + ) + with context: - tmodel = model.TrivialModel(mask=mask, predicted=_clipped_S0) + tmodel = model.TrivialModel(data) - data = None - assert tmodel.fit(data) is None + predicted = tmodel.fit_predict(4) - assert np.all(_clipped_S0 == tmodel.predict((1, 0, 0))) + assert np.all(_clipped_S0 == predicted) def test_average_model(): """Check the 
implementation of the average DW model.""" - size = (100, 100, 100, 6) - data = np.ones(size, dtype=float) - mask = np.ones(size, dtype=bool) - gtab = np.array( [ [0, 0, 0, 0], @@ -87,41 +85,41 @@ def test_average_model(): [0.25, 0.565, 0.21, 500], [-0.861, -0.464, 0.564, 1000], [0.307, -0.766, 0.677, 1000], - [0.736, 0.013, 0.774, 1300], + [0.736, 0.013, 0.774, 1000], + [-0.31, 0.933, 0.785, 1000], + [0.25, 0.565, 0.21, 2000], + [-0.861, -0.464, 0.564, 2000], + [0.307, -0.766, 0.677, 2000], ] ) + size = (100, 100, 100, gtab.shape[0]) + data = np.ones(size, dtype=float) + mask = np.ones(size[:3], dtype=bool) + data *= gtab[:, -1] + dataset = DWI(dataobj=data, gradients=gtab, brainmask=mask) - tmodel_mean = model.AverageDWIModel(mask=mask, gtab=gtab, bias=False, stat="mean") - tmodel_median = model.AverageDWIModel(mask=mask, gtab=gtab, bias=False, stat="median") - tmodel_1000 = model.AverageDWIModel(mask=mask, gtab=gtab, bias=False, th_high=1000, th_low=900) - tmodel_2000 = model.AverageDWIModel( - mask=mask, - gtab=gtab, - bias=False, - th_high=2000, - th_low=900, - stat="mean", - ) + tmodel_mean = model.AverageDWIModel(dataset, stat="mean") + tmodel_mean_full = model.AverageDWIModel(dataset, stat="mean", th_low=2000, th_high=2000) + tmodel_median = model.AverageDWIModel(dataset) - with pytest.raises(ModelNotFittedError): - tmodel_mean.predict([0, 0, 0]) + # Verify that average cannot be calculated in shells with one single value + with pytest.raises(RuntimeError): + tmodel_mean.fit_predict(2) - # Verify that fit function returns nothing - assert tmodel_mean.fit(data[..., 1:], gtab=gtab[1:].T) is None + assert np.allclose(tmodel_mean.fit_predict(3), 1000) + assert np.allclose(tmodel_median.fit_predict(3), 1000) - tmodel_median.fit(data[..., 1:], gtab=gtab[1:].T) - tmodel_1000.fit(data[..., 1:], gtab=gtab[1:].T) - tmodel_2000.fit(data[..., 1:], gtab=gtab[1:].T) + grads = list(gtab[:, -1]) + del grads[1] + assert np.allclose(tmodel_mean_full.fit_predict(1), np.mean(grads)) - # Verify that the right statistics is applied and that the model discard b-values < 50 - assert np.all(tmodel_mean.predict([0, 0, 0]) == 950) - assert np.all(tmodel_median.predict([0, 0, 0]) == 1000) + tmodel_mean_2000 = model.AverageDWIModel(dataset, stat="mean", th_low=1100) + tmodel_median_2000 = model.AverageDWIModel(dataset, th_low=1100) - # Verify that the threshold for b-value selection works as expected - assert np.all(tmodel_1000.predict([0, 0, 0]) == 1000) - assert np.all(tmodel_2000.predict([0, 0, 0]) == 1100) + assert np.allclose(tmodel_mean_2000.fit_predict(9), gtab[3:-1, -1].mean()) + assert np.allclose(tmodel_median_2000.fit_predict(9), 1000) @pytest.mark.parametrize( @@ -159,44 +157,26 @@ def test_gp_model(evals, S0, snr, hsph_dirs, bval_shell): assert prediction.shape == (2,) -def test_two_initialisations(datadir): +def test_factory(datadir): """Check that the two different initialisations result in the same models""" # Load test data dmri_dataset = DWI.from_filename(datadir / "dwi.h5") - # Split data into test and train set - data_train, data_test = lovo_split(dmri_dataset, 10) - + modelargs = { + "th_low": 25, + "th_high": 25, + "detrend": True, + "stat": "mean", + } # Direct initialisation - model1 = model.AverageDWIModel( - mask=dmri_dataset.brainmask.astype(bool), - gtab=data_train[-1], - S0=dmri_dataset.bzero, - th_low=100, - th_high=1000, - bias=False, - stat="mean", - ) - model1.fit(data_train[0], gtab=data_train[-1]) - predicted1 = model1.predict(data_test[-1]) + model1 = 
model.AverageDWIModel(dmri_dataset, **modelargs) # Initialisation via ModelFactory - model2 = model.ModelFactory.init( - model="avgdwi", - mask=dmri_dataset.brainmask.astype(bool), - gtab=data_train[-1], - S0=dmri_dataset.bzero, - th_low=100, - th_high=1000, - bias=False, - stat="mean", - ) - - with pytest.raises(ModelNotFittedError): - model2.predict(data_test[-1]) - - model2.fit(data_train[0], gtab=data_train[-1]) - predicted2 = model2.predict(data_test[-1]) + model2 = model.ModelFactory.init(model="avgdwi", dataset=dmri_dataset, **modelargs) - assert np.all(predicted1 == predicted2) + assert model1._dataset == model2._dataset + assert model1._detrend == model2._detrend + assert model1._th_low == model2._th_low + assert model1._th_high == model2._th_high + assert model1._stat == model2._stat
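The diff above retires the static `Estimator.estimate(...)` entry point in favor of dataset-bound models and chainable estimators. The sketch below distills the new usage pattern from the changes to `test_integration.py` and `src/nifreeze/cli/run.py`; it is a minimal illustration, not part of the patch — the file paths are placeholders, and the particular keyword values (`clip`, `stat`, `seed`, `num_threads`) are examples of arguments the new signatures accept.

```python
from nifreeze.data.dmri import DWI
from nifreeze.estimator import Estimator
from nifreeze.model.base import TrivialModel

# "dwi.h5" is an illustrative path; any HDF5 file written by DWI.to_filename() works.
dataset = DWI.from_filename("dwi.h5")

# Models are now bound to the dataset at construction time ...
estimator1 = Estimator(TrivialModel(dataset), clip=False)

# ... and estimators chain through ``prev=`` instead of a ``models=(...)`` tuple.
# A string identifier is resolved lazily through ModelFactory when run() executes.
estimator2 = Estimator("avgdwi", prev=estimator1, model_kwargs={"stat": "mean"})

# run() triggers the chained workflow (estimator1 first) and operates in place.
estimator2.run(dataset, seed=12345, num_threads=8)

# Estimated rigid-body transforms land in ``dataset.motion_affines``;
# the realigned dataset is saved through the dataset object itself.
dataset.to_filename("dwi_desc-realigned.h5")
```

This mirrors the loop in the updated `cli/run.py`, where one `Estimator` is built per entry of `args.models` with `prev=` pointing to the previous one, and only the last estimator's `run()` is invoked on the dataset.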