
Commit f241d3a

Merge pull request #266 from bioimage-io/image-helper
Refactor image functionality
2 parents 7af0e0f + 94bcd3f commit f241d3a

File tree

5 files changed: +225 -135 lines

bioimageio/core/image_helper.py

+180 lines (new file)

@@ -0,0 +1,180 @@

import os
from copy import deepcopy
from typing import Dict, List, Optional, Sequence, Tuple, Union

import imageio
import numpy as np
from xarray import DataArray
from bioimageio.core.resource_io.nodes import InputTensor, OutputTensor


#
# helper functions to transform input images / output tensors to the required axes
#


def transform_input_image(image: np.ndarray, tensor_axes: str, image_axes: Optional[str] = None):
    """Transform input image into output tensor with desired axes.

    Args:
        image: the input image
        tensor_axes: the desired tensor axes
        image_axes: the axes of the input image (optional)
    """
    # if the image axes are not given deduce them from the required axes and image shape
    if image_axes is None:
        has_z_axis = "z" in tensor_axes
        ndim = image.ndim
        if ndim == 2:
            image_axes = "yx"
        elif ndim == 3:
            image_axes = "zyx" if has_z_axis else "cyx"
        elif ndim == 4:
            image_axes = "czyx"
        elif ndim == 5:
            image_axes = "bczyx"
        else:
            raise ValueError(f"Invalid number of image dimensions: {ndim}")
    tensor = DataArray(image, dims=tuple(image_axes))
    # expand the missing image axes
    missing_axes = tuple(set(tensor_axes) - set(image_axes))
    tensor = tensor.expand_dims(dim=missing_axes)
    # transpose to the correct axis order
    tensor = tensor.transpose(*tuple(tensor_axes))
    # return numpy array
    return tensor.values


def _drop_axis_default(axis_name, axis_len):
    # spatial axes: drop at middle coordinate
    # other axes (channel or batch): drop at 0 coordinate
    return axis_len // 2 if axis_name in "zyx" else 0


def transform_output_tensor(tensor: np.ndarray, tensor_axes: str, output_axes: str, drop_function=_drop_axis_default):
    """Transform output tensor into image with desired axes.

    Args:
        tensor: the output tensor
        tensor_axes: the axes of the output tensor
        output_axes: the desired output axes
        drop_function: function that determines how to drop unwanted axes
    """
    if len(tensor_axes) != tensor.ndim:
        raise ValueError(f"Number of axes {len(tensor_axes)} and dimension of tensor {tensor.ndim} don't match")
    shape = {ax_name: sh for ax_name, sh in zip(tensor_axes, tensor.shape)}
    output = DataArray(tensor, dims=tuple(tensor_axes))
    # drop unwanted axes
    drop_axis_names = tuple(set(tensor_axes) - set(output_axes))
    drop_axes = {ax_name: drop_function(ax_name, shape[ax_name]) for ax_name in drop_axis_names}
    output = output[drop_axes]
    # transpose to the desired axis order
    output = output.transpose(*tuple(output_axes))
    # return numpy array
    return output.values


def to_channel_last(image):
    chan_id = image.dims.index("c")
    if chan_id != image.ndim - 1:
        target_axes = tuple(ax for ax in image.dims if ax != "c") + ("c",)
        image = image.transpose(*target_axes)
    return image


#
# helper functions for loading and saving images
#


def load_image(in_path, axes: Sequence[str]) -> DataArray:
    ext = os.path.splitext(in_path)[1]
    if ext == ".npy":
        im = np.load(in_path)
    else:
        is_volume = "z" in axes
        im = imageio.volread(in_path) if is_volume else imageio.imread(in_path)
        im = transform_input_image(im, axes)
    return DataArray(im, dims=axes)


def load_tensors(sources, tensor_specs: List[Union[InputTensor, OutputTensor]]) -> List[DataArray]:
    return [load_image(s, sspec.axes) for s, sspec in zip(sources, tensor_specs)]


def save_image(out_path, image):
    ext = os.path.splitext(out_path)[1]
    if ext == ".npy":
        np.save(out_path, image)
    else:
        is_volume = "z" in image.dims

        # squeeze batch or channel axes if they are singletons
        squeeze = {ax: 0 if (ax in "bc" and sh == 1) else slice(None) for ax, sh in zip(image.dims, image.shape)}
        image = image[squeeze]

        if "b" in image.dims:
            raise RuntimeError(f"Cannot save prediction with batchsize > 1 as {ext}-file")
        if "c" in image.dims:  # image formats need channel last
            image = to_channel_last(image)

        save_function = imageio.volsave if is_volume else imageio.imsave
        # most image formats only support channel dimensions of 1, 3 or 4;
        # if not we need to save the channels separately
        ndim = 3 if is_volume else 2
        save_as_single_image = image.ndim == ndim or (image.shape[-1] in (3, 4))

        if save_as_single_image:
            save_function(out_path, image)
        else:
            out_prefix, ext = os.path.splitext(out_path)
            for c in range(image.shape[-1]):
                chan_out_path = f"{out_prefix}-c{c}{ext}"
                save_function(chan_out_path, image[..., c])


#
# helper function for padding
#


def pad(image, axes: Sequence[str], padding, pad_right=True) -> Tuple[np.ndarray, Dict[str, slice]]:
    assert image.ndim == len(axes), f"{image.ndim}, {len(axes)}"

    padding_ = deepcopy(padding)
    mode = padding_.pop("mode", "dynamic")
    assert mode in ("dynamic", "fixed")

    is_volume = "z" in axes
    if is_volume:
        assert len(padding_) == 3
    else:
        assert len(padding_) == 2

    if isinstance(pad_right, bool):
        pad_right = len(axes) * [pad_right]

    pad_width = []
    crop = {}
    for ax, dlen, pr in zip(axes, image.shape, pad_right):

        if ax in "zyx":
            pad_to = padding_[ax]

            if mode == "dynamic":
                r = dlen % pad_to
                pwidth = 0 if r == 0 else (pad_to - r)
            else:
                if pad_to < dlen:
                    msg = f"Padding for axis {ax} failed; pad shape {pad_to} is smaller than the image shape {dlen}."
                    raise RuntimeError(msg)
                pwidth = pad_to - dlen

            pad_width.append([0, pwidth] if pr else [pwidth, 0])
            crop[ax] = slice(0, dlen) if pr else slice(pwidth, None)
        else:
            pad_width.append([0, 0])
            crop[ax] = slice(None)

    image = np.pad(image, pad_width, mode="symmetric")
    return image, crop
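
For orientation, here is a minimal usage sketch of the new image_helper module. The array shapes, the padding values, and the shapes asserted below are illustrative assumptions for this sketch, not part of the commit.

import numpy as np
from bioimageio.core import image_helper

# hypothetical 2D grayscale image for a model that expects "bcyx" input tensors
image = np.random.rand(256, 256)

# the image axes are deduced as "yx"; singleton batch/channel axes are added and
# the result is transposed to the requested tensor layout
tensor = image_helper.transform_input_image(image, "bcyx")
assert tensor.shape == (1, 1, 256, 256)

# pad the spatial axes to multiples of 16 and keep the crop that undoes the padding
padded, crop = image_helper.pad(tensor, "bcyx", {"x": 16, "y": 16, "mode": "dynamic"})

# turn a "bcyx" tensor back into a plain "yx" image; the default drop function
# slices the batch and channel axes at index 0
restored = image_helper.transform_output_tensor(tensor, "bcyx", "yx")
assert restored.shape == (256, 256)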

bioimageio/core/prediction.py

+5 -134 lines

@@ -1,150 +1,21 @@
 import collections
 import os
-from copy import deepcopy
 from itertools import product
 from pathlib import Path
 from typing import Dict, Iterator, List, NamedTuple, Optional, OrderedDict, Sequence, Tuple, Union

-import imageio
 import numpy as np
 import xarray as xr
 from tqdm import tqdm

+from bioimageio.core import image_helper
 from bioimageio.core import load_resource_description
 from bioimageio.core.prediction_pipeline import PredictionPipeline, create_prediction_pipeline
-from bioimageio.core.resource_io.nodes import ImplicitOutputShape, InputTensor, Model, ResourceDescription, OutputTensor
+from bioimageio.core.resource_io.nodes import ImplicitOutputShape, Model, ResourceDescription
 from bioimageio.spec.shared import raw_nodes
 from bioimageio.spec.shared.raw_nodes import ResourceDescription as RawResourceDescription


-#
-# utility functions for prediction
-#
-def _require_axes(im, axes):
-    is_volume = "z" in axes
-    # we assume images / volumes are loaded as one of
-    # yx, yxc, zyxc
-    if im.ndim == 2:
-        im_axes = ("y", "x")
-    elif im.ndim == 3:
-        im_axes = ("z", "y", "x") if is_volume else ("y", "x", "c")
-    elif im.ndim == 4:
-        raise NotImplementedError
-    else:  # ndim >= 5 not implemented
-        raise RuntimeError
-
-    # add singleton channel dimension if not present
-    if "c" not in im_axes:
-        im = im[..., None]
-        im_axes = im_axes + ("c",)
-
-    # add singleton batch dim
-    im = im[None]
-    im_axes = ("b",) + im_axes
-
-    # permute the axes correctly
-    assert set(axes) == set(im_axes)
-    axes_permutation = tuple(im_axes.index(ax) for ax in axes)
-    im = im.transpose(axes_permutation)
-    return im
-
-
-def _pad(im, axes: Sequence[str], padding, pad_right=True) -> Tuple[np.ndarray, Dict[str, slice]]:
-    assert im.ndim == len(axes), f"{im.ndim}, {len(axes)}"
-
-    padding_ = deepcopy(padding)
-    mode = padding_.pop("mode", "dynamic")
-    assert mode in ("dynamic", "fixed")
-
-    is_volume = "z" in axes
-    if is_volume:
-        assert len(padding_) == 3
-    else:
-        assert len(padding_) == 2
-
-    if isinstance(pad_right, bool):
-        pad_right = len(axes) * [pad_right]
-
-    pad_width = []
-    crop = {}
-    for ax, dlen, pr in zip(axes, im.shape, pad_right):
-
-        if ax in "zyx":
-            pad_to = padding_[ax]
-
-            if mode == "dynamic":
-                r = dlen % pad_to
-                pwidth = 0 if r == 0 else (pad_to - r)
-            else:
-                if pad_to < dlen:
-                    msg = f"Padding for axis {ax} failed; pad shape {pad_to} is smaller than the image shape {dlen}."
-                    raise RuntimeError(msg)
-                pwidth = pad_to - dlen
-
-            pad_width.append([0, pwidth] if pr else [pwidth, 0])
-            crop[ax] = slice(0, dlen) if pr else slice(pwidth, None)
-        else:
-            pad_width.append([0, 0])
-            crop[ax] = slice(None)
-
-    im = np.pad(im, pad_width, mode="symmetric")
-    return im, crop
-
-
-def _load_image(in_path, axes: Sequence[str]) -> xr.DataArray:
-    ext = os.path.splitext(in_path)[1]
-    if ext == ".npy":
-        im = np.load(in_path)
-    else:
-        is_volume = "z" in axes
-        im = imageio.volread(in_path) if is_volume else imageio.imread(in_path)
-        im = _require_axes(im, axes)
-    return xr.DataArray(im, dims=axes)
-
-
-def _load_tensors(sources, tensor_specs: List[Union[InputTensor, OutputTensor]]) -> List[xr.DataArray]:
-    return [_load_image(s, sspec.axes) for s, sspec in zip(sources, tensor_specs)]
-
-
-def _to_channel_last(image):
-    chan_id = image.dims.index("c")
-    if chan_id != image.ndim - 1:
-        target_axes = tuple(ax for ax in image.dims if ax != "c") + ("c",)
-        image = image.transpose(*target_axes)
-    return image
-
-
-def _save_image(out_path, image):
-    ext = os.path.splitext(out_path)[1]
-    if ext == ".npy":
-        np.save(out_path, image)
-    else:
-        is_volume = "z" in image.dims
-
-        # squeeze batch or channel axes if they are singletons
-        squeeze = {ax: 0 if (ax in "bc" and sh == 1) else slice(None) for ax, sh in zip(image.dims, image.shape)}
-        image = image[squeeze]
-
-        if "b" in image.dims:
-            raise RuntimeError(f"Cannot save prediction with batchsize > 1 as {ext}-file")
-        if "c" in image.dims:  # image formats need channel last
-            image = _to_channel_last(image)
-
-        save_function = imageio.volsave if is_volume else imageio.imsave
-        # most image formats only support channel dimensions of 1, 3 or 4;
-        # if not we need to save the channels separately
-        ndim = 3 if is_volume else 2
-        save_as_single_image = image.ndim == ndim or (image.shape[-1] in (3, 4))
-
-        if save_as_single_image:
-            save_function(out_path, image)
-        else:
-            out_prefix, ext = os.path.splitext(out_path)
-            for c in range(image.shape[-1]):
-                chan_out_path = f"{out_prefix}-c{c}{ext}"
-                save_function(chan_out_path, image[..., c])
-
-
 def _apply_crop(data, crop):
     crop = tuple(crop[ax] for ax in data.dims)
     return data[crop]

@@ -345,7 +216,7 @@ def predict_with_padding(
     assert len(padding) == len(prediction_pipeline.input_specs)
     inputs, crops = zip(
         *[
-            _pad(inp, spec.axes, p, pad_right=pad_right)
+            image_helper.pad(inp, spec.axes, p, pad_right=pad_right)
             for inp, spec, p in zip(inputs, prediction_pipeline.input_specs, padding)
         ]
     )

@@ -508,7 +379,7 @@ def _predict_sample(prediction_pipeline, inputs, outputs, padding, tiling):
     if padding and tiling:
         raise ValueError("Only one of padding or tiling is supported")

-    input_data = _load_tensors(inputs, prediction_pipeline.input_specs)
+    input_data = image_helper.load_tensors(inputs, prediction_pipeline.input_specs)
     if padding is not None:
         result = predict_with_padding(prediction_pipeline, input_data, padding)
     elif tiling is not None:

@@ -519,7 +390,7 @@ def _predict_sample(prediction_pipeline, inputs, outputs, padding, tiling):
     assert isinstance(result, list)
     assert len(result) == len(outputs)
     for res, out in zip(result, outputs):
-        _save_image(out, res)
+        image_helper.save_image(out, res)


 def predict_image(
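
Below is a hedged sketch of the pad-then-crop round trip that predict_with_padding relies on, pairing image_helper.pad with the crop dictionary consumed by the private _apply_crop helper shown above; the input shape and padding values are made up for illustration.

import numpy as np
from bioimageio.core import image_helper

# hypothetical "bcyx" input whose spatial extent is not a multiple of 16
inp = np.zeros((1, 1, 250, 250), dtype="float32")

padded, crop = image_helper.pad(inp, "bcyx", {"x": 16, "y": 16, "mode": "dynamic"})
# dynamic mode pads 250 up to 256 on the right of each spatial axis
assert padded.shape == (1, 1, 256, 256)

# after prediction, one slice per axis undoes the padding, as in _apply_crop
crop_tuple = tuple(crop[ax] for ax in "bcyx")
assert padded[crop_tuple].shape == inp.shape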

bioimageio/core/resource_tests.py

+1 -1 lines

@@ -89,7 +89,7 @@ def _validate_output_shape(shape: Tuple[int, ...], shape_spec, input_shapes) ->
         if ref_tensor not in input_shapes:
             raise ValidationError(f"The reference tensor name {ref_tensor} is not in {input_shapes}")
         ipt_shape = numpy.array(input_shapes[ref_tensor])
-        scale = numpy.array(shape_spec.scale)
+        scale = numpy.array([0.0 if sc is None else sc for sc in shape_spec.scale])
         offset = numpy.array(shape_spec.offset)
         exp_shape = numpy.round_(ipt_shape * scale) + 2 * offset
9595

tests/conftest.py

+10 lines

@@ -32,6 +32,10 @@
         "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_specs/models/"
         "unet2d_nuclei_broad/rdf.yaml"
     ),
+    "unet2d_expand_output_shape": (
+        "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_specs/models/"
+        "unet2d_nuclei_broad/rdf_expand_output_shape.yaml"
+    ),
     "unet2d_fixed_shape": (
         "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_specs/models/"
         "unet2d_fixed_shape/rdf.yaml"

@@ -205,6 +209,12 @@ def unet2d_diff_output_shape(request):
     return pytest.model_packages[request.param]


+# written as model group to automatically skip on missing torch
+@pytest.fixture(params=[] if skip_torch else ["unet2d_expand_output_shape"])
+def unet2d_expand_output_shape(request):
+    return pytest.model_packages[request.param]
+
+
 # written as model group to automatically skip on missing torch
 @pytest.fixture(params=[] if skip_torch else ["unet2d_fixed_shape"])
 def unet2d_fixed_shape(request):
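
As a hedged sketch, a hypothetical test (not part of this commit) could consume the new fixture like any of the existing model-group fixtures; whether load_resource_description is the right entry point for the packaged model here is an assumption.

from bioimageio.core import load_resource_description


# hypothetical test: the fixture yields the packaged unet2d_expand_output_shape model,
# or the test is skipped automatically when torch is not installed
def test_load_expand_output_shape_model(unet2d_expand_output_shape):
    model = load_resource_description(unet2d_expand_output_shape)
    assert model is not None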
