@@ -74,7 +74,16 @@ def run(self, dataset: DatasetT, **kwargs) -> DatasetT:
7474class Estimator :
7575 """Orchestrates components for a single estimation step."""
7676
77- __slots__ = ("_model" , "_single_fit" , "_strategy" , "_prev" , "_model_kwargs" , "_align_kwargs" )
77+ __slots__ = (
78+ "_model" ,
79+ "_single_fit" ,
80+ "_strategy" ,
81+ "_prev" ,
82+ "_model_kwargs" ,
83+ "_align_kwargs" ,
84+ "_start_index" ,
85+ "_stop_index" ,
86+ )
7887
7988 def __init__ (
8089 self ,
@@ -83,6 +92,8 @@ def __init__(
8392 prev : Estimator | Filter | None = None ,
8493 model_kwargs : dict | None = None ,
8594 single_fit : bool = False ,
95+ start_index : int = 0 ,
96+ stop_index : int | None = None ,
8697 ** kwargs ,
8798 ):
8899 self ._model = model
@@ -92,6 +103,9 @@ def __init__(
92103 self ._model_kwargs = model_kwargs or {}
93104 self ._align_kwargs = kwargs or {}
94105
106+ self ._start_index = start_index
107+ self ._stop_index = stop_index
108+
95109 def run (self , dataset : DatasetT , ** kwargs ) -> Self :
96110 """
97111 Trigger execution of the workflow this estimator belongs.
@@ -118,14 +132,26 @@ def run(self, dataset: DatasetT, **kwargs) -> Self:
118132 num_voxels = dataset .brainmask .sum () if dataset .brainmask is not None else dataset .size3d
119133 chunk_size = DEFAULT_CHUNK_SIZE * (n_threads or 1 )
120134
135+ # Calculate the size parameter for the iterator (exclusive upper bound)
136+ n = len (dataset )
137+ size = n
138+ start_index = self ._start_index or 0
139+ stop_index = (
140+ None
141+ if self ._stop_index is None
142+ else (n + self ._stop_index if self ._stop_index < 0 else self ._stop_index )
143+ )
144+
121145 # Prepare iterator
122146 iterfunc = getattr (iterators , f"{ self ._strategy } _iterator" )
123147 index_iter = iterfunc (
124- size = len ( dataset ) ,
148+ size = size ,
125149 bvals = kwargs .pop ("bvals" , None ),
126150 uptake = kwargs .pop ("uptake" , None ),
127151 seed = kwargs .get ("seed" , None ),
128152 round_decimals = kwargs .pop ("round_decimals" , iterators .DEFAULT_ROUND_DECIMALS ),
153+ start_index = start_index ,
154+ stop_index = stop_index ,
129155 )
130156
131157 # Initialize model
@@ -167,7 +193,12 @@ def run(self, dataset: DatasetT, **kwargs) -> Self:
167193 kwargs ["num_threads" ] = n_threads
168194 kwargs = self ._align_kwargs | kwargs
169195
170- dataset_length = len (dataset )
196+ # Calculate effective dataset length for progress bar
197+ if self ._stop_index is not None :
198+ dataset_length = self ._stop_index - self ._start_index
199+ else :
200+ dataset_length = len (dataset ) - self ._start_index
201+
171202 with TemporaryDirectory () as tmp_dir :
172203 print (f"Processing in <{ tmp_dir } >" )
173204 ptmp_dir = Path (tmp_dir )
0 commit comments