@@ -179,40 +179,17 @@ def get_output_sample_id(self, input_sample_id: SampleId):
             self.model_description.id or self.model_description.name
         )
 
-    def predict_sample_with_blocking(
+    def predict_sample_with_fixed_blocking(
         self,
         sample: Sample,
+        input_block_shape: Mapping[MemberId, Mapping[AxisId, int]],
+        *,
         skip_preprocessing: bool = False,
         skip_postprocessing: bool = False,
-        ns: Optional[
-            Union[
-                v0_5.ParameterizedSize_N,
-                Mapping[Tuple[MemberId, AxisId], v0_5.ParameterizedSize_N],
-            ]
-        ] = None,
-        batch_size: Optional[int] = None,
     ) -> Sample:
-        """predict a sample by splitting it into blocks according to the model and the `ns` parameter"""
         if not skip_preprocessing:
             self.apply_preprocessing(sample)
 
-        if isinstance(self.model_description, v0_4.ModelDescr):
-            raise NotImplementedError(
-                "predict with blocking not implemented for v0_4.ModelDescr {self.model_description.name}"
-            )
-
-        ns = ns or self._default_ns
-        if isinstance(ns, int):
-            ns = {
-                (ipt.id, a.id): ns
-                for ipt in self.model_description.inputs
-                for a in ipt.axes
-                if isinstance(a.size, v0_5.ParameterizedSize)
-            }
-        input_block_shape = self.model_description.get_tensor_sizes(
-            ns, batch_size or self._default_batch_size
-        ).inputs
-
         n_blocks, input_blocks = sample.split_into_blocks(
             input_block_shape,
             halo=self._default_input_halo,
@@ -239,6 +216,47 @@ def predict_sample_with_blocking(
 
         return predicted_sample
 
+    def predict_sample_with_blocking(
+        self,
+        sample: Sample,
+        skip_preprocessing: bool = False,
+        skip_postprocessing: bool = False,
+        ns: Optional[
+            Union[
+                v0_5.ParameterizedSize_N,
+                Mapping[Tuple[MemberId, AxisId], v0_5.ParameterizedSize_N],
+            ]
+        ] = None,
+        batch_size: Optional[int] = None,
+    ) -> Sample:
+        """predict a sample by splitting it into blocks according to the model and the `ns` parameter"""
+
+        if isinstance(self.model_description, v0_4.ModelDescr):
+            raise NotImplementedError(
+                "`predict_sample_with_blocking` not implemented for v0_4.ModelDescr"
+                + f" {self.model_description.name}."
+                + " Consider using `predict_sample_with_fixed_blocking`"
+            )
+
+        ns = ns or self._default_ns
+        if isinstance(ns, int):
+            ns = {
+                (ipt.id, a.id): ns
+                for ipt in self.model_description.inputs
+                for a in ipt.axes
+                if isinstance(a.size, v0_5.ParameterizedSize)
+            }
+        input_block_shape = self.model_description.get_tensor_sizes(
+            ns, batch_size or self._default_batch_size
+        ).inputs
+
+        return self.predict_sample_with_fixed_blocking(
+            sample,
+            input_block_shape=input_block_shape,
+            skip_preprocessing=skip_preprocessing,
+            skip_postprocessing=skip_postprocessing,
+        )
+
 
     # def predict(
     #     self,
     #     inputs: Predict_IO,
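
The commit splits block-wise prediction into two entry points: `predict_sample_with_fixed_blocking` takes an explicit per-tensor block shape, while the re-added `predict_sample_with_blocking` derives that shape from `ns` (the `n` in each parameterized axis size, `size = min + n * step`) and `batch_size`, then delegates. A minimal usage sketch, assuming `pipeline` is an instance of the edited class, `sample` is an already constructed `Sample`, and `MemberId`/`AxisId` are in scope as in this module; the member and axis ids and the block sizes below are illustrative placeholders, not taken from the diff:

# Sketch only: `pipeline` and `sample` are assumed to exist already.

# Let the pipeline compute the block shape: `ns` is applied as the parameter
# `n` for every parameterized input axis of the v0_5 model description.
prediction = pipeline.predict_sample_with_blocking(sample, ns=10, batch_size=1)

# The same prediction with the block shape fixed explicitly per input tensor
# and axis; skip_preprocessing/skip_postprocessing are keyword-only here.
prediction = pipeline.predict_sample_with_fixed_blocking(
    sample,
    input_block_shape={
        MemberId("raw"): {AxisId("batch"): 1, AxisId("x"): 256, AxisId("y"): 256}
    },
)

Routing `predict_sample_with_blocking` through `predict_sample_with_fixed_blocking` keeps the `ns`-to-shape computation in one place, and callers that already know their block shape (including users of v0_4 model descriptions, which the new error message points at `predict_sample_with_fixed_blocking`) are no longer forced through the `ns` path.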