Skip to content

Commit 544249b

Browse files
committed
fix: continue adaptations to new Estimator
1 parent 9706c1e commit 544249b

File tree

3 files changed

+42
-9
lines changed

3 files changed

+42
-9
lines changed

src/nifreeze/estimator.py

Lines changed: 35 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,11 +39,28 @@
3939
class Filter:
4040
"""Alters an input data object (e.g., downsampling)."""
4141

42+
def run(self, dataset: BaseDataset, **kwargs):
43+
"""
44+
Trigger execution of the designated filter.
45+
46+
Parameters
47+
----------
48+
dataset : :obj:`~nifreeze.data.base.BaseDataset`
49+
The input dataset this estimator operates on.
50+
51+
Returns
52+
-------
53+
:obj:`~nifreeze.estimator.Estimator`
54+
The estimator, after fitting.
55+
56+
"""
57+
return dataset
58+
4259

4360
class Estimator:
4461
"""Estimates rigid-body head-motion and distortions derived from eddy-currents."""
4562

46-
__slots__ = ("_model", "_strategy", "_prev", "_model_kwargs", "_align_kwargs")
63+
__slots__ = ("_model", "_strategy", "_dataset", "_prev", "_model_kwargs", "_align_kwargs")
4764

4865
def __init__(
4966
self,
@@ -56,10 +73,24 @@ def __init__(
5673
self._model = model
5774
self._prev = prev
5875
self._strategy = strategy
59-
self._model_kwargs = model_kwargs
60-
self._align_kwargs = kwargs
76+
self._model_kwargs = model_kwargs or {}
77+
self._align_kwargs = kwargs or {}
6178

6279
def run(self, dataset: BaseDataset, **kwargs):
80+
"""
81+
Trigger execution of the workflow this estimator belongs.
82+
83+
Parameters
84+
----------
85+
dataset : :obj:`~nifreeze.data.base.BaseDataset`
86+
The input dataset this estimator operates on.
87+
88+
Returns
89+
-------
90+
:obj:`~nifreeze.estimator.Estimator`
91+
The estimator, after fitting.
92+
93+
"""
6394
if self._prev is not None:
6495
result = self._prev.run(dataset, **kwargs)
6596
if isinstance(self._prev, Filter):
@@ -69,7 +100,7 @@ def run(self, dataset: BaseDataset, **kwargs):
69100

70101
# Prepare iterator
71102
iterfunc = getattr(iterators, f"{self._strategy}_iterator")
72-
index_iter = iterfunc(dataset, seed=kwargs.get("seed", None))
103+
index_iter = iterfunc(len(dataset), seed=kwargs.get("seed", None))
73104

74105
# Initialize model
75106
if isinstance(self._model, str):

src/nifreeze/model/base.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ class ModelFactory:
3333
"""A factory for instantiating data models."""
3434

3535
@staticmethod
36-
def init(model="DTI", **kwargs):
36+
def init(model=None, **kwargs):
3737
"""
3838
Instantiate a diffusion model.
3939
@@ -49,6 +49,9 @@ def init(model="DTI", **kwargs):
4949
A model object compliant with DIPY's interface.
5050
5151
"""
52+
if model is None:
53+
raise RuntimeError("No model identifier provided.")
54+
5255
if model.lower() in ("s0", "b0"):
5356
return TrivialModel(predicted=kwargs.pop("S0"), gtab=kwargs.pop("gtab"))
5457

@@ -143,7 +146,7 @@ def fit(self, data, **kwargs):
143146
"""Do nothing."""
144147

145148
def predict(self, *_, **kwargs):
146-
"""Return the *b=0* map."""
149+
"""Return the reference map."""
147150

148151
# No need to check fit (if not fitted, has raised already)
149152
return self._predicted

test/test_integration.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,10 +70,9 @@ def test_proximity_estimator_trivial_model(datadir, tmp_path):
7070
brainmask=dwdata.brainmask,
7171
)
7272

73-
estimator = Estimator(dwi_motion, model="b0")
73+
estimator = Estimator("b0")
7474
estimator.run(
75-
data=dwi_motion,
76-
models=("b0",),
75+
dwi_motion,
7776
seed=None,
7877
align_kwargs={
7978
"config_file": "b0-to-b0_level0.json",

0 commit comments

Comments
 (0)