Skip to content

Commit 44a0da7

Browse files
Clean up
1 parent 0001932 commit 44a0da7

File tree

1 file changed

+22
-47
lines changed

1 file changed

+22
-47
lines changed

bioimageio/core/prediction.py

Lines changed: 22 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,6 @@ def _get_tiling(shape, tile_shape, halo, input_axes):
193193

194194
def _predict_with_tiling_impl(
195195
prediction_pipeline: PredictionPipeline,
196-
# TODO this can be anything with a numpy-like interface
197196
inputs: List[xr.DataArray],
198197
outputs: List[xr.DataArray],
199198
tile_shapes: List[dict],
@@ -219,9 +218,6 @@ def _predict_with_tiling_impl(
219218
assert all(isinstance(ax, str) for ax in input_.dims)
220219
input_axes: Tuple[str, ...] = input_.dims # noqa
221220

222-
# TODO need to adapt this that it supports out of core.
223-
# maybe xarray dask integration would help?
224-
# https://xarray.pydata.org/en/stable/user-guide/dask.html
225221
def load_tile(tile):
226222
inp = input_[tile]
227223
# whether to pad on the right or left of the dim for the spatial dims
@@ -406,16 +402,10 @@ def check_tiling(tiling):
406402
return tiling
407403

408404

409-
# TODO enable passing anything that is numpy array compatible, e.g. a zarr array
410-
# Maybe use xarray dask integration? See https://xarray.pydata.org/en/stable/user-guide/dask.html
411-
# TODO how do we do this with typing?
412405
def predict_with_tiling(
413406
prediction_pipeline: PredictionPipeline,
414-
# TODO needs to be list, use Sequence instead of List / Tuple, allow numpy like
415407
inputs: Union[xr.DataArray, List[xr.DataArray], Tuple[xr.DataArray]],
416408
tiling: Union[bool, Dict[str, Dict[str, int]]],
417-
# TODO Sequence, allow numpy like
418-
outputs: Optional[Union[List[xr.DataArray]]] = None,
419409
verbose: bool = False,
420410
) -> List[xr.DataArray]:
421411
"""Run prediction with tiling for a single set of input(s) with a bioimage.io model.
@@ -424,7 +414,6 @@ def predict_with_tiling(
424414
prediction_pipeline: the prediction pipeline for the input model.
425415
inputs: the input(s) for this model represented as xarray data.
426416
tiling: the tiling settings. Pass True to derive from the model spec.
427-
outputs: optional output arrays.
428417
verbose: whether to print the prediction progress.
429418
"""
430419
if not tiling:
@@ -441,44 +430,30 @@ def predict_with_tiling(
441430
}
442431
)
443432

444-
if outputs is None:
445-
outputs = []
446-
for output_spec in prediction_pipeline.output_specs:
447-
if isinstance(output_spec.shape, ImplicitOutputShape):
448-
scale = dict(zip(output_spec.axes, output_spec.shape.scale))
449-
offset = dict(zip(output_spec.axes, output_spec.shape.offset))
450-
451-
# for now, we only support tiling if the spatial shape doesn't change
452-
# supporting this should not be so difficult, we would just need to apply the inverse
453-
# to "out_shape = scale * in_shape + 2 * offset" ("in_shape = (out_shape - 2 * offset) / scale")
454-
# to 'outer_tile' in 'get_tiling'
455-
if any(sc != 1 for ax, sc in scale.items() if ax in "xyz") or any(
456-
off != 0 for ax, off in offset.items() if ax in "xyz"
457-
):
458-
raise NotImplementedError("Tiling with a different output shape is not yet supported")
459-
460-
ref_input = named_inputs[output_spec.shape.reference_tensor]
461-
ref_input_shape = dict(zip(ref_input.dims, ref_input.shape))
462-
output_shape = tuple(int(scale[ax] * ref_input_shape[ax] + 2 * offset[ax]) for ax in output_spec.axes)
463-
else:
464-
output_shape = tuple(output_spec.shape)
433+
outputs = []
434+
for output_spec in prediction_pipeline.output_specs:
435+
if isinstance(output_spec.shape, ImplicitOutputShape):
436+
scale = dict(zip(output_spec.axes, output_spec.shape.scale))
437+
offset = dict(zip(output_spec.axes, output_spec.shape.offset))
438+
439+
# for now, we only support tiling if the spatial shape doesn't change
440+
# supporting this should not be so difficult, we would just need to apply the inverse
441+
# to "out_shape = scale * in_shape + 2 * offset" ("in_shape = (out_shape - 2 * offset) / scale")
442+
# to 'outer_tile' in 'get_tiling'
443+
if any(sc != 1 for ax, sc in scale.items() if ax in "xyz") or any(
444+
off != 0 for ax, off in offset.items() if ax in "xyz"
445+
):
446+
raise NotImplementedError("Tiling with a different output shape is not yet supported")
447+
448+
ref_input = named_inputs[output_spec.shape.reference_tensor]
449+
ref_input_shape = dict(zip(ref_input.dims, ref_input.shape))
450+
output_shape = tuple(int(scale[ax] * ref_input_shape[ax] + 2 * offset[ax]) for ax in output_spec.axes)
451+
else:
452+
output_shape = tuple(output_spec.shape)
465453

466-
outputs.append(
467-
xr.DataArray(np.zeros(output_shape, dtype=output_spec.data_type), dims=tuple(output_spec.axes))
468-
)
469-
elif len(outputs) != len(prediction_pipeline.output_specs):
470-
raise ValueError(
471-
f"Number of outputs are incompatible: expected {len(prediction_pipeline.output_specs)}, got {len(outputs)}"
454+
outputs.append(
455+
xr.DataArray(np.zeros(output_shape, dtype=output_spec.data_type), dims=tuple(output_spec.axes))
472456
)
473-
else:
474-
# eventually we need to fully validate the output shape against the spec, for now we only
475-
# support a single output of same spatial shape as the (single) input
476-
if len(outputs) != len(inputs):
477-
raise NotImplementedError("Tiling with a different number of inputs and outputs is not yet supported")
478-
spatial_in_shape = tuple(sh for ax, sh in zip(prediction_pipeline.input_specs[0].axes, inputs[0].shape))
479-
spatial_out_shape = tuple(sh for ax, sh in zip(prediction_pipeline.output_specs[0].axes, outputs[0].shape))
480-
if spatial_in_shape != spatial_out_shape:
481-
raise NotImplementedError("Tiling with a different output shape is not yet supported")
482457

483458
_predict_with_tiling_impl(
484459
prediction_pipeline,

0 commit comments

Comments
 (0)