26
26
27
27
from pathlib import Path
28
28
from tempfile import TemporaryDirectory
29
- from typing import Self
29
+ from typing import Self , TypeVar
30
30
31
31
from tqdm import tqdm
32
32
38
38
)
39
39
from nifreeze .utils import iterators
40
40
41
+ DatasetT = TypeVar ("DatasetT" , bound = BaseDataset )
42
+
41
43
42
44
class Filter :
43
45
"""Alters an input data object (e.g., downsampling)."""
44
46
45
- def run (self , dataset : BaseDataset , ** kwargs ):
47
+ def run (self , dataset : DatasetT , ** kwargs ) -> DatasetT :
46
48
"""
47
49
Trigger execution of the designated filter.
48
50
@@ -53,8 +55,8 @@ def run(self, dataset: BaseDataset, **kwargs):
53
55
54
56
Returns
55
57
-------
56
- : obj:`~nifreeze.estimator.Estimator `
57
- The estimator , after fitting .
58
+ dataset : : obj:`~nifreeze.data.base.BaseDataset `
59
+ The dataset , after filtering .
58
60
59
61
"""
60
62
return dataset
@@ -69,7 +71,7 @@ def __init__(
69
71
self ,
70
72
model : BaseModel | str ,
71
73
strategy : str = "random" ,
72
- prev : Self | None = None ,
74
+ prev : Estimator | Filter | None = None ,
73
75
model_kwargs : dict | None = None ,
74
76
** kwargs ,
75
77
):
@@ -79,7 +81,7 @@ def __init__(
79
81
self ._model_kwargs = model_kwargs or {}
80
82
self ._align_kwargs = kwargs or {}
81
83
82
- def run (self , dataset : BaseDataset , ** kwargs ):
84
+ def run (self , dataset : DatasetT , ** kwargs ) -> Self :
83
85
"""
84
86
Trigger execution of the workflow this estimator belongs.
85
87
0 commit comments