Skip to content

Commit 18e9de5

Browse files
authored
Merge pull request #68 from jhlegarreta/AddMiscTypeHints
ENH: Add type hints across miscellaneous methods
2 parents 7e363d8 + f005b7e commit 18e9de5

File tree

3 files changed

+24
-14
lines changed

3 files changed

+24
-14
lines changed

Diff for: src/nifreeze/data/dmri.py

+7-6
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
import h5py
3434
import nibabel as nb
3535
import numpy as np
36+
import numpy.typing as npt
3637
from nibabel.spatialimages import SpatialImage
3738
from nitransforms.linear import Affine
3839

@@ -369,11 +370,11 @@ def load(
369370

370371

371372
def find_shelling_scheme(
372-
bvals,
373-
num_bins=DEFAULT_NUM_BINS,
374-
multishell_nonempty_bin_count_thr=DEFAULT_MULTISHELL_BIN_COUNT_THR,
375-
bval_cap=DEFAULT_HIGHB_THRESHOLD,
376-
):
373+
bvals: np.ndarray,
374+
num_bins: int = DEFAULT_NUM_BINS,
375+
multishell_nonempty_bin_count_thr: int = DEFAULT_MULTISHELL_BIN_COUNT_THR,
376+
bval_cap: float = DEFAULT_HIGHB_THRESHOLD,
377+
) -> tuple[str, list[npt.NDArray[np.floating]], list[np.floating]]:
377378
"""
378379
Find the shelling scheme on the given b-values.
379380
@@ -390,7 +391,7 @@ def find_shelling_scheme(
390391
Number of bins.
391392
multishell_nonempty_bin_count_thr : :obj:`int`, optional
392393
Bin count to consider a multi-shell scheme.
393-
bval_cap : :obj:`int`, optional
394+
bval_cap : :obj:`float`, optional
394395
Maximum b-value to be considered in a multi-shell scheme.
395396
396397
Returns

Diff for: src/nifreeze/model/base.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ class ModelFactory:
3636
"""A factory for instantiating data models."""
3737

3838
@staticmethod
39-
def init(model=None, **kwargs):
39+
def init(model: str | None = None, **kwargs):
4040
"""
4141
Instantiate a diffusion model.
4242
@@ -136,7 +136,7 @@ def __init__(self, dataset, stat="median", **kwargs):
136136
super().__init__(dataset, **kwargs)
137137
self._stat = stat
138138

139-
def fit_predict(self, index, **kwargs):
139+
def fit_predict(self, index: int, **kwargs):
140140
"""
141141
Return the expectation map.
142142

Diff for: src/nifreeze/model/dmri.py

+15-6
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from nifreeze.data.dmri import (
3030
DEFAULT_CLIP_PERCENTILE,
3131
DTI_MIN_ORIENTATIONS,
32+
DWI,
3233
)
3334
from nifreeze.model.base import BaseModel, ExpectationModel
3435

@@ -51,7 +52,7 @@ class BaseDWIModel(BaseModel):
5152
"_modelargs": "Arguments acceptable by the underlying DIPY-like model.",
5253
}
5354

54-
def __init__(self, dataset, **kwargs):
55+
def __init__(self, dataset: DWI, **kwargs):
5556
r"""Initialization.
5657
5758
Parameters
@@ -117,7 +118,7 @@ def _fit(self, index, n_jobs=None, **kwargs):
117118
self._model = None # Preempt further actions on the model
118119
return n_jobs
119120

120-
def fit_predict(self, index, **kwargs):
121+
def fit_predict(self, index: int, **kwargs):
121122
"""
122123
Predict asynchronously chunk-by-chunk the diffusion signal.
123124
@@ -140,7 +141,7 @@ def fit_predict(self, index, **kwargs):
140141
if n_models == 1:
141142
predicted, _ = _exec_predict(self._model, **(kwargs | {"gtab": gradient, "S0": S0}))
142143
else:
143-
S0 = np.array_split(S0, n_models) if S0 is not None else [None] * n_models
144+
S0 = np.array_split(S0, n_models) if S0 is not None else np.full(n_models, None)
144145

145146
predicted = [None] * n_models
146147

@@ -173,7 +174,15 @@ class AverageDWIModel(ExpectationModel):
173174

174175
__slots__ = ("_th_low", "_th_high", "_detrend")
175176

176-
def __init__(self, dataset, stat="median", th_low=100, th_high=100, detrend=False, **kwargs):
177+
def __init__(
178+
self,
179+
dataset: DWI,
180+
stat: str = "median",
181+
th_low: float = 100.0,
182+
th_high: float = 100.0,
183+
detrend: bool = False,
184+
**kwargs,
185+
):
177186
r"""
178187
Implement object initialization.
179188
@@ -183,10 +192,10 @@ def __init__(self, dataset, stat="median", th_low=100, th_high=100, detrend=Fals
183192
Reference to a DWI object.
184193
stat : :obj:`str`, optional
185194
Whether the summary statistic to apply is ``"mean"`` or ``"median"``.
186-
th_low : :obj:`numbers.Number`, optional
195+
th_low : :obj:`float`, optional
187196
A lower bound for the b-value corresponding to the diffusion weighted images
188197
that will be averaged.
189-
th_high : :obj:`numbers.Number`, optional
198+
th_high : :obj:`float`, optional
190199
An upper bound for the b-value corresponding to the diffusion weighted images
191200
that will be averaged.
192201
detrend : :obj:`bool`, optional

0 commit comments

Comments
 (0)