1
1
import collections
2
2
import os
3
- from copy import deepcopy
4
3
from itertools import product
5
4
from pathlib import Path
6
5
from typing import Dict , Iterator , List , NamedTuple , Optional , OrderedDict , Sequence , Tuple , Union
7
6
8
- import imageio
9
7
import numpy as np
10
8
import xarray as xr
11
9
from tqdm import tqdm
12
10
11
+ from bioimageio .core import image_helper
13
12
from bioimageio .core import load_resource_description
14
13
from bioimageio .core .prediction_pipeline import PredictionPipeline , create_prediction_pipeline
15
- from bioimageio .core .resource_io .nodes import ImplicitOutputShape , InputTensor , Model , ResourceDescription , OutputTensor
14
+ from bioimageio .core .resource_io .nodes import ImplicitOutputShape , Model , ResourceDescription
16
15
from bioimageio .spec .shared import raw_nodes
17
16
from bioimageio .spec .shared .raw_nodes import ResourceDescription as RawResourceDescription
18
17
19
18
20
- #
21
- # utility functions for prediction
22
- #
23
- def _require_axes (im , axes ):
24
- is_volume = "z" in axes
25
- # we assume images / volumes are loaded as one of
26
- # yx, yxc, zyxc
27
- if im .ndim == 2 :
28
- im_axes = ("y" , "x" )
29
- elif im .ndim == 3 :
30
- im_axes = ("z" , "y" , "x" ) if is_volume else ("y" , "x" , "c" )
31
- elif im .ndim == 4 :
32
- raise NotImplementedError
33
- else : # ndim >= 5 not implemented
34
- raise RuntimeError
35
-
36
- # add singleton channel dimension if not present
37
- if "c" not in im_axes :
38
- im = im [..., None ]
39
- im_axes = im_axes + ("c" ,)
40
-
41
- # add singleton batch dim
42
- im = im [None ]
43
- im_axes = ("b" ,) + im_axes
44
-
45
- # permute the axes correctly
46
- assert set (axes ) == set (im_axes )
47
- axes_permutation = tuple (im_axes .index (ax ) for ax in axes )
48
- im = im .transpose (axes_permutation )
49
- return im
50
-
51
-
52
- def _pad (im , axes : Sequence [str ], padding , pad_right = True ) -> Tuple [np .ndarray , Dict [str , slice ]]:
53
- assert im .ndim == len (axes ), f"{ im .ndim } , { len (axes )} "
54
-
55
- padding_ = deepcopy (padding )
56
- mode = padding_ .pop ("mode" , "dynamic" )
57
- assert mode in ("dynamic" , "fixed" )
58
-
59
- is_volume = "z" in axes
60
- if is_volume :
61
- assert len (padding_ ) == 3
62
- else :
63
- assert len (padding_ ) == 2
64
-
65
- if isinstance (pad_right , bool ):
66
- pad_right = len (axes ) * [pad_right ]
67
-
68
- pad_width = []
69
- crop = {}
70
- for ax , dlen , pr in zip (axes , im .shape , pad_right ):
71
-
72
- if ax in "zyx" :
73
- pad_to = padding_ [ax ]
74
-
75
- if mode == "dynamic" :
76
- r = dlen % pad_to
77
- pwidth = 0 if r == 0 else (pad_to - r )
78
- else :
79
- if pad_to < dlen :
80
- msg = f"Padding for axis { ax } failed; pad shape { pad_to } is smaller than the image shape { dlen } ."
81
- raise RuntimeError (msg )
82
- pwidth = pad_to - dlen
83
-
84
- pad_width .append ([0 , pwidth ] if pr else [pwidth , 0 ])
85
- crop [ax ] = slice (0 , dlen ) if pr else slice (pwidth , None )
86
- else :
87
- pad_width .append ([0 , 0 ])
88
- crop [ax ] = slice (None )
89
-
90
- im = np .pad (im , pad_width , mode = "symmetric" )
91
- return im , crop
92
-
93
-
94
- def _load_image (in_path , axes : Sequence [str ]) -> xr .DataArray :
95
- ext = os .path .splitext (in_path )[1 ]
96
- if ext == ".npy" :
97
- im = np .load (in_path )
98
- else :
99
- is_volume = "z" in axes
100
- im = imageio .volread (in_path ) if is_volume else imageio .imread (in_path )
101
- im = _require_axes (im , axes )
102
- return xr .DataArray (im , dims = axes )
103
-
104
-
105
- def _load_tensors (sources , tensor_specs : List [Union [InputTensor , OutputTensor ]]) -> List [xr .DataArray ]:
106
- return [_load_image (s , sspec .axes ) for s , sspec in zip (sources , tensor_specs )]
107
-
108
-
109
- def _to_channel_last (image ):
110
- chan_id = image .dims .index ("c" )
111
- if chan_id != image .ndim - 1 :
112
- target_axes = tuple (ax for ax in image .dims if ax != "c" ) + ("c" ,)
113
- image = image .transpose (* target_axes )
114
- return image
115
-
116
-
117
- def _save_image (out_path , image ):
118
- ext = os .path .splitext (out_path )[1 ]
119
- if ext == ".npy" :
120
- np .save (out_path , image )
121
- else :
122
- is_volume = "z" in image .dims
123
-
124
- # squeeze batch or channel axes if they are singletons
125
- squeeze = {ax : 0 if (ax in "bc" and sh == 1 ) else slice (None ) for ax , sh in zip (image .dims , image .shape )}
126
- image = image [squeeze ]
127
-
128
- if "b" in image .dims :
129
- raise RuntimeError (f"Cannot save prediction with batchsize > 1 as { ext } -file" )
130
- if "c" in image .dims : # image formats need channel last
131
- image = _to_channel_last (image )
132
-
133
- save_function = imageio .volsave if is_volume else imageio .imsave
134
- # most image formats only support channel dimensions of 1, 3 or 4;
135
- # if not we need to save the channels separately
136
- ndim = 3 if is_volume else 2
137
- save_as_single_image = image .ndim == ndim or (image .shape [- 1 ] in (3 , 4 ))
138
-
139
- if save_as_single_image :
140
- save_function (out_path , image )
141
- else :
142
- out_prefix , ext = os .path .splitext (out_path )
143
- for c in range (image .shape [- 1 ]):
144
- chan_out_path = f"{ out_prefix } -c{ c } { ext } "
145
- save_function (chan_out_path , image [..., c ])
146
-
147
-
148
19
def _apply_crop (data , crop ):
149
20
crop = tuple (crop [ax ] for ax in data .dims )
150
21
return data [crop ]
@@ -345,7 +216,7 @@ def predict_with_padding(
345
216
assert len (padding ) == len (prediction_pipeline .input_specs )
346
217
inputs , crops = zip (
347
218
* [
348
- _pad (inp , spec .axes , p , pad_right = pad_right )
219
+ image_helper . pad (inp , spec .axes , p , pad_right = pad_right )
349
220
for inp , spec , p in zip (inputs , prediction_pipeline .input_specs , padding )
350
221
]
351
222
)
@@ -508,7 +379,7 @@ def _predict_sample(prediction_pipeline, inputs, outputs, padding, tiling):
508
379
if padding and tiling :
509
380
raise ValueError ("Only one of padding or tiling is supported" )
510
381
511
- input_data = _load_tensors (inputs , prediction_pipeline .input_specs )
382
+ input_data = image_helper . load_tensors (inputs , prediction_pipeline .input_specs )
512
383
if padding is not None :
513
384
result = predict_with_padding (prediction_pipeline , input_data , padding )
514
385
elif tiling is not None :
@@ -519,7 +390,7 @@ def _predict_sample(prediction_pipeline, inputs, outputs, padding, tiling):
519
390
assert isinstance (result , list )
520
391
assert len (result ) == len (outputs )
521
392
for res , out in zip (result , outputs ):
522
- _save_image (out , res )
393
+ image_helper . save_image (out , res )
523
394
524
395
525
396
def predict_image (
0 commit comments