REF: Refactor filters into filtering #71

Open. Wants to merge 1 commit into base: main.
135 changes: 135 additions & 0 deletions src/nifreeze/data/filtering.py
@@ -28,8 +28,12 @@
from scipy.ndimage import median_filter
from skimage.morphology import ball

from nifreeze.data.dmri import DEFAULT_CLIP_PERCENTILE

DEFAULT_DTYPE = "int16"
"""The default image data type."""
BVAL_ATOL = 100.0
"""Absolute b-value tolerance."""


def advanced_clip(
@@ -96,3 +100,134 @@
    data = np.round(255 * data).astype(dtype)

    return data


def robust_minmax_normalization(
    data: np.ndarray, mask: np.ndarray | None = None, p_min: float = 5.0, p_max: float = 95.0
) -> np.ndarray:
    r"""Normalize min-max percentiles of each volume to the grand min-max
    percentiles.

    Robust min/max normalization of the volumes in the dataset following:

    .. math::
        \text{data}_{\text{normalized}} = \frac{(\text{data} - p_{min}) \cdot p_{\text{mean}}}{p_{\text{range}}} + p_{min}^{\text{mean}}

    where

    .. math::
        p_{\text{range}} = p_{max} - p_{min}, \quad p_{\text{mean}} = \frac{1}{N} \sum_{i=1}^N p_{\text{range}_i}, \quad p_{min}^{\text{mean}} = \frac{1}{N} \sum_{i=1}^N p_{{min}_i}

    If a mask is provided, only the data within the mask are considered.

    Parameters
    ----------
    data : :obj:`~numpy.ndarray`
        Data to be normalized.
    mask : :obj:`~numpy.ndarray`, optional
        Mask. If provided, only the data within the mask are considered.
    p_min : :obj:`float`, optional
        The lower percentile value for normalization.
    p_max : :obj:`float`, optional
        The upper percentile value for normalization.

    Returns
    -------
    :obj:`~numpy.ndarray`
        Normalized data.
    """

    data = data.copy().astype("float32")
    reshaped_data = data.reshape((-1, data.shape[-1])) if mask is None else data[mask]
    p5 = np.percentile(reshaped_data, p_min, axis=0)
    p95 = np.percentile(reshaped_data, p_max, axis=0) - p5
    return (data - p5) * p95.mean() / p95 + p5.mean()
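
As an aside for reviewers, here is a minimal sketch of calling the new helper on synthetic data; the shapes, seed, and values are illustrative assumptions, not taken from the PR, and it assumes this branch is installed:

```python
import numpy as np

from nifreeze.data.filtering import robust_minmax_normalization

# Synthetic 4D dataset: 10x10x10 voxels, 4 volumes (illustrative shape only)
rng = np.random.default_rng(1234)
data = rng.uniform(0.0, 2.0, size=(10, 10, 10, 4)).astype("float32")

normalized = robust_minmax_normalization(data, mask=None, p_min=5.0, p_max=95.0)

# Each volume's 5th/95th percentiles are now aligned to the grand (mean)
# 5th percentile and the grand percentile range across volumes.
print(np.percentile(normalized.reshape(-1, 4), [5, 95], axis=0))
```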


def grand_mean_normalization(
    data: np.ndarray, mask: np.ndarray | None = None, center: float = DEFAULT_CLIP_PERCENTILE
) -> np.ndarray:
    """Robust grand mean normalization.

    Regresses out global signal differences so that data are normalized and
    centered around a given value.

    If a mask is provided, only the data within the mask are considered.

    Parameters
    ----------
    data : :obj:`~numpy.ndarray`
        Data to be normalized.
    mask : :obj:`~numpy.ndarray`, optional
        Mask. If provided, only the data within the mask are considered.
    center : :obj:`float`, optional
        Central value around which to normalize the data.

    Returns
    -------
    :obj:`~numpy.ndarray`
        Normalized data.
    """

    volumes = data
    if mask is not None:
        volumes = data[..., mask]

Codecov warning (src/nifreeze/data/filtering.py, line 174): added line not covered by tests.

Comment on lines +173 to +174 (Contributor Author):

The check here is necessary; otherwise a new, unnecessary dimension is added to the data.

Comment on lines +172 to +174 (Member):

Suggested change:

    mask = mask if mask is not None else np.ones(data.shape[-1], dtype=bool)
    volumes = data[..., mask]


    centers = np.median(volumes, axis=(0, 1, 2))
    reference = np.percentile(centers[centers >= 1.0], center)
    centers[centers < 1.0] = reference
    drift = reference / centers
    return volumes * drift

Comment (Member):

I don't think this function should filter out the masked directions, right?

Suggested change:

    data[..., mask] = volumes * drift
    return data
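
For context, a small illustration of what ``grand_mean_normalization`` does; the synthetic volumes and the explicit ``center=75`` are assumptions for the example, and it assumes this branch is installed:

```python
import numpy as np

from nifreeze.data.filtering import grand_mean_normalization

# Three synthetic volumes whose global signal levels differ by known factors
rng = np.random.default_rng(5678)
base = rng.uniform(1.0, 2.0, size=(8, 8, 8, 1)).astype("float32")
data = np.concatenate([base, 2.0 * base, 4.0 * base], axis=-1)

normalized = grand_mean_normalization(data, mask=None, center=75)

# The per-volume medians are equalized to the chosen percentile of the medians
print(np.median(data, axis=(0, 1, 2)), np.median(normalized, axis=(0, 1, 2)))
```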



def dwi_select_shells(
    gradients: np.ndarray,
    index: int,
    atol_low: float | None = None,
    atol_high: float | None = None,
) -> np.ndarray:
    """Select the DWI shells around the given index using lower and upper
    b-value bounds.

    Computes a boolean mask of the DWI shells around the given index with the
    provided lower and upper bound b-values.

    If ``atol_low`` and ``atol_high`` are both ``None``, the returned shell mask
    selects all diffusion-sensitizing gradients.

    Parameters
    ----------
    gradients : :obj:`~numpy.ndarray`
        Gradients.
    index : :obj:`int`
        Index of the shell data.
    atol_low : :obj:`float`, optional
        A lower bound for the b-value.
    atol_high : :obj:`float`, optional
        An upper bound for the b-value.

    Returns
    -------
    shellmask : :obj:`~numpy.ndarray`
        Shell mask.
    """

    bvalues = gradients[:, -1]
    bcenter = bvalues[index]

    shellmask = np.ones(len(bvalues), dtype=bool)

Comment (Member):

Suggested change:

    shellmask = np.ones(len(bvalues), dtype=bool)
    shellmask[index] = False  # Drop the left-out index


    if atol_low is None and atol_high is None:
        return shellmask

Codecov warning (src/nifreeze/data/filtering.py, line 221): added line not covered by tests.

    atol_low = 0 if atol_low is None else atol_low
    atol_high = gradients[:, -1].max() if atol_high is None else atol_high

    # Keep only bvalues within the range defined by atol_high and atol_low
    shellmask[bvalues > (bcenter + atol_high)] = False
    shellmask[bvalues < (bcenter - atol_low)] = False

    if not shellmask.sum():
        raise RuntimeError(f"Shell corresponding to index {index} (b={bcenter}) is empty.")

Codecov warning (src/nifreeze/data/filtering.py, line 231): added line not covered by tests.

    return shellmask
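
A short, illustrative call to the shell-selection helper as it behaves in this PR (i.e., without the suggested ``shellmask[index] = False`` change); the gradient table below is made up:

```python
import numpy as np

from nifreeze.data.filtering import BVAL_ATOL, dwi_select_shells

# Gradients as an (N, 4) array: [bvec_x, bvec_y, bvec_z, bval]
bvals = np.array([0, 1000, 1005, 2000, 2010, 3000], dtype="float32")
gradients = np.hstack([np.zeros((len(bvals), 3), dtype="float32"), bvals[:, None]])

# Keep only volumes whose b-value lies within +/- BVAL_ATOL of b=1000 (index 1)
mask = dwi_select_shells(gradients, index=1, atol_low=BVAL_ATOL, atol_high=BVAL_ATOL)
print(mask)  # [False  True  True False False False]
```
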
52 changes: 19 additions & 33 deletions src/nifreeze/model/dmri.py
@@ -27,11 +27,8 @@
from dipy.core.gradients import gradient_table_from_bvals_bvecs
from joblib import Parallel, delayed

from nifreeze.data.dmri import (
    DEFAULT_CLIP_PERCENTILE,
    DTI_MIN_ORIENTATIONS,
    DWI,
)
from nifreeze.data.dmri import DTI_MIN_ORIENTATIONS, DWI
from nifreeze.data.filtering import BVAL_ATOL, dwi_select_shells, grand_mean_normalization
from nifreeze.model.base import BaseModel, ExpectationModel


@@ -183,14 +180,14 @@
class AverageDWIModel(ExpectationModel):
"""A trivial model that returns an average DWI volume."""

__slots__ = ("_th_low", "_th_high", "_detrend")
__slots__ = ("_atol_low", "_atol_high", "_detrend")

def __init__(
self,
dataset: DWI,
stat: str = "median",
th_low: float = 100.0,
th_high: float = 100.0,
atol_low: float = BVAL_ATOL,
atol_high: float = BVAL_ATOL,
detrend: bool = False,
**kwargs,
):
@@ -203,10 +200,10 @@
            Reference to a DWI object.
        stat : :obj:`str`, optional
            Whether the summary statistic to apply is ``"mean"`` or ``"median"``.
        th_low : :obj:`float`, optional
        atol_low : :obj:`float`, optional
            A lower bound for the b-value corresponding to the diffusion weighted images
            that will be averaged.
        th_high : :obj:`float`, optional
        atol_high : :obj:`float`, optional
            An upper bound for the b-value corresponding to the diffusion weighted images
            that will be averaged.
        detrend : :obj:`bool`, optional
@@ -217,38 +214,27 @@
"""
super().__init__(dataset, stat=stat, **kwargs)

self._th_low = th_low
self._th_high = th_high
self._atol_low = atol_low
self._atol_high = atol_high
self._detrend = detrend

def fit_predict(self, index, *_, **kwargs):
def fit_predict(self, index: int, *_, **kwargs):
"""Return the average map."""

bvalues = self._dataset.gradients[:, -1]
bcenter = bvalues[index]

shellmask = np.ones(len(self._dataset), dtype=bool)

# Keep only bvalues within the range defined by th_high and th_low
shellmask[index] = False

Comment (Contributor Author):

@oesteban While writing the test, I have realized that this statement should be removed: we want to keep the value at the index position, right?

Comment (Contributor Author):

I removed the statement. Consequently, the exception here is now not raised:

https://github.com/nipreps/nifreeze/pull/71/files#diff-2197e776366f74ac177f201dce923642554f122630fabab0b45063f2f6cf1832R230-R231

https://github.com/nipreps/nifreeze/actions/runs/14149369168/job/39640631654?pr=71#step:11:539

So I am wondering about the rationale behind assigning False to the index.

Comment (Member):

As long as dwi_select_shells does this inside (which I imagine it does), yes, you need to remove this because it has moved into the new function.

Comment (Member):

> So I am wondering about the rationale behind assigning False to the index.

Okay, I just came back to this and now I understand. The idea for this method is: we want to leave one volume out at each iteration, so you want to mark that specific volume for dropping in the returned mask. Does that make sense?
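
For clarity, a minimal sketch (with made-up sizes) of the leave-one-out masking described above:

```python
import numpy as np

# The model predicts the volume at ``index`` from the other volumes in its
# shell, so the suggestion marks that volume for dropping in the mask.
n_volumes, index = 6, 3
shellmask = np.ones(n_volumes, dtype=bool)
shellmask[index] = False  # Drop the left-out index
print(shellmask)  # [ True  True  True False  True  True]
```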

        shellmask[bvalues > (bcenter + self._th_high)] = False
        shellmask[bvalues < (bcenter - self._th_low)] = False
        # Select the summary statistic
        avg_func = np.median if self._stat == "median" else np.mean

        if not shellmask.sum():
            raise RuntimeError(f"Shell corresponding to index {index} (b={bcenter}) is empty.")
        shellmask = dwi_select_shells(
            self._dataset.gradients,
            index,
            atol_low=self._atol_low,
            atol_high=self._atol_high,
        )

        shelldata = self._dataset.dataobj[..., shellmask]

        # Regress out global signal differences
        if self._detrend:
            centers = np.median(shelldata, axis=(0, 1, 2))
            reference = np.percentile(centers[centers >= 1.0], DEFAULT_CLIP_PERCENTILE)
            centers[centers < 1.0] = reference
            drift = reference / centers
            shelldata = shelldata * drift
            shelldata = grand_mean_normalization(shelldata, mask=None)

Codecov warning (src/nifreeze/model/dmri.py, line 236): added line not covered by tests.

        # Select the summary statistic
        avg_func = np.median if self._stat == "median" else np.mean
        # Calculate the average
        return avg_func(shelldata, axis=-1)
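
To summarize the refactor, here is a hedged sketch of how the new helpers compose into the average computation; the free function ``average_dwi`` and its arguments are illustrative, not part of the PR:

```python
import numpy as np

from nifreeze.data.filtering import (
    BVAL_ATOL,
    dwi_select_shells,
    grand_mean_normalization,
)


def average_dwi(dataobj: np.ndarray, gradients: np.ndarray, index: int) -> np.ndarray:
    """Shell average around ``index``, mirroring the refactored fit_predict with detrending."""
    shellmask = dwi_select_shells(gradients, index, atol_low=BVAL_ATOL, atol_high=BVAL_ATOL)
    shelldata = grand_mean_normalization(dataobj[..., shellmask], mask=None)
    return np.median(shelldata, axis=-1)
```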

118 changes: 118 additions & 0 deletions test/test_filtering.py
@@ -0,0 +1,118 @@
# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*-
# vi: set ft=python sts=4 ts=4 sw=4 et:
#
# Copyright The NiPreps Developers <[email protected]>
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# We support and encourage derived works from this project, please read
# about our expectations at
#
# https://www.nipreps.org/community/licensing/
#

import numpy as np
import pytest

from nifreeze.data.filtering import (
    BVAL_ATOL,
    dwi_select_shells,
    grand_mean_normalization,
    robust_minmax_normalization,
)


def _generate_random_choices(request, values, count):
    rng = request.node.rng

    num_elements = len(values)

    # Randomly distribute N among the given values
    partitions = rng.multinomial(count, np.ones(num_elements) / num_elements)

    # Create a list of selected values
    selected_values = [
        val for val, count in zip(values, partitions, strict=True) for _ in range(count)
    ]

    return sorted(selected_values)


def _create_random_gtab_dataobj(request, n_gradients=10, shells=(1000, 2000, 3000), b0s=1):
    rng = request.node.rng

    # Generate a random number of elements for each shell
    bvals_shells = _generate_random_choices(request, shells, n_gradients)

    bvals = np.hstack([b0s * [0], bvals_shells])
    bvecs = np.hstack([np.zeros((3, b0s)), rng.random((3, n_gradients))])

    return bvals, bvecs


def _random_uniform_4d_data(request, size=(32, 32, 32, 5), a=0.0, b=1.0) -> np.ndarray:
"""Create 4D random uniform data for testing."""

rng = request.node.rng
data = rng.random(size=size).astype(np.float32)
return (b - a) * data + a


@pytest.mark.parametrize(
"n_gradients, shells, index, expected_output",
[(5, (1000, 2000, 3000), 3, np.asarray([False, True, True, True, True, False]))],
)
def test_dwi_select_shells(request, n_gradients, shells, index, expected_output):
bvals, bvecs = _create_random_gtab_dataobj(request, n_gradients=n_gradients, shells=shells)

gradients = np.vstack([bvecs, bvals[np.newaxis, :]], dtype="float32")

shell_mask = dwi_select_shells(
gradients.T,
index,
atol_low=BVAL_ATOL,
atol_high=BVAL_ATOL,
)

assert np.all(shell_mask == expected_output)


@pytest.mark.parametrize("a, b, mask, center", [(0.0, 2.0, None, 1)])
def test_grand_mean_normalization(request, a, b, mask, center):
    data = _random_uniform_4d_data(request, a=a, b=b)

    centers = np.median(data, axis=(0, 1, 2))
    reference = np.percentile(centers[centers >= 1.0], center)
    centers[centers < 1.0] = reference
    drift = reference / centers
    expected_output = data * drift

    normalized_data = grand_mean_normalization(data, mask=mask, center=center)

    assert np.allclose(normalized_data, expected_output, atol=1e-6)


@pytest.mark.parametrize("a, b, mask, p_min, p_max", [(0.0, 2.0, None, 5.0, 95.0)])
def test_robust_minmax_normalization(request, a, b, mask, p_min, p_max):
    data = _random_uniform_4d_data(request, a=a, b=b)

    reshaped_data = data.reshape((-1, data.shape[-1]))
    p5 = np.percentile(reshaped_data, p_min, axis=0)
    p95 = np.percentile(reshaped_data, p_max, axis=0)
    p_range = p95 - p5
    p_mean = np.mean(p_range)
    p5_mean = np.mean(p5)
    expected_output = (data - p5) * p_mean / p_range + p5_mean

    normalized_data = robust_minmax_normalization(data, mask=mask, p_min=p_min, p_max=p_max)
    assert np.allclose(normalized_data, expected_output, atol=1e-6)