Skip to content

Commit 7237ecf

Browse files
authored
Merge pull request #62 from nipreps/enh/estimator-refactor
ENH: Implicitly-pipelined and modality-agnostic ``Estimator``
2 parents d853cab + 6697c78 commit 7237ecf

File tree

14 files changed

+635
-744
lines changed

14 files changed

+635
-744
lines changed

docs/notebooks/bold_realignment.ipynb

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -337,7 +337,7 @@
337337
"metadata": {},
338338
"outputs": [],
339339
"source": [
340-
"from nifreeze.model.base import AverageModel\n",
340+
"from nifreeze.model.base import ExpectationModel\n",
341341
"from nifreeze.utils.iterators import random_iterator"
342342
]
343343
},
@@ -358,7 +358,7 @@
358358
" t_mask[t] = True\n",
359359
"\n",
360360
" # Fit and predict using the model\n",
361-
" model = AverageModel()\n",
361+
" model = ExpectationModel()\n",
362362
" model.fit(\n",
363363
" data[..., ~t_mask],\n",
364364
" stat=\"median\",\n",
@@ -376,7 +376,7 @@
376376
" fixedmask_path=brainmask_path,\n",
377377
" output_transform_prefix=f\"conversion-{t:02d}\",\n",
378378
" num_threads=8,\n",
379-
" )\n",
379+
" ).cmdline\n",
380380
"\n",
381381
" # Run the command\n",
382382
" proc = await asyncio.create_subprocess_shell(\n",

scripts/optimize_registration.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ async def train_coro(
134134
moving_path,
135135
fixedmask_path=brainmask_path,
136136
**_kwargs,
137-
)
137+
).cmdline
138138

