@@ -197,6 +197,7 @@ def _predict_with_tiling_impl(
197
197
outputs : List [xr .DataArray ],
198
198
tile_shapes : List [dict ],
199
199
halos : List [dict ],
200
+ verbose : bool = False ,
200
201
):
201
202
if len (inputs ) > 1 :
202
203
raise NotImplementedError ("Tiling with multiple inputs not implemented yet" )
@@ -224,6 +225,11 @@ def load_tile(tile):
224
225
pad_right = [tile [ax ].start == 0 if ax in "xyz" else None for ax in input_axes ]
225
226
return inp , pad_right
226
227
228
+ if verbose :
229
+ shape = {ax : sh for ax , sh in zip (prediction_pipeline .input_specs [0 ].axes , input_ .shape )}
230
+ n_tiles = int (np .prod ([np .ceil (float (shape [ax ]) / tsh ) for ax , tsh in tile_shape .items ()]))
231
+ tiles = tqdm (tiles , total = n_tiles , desc = "prediction with tiling" )
232
+
227
233
# we need to use padded prediction for the individual tiles in case the
228
234
# border tiles don't match the requested tile shape
229
235
padding = {ax : tile_shape [ax ] + 2 * halo [ax ] for ax in input_axes if ax in "xyz" }
@@ -400,13 +406,15 @@ def predict_with_tiling(
400
406
prediction_pipeline : PredictionPipeline ,
401
407
inputs : Union [xr .DataArray , List [xr .DataArray ], Tuple [xr .DataArray ]],
402
408
tiling : Union [bool , Dict [str , Dict [str , int ]]],
409
+ verbose : bool = False ,
403
410
) -> List [xr .DataArray ]:
404
411
"""Run prediction with tiling for a single set of input(s) with a bioimage.io model.
405
412
406
413
Args:
407
414
prediction_pipeline: the prediction pipeline for the input model.
408
415
inputs: the input(s) for this model represented as xarray data.
409
416
tiling: the tiling settings. Pass True to derive from the model spec.
417
+ verbose: whether to print the prediction progress.
410
418
"""
411
419
if not tiling :
412
420
raise ValueError
@@ -451,6 +459,7 @@ def predict_with_tiling(
451
459
outputs ,
452
460
tile_shapes = [tiling ["tile" ]], # todo: update tiling for multiple inputs/outputs
453
461
halos = [tiling ["halo" ]],
462
+ verbose = verbose ,
454
463
)
455
464
456
465
return outputs
0 commit comments