@@ -193,7 +193,6 @@ def _get_tiling(shape, tile_shape, halo, input_axes):
193
193
194
194
def _predict_with_tiling_impl (
195
195
prediction_pipeline : PredictionPipeline ,
196
- # TODO this can be anything with a numpy-like interface
197
196
inputs : List [xr .DataArray ],
198
197
outputs : List [xr .DataArray ],
199
198
tile_shapes : List [dict ],
@@ -219,9 +218,6 @@ def _predict_with_tiling_impl(
219
218
assert all (isinstance (ax , str ) for ax in input_ .dims )
220
219
input_axes : Tuple [str , ...] = input_ .dims # noqa
221
220
222
- # TODO need to adapt this that it supports out of core.
223
- # maybe xarray dask integration would help?
224
- # https://xarray.pydata.org/en/stable/user-guide/dask.html
225
221
def load_tile (tile ):
226
222
inp = input_ [tile ]
227
223
# whether to pad on the right or left of the dim for the spatial dims
@@ -406,16 +402,10 @@ def check_tiling(tiling):
406
402
return tiling
407
403
408
404
409
- # TODO enable passing anything that is numpy array compatible, e.g. a zarr array
410
- # Maybe use xarray dask integration? See https://xarray.pydata.org/en/stable/user-guide/dask.html
411
- # TODO how do we do this with typing?
412
405
def predict_with_tiling (
413
406
prediction_pipeline : PredictionPipeline ,
414
- # TODO needs to be list, use Sequence instead of List / Tuple, allow numpy like
415
407
inputs : Union [xr .DataArray , List [xr .DataArray ], Tuple [xr .DataArray ]],
416
408
tiling : Union [bool , Dict [str , Dict [str , int ]]],
417
- # TODO Sequence, allow numpy like
418
- outputs : Optional [Union [List [xr .DataArray ]]] = None ,
419
409
verbose : bool = False ,
420
410
) -> List [xr .DataArray ]:
421
411
"""Run prediction with tiling for a single set of input(s) with a bioimage.io model.
@@ -424,7 +414,6 @@ def predict_with_tiling(
424
414
prediction_pipeline: the prediction pipeline for the input model.
425
415
inputs: the input(s) for this model represented as xarray data.
426
416
tiling: the tiling settings. Pass True to derive from the model spec.
427
- outputs: optional output arrays.
428
417
verbose: whether to print the prediction progress.
429
418
"""
430
419
if not tiling :
@@ -441,44 +430,30 @@ def predict_with_tiling(
441
430
}
442
431
)
443
432
444
- if outputs is None :
445
- outputs = []
446
- for output_spec in prediction_pipeline .output_specs :
447
- if isinstance (output_spec .shape , ImplicitOutputShape ):
448
- scale = dict (zip (output_spec .axes , output_spec .shape .scale ))
449
- offset = dict (zip (output_spec .axes , output_spec .shape .offset ))
450
-
451
- # for now, we only support tiling if the spatial shape doesn't change
452
- # supporting this should not be so difficult, we would just need to apply the inverse
453
- # to "out_shape = scale * in_shape + 2 * offset" ("in_shape = (out_shape - 2 * offset) / scale")
454
- # to 'outer_tile' in 'get_tiling'
455
- if any (sc != 1 for ax , sc in scale .items () if ax in "xyz" ) or any (
456
- off != 0 for ax , off in offset .items () if ax in "xyz"
457
- ):
458
- raise NotImplementedError ("Tiling with a different output shape is not yet supported" )
459
-
460
- ref_input = named_inputs [output_spec .shape .reference_tensor ]
461
- ref_input_shape = dict (zip (ref_input .dims , ref_input .shape ))
462
- output_shape = tuple (int (scale [ax ] * ref_input_shape [ax ] + 2 * offset [ax ]) for ax in output_spec .axes )
463
- else :
464
- output_shape = tuple (output_spec .shape )
433
+ outputs = []
434
+ for output_spec in prediction_pipeline .output_specs :
435
+ if isinstance (output_spec .shape , ImplicitOutputShape ):
436
+ scale = dict (zip (output_spec .axes , output_spec .shape .scale ))
437
+ offset = dict (zip (output_spec .axes , output_spec .shape .offset ))
438
+
439
+ # for now, we only support tiling if the spatial shape doesn't change
440
+ # supporting this should not be so difficult, we would just need to apply the inverse
441
+ # to "out_shape = scale * in_shape + 2 * offset" ("in_shape = (out_shape - 2 * offset) / scale")
442
+ # to 'outer_tile' in 'get_tiling'
443
+ if any (sc != 1 for ax , sc in scale .items () if ax in "xyz" ) or any (
444
+ off != 0 for ax , off in offset .items () if ax in "xyz"
445
+ ):
446
+ raise NotImplementedError ("Tiling with a different output shape is not yet supported" )
447
+
448
+ ref_input = named_inputs [output_spec .shape .reference_tensor ]
449
+ ref_input_shape = dict (zip (ref_input .dims , ref_input .shape ))
450
+ output_shape = tuple (int (scale [ax ] * ref_input_shape [ax ] + 2 * offset [ax ]) for ax in output_spec .axes )
451
+ else :
452
+ output_shape = tuple (output_spec .shape )
465
453
466
- outputs .append (
467
- xr .DataArray (np .zeros (output_shape , dtype = output_spec .data_type ), dims = tuple (output_spec .axes ))
468
- )
469
- elif len (outputs ) != len (prediction_pipeline .output_specs ):
470
- raise ValueError (
471
- f"Number of outputs are incompatible: expected { len (prediction_pipeline .output_specs )} , got { len (outputs )} "
454
+ outputs .append (
455
+ xr .DataArray (np .zeros (output_shape , dtype = output_spec .data_type ), dims = tuple (output_spec .axes ))
472
456
)
473
- else :
474
- # eventually we need to fully validate the output shape against the spec, for now we only
475
- # support a single output of same spatial shape as the (single) input
476
- if len (outputs ) != len (inputs ):
477
- raise NotImplementedError ("Tiling with a different number of inputs and outputs is not yet supported" )
478
- spatial_in_shape = tuple (sh for ax , sh in zip (prediction_pipeline .input_specs [0 ].axes , inputs [0 ].shape ))
479
- spatial_out_shape = tuple (sh for ax , sh in zip (prediction_pipeline .output_specs [0 ].axes , outputs [0 ].shape ))
480
- if spatial_in_shape != spatial_out_shape :
481
- raise NotImplementedError ("Tiling with a different output shape is not yet supported" )
482
457
483
458
_predict_with_tiling_impl (
484
459
prediction_pipeline ,
0 commit comments