Skip to content

Commit

Permalink
Merge pull request #62 from nipreps/enh/estimator-refactor
Browse files Browse the repository at this point in the history
ENH: Implicitly-pipelined and modality-agnostic ``Estimator``
  • Loading branch information
oesteban authored Jan 24, 2025
2 parents d853cab + 6697c78 commit 7237ecf
Show file tree
Hide file tree
Showing 14 changed files with 635 additions and 744 deletions.
6 changes: 3 additions & 3 deletions docs/notebooks/bold_realignment.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,7 @@
"metadata": {},
"outputs": [],
"source": [
"from nifreeze.model.base import AverageModel\n",
"from nifreeze.model.base import ExpectationModel\n",
"from nifreeze.utils.iterators import random_iterator"
]
},
Expand All @@ -358,7 +358,7 @@
" t_mask[t] = True\n",
"\n",
" # Fit and predict using the model\n",
" model = AverageModel()\n",
" model = ExpectationModel()\n",
" model.fit(\n",
" data[..., ~t_mask],\n",
" stat=\"median\",\n",
Expand All @@ -376,7 +376,7 @@
" fixedmask_path=brainmask_path,\n",
" output_transform_prefix=f\"conversion-{t:02d}\",\n",
" num_threads=8,\n",
" )\n",
" ).cmdline\n",
"\n",
" # Run the command\n",
" proc = await asyncio.create_subprocess_shell(\n",
Expand Down
2 changes: 1 addition & 1 deletion scripts/optimize_registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ async def train_coro(
moving_path,
fixedmask_path=brainmask_path,
**_kwargs,
)
).cmdline

tasks.append(
ants(
Expand Down
19 changes: 12 additions & 7 deletions src/nifreeze/cli/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from pathlib import Path

from nifreeze.cli.parser import parse_args
from nifreeze.data.dmri import DWI
from nifreeze.data.base import BaseDataset
from nifreeze.estimator import Estimator


Expand All @@ -40,14 +40,19 @@ def main(argv=None) -> None:
args = parse_args(argv)

# Open the data with the given file path
dwi_dataset: DWI = DWI.from_filename(args.input_file)
dataset: BaseDataset = BaseDataset.from_filename(args.input_file)

estimator: Estimator = Estimator()
prev_model: Estimator | None = None
for _model in args.models:
estimator: Estimator = Estimator(
_model,
prev=prev_model,
)
prev_model = estimator

_ = estimator.estimate(
dwi_dataset,
_ = estimator.run(
dataset,
align_kwargs=args.align_config,
models=args.models,
omp_nthreads=args.nthreads,
njobs=args.njobs,
seed=args.seed,
Expand All @@ -58,7 +63,7 @@ def main(argv=None) -> None:
output_path: Path = Path(args.output_dir) / output_filename

# Save the DWI dataset to the output path
dwi_dataset.to_filename(output_path)
dataset.to_filename(output_path)


if __name__ == "__main__":
Expand Down
110 changes: 109 additions & 1 deletion src/nifreeze/data/dmri.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,30 @@
from nifreeze.data.base import BaseDataset, _cmp, _data_repr
from nifreeze.utils.ndimage import load_api

DEFAULT_CLIP_PERCENTILE = 75
"""Upper percentile threshold for intensity clipping."""

DEFAULT_MIN_S0 = 1e-5
"""Minimum value when considering the :math:`S_{0}` DWI signal."""

DEFAULT_MAX_S0 = 1.0
"""Maximum value when considering the :math:`S_{0}` DWI signal."""

DEFAULT_LOWB_THRESHOLD = 50
"""The lower bound for the b-value so that the orientation is considered a DW volume."""

DEFAULT_HIGHB_THRESHOLD = 8000
"""A b-value cap for DWI data."""

DEFAULT_NUM_BINS = 15
"""Number of bins to classify b-values."""

DEFAULT_MULTISHELL_BIN_COUNT_THR = 7
"""Default bin count to consider a multishell scheme."""

DTI_MIN_ORIENTATIONS = 6
"""Minimum number of nonzero b-values in a DWI dataset."""


@attr.s(slots=True)
class DWI(BaseDataset[np.ndarray | None]):
Expand Down Expand Up @@ -226,7 +250,7 @@ def load(
bvec_file: Path | str | None = None,
bval_file: Path | str | None = None,
b0_file: Path | str | None = None,
b0_thres: float = 50.0,
b0_thres: float = DEFAULT_LOWB_THRESHOLD,
) -> DWI:
"""
Load DWI data and construct a DWI object.
Expand Down Expand Up @@ -342,3 +366,87 @@ def load(
dwi_obj.brainmask = np.asanyarray(mask_img.dataobj, dtype=bool)

return dwi_obj


def find_shelling_scheme(
bvals,
num_bins=DEFAULT_NUM_BINS,
multishell_nonempty_bin_count_thr=DEFAULT_MULTISHELL_BIN_COUNT_THR,
bval_cap=DEFAULT_HIGHB_THRESHOLD,
):
"""
Find the shelling scheme on the given b-values.
Computes the histogram of the b-values according to ``num_bins``
and depending on the nonempty bin count, classify the shelling scheme
as single-shell if they are 2 (low-b and a shell); multi-shell if they are
below the ``multishell_nonempty_bin_count_thr`` value; and DSI otherwise.
Parameters
----------
bvals : :obj:`list` or :obj:`~numpy.ndarray`
List or array of b-values.
num_bins : :obj:`int`, optional
Number of bins.
multishell_nonempty_bin_count_thr : :obj:`int`, optional
Bin count to consider a multi-shell scheme.
Returns
-------
scheme : :obj:`str`
Shelling scheme.
bval_groups : :obj:`list`
List of grouped b-values.
bval_estimated : :obj:`list`
List of 'estimated' b-values as the median value of each b-value group.
"""

# Bin the b-values: use -1 as the lower bound to be able to appropriately
# include b0 values
hist, bin_edges = np.histogram(bvals, bins=num_bins, range=(-1, min(max(bvals), bval_cap)))

# Collect values in each bin
bval_groups = []
bval_estimated = []
for lower, upper in zip(bin_edges[:-1], bin_edges[1:], strict=False):
# Add only if a nonempty b-values mask
if (mask := (bvals > lower) & (bvals <= upper)).sum():
bval_groups.append(bvals[mask])
bval_estimated.append(np.median(bvals[mask]))

nonempty_bins = len(bval_groups)

if nonempty_bins < 2:
raise ValueError("DWI must have at least one high-b shell")

if nonempty_bins == 2:
scheme = "single-shell"
elif nonempty_bins < multishell_nonempty_bin_count_thr:
scheme = "multi-shell"
else:
scheme = "DSI"

return scheme, bval_groups, bval_estimated


def _rasb2dipy(gradient):
import warnings

gradient = np.asanyarray(gradient)
if gradient.ndim == 1:
if gradient.size != 4:
raise ValueError("Missing gradient information.")
gradient = gradient[..., np.newaxis]

if gradient.shape[0] != 4:
gradient = gradient.T
elif gradient.shape == (4, 4):
print("Warning: make sure gradient information is not transposed!")

with warnings.catch_warnings():
from dipy.core.gradients import gradient_table

warnings.filterwarnings("ignore", category=UserWarning)
retval = gradient_table(gradient[3, :], bvecs=gradient[:3, :].T)
return retval
Loading

0 comments on commit 7237ecf

Please sign in to comment.