139139
tasks.append(
140140
ants(

src/nifreeze/cli/run.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from pathlib import Path
2626

2727
from nifreeze.cli.parser import parse_args
28-
from nifreeze.data.dmri import DWI
28+
from nifreeze.data.base import BaseDataset
2929
from nifreeze.estimator import Estimator
3030

3131

@@ -40,14 +40,19 @@ def main(argv=None) -> None:
4040
args = parse_args(argv)
4141

4242
# Open the data with the given file path
43-
dwi_dataset: DWI = DWI.from_filename(args.input_file)
43+
dataset: BaseDataset = BaseDataset.from_filename(args.input_file)
4444

45-
estimator: Estimator = Estimator()
45+
prev_model: Estimator | None = None
46+
for _model in args.models:
47+
estimator: Estimator = Estimator(
48+
_model,
49+
prev=prev_model,
50+
)
51+
prev_model = estimator
4652

47-
_ = estimator.estimate(
48-
dwi_dataset,
53+
_ = estimator.run(
54+
dataset,
4955
align_kwargs=args.align_config,
50-
models=args.models,
5156
omp_nthreads=args.nthreads,
5257
njobs=args.njobs,
5358
seed=args.seed,
@@ -58,7 +63,7 @@ def main(argv=None) -> None:
5863
output_path: Path = Path(args.output_dir) / output_filename
5964

6065
# Save the DWI dataset to the output path
61-
dwi_dataset.to_filename(output_path)
66+
dataset.to_filename(output_path)
6267

6368

6469
if __name__ == "__main__":

src/nifreeze/data/dmri.py

Lines changed: 109 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,30 @@
3939
from nifreeze.data.base import BaseDataset, _cmp, _data_repr
4040
from nifreeze.utils.ndimage import load_api
4141

42+
DEFAULT_CLIP_PERCENTILE = 75
43+
"""Upper percentile threshold for intensity clipping."""
44+
45+
DEFAULT_MIN_S0 = 1e-5
46+
"""Minimum value when considering the :math:`S_{0}` DWI signal."""
47+
48+
DEFAULT_MAX_S0 = 1.0
49+
"""Maximum value when considering the :math:`S_{0}` DWI signal."""
50+
51+
DEFAULT_LOWB_THRESHOLD = 50
52+
"""The lower bound for the b-value so that the orientation is considered a DW volume."""
53+
54+
DEFAULT_HIGHB_THRESHOLD = 8000
55+
"""A b-value cap for DWI data."""
56+
57+
DEFAULT_NUM_BINS = 15
58+
"""Number of bins to classify b-values."""
59+
60+
DEFAULT_MULTISHELL_BIN_COUNT_THR = 7
61+
"""Default bin count to consider a multishell scheme."""
62+
63+
DTI_MIN_ORIENTATIONS = 6
64+
"""Minimum number of nonzero b-values in a DWI dataset."""
65+
4266

4367
@attr.s(slots=True)
4468
class DWI(BaseDataset[np.ndarray | None]):
@@ -226,7 +250,7 @@ def load(
226250
bvec_file: Path | str | None = None,
227251
bval_file: Path | str | None = None,
228252
b0_file: Path | str | None = None,
229-
b0_thres: float = 50.0,
253+
b0_thres: float = DEFAULT_LOWB_THRESHOLD,
230254
) -> DWI:
231255
"""
232256
Load DWI data and construct a DWI object.
@@ -342,3 +366,87 @@ def load(
342366
dwi_obj.brainmask = np.asanyarray(mask_img.dataobj, dtype=bool)
343367

344368
return dwi_obj
369+
370+
371+
def find_shelling_scheme(
372+
bvals,
373+
num_bins=DEFAULT_NUM_BINS,
374+
multishell_nonempty_bin_count_thr=DEFAULT_MULTISHELL_BIN_COUNT_THR,
375+
bval_cap=DEFAULT_HIGHB_THRESHOLD,
376+
):
377+
"""
378+
Find the shelling scheme on the given b-values.
379+
380+
Computes the histogram of the b-values according to ``num_bins``
381+
and depending on the nonempty bin count, classify the shelling scheme
382+
as single-shell if they are 2 (low-b and a shell); multi-shell if they are
383+
below the ``multishell_nonempty_bin_count_thr`` value; and DSI otherwise.
384+
385+
Parameters
386+
----------
387+
bvals : :obj:`list` or :obj:`~numpy.ndarray`
388+
List or array of b-values.
389+
num_bins : :obj:`int`, optional
390+
Number of bins.
391+
multishell_nonempty_bin_count_thr : :obj:`int`, optional
392+
Bin count to consider a multi-shell scheme.
393+
394+
Returns
395+
-------
396+
scheme : :obj:`str`
397+
Shelling scheme.
398+
bval_groups : :obj:`list`
399+
List of grouped b-values.
400+
bval_estimated : :obj:`list`
401+
List of 'estimated' b-values as the median value of each b-value group.
402+
403+
"""
404+
405+
# Bin the b-values: use -1 as the lower bound to be able to appropriately
406+
# include b0 values
407+
hist, bin_edges = np.histogram(bvals, bins=num_bins, range=(-1, min(max(bvals), bval_cap)))
408+
409+
# Collect values in each bin
410+
bval_groups = []
411+
bval_estimated = []
412+
for lower, upper in zip(bin_edges[:-1], bin_edges[1:], strict=False):
413+
# Add only if a nonempty b-values mask
414+
if (mask := (bvals > lower) & (bvals <= upper)).sum():
415+
bval_groups.append(bvals[mask])
416+
bval_estimated.append(np.median(bvals[mask]))
417+
418+
nonempty_bins = len(bval_groups)
419+
420+
if nonempty_bins < 2:
421+
raise ValueError("DWI must have at least one high-b shell")
422+
423+
if nonempty_bins == 2:
424+
scheme = "single-shell"
425+
elif nonempty_bins < multishell_nonempty_bin_count_thr:
426+
scheme = "multi-shell"
427+
else:
428+
scheme = "DSI"
429+
430+
return scheme, bval_groups, bval_estimated
431+
432+
433+
def _rasb2dipy(gradient):
434+
import warnings
435+
436+
gradient = np.asanyarray(gradient)
437+
if gradient.ndim == 1:
438+
if gradient.size != 4:
439+
raise ValueError("Missing gradient information.")
440+
gradient = gradient[..., np.newaxis]
441+
442+
if gradient.shape[0] != 4:
443+
gradient = gradient.T
444+
elif gradient.shape == (4, 4):
445+
print("Warning: make sure gradient information is not transposed!")
446+
447+
with warnings.catch_warnings():
448+
from dipy.core.gradients import gradient_table
449+
450+
warnings.filterwarnings("ignore", category=UserWarning)
451+
retval = gradient_table(gradient[3, :], bvecs=gradient[:3, :].T)
452+
return retval

0 commit comments

Comments
 (0)