Skip to content

Commit 668e240

Browse files
oestebaneffigies
andcommitted
Apply suggestions from code review
Co-authored-by: Chris Markiewicz <[email protected]>
1 parent 97cfc2f commit 668e240

File tree

2 files changed

+12
-8
lines changed

2 files changed

+12
-8
lines changed

src/nifreeze/estimator.py

+8-6
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626

2727
from pathlib import Path
2828
from tempfile import TemporaryDirectory
29-
from typing import Self
29+
from typing import Self, TypeVar
3030

3131
from tqdm import tqdm
3232

@@ -38,11 +38,13 @@
3838
)
3939
from nifreeze.utils import iterators
4040

41+
DatasetT = TypeVar("DatasetT", bound=BaseDataset)
42+
4143

4244
class Filter:
4345
"""Alters an input data object (e.g., downsampling)."""
4446

45-
def run(self, dataset: BaseDataset, **kwargs):
47+
def run(self, dataset: DatasetT, **kwargs) -> DatasetT:
4648
"""
4749
Trigger execution of the designated filter.
4850
@@ -53,8 +55,8 @@ def run(self, dataset: BaseDataset, **kwargs):
5355
5456
Returns
5557
-------
56-
:obj:`~nifreeze.estimator.Estimator`
57-
The estimator, after fitting.
58+
dataset : :obj:`~nifreeze.data.base.BaseDataset`
59+
The dataset, after filtering.
5860
5961
"""
6062
return dataset
@@ -69,7 +71,7 @@ def __init__(
6971
self,
7072
model: BaseModel | str,
7173
strategy: str = "random",
72-
prev: Self | None = None,
74+
prev: Estimator | Filter | None = None,
7375
model_kwargs: dict | None = None,
7476
**kwargs,
7577
):
@@ -79,7 +81,7 @@ def __init__(
7981
self._model_kwargs = model_kwargs or {}
8082
self._align_kwargs = kwargs or {}
8183

82-
def run(self, dataset: BaseDataset, **kwargs):
84+
def run(self, dataset: DatasetT, **kwargs) -> Self:
8385
"""
8486
Trigger execution of the workflow this estimator belongs.
8587

src/nifreeze/model/base.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#
2323
"""Base infrastructure for nifreeze's models."""
2424

25+
from abc import abstractmethod
2526
from warnings import warn
2627

2728
import numpy as np
@@ -95,7 +96,8 @@ def __init__(self, dataset, **kwargs):
9596
if dataset.brainmask is None:
9697
warn(mask_absence_warn_msg, stacklevel=2)
9798

98-
def fit_predict(self, *_, **kwargs):
99+
@abstractmethod
100+
def fit_predict(self, index, **kwargs) -> np.ndarray:
99101
"""Fit and predict the indicate index of the dataset (abstract signature)."""
100102
raise NotImplementedError("Cannot call fit_predict() on a BaseModel instance.")
101103

@@ -139,7 +141,7 @@ def __init__(self, dataset, stat="median", **kwargs):
139141
super().__init__(dataset, **kwargs)
140142
self._stat = stat
141143

142-
def fit_predict(self, index, *_, **kwargs):
144+
def fit_predict(self, index, **kwargs):
143145
"""
144146
Return the expectation map.
145147

0 commit comments

Comments
 (0)