-
Notifications
You must be signed in to change notification settings - Fork 4
REF: Refactor filters into filtering
#71
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -28,8 +28,12 @@ | |||||||||||
from scipy.ndimage import median_filter | ||||||||||||
from skimage.morphology import ball | ||||||||||||
|
||||||||||||
from nifreeze.data.dmri import DEFAULT_CLIP_PERCENTILE | ||||||||||||
|
||||||||||||
DEFAULT_DTYPE = "int16" | ||||||||||||
"""The default image's data type.""" | ||||||||||||
BVAL_ATOL = 100.0 | ||||||||||||
"""b-value tolerance value.""" | ||||||||||||
|
||||||||||||
|
||||||||||||
def advanced_clip( | ||||||||||||
|
@@ -96,3 +100,134 @@ | |||||||||||
data = np.round(255 * data).astype(dtype) | ||||||||||||
|
||||||||||||
return data | ||||||||||||
|
||||||||||||
|
||||||||||||
def robust_minmax_normalization( | ||||||||||||
data: np.ndarray, mask: np.ndarray | None = None, p_min: float = 5.0, p_max: float = 95.0 | ||||||||||||
) -> np.ndarray: | ||||||||||||
r"""Normalize min-max percentiles of each volume to the grand min-max | ||||||||||||
percentiles. | ||||||||||||
|
||||||||||||
Robust min/max normalization of the volumes in the dataset following: | ||||||||||||
|
||||||||||||
.. math:: | ||||||||||||
\text{data}_{\text{normalized}} = \frac{(\text{data} - p_{min}) \cdot p_{\text{mean}}}{p_{\text{range}}} + p_{min}^{\text{mean}} | ||||||||||||
|
||||||||||||
where | ||||||||||||
|
||||||||||||
.. math:: | ||||||||||||
p_{\text{range}} = p_{max} - p_{min}, \quad p_{\text{mean}} = \frac{1}{N} \sum_{i=1}^N p_{\text{range}_i}, \quad p_{min}^{\text{mean}} = \frac{1}{N} \sum_{i=1}^N p_{5_i} | ||||||||||||
|
||||||||||||
If a mask is provided, only the data within the mask are considered. | ||||||||||||
|
||||||||||||
Parameters | ||||||||||||
---------- | ||||||||||||
data : :obj:`~numpy.ndarray` | ||||||||||||
Data to be normalized. | ||||||||||||
mask : :obj:`~numpy.ndarray`, optional | ||||||||||||
Mask. If provided, only the data within the mask are considered. | ||||||||||||
p_min : :obj:`float`, optional | ||||||||||||
The lower percentile value for normalization. | ||||||||||||
p_max : :obj:`float`, optional | ||||||||||||
The upper percentile value for normalization. | ||||||||||||
|
||||||||||||
Returns | ||||||||||||
------- | ||||||||||||
:obj:`~numpy.ndarray` | ||||||||||||
Normalized data. | ||||||||||||
""" | ||||||||||||
|
||||||||||||
data = data.copy().astype("float32") | ||||||||||||
reshaped_data = data.reshape((-1, data.shape[-1])) if mask is None else data[mask] | ||||||||||||
p5 = np.percentile(reshaped_data, p_min, axis=0) | ||||||||||||
p95 = np.percentile(reshaped_data, p_max, axis=0) - p5 | ||||||||||||
return (data - p5) * p95.mean() / p95 + p5.mean() | ||||||||||||
|
||||||||||||
|
||||||||||||
def grand_mean_normalization( | ||||||||||||
data: np.ndarray, mask: np.ndarray | None = None, center: float = DEFAULT_CLIP_PERCENTILE | ||||||||||||
) -> np.ndarray: | ||||||||||||
"""Robust grand mean normalization. | ||||||||||||
|
||||||||||||
Regresses out global signal differences so that data are normalized and | ||||||||||||
centered around a given value. | ||||||||||||
|
||||||||||||
If a mask is provided, only the data within the mask are considered. | ||||||||||||
|
||||||||||||
Parameters | ||||||||||||
---------- | ||||||||||||
data : :obj:`~numpy.ndarray` | ||||||||||||
Data to be normalized. | ||||||||||||
mask : :obj:`~numpy.ndarray`, optional | ||||||||||||
Mask. If provided, only the data within the mask are considered. | ||||||||||||
center : float, optional | ||||||||||||
Central value around which to normalize the data. | ||||||||||||
|
||||||||||||
Returns | ||||||||||||
------- | ||||||||||||
:obj:`~numpy.ndarray` | ||||||||||||
Normalized data. | ||||||||||||
""" | ||||||||||||
|
||||||||||||
volumes = data | ||||||||||||
if mask is not None: | ||||||||||||
volumes = data[..., mask] | ||||||||||||
Comment on lines
+172
to
+174
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||
|
||||||||||||
centers = np.median(volumes, axis=(0, 1, 2)) | ||||||||||||
reference = np.percentile(centers[centers >= 1.0], center) | ||||||||||||
centers[centers < 1.0] = reference | ||||||||||||
drift = reference / centers | ||||||||||||
return volumes * drift | ||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't think this function should filter out masked direction, right?
Suggested change
|
||||||||||||
|
||||||||||||
|
||||||||||||
def dwi_select_shells( | ||||||||||||
gradients: np.ndarray, | ||||||||||||
index: int, | ||||||||||||
atol_low: float | None = None, | ||||||||||||
atol_high: float | None = None, | ||||||||||||
) -> np.ndarray: | ||||||||||||
"""Select DWI shells around the given index and lower and upper b-value | ||||||||||||
bounds. | ||||||||||||
|
||||||||||||
Computes a boolean mask of the DWI shells around the given index with the | ||||||||||||
provided lower and upper bound b-values. | ||||||||||||
|
||||||||||||
If ``atol_low`` and ``atol_high`` are both ``None``, the returned shell mask | ||||||||||||
corresponds to the lengths of the diffusion-sensitizing gradients. | ||||||||||||
|
||||||||||||
Parameters | ||||||||||||
---------- | ||||||||||||
gradients : :obj:`~numpy.ndarray` | ||||||||||||
Gradients. | ||||||||||||
index : :obj:`int` | ||||||||||||
Index of the shell data. | ||||||||||||
atol_low : :obj:`float`, optional | ||||||||||||
A lower bound for the b-value. | ||||||||||||
atol_high : :obj:`float`, optional | ||||||||||||
An upper bound for the b-value. | ||||||||||||
|
||||||||||||
Returns | ||||||||||||
------- | ||||||||||||
shellmask : :obj:`~numpy.ndarray` | ||||||||||||
Shell mask. | ||||||||||||
""" | ||||||||||||
|
||||||||||||
jhlegarreta marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||
bvalues = gradients[:, -1] | ||||||||||||
bcenter = bvalues[index] | ||||||||||||
|
||||||||||||
shellmask = np.ones(len(bvalues), dtype=bool) | ||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||
|
||||||||||||
if atol_low is None and atol_high is None: | ||||||||||||
return shellmask | ||||||||||||
|
||||||||||||
atol_low = 0 if atol_low is None else atol_low | ||||||||||||
atol_high = gradients[:, -1].max() if atol_high is None else atol_high | ||||||||||||
|
||||||||||||
# Keep only bvalues within the range defined by atol_high and atol_low | ||||||||||||
shellmask[bvalues > (bcenter + atol_high)] = False | ||||||||||||
shellmask[bvalues < (bcenter - atol_low)] = False | ||||||||||||
|
||||||||||||
if not shellmask.sum(): | ||||||||||||
raise RuntimeError(f"Shell corresponding to index {index} (b={bcenter}) is empty.") | ||||||||||||
|
||||||||||||
return shellmask |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -27,11 +27,8 @@ | |
from dipy.core.gradients import gradient_table_from_bvals_bvecs | ||
from joblib import Parallel, delayed | ||
|
||
from nifreeze.data.dmri import ( | ||
DEFAULT_CLIP_PERCENTILE, | ||
DTI_MIN_ORIENTATIONS, | ||
DWI, | ||
) | ||
from nifreeze.data.dmri import DTI_MIN_ORIENTATIONS, DWI | ||
from nifreeze.data.filtering import BVAL_ATOL, dwi_select_shells, grand_mean_normalization | ||
from nifreeze.model.base import BaseModel, ExpectationModel | ||
|
||
|
||
|
@@ -183,14 +180,14 @@ | |
class AverageDWIModel(ExpectationModel): | ||
"""A trivial model that returns an average DWI volume.""" | ||
|
||
__slots__ = ("_th_low", "_th_high", "_detrend") | ||
__slots__ = ("_atol_low", "_atol_high", "_detrend") | ||
|
||
def __init__( | ||
self, | ||
dataset: DWI, | ||
stat: str = "median", | ||
th_low: float = 100.0, | ||
th_high: float = 100.0, | ||
atol_low: float = BVAL_ATOL, | ||
atol_high: float = BVAL_ATOL, | ||
detrend: bool = False, | ||
**kwargs, | ||
): | ||
|
@@ -203,10 +200,10 @@ | |
Reference to a DWI object. | ||
stat : :obj:`str`, optional | ||
Whether the summary statistic to apply is ``"mean"`` or ``"median"``. | ||
th_low : :obj:`float`, optional | ||
atol_low : :obj:`float`, optional | ||
A lower bound for the b-value corresponding to the diffusion weighted images | ||
that will be averaged. | ||
th_high : :obj:`float`, optional | ||
atol_low : :obj:`float`, optional | ||
An upper bound for the b-value corresponding to the diffusion weighted images | ||
that will be averaged. | ||
detrend : :obj:`bool`, optional | ||
|
@@ -217,38 +214,27 @@ | |
""" | ||
super().__init__(dataset, stat=stat, **kwargs) | ||
|
||
self._th_low = th_low | ||
self._th_high = th_high | ||
self._atol_low = atol_low | ||
self._atol_high = atol_high | ||
self._detrend = detrend | ||
|
||
def fit_predict(self, index, *_, **kwargs): | ||
def fit_predict(self, index: int, *_, **kwargs): | ||
"""Return the average map.""" | ||
|
||
bvalues = self._dataset.gradients[:, -1] | ||
bcenter = bvalues[index] | ||
|
||
shellmask = np.ones(len(self._dataset), dtype=bool) | ||
|
||
# Keep only bvalues within the range defined by th_high and th_low | ||
shellmask[index] = False | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @oesteban While writing the test, I have realized that this statement should be removed: we want to keep the value at the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I removed the statement. Consequently, the exception here is not not raised: https://github.com/nipreps/nifreeze/actions/runs/14149369168/job/39640631654?pr=71#step:11:539 So I am wondering about the rationale behind assigning There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. As long as There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It does not, please have a look at the definition https://github.com/nipreps/nifreeze/pull/71/files#diff-2197e776366f74ac177f201dce923642554f122630fabab0b45063f2f6cf1832R183 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Okay, I just came back to this and now I understand. The idea for this method is: we want to leave one volume out at each iteration, so you want to mark that specific volume for dropping in the returned mask. Does that make sense? |
||
shellmask[bvalues > (bcenter + self._th_high)] = False | ||
shellmask[bvalues < (bcenter - self._th_low)] = False | ||
# Select the summary statistic | ||
avg_func = np.median if self._stat == "median" else np.mean | ||
|
||
if not shellmask.sum(): | ||
raise RuntimeError(f"Shell corresponding to index {index} (b={bcenter}) is empty.") | ||
shellmask = dwi_select_shells( | ||
self._dataset.gradients, | ||
index, | ||
atol_low=self._atol_low, | ||
atol_high=self._atol_high, | ||
) | ||
|
||
shelldata = self._dataset.dataobj[..., shellmask] | ||
|
||
# Regress out global signal differences | ||
if self._detrend: | ||
centers = np.median(shelldata, axis=(0, 1, 2)) | ||
reference = np.percentile(centers[centers >= 1.0], DEFAULT_CLIP_PERCENTILE) | ||
centers[centers < 1.0] = reference | ||
drift = reference / centers | ||
shelldata = shelldata * drift | ||
shelldata = grand_mean_normalization(shelldata, mask=None) | ||
|
||
# Select the summary statistic | ||
avg_func = np.median if self._stat == "median" else np.mean | ||
# Calculate the average | ||
return avg_func(shelldata, axis=-1) | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,118 @@ | ||
# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*- | ||
# vi: set ft=python sts=4 ts=4 sw=4 et: | ||
# | ||
# Copyright The NiPreps Developers <[email protected]> | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
# We support and encourage derived works from this project, please read | ||
# about our expectations at | ||
# | ||
# https://www.nipreps.org/community/licensing/ | ||
# | ||
|
||
import numpy as np | ||
import pytest | ||
|
||
from nifreeze.data.filtering import ( | ||
BVAL_ATOL, | ||
dwi_select_shells, | ||
grand_mean_normalization, | ||
robust_minmax_normalization, | ||
) | ||
|
||
|
||
def _generate_random_choices(request, values, count): | ||
rng = request.node.rng | ||
|
||
num_elements = len(values) | ||
|
||
# Randomly distribute N among the given values | ||
partitions = rng.multinomial(count, np.ones(num_elements) / num_elements) | ||
|
||
# Create a list of selected values | ||
selected_values = [ | ||
val for val, count in zip(values, partitions, strict=True) for _ in range(count) | ||
] | ||
|
||
return sorted(selected_values) | ||
|
||
|
||
def _create_random_gtab_dataobj(request, n_gradients=10, shells=(1000, 2000, 3000), b0s=1): | ||
rng = request.node.rng | ||
|
||
# Generate a random number of elements for each shell | ||
bvals_shells = _generate_random_choices(request, shells, n_gradients) | ||
|
||
bvals = np.hstack([b0s * [0], bvals_shells]) | ||
bvecs = np.hstack([np.zeros((3, b0s)), rng.random((3, n_gradients))]) | ||
|
||
return bvals, bvecs | ||
|
||
|
||
def _random_uniform_4d_data(request, size=(32, 32, 32, 5), a=0.0, b=1.0) -> np.ndarray: | ||
"""Create 4D random uniform data for testing.""" | ||
|
||
rng = request.node.rng | ||
data = rng.random(size=size).astype(np.float32) | ||
return (b - a) * data + a | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"n_gradients, shells, index, expected_output", | ||
[(5, (1000, 2000, 3000), 3, np.asarray([False, True, True, True, True, False]))], | ||
) | ||
def test_dwi_select_shells(request, n_gradients, shells, index, expected_output): | ||
bvals, bvecs = _create_random_gtab_dataobj(request, n_gradients=n_gradients, shells=shells) | ||
|
||
gradients = np.vstack([bvecs, bvals[np.newaxis, :]], dtype="float32") | ||
|
||
shell_mask = dwi_select_shells( | ||
gradients.T, | ||
index, | ||
atol_low=BVAL_ATOL, | ||
atol_high=BVAL_ATOL, | ||
) | ||
|
||
assert np.all(shell_mask == expected_output) | ||
|
||
|
||
@pytest.mark.parametrize("a, b, mask, center", [(0.0, 2.0, None, 1)]) | ||
def test_grand_mean_normalization(request, a, b, mask, center): | ||
data = _random_uniform_4d_data(request, a=a, b=b) | ||
|
||
centers = np.median(data, axis=(0, 1, 2)) | ||
reference = np.percentile(centers[centers >= 1.0], center) | ||
centers[centers < 1.0] = reference | ||
drift = reference / centers | ||
expected_output = data * drift | ||
|
||
normalized_data = grand_mean_normalization(data, mask=mask, center=center) | ||
|
||
assert np.allclose(normalized_data, expected_output, atol=1e-6) | ||
|
||
|
||
@pytest.mark.parametrize("a, b, mask, p_min, p_max", [(0.0, 2.0, None, 5.0, 95.0)]) | ||
def test_robust_minmax_normalization(request, a, b, mask, p_min, p_max): | ||
data = _random_uniform_4d_data(request, a=a, b=b) | ||
|
||
reshaped_data = data.reshape((-1, data.shape[-1])) | ||
p5 = np.percentile(reshaped_data, p_min, axis=0) | ||
p95 = np.percentile(reshaped_data, p_max, axis=0) | ||
p_range = p95 - p5 | ||
p_mean = np.mean(p_range) | ||
p5_mean = np.mean(p5) | ||
expected_output = (data - p5) * p_mean / p_range + p5_mean | ||
|
||
normalized_data = robust_minmax_normalization(data, mask=mask, p_min=p_min, p_max=p_max) | ||
assert np.allclose(normalized_data, expected_output, atol=1e-6) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The check here is necessary otherwise a new, unnecessary dimension is added to the data.