Skip to content

Commit 79c4d9f

Browse files
Merge pull request #271 from bioimage-io/update-pred
Updates to prediction functionality etc
2 parents f241d3a + 200ba09 commit 79c4d9f

File tree

4 files changed

+9
-8
lines changed

4 files changed

+9
-8
lines changed

bioimageio/core/VERSION

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
{
2-
"version": "0.5.3post2"
2+
"version": "0.5.4"
33
}

bioimageio/core/__init__.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,5 @@
1111
serialize_raw_resource_description,
1212
)
1313
from .prediction_pipeline import create_prediction_pipeline
14-
from .prediction import predict_image, predict_images
14+
from .prediction import predict_image, predict_images, predict_with_padding, predict_with_tiling
15+
from .resource_tests import check_input_shape, check_output_shape, test_resource

bioimageio/core/prediction.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ def check_padding(padding):
183183
def predict_with_padding(
184184
prediction_pipeline: PredictionPipeline,
185185
inputs: Union[xr.DataArray, List[xr.DataArray], Tuple[xr.DataArray]],
186-
padding: Union[bool, Dict[str, int]],
186+
padding: Union[bool, Dict[str, int]] = True,
187187
pad_right: bool = True,
188188
) -> List[xr.DataArray]:
189189
"""Run prediction with padding for a single set of input(s) with a bioimage.io model.
@@ -305,7 +305,7 @@ def check_tiling(tiling):
305305
def predict_with_tiling(
306306
prediction_pipeline: PredictionPipeline,
307307
inputs: Union[xr.DataArray, List[xr.DataArray], Tuple[xr.DataArray]],
308-
tiling: Union[bool, Dict[str, Dict[str, int]]],
308+
tiling: Union[bool, Dict[str, Dict[str, int]]] = True,
309309
verbose: bool = False,
310310
) -> List[xr.DataArray]:
311311
"""Run prediction with tiling for a single set of input(s) with a bioimage.io model.

bioimageio/core/resource_tests.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def test_model(
6060
)
6161

6262

63-
def _validate_input_shape(shape: Tuple[int, ...], shape_spec) -> bool:
63+
def check_input_shape(shape: Tuple[int, ...], shape_spec) -> bool:
6464
if isinstance(shape_spec, list):
6565
if shape != tuple(shape_spec):
6666
return False
@@ -81,7 +81,7 @@ def _validate_input_shape(shape: Tuple[int, ...], shape_spec) -> bool:
8181
return True
8282

8383

84-
def _validate_output_shape(shape: Tuple[int, ...], shape_spec, input_shapes) -> bool:
84+
def check_output_shape(shape: Tuple[int, ...], shape_spec, input_shapes) -> bool:
8585
if isinstance(shape_spec, list):
8686
return shape == tuple(shape_spec)
8787
elif isinstance(shape_spec, ImplicitOutputShape):
@@ -129,7 +129,7 @@ def test_resource(
129129
assert len(inputs) == len(model.inputs) # should be checked by validation
130130
input_shapes = {}
131131
for idx, (ipt, ipt_spec) in enumerate(zip(inputs, model.inputs)):
132-
if not _validate_input_shape(tuple(ipt.shape), ipt_spec.shape):
132+
if not check_input_shape(tuple(ipt.shape), ipt_spec.shape):
133133
raise ValidationError(
134134
f"Shape {tuple(ipt.shape)} of test input {idx} '{ipt_spec.name}' does not match "
135135
f"input shape description: {ipt_spec.shape}."
@@ -138,7 +138,7 @@ def test_resource(
138138

139139
assert len(expected) == len(model.outputs) # should be checked by validation
140140
for idx, (out, out_spec) in enumerate(zip(expected, model.outputs)):
141-
if not _validate_output_shape(tuple(out.shape), out_spec.shape, input_shapes):
141+
if not check_output_shape(tuple(out.shape), out_spec.shape, input_shapes):
142142
error = (error or "") + (
143143
f"Shape {tuple(out.shape)} of test output {idx} '{out_spec.name}' does not match "
144144
f"output shape description: {out_spec.shape}."

0 commit comments

Comments
 (0)