
Commit 59a3d21

Merge pull request #245 from bioimage-io/tile-update
Fix computation for number of tiles
2 parents: 7ebe658 + 237e8d7

3 files changed: +65 −40 lines changed

bioimageio/core/prediction.py

+19 −4
@@ -228,7 +228,7 @@ def load_tile(tile):
 
     if verbose:
         shape = {ax: sh for ax, sh in zip(prediction_pipeline.input_specs[0].axes, input_.shape)}
-        n_tiles = int(np.prod([np.ceil(float(shape[ax]) / tsh) for ax, tsh in tile_shape.items()]))
+        n_tiles = int(np.prod([np.ceil(float(shape[ax]) / (tsh - 2 * halo[ax])) for ax, tsh in tile_shape.items()]))
         tiles = tqdm(tiles, total=n_tiles, desc="prediction with tiling")
 
     # we need to use padded prediction for the individual tiles in case the
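
This hunk is the core of the fix: with overlapping tiles, each tile only advances by its size minus twice the halo along every spatial axis, so the tile count shown by the progress bar has to divide by that effective stride. A minimal standalone sketch of the old versus the new count, using hypothetical image, tile, and halo sizes:

    import numpy as np

    # hypothetical sizes: a 512 x 512 image, 256-pixel tiles, 32-pixel halo per side
    shape = {"x": 512, "y": 512}
    tile_shape = {"x": 256, "y": 256}
    halo = {"x": 32, "y": 32}

    # old count ignores the overlap: ceil(512 / 256) ** 2 == 4
    n_tiles_old = int(np.prod([np.ceil(float(shape[ax]) / tsh) for ax, tsh in tile_shape.items()]))

    # new count divides by the effective stride tile - 2 * halo: ceil(512 / 192) ** 2 == 9
    n_tiles_new = int(np.prod([np.ceil(float(shape[ax]) / (tsh - 2 * halo[ax])) for ax, tsh in tile_shape.items()]))

    print(n_tiles_old, n_tiles_new)  # prints: 4 9
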
@@ -388,7 +388,7 @@ def check_tiling(tiling):
         spatial_axes = [ax for ax in axes if ax in "xyz"]
         halo = tiling["halo"]
         tile = tiling["tile"]
-        assert all(halo.get(ax, 0) > 0 for ax in spatial_axes)
+        assert all(halo.get(ax, 0) >= 0 for ax in spatial_axes)
         assert all(tile.get(ax, 0) > 0 for ax in spatial_axes)
 
     if isinstance(tiling, dict):
@@ -408,7 +408,8 @@ def check_tiling(tiling):
 
         halo = output_spec.halo
         if halo is None:
-            raise ValueError("Model does not provide a valid halo to use for tiling with default parameters")
+            halo = [0] * len(axes)
+        assert len(halo) == len(axes)
 
         tiling = {
            "halo": {ax: ha for ax, ha in zip(axes, halo) if ax in "xyz"},
@@ -465,7 +466,21 @@ def predict_with_tiling(
             ref_input_shape = dict(zip(ref_input.dims, ref_input.shape))
             output_shape = tuple(int(scale[ax] * ref_input_shape[ax] + 2 * offset[ax]) for ax in output_spec.axes)
         else:
-            output_shape = tuple(output_spec.shape)
+            if len(inputs) > 1:
+                raise NotImplementedError
+            input_spec = prediction_pipeline.input_specs[0]
+            if input_spec.axes != output_spec.axes:
+                raise NotImplementedError("Tiling with a different output shape is not yet supported")
+            out_axes = output_spec.axes
+            fixed_shape = tuple(output_spec.shape)
+            if not all(fixed_shape[out_axes.index(ax)] == tile_shape for ax, tile_shape in tiling["tile"].items()):
+                raise NotImplementedError("Tiling with a different output shape is not yet supported")
+
+            output_shape = list(inputs[0].shape)
+            chan_id = out_axes.index("c")
+            if fixed_shape[chan_id] != output_shape[chan_id]:
+                output_shape[chan_id] = fixed_shape[chan_id]
+            output_shape = tuple(output_shape)
 
         outputs.append(xr.DataArray(np.zeros(output_shape, dtype=output_spec.data_type), dims=tuple(output_spec.axes)))
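
The last hunk adds a fallback for models with a fixed output shape: as long as the input and output axes match and the tile size equals that fixed shape, the output array is allocated with the shape of the tiled input, adopting only the channel count declared by the model. A self-contained sketch of that shape logic, with hypothetical axes and sizes:

    # hypothetical fixed-shape model: declared output (1, 3, 256, 256), actual input (1, 1, 512, 512)
    out_axes = ("b", "c", "y", "x")
    fixed_shape = (1, 3, 256, 256)   # stands in for output_spec.shape
    input_shape = (1, 1, 512, 512)   # shape of the image actually being predicted

    # start from the full input shape and swap in the fixed channel count
    output_shape = list(input_shape)
    chan_id = out_axes.index("c")
    if fixed_shape[chan_id] != output_shape[chan_id]:
        output_shape[chan_id] = fixed_shape[chan_id]

    print(tuple(output_shape))  # prints: (1, 3, 512, 512)
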

tests/conftest.py

+42 −36
@@ -10,8 +10,7 @@
 warnings.warn(f"testing with bioimageio.spec {bioimageio_spec_version}")
 
 # test models for various frameworks
-torch_models = []
-torch_models_pre_3_10 = [
+torch_models = [
     "unet2d_fixed_shape",
     "unet2d_multi_tensor",
     "unet2d_nuclei_broad_model",
@@ -97,9 +96,6 @@
 # load all model packages we need for testing
 load_model_packages = set()
 if not skip_torch:
-    if torch_version < (3, 10):
-        torch_models += torch_models_pre_3_10
-
     load_model_packages |= set(torch_models + torchscript_models)
 
 if not skip_onnx:
@@ -130,35 +126,6 @@ def pytest_configure():
 # model groups of the form any_<weight format>_model that include all models providing a specific weight format
 #
 
-# written as model group to automatically skip on missing torch
-@pytest.fixture(params=[] if skip_torch or torch_version >= (3, 10) else ["unet2d_nuclei_broad_model"])
-def unet2d_nuclei_broad_model(request):
-    return pytest.model_packages[request.param]
-
-
-# written as model group to automatically skip on missing torch
-@pytest.fixture(params=[] if skip_torch or torch_version >= (3, 10) else ["unet2d_diff_output_shape"])
-def unet2d_diff_output_shape(request):
-    return pytest.model_packages[request.param]
-
-
-# written as model group to automatically skip on missing tensorflow 1
-@pytest.fixture(params=[] if skip_tensorflow or tf_major_version != 1 else ["stardist_wrong_shape"])
-def stardist_wrong_shape(request):
-    return pytest.model_packages[request.param]
-
-
-# written as model group to automatically skip on missing tensorflow 1
-@pytest.fixture(params=[] if skip_tensorflow or tf_major_version != 1 else ["stardist_wrong_shape2"])
-def stardist_wrong_shape2(request):
-    return pytest.model_packages[request.param]
-
-
-# written as model group to automatically skip on missing tensorflow 1
-@pytest.fixture(params=[] if skip_tensorflow or tf_major_version != 1 else ["stardist"])
-def stardist(request):
-    return pytest.model_packages[request.param]
-
 
 @pytest.fixture(params=[] if skip_torch else torch_models)
 def any_torch_model(request):
@@ -200,19 +167,22 @@ def any_model(request):
     return pytest.model_packages[request.param]
 
 
+# TODO it would be nice to just generate fixtures for all the individual models dynamically
 #
 # temporary fixtures to test not with all, but only a manual selection of models
 # (models/functionality should be improved to get rid of this specific model group)
 #
+
+
 @pytest.fixture(
-    params=[] if skip_torch or torch_version >= (3, 10) else ["unet2d_nuclei_broad_model", "unet2d_fixed_shape"]
+    params=[] if skip_torch else ["unet2d_nuclei_broad_model", "unet2d_fixed_shape"]
 )
 def unet2d_fixed_shape_or_not(request):
     return pytest.model_packages[request.param]
 
 
 @pytest.fixture(
-    params=[] if skip_torch or torch_version >= (3, 10) else ["unet2d_nuclei_broad_model", "unet2d_multi_tensor"]
+    params=[] if skip_torch else ["unet2d_nuclei_broad_model", "unet2d_multi_tensor"]
 )
 def unet2d_multi_tensor_or_not(request):
     return pytest.model_packages[request.param]
@@ -221,3 +191,39 @@ def unet2d_multi_tensor_or_not(request):
 @pytest.fixture(params=[] if skip_keras else ["unet2d_keras"])
 def unet2d_keras(request):
     return pytest.model_packages[request.param]
+
+
+# written as model group to automatically skip on missing torch
+@pytest.fixture(params=[] if skip_torch else ["unet2d_nuclei_broad_model"])
+def unet2d_nuclei_broad_model(request):
+    return pytest.model_packages[request.param]
+
+
+# written as model group to automatically skip on missing torch
+@pytest.fixture(params=[] if skip_torch else ["unet2d_diff_output_shape"])
+def unet2d_diff_output_shape(request):
+    return pytest.model_packages[request.param]
+
+
+# written as model group to automatically skip on missing torch
+@pytest.fixture(params=[] if skip_torch else ["unet2d_fixed_shape"])
+def unet2d_fixed_shape(request):
+    return pytest.model_packages[request.param]
+
+
+# written as model group to automatically skip on missing tensorflow 1
+@pytest.fixture(params=[] if skip_tensorflow or tf_major_version != 1 else ["stardist_wrong_shape"])
+def stardist_wrong_shape(request):
+    return pytest.model_packages[request.param]
+
+
+# written as model group to automatically skip on missing tensorflow 1
+@pytest.fixture(params=[] if skip_tensorflow or tf_major_version != 1 else ["stardist_wrong_shape2"])
+def stardist_wrong_shape2(request):
+    return pytest.model_packages[request.param]
+
+
+# written as model group to automatically skip on missing tensorflow 1
+@pytest.fixture(params=[] if skip_tensorflow or tf_major_version != 1 else ["stardist"])
+def stardist(request):
+    return pytest.model_packages[request.param]
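
The conftest changes drop the torch_version < (3, 10) gating, re-add the per-model fixtures (now including unet2d_fixed_shape), and guard them only by framework availability. All of these fixtures rely on pytest skipping tests whose fixture is parametrized with an empty list; a tiny standalone sketch of that pattern, with made-up flag, fixture, and test names:

    import pytest

    framework_missing = True  # stands in for skip_torch / skip_tensorflow

    # with an empty params list, pytest collects no parameter set and reports every
    # test that requests this fixture as skipped ("got empty parameter set")
    @pytest.fixture(params=[] if framework_missing else ["some_model"])
    def some_model(request):
        return request.param

    def test_uses_model(some_model):
        assert some_model == "some_model"
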

tests/test_prediction.py

+4 −0
@@ -176,6 +176,10 @@ def test_predict_image_with_tiling_channel_last(stardist, tmp_path):
     _test_predict_image_with_tiling(stardist, tmp_path, 0.13)
 
 
+def test_predict_image_with_tiling_fixed_output_shape(unet2d_fixed_shape, tmp_path):
+    _test_predict_image_with_tiling(unet2d_fixed_shape, tmp_path, 0.025)
+
+
 def test_predict_images(unet2d_nuclei_broad_model, tmp_path):
     from bioimageio.core.prediction import predict_images
 
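
The new test exercises the shared tiling helper against the fixed-output-shape model. For orientation, a tiled prediction from user code might look roughly like the sketch below; the predict_image call, its argument names, and the file paths are assumptions for illustration rather than something taken from this diff, while the layout of the tiling dict follows the parsing code above:

    from bioimageio.core.prediction import predict_image  # assumed entry point

    tiling = {
        "tile": {"x": 256, "y": 256},
        "halo": {"x": 16, "y": 16},  # a zero halo would also be accepted after this commit
    }

    predict_image(
        "unet2d_fixed_shape/rdf.yaml",  # hypothetical model RDF path
        inputs=["input.tif"],           # image(s) to predict
        outputs=["prediction.tif"],     # where the tiled prediction is written
        tiling=tiling,                  # or tiling=True to derive tile and halo from the model spec
    )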