Skip to content

Commit

Permalink
Merge pull request #54 from jhlegarreta/TestAbsentMaskWarning
Browse files Browse the repository at this point in the history
ENH: Test absent mask warning
  • Loading branch information
oesteban authored Jan 23, 2025
2 parents 72ae039 + b78f9d8 commit f82305c
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 9 deletions.
2 changes: 0 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -222,8 +222,6 @@ filterwarnings = [
"ignore:Updating b0_threshold to.*:UserWarning",
# scikit-learn
"ignore:The optimal value found for dimension.*:sklearn.exceptions.ConvergenceWarning",
# masks
"ignore:No mask provided;.*:UserWarning",
]


Expand Down
9 changes: 5 additions & 4 deletions src/nifreeze/model/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@

from nifreeze.exceptions import ModelNotFittedError

mask_absence_warn_msg = (
"No mask provided; consider using a mask to avoid issues in model optimization."
)


class ModelFactory:
"""A factory for instantiating data models."""
Expand Down Expand Up @@ -98,10 +102,7 @@ def __init__(self, mask=None, **kwargs):

# Setup brain mask
if mask is None:
warn(
"No mask provided; consider using a mask to avoid issues in model optimization.",
stacklevel=2,
)
warn(mask_absence_warn_msg, stacklevel=2)

self._mask = mask

Expand Down
16 changes: 13 additions & 3 deletions test/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
#
"""Unit tests exercising models."""

import contextlib

import numpy as np
import pytest
from dipy.sims.voxel import single_tensor
Expand All @@ -31,11 +33,13 @@
from nifreeze.data.splitting import lovo_split
from nifreeze.exceptions import ModelNotFittedError
from nifreeze.model._dipy import GaussianProcessModel
from nifreeze.model.base import mask_absence_warn_msg
from nifreeze.model.dmri import DEFAULT_MAX_S0, DEFAULT_MIN_S0
from nifreeze.testing import simulations as _sim


def test_trivial_model():
@pytest.mark.parametrize("use_mask", (False, True))
def test_trivial_model(use_mask):
"""Check the implementation of the trivial B0 model."""

rng = np.random.default_rng(1234)
Expand All @@ -45,7 +49,12 @@ def test_trivial_model():
model.TrivialModel()

size = (2, 2, 2)
mask = np.ones(size, dtype=bool)
mask = None
if use_mask:
mask = np.ones(size, dtype=bool)
context = contextlib.nullcontext()
else:
context = pytest.warns(UserWarning, match=mask_absence_warn_msg)

_S0 = rng.normal(size=size)

Expand All @@ -55,7 +64,8 @@ def test_trivial_model():
a_max=DEFAULT_MAX_S0,
)

tmodel = model.TrivialModel(mask=mask, predicted=_clipped_S0)
with context:
tmodel = model.TrivialModel(mask=mask, predicted=_clipped_S0)

data = None
assert tmodel.fit(data) is None
Expand Down

0 comments on commit f82305c

Please sign in to comment.