Skip to content

Commit 0b51bd4

Browse files
authored
Merge pull request #423 from jhlegarreta/enh/implement-start-end-index
ENH: Implement start/end index for the estimator
2 parents 96df7b3 + 6a2b038 commit 0b51bd4

5 files changed

Lines changed: 930 additions & 244 deletions

File tree

src/nifreeze/estimator.py

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,16 @@ def run(self, dataset: DatasetT, **kwargs) -> DatasetT:
7474
class Estimator:
7575
"""Orchestrates components for a single estimation step."""
7676

77-
__slots__ = ("_model", "_single_fit", "_strategy", "_prev", "_model_kwargs", "_align_kwargs")
77+
__slots__ = (
78+
"_model",
79+
"_single_fit",
80+
"_strategy",
81+
"_prev",
82+
"_model_kwargs",
83+
"_align_kwargs",
84+
"_start_index",
85+
"_stop_index",
86+
)
7887

7988
def __init__(
8089
self,
@@ -83,6 +92,8 @@ def __init__(
8392
prev: Estimator | Filter | None = None,
8493
model_kwargs: dict | None = None,
8594
single_fit: bool = False,
95+
start_index: int = 0,
96+
stop_index: int | None = None,
8697
**kwargs,
8798
):
8899
self._model = model
@@ -92,6 +103,9 @@ def __init__(
92103
self._model_kwargs = model_kwargs or {}
93104
self._align_kwargs = kwargs or {}
94105

106+
self._start_index = start_index
107+
self._stop_index = stop_index
108+
95109
def run(self, dataset: DatasetT, **kwargs) -> Self:
96110
"""
97111
Trigger execution of the workflow this estimator belongs.
@@ -118,14 +132,26 @@ def run(self, dataset: DatasetT, **kwargs) -> Self:
118132
num_voxels = dataset.brainmask.sum() if dataset.brainmask is not None else dataset.size3d
119133
chunk_size = DEFAULT_CHUNK_SIZE * (n_threads or 1)
120134

135+
# Calculate the size parameter for the iterator (exclusive upper bound)
136+
n = len(dataset)
137+
size = n
138+
start_index = self._start_index or 0
139+
stop_index = (
140+
None
141+
if self._stop_index is None
142+
else (n + self._stop_index if self._stop_index < 0 else self._stop_index)
143+
)
144+
121145
# Prepare iterator
122146
iterfunc = getattr(iterators, f"{self._strategy}_iterator")
123147
index_iter = iterfunc(
124-
size=len(dataset),
148+
size=size,
125149
bvals=kwargs.pop("bvals", None),
126150
uptake=kwargs.pop("uptake", None),
127151
seed=kwargs.get("seed", None),
128152
round_decimals=kwargs.pop("round_decimals", iterators.DEFAULT_ROUND_DECIMALS),
153+
start_index=start_index,
154+
stop_index=stop_index,
129155
)
130156

131157
# Initialize model
@@ -167,7 +193,12 @@ def run(self, dataset: DatasetT, **kwargs) -> Self:
167193
kwargs["num_threads"] = n_threads
168194
kwargs = self._align_kwargs | kwargs
169195

170-
dataset_length = len(dataset)
196+
# Calculate effective dataset length for progress bar
197+
if self._stop_index is not None:
198+
dataset_length = self._stop_index - self._start_index
199+
else:
200+
dataset_length = len(dataset) - self._start_index
201+
171202
with TemporaryDirectory() as tmp_dir:
172203
print(f"Processing in <{tmp_dir}>")
173204
ptmp_dir = Path(tmp_dir)

0 commit comments

Comments
 (0)