Skip to content

Commit ce6f3b2

Browse files
oestebaneffigies
andcommitted
Apply suggestions from code review
Co-authored-by: Chris Markiewicz <[email protected]>
1 parent 97cfc2f commit ce6f3b2

File tree

2 files changed

+13
-8
lines changed

2 files changed

+13
-8
lines changed

src/nifreeze/estimator.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626

2727
from pathlib import Path
2828
from tempfile import TemporaryDirectory
29-
from typing import Self
29+
from typing import Self, TypeVar
3030

3131
from tqdm import tqdm
3232

@@ -39,10 +39,13 @@
3939
from nifreeze.utils import iterators
4040

4141

42+
DatasetT = TypeVar("DatasetT", bound=BaseDataset)
43+
44+
4245
class Filter:
4346
"""Alters an input data object (e.g., downsampling)."""
4447

45-
def run(self, dataset: BaseDataset, **kwargs):
48+
def run(self, dataset: DatasetT, **kwargs) -> DatasetT:
4649
"""
4750
Trigger execution of the designated filter.
4851
@@ -53,8 +56,8 @@ def run(self, dataset: BaseDataset, **kwargs):
5356
5457
Returns
5558
-------
56-
:obj:`~nifreeze.estimator.Estimator`
57-
The estimator, after fitting.
59+
dataset : :obj:`~nifreeze.data.base.BaseDataset`
60+
The dataset, after filtering.
5861
5962
"""
6063
return dataset
@@ -69,7 +72,7 @@ def __init__(
6972
self,
7073
model: BaseModel | str,
7174
strategy: str = "random",
72-
prev: Self | None = None,
75+
prev: Estimator | Filter | None = None,
7376
model_kwargs: dict | None = None,
7477
**kwargs,
7578
):
@@ -79,7 +82,7 @@ def __init__(
7982
self._model_kwargs = model_kwargs or {}
8083
self._align_kwargs = kwargs or {}
8184

82-
def run(self, dataset: BaseDataset, **kwargs):
85+
def run(self, dataset: DatasetT, **kwargs) -> Self:
8386
"""
8487
Trigger execution of the workflow this estimator belongs.
8588

src/nifreeze/model/base.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#
2323
"""Base infrastructure for nifreeze's models."""
2424

25+
from abc import abstractmethod
2526
from warnings import warn
2627

2728
import numpy as np
@@ -95,7 +96,8 @@ def __init__(self, dataset, **kwargs):
9596
if dataset.brainmask is None:
9697
warn(mask_absence_warn_msg, stacklevel=2)
9798

98-
def fit_predict(self, *_, **kwargs):
99+
@abstractmethod
100+
def fit_predict(self, index, **kwargs) -> np.ndarray:
99101
"""Fit and predict the indicate index of the dataset (abstract signature)."""
100102
raise NotImplementedError("Cannot call fit_predict() on a BaseModel instance.")
101103

@@ -139,7 +141,7 @@ def __init__(self, dataset, stat="median", **kwargs):
139141
super().__init__(dataset, **kwargs)
140142
self._stat = stat
141143

142-
def fit_predict(self, index, *_, **kwargs):
144+
def fit_predict(self, index, **kwargs):
143145
"""
144146
Return the expectation map.
145147

0 commit comments

Comments
 (0)