Skip to content

Commit 7322e93

Browse files
committed
enh: continue with the refactor
1 parent 544249b commit 7322e93

13 files changed

+390
-470
lines changed

docs/notebooks/bold_realignment.ipynb

+3-3
Original file line numberDiff line numberDiff line change
@@ -337,7 +337,7 @@
337337
"metadata": {},
338338
"outputs": [],
339339
"source": [
340-
"from nifreeze.model.base import AverageModel\n",
340+
"from nifreeze.model.base import ExpectationModel\n",
341341
"from nifreeze.utils.iterators import random_iterator"
342342
]
343343
},
@@ -358,7 +358,7 @@
358358
" t_mask[t] = True\n",
359359
"\n",
360360
" # Fit and predict using the model\n",
361-
" model = AverageModel()\n",
361+
" model = ExpectationModel()\n",
362362
" model.fit(\n",
363363
" data[..., ~t_mask],\n",
364364
" stat=\"median\",\n",
@@ -376,7 +376,7 @@
376376
" fixedmask_path=brainmask_path,\n",
377377
" output_transform_prefix=f\"conversion-{t:02d}\",\n",
378378
" num_threads=8,\n",
379-
" )\n",
379+
" ).cmdline\n",
380380
"\n",
381381
" # Run the command\n",
382382
" proc = await asyncio.create_subprocess_shell(\n",

scripts/optimize_registration.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ async def train_coro(
133133
fixedmask_path=brainmask_path,
134134
output_transform_prefix=f"conversion-{index:04d}",
135135
**align_kwargs,
136-
)
136+
).cmdline
137137

138138
tasks.append(
139139
ants(

src/nifreeze/data/dmri.py

+109-1
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,30 @@
3737

3838
from nifreeze.data.base import BaseDataset, _cmp, _data_repr
3939

40+
DEFAULT_CLIP_PERCENTILE = 75
41+
"""Upper percentile threshold for intensity clipping."""
42+
43+
DEFAULT_MIN_S0 = 1e-5
44+
"""Minimum value when considering the :math:`S_{0}` DWI signal."""
45+
46+
DEFAULT_MAX_S0 = 1.0
47+
"""Maximum value when considering the :math:`S_{0}` DWI signal."""
48+
49+
DEFAULT_LOWB_THRESHOLD = 50
50+
"""The lower bound for the b-value so that the orientation is considered a DW volume."""
51+
52+
DEFAULT_HIGHB_THRESHOLD = 8000
53+
"""A b-value cap for DWI data."""
54+
55+
DEFAULT_NUM_BINS = 15
56+
"""Number of bins to classify b-values."""
57+
58+
DEFAULT_MULTISHELL_BIN_COUNT_THR = 7
59+
"""Default bin count to consider a multishell scheme."""
60+
61+
DTI_MIN_ORIENTATIONS = 6
62+
"""Minimum number of nonzero b-values in a DWI dataset."""
63+
4064

4165
@attr.s(slots=True)
4266
class DWI(BaseDataset):
@@ -221,7 +245,7 @@ def load(
221245
bvec_file: Path | str | None = None,
222246
bval_file: Path | str | None = None,
223247
b0_file: Path | str | None = None,
224-
b0_thres: float = 50.0,
248+
b0_thres: float = DEFAULT_LOWB_THRESHOLD,
225249
) -> DWI:
226250
"""
227251
Load DWI data and construct a DWI object.
@@ -337,3 +361,87 @@ def load(
337361
dwi_obj.brainmask = np.asanyarray(mask_img.dataobj, dtype=bool)
338362

339363
return dwi_obj
364+
365+
366+
def find_shelling_scheme(
367+
bvals,
368+
num_bins=DEFAULT_NUM_BINS,
369+
multishell_nonempty_bin_count_thr=DEFAULT_MULTISHELL_BIN_COUNT_THR,
370+
bval_cap=DEFAULT_HIGHB_THRESHOLD,
371+
):
372+
"""
373+
Find the shelling scheme on the given b-values.
374+
375+
Computes the histogram of the b-values according to ``num_bins``
376+
and depending on the nonempty bin count, classify the shelling scheme
377+
as single-shell if they are 2 (low-b and a shell); multi-shell if they are
378+
below the ``multishell_nonempty_bin_count_thr`` value; and DSI otherwise.
379+
380+
Parameters
381+
----------
382+
bvals : :obj:`list` or :obj:`~numpy.ndarray`
383+
List or array of b-values.
384+
num_bins : :obj:`int`, optional
385+
Number of bins.
386+
multishell_nonempty_bin_count_thr : :obj:`int`, optional
387+
Bin count to consider a multi-shell scheme.
388+
389+
Returns
390+
-------
391+
scheme : :obj:`str`
392+
Shelling scheme.
393+
bval_groups : :obj:`list`
394+
List of grouped b-values.
395+
bval_estimated : :obj:`list`
396+
List of 'estimated' b-values as the median value of each b-value group.
397+
398+
"""
399+
400+
# Bin the b-values: use -1 as the lower bound to be able to appropriately
401+
# include b0 values
402+
hist, bin_edges = np.histogram(bvals, bins=num_bins, range=(-1, min(max(bvals), bval_cap)))
403+
404+
# Collect values in each bin
405+
bval_groups = []
406+
bval_estimated = []
407+
for lower, upper in zip(bin_edges[:-1], bin_edges[1:], strict=False):
408+
# Add only if a nonempty b-values mask
409+
if (mask := (bvals > lower) & (bvals <= upper)).sum():
410+
bval_groups.append(bvals[mask])
411+
bval_estimated.append(np.median(bvals[mask]))
412+
413+
nonempty_bins = len(bval_groups)
414+
415+
if nonempty_bins < 2:
416+
raise ValueError("DWI must have at least one high-b shell")
417+
418+
if nonempty_bins == 2:
419+
scheme = "single-shell"
420+
elif nonempty_bins < multishell_nonempty_bin_count_thr:
421+
scheme = "multi-shell"
422+
else:
423+
scheme = "DSI"
424+
425+
return scheme, bval_groups, bval_estimated
426+
427+
428+
def _rasb2dipy(gradient):
429+
import warnings
430+
431+
gradient = np.asanyarray(gradient)
432+
if gradient.ndim == 1:
433+
if gradient.size != 4:
434+
raise ValueError("Missing gradient information.")
435+
gradient = gradient[..., np.newaxis]
436+
437+
if gradient.shape[0] != 4:
438+
gradient = gradient.T
439+
elif gradient.shape == (4, 4):
440+
print("Warning: make sure gradient information is not transposed!")
441+
442+
with warnings.catch_warnings():
443+
from dipy.core.gradients import gradient_table
444+
445+
warnings.filterwarnings("ignore", category=UserWarning)
446+
retval = gradient_table(gradient[3, :], bvecs=gradient[:3, :].T)
447+
return retval

src/nifreeze/estimator.py

+24-14
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,10 @@
3232

3333
from nifreeze.data.base import BaseDataset
3434
from nifreeze.model.base import BaseModel, ModelFactory
35-
from nifreeze.registration.ants import _prepare_registration_data, _run_registration
35+
from nifreeze.registration.ants import (
36+
_prepare_registration_data,
37+
_run_registration,
38+
)
3639
from nifreeze.utils import iterators
3740

3841

@@ -60,7 +63,7 @@ def run(self, dataset: BaseDataset, **kwargs):
6063
class Estimator:
6164
"""Estimates rigid-body head-motion and distortions derived from eddy-currents."""
6265

63-
__slots__ = ("_model", "_strategy", "_dataset", "_prev", "_model_kwargs", "_align_kwargs")
66+
__slots__ = ("_model", "_strategy", "_prev", "_model_kwargs", "_align_kwargs")
6467

6568
def __init__(
6669
self,
@@ -111,29 +114,37 @@ def run(self, dataset: BaseDataset, **kwargs):
111114
**self._model_kwargs,
112115
)
113116

114-
if self._model.is_static:
115-
self._model.fit(dataset, **kwargs)
116-
117117
kwargs["num_threads"] = kwargs.pop("omp_nthreads", None) or kwargs.pop("num_threads", None)
118118

119119
dataset_length = len(dataset)
120120
with TemporaryDirectory() as tmp_dir:
121121
print(f"Processing in <{tmp_dir}>")
122122
ptmp_dir = Path(tmp_dir)
123+
124+
bmask_path = None
125+
if dataset.brainmask is not None:
126+
import nibabel as nb
127+
128+
bmask_path = ptmp_dir / "brainmask.nii.gz"
129+
nb.Nifti1Image(
130+
dataset.brainmask.astype("uint8"), dataset.affine, None
131+
).to_filename(bmask_path)
132+
123133
with tqdm(total=dataset_length, unit="vols.") as pbar:
124134
# run a original-to-synthetic affine registration
125135
for i in index_iter:
126136
pbar.set_description_str(f"Fit and predict vol. <{i}>")
127137

128138
# fit the model
129-
reference, predicted = self._model.fit_predict(
139+
test_set = dataset[i]
140+
predicted = self._model.fit_predict(
130141
i,
131142
n_jobs=n_jobs,
132143
)
133144

134145
# prepare data for running ANTs
135-
fixed, moving = _prepare_registration_data(
136-
reference,
146+
predicted_path, volume_path, init_path = _prepare_registration_data(
147+
test_set[0],
137148
predicted,
138149
dataset.affine,
139150
i,
@@ -144,14 +155,13 @@ def run(self, dataset: BaseDataset, **kwargs):
144155
pbar.set_description_str(f"Realign vol. <{i}>")
145156

146157
xform = _run_registration(
147-
fixed,
148-
moving,
149-
dataset.brainmask,
150-
dataset.motion_affines,
151-
dataset.affine,
152-
dataset.dataobj.shape[:3],
158+
predicted_path,
159+
volume_path,
153160
i,
154161
ptmp_dir,
162+
init_affine=init_path,
163+
fixedmask_path=bmask_path,
164+
output_transform_prefix=f"ants-{i:05d}",
155165
**kwargs,
156166
)
157167

src/nifreeze/model/__init__.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
"""Data models."""
2424

2525
from nifreeze.model.base import (
26-
AverageModel,
26+
ExpectationModel,
2727
ModelFactory,
2828
TrivialModel,
2929
)
@@ -37,7 +37,7 @@
3737

3838
__all__ = (
3939
"ModelFactory",
40-
"AverageModel",
40+
"ExpectationModel",
4141
"AverageDWIModel",
4242
"DKIModel",
4343
"DTIModel",

src/nifreeze/model/_dipy.py

-22
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,6 @@
2424

2525
from __future__ import annotations
2626

27-
import warnings
28-
2927
import numpy as np
3028
from dipy.core.gradients import GradientTable
3129
from dipy.reconst.base import ReconstModel
@@ -268,23 +266,3 @@ def predict(
268266
269267
"""
270268
return gp_prediction(self.model, gtab, mask=self.mask)
271-
272-
273-
def _rasb2dipy(gradient):
274-
gradient = np.asanyarray(gradient)
275-
if gradient.ndim == 1:
276-
if gradient.size != 4:
277-
raise ValueError("Missing gradient information.")
278-
gradient = gradient[..., np.newaxis]
279-
280-
if gradient.shape[0] != 4:
281-
gradient = gradient.T
282-
elif gradient.shape == (4, 4):
283-
print("Warning: make sure gradient information is not transposed!")
284-
285-
with warnings.catch_warnings():
286-
from dipy.core.gradients import gradient_table
287-
288-
warnings.filterwarnings("ignore", category=UserWarning)
289-
retval = gradient_table(gradient[3, :], bvecs=gradient[:3, :].T)
290-
return retval

0 commit comments

Comments
 (0)