Commit c59a1f8

Merge pull request #406 from bioimage-io/predict_cmd
add predict command
2 parents: db331df + 7306f98

27 files changed: +1486 -448 lines

README.md  (+289 -40)

(large diff not rendered by default)

bioimageio/core/VERSION  (+1 -1)

```diff
@@ -1,3 +1,3 @@
 {
-    "version": "0.6.8"
+    "version": "0.6.9"
 }
```

bioimageio/core/__init__.py  (+4 -1)

```diff
@@ -4,17 +4,20 @@
 
 from bioimageio.spec import build_description as build_description
 from bioimageio.spec import dump_description as dump_description
+from bioimageio.spec import load_dataset_description as load_dataset_description
 from bioimageio.spec import load_description as load_description
 from bioimageio.spec import (
     load_description_and_validate_format_only as load_description_and_validate_format_only,
 )
+from bioimageio.spec import load_model_description as load_model_description
 from bioimageio.spec import save_bioimageio_package as save_bioimageio_package
 from bioimageio.spec import (
     save_bioimageio_package_as_folder as save_bioimageio_package_as_folder,
 )
 from bioimageio.spec import save_bioimageio_yaml_only as save_bioimageio_yaml_only
 from bioimageio.spec import validate_format as validate_format
 
+from . import digest_spec as digest_spec
 from ._prediction_pipeline import PredictionPipeline as PredictionPipeline
 from ._prediction_pipeline import (
     create_prediction_pipeline as create_prediction_pipeline,
@@ -38,4 +41,4 @@
 # aliases
 test_resource = test_description
 load_resource = load_description
-load_model = load_description
+load_model = load_model_description
```
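The `load_model` alias now points to `load_model_description` rather than the generic `load_description`. A minimal sketch of the updated alias in use (the RDF source below is a placeholder):

```python
from bioimageio.core import load_model_description

# load and validate a model resource description;
# unlike load_description, this is expected to fail for non-model resources
model_descr = load_model_description("https://example.com/some/model/rdf.yaml")
print(model_descr.name)
```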

bioimageio/core/__main__.py  (+7 -1)

```diff
@@ -1,4 +1,10 @@
-from bioimageio.core.commands import main
+from bioimageio.core.cli import Bioimageio
+
+
+def main():
+    cli = Bioimageio()  # pyright: ignore[reportCallIssue]
+    cli.run()
+
 
 if __name__ == "__main__":
     main()
```
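With `__main__.py` delegating to the new `Bioimageio` CLI class, the package is driven via `python -m bioimageio.core`. A small sketch of invoking the `predict` subcommand added by this PR (assuming it supports the usual `--help` flag):

```python
import subprocess
import sys

# print the help text of the new predict subcommand;
# the exact flags and arguments are defined by the Bioimageio CLI class
subprocess.run(
    [sys.executable, "-m", "bioimageio.core", "predict", "--help"],
    check=True,
)
```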

bioimageio/core/_prediction_pipeline.py  (+48 -30)

```diff
@@ -55,8 +55,8 @@ def __init__(
         postprocessing: List[Processing],
         model_adapter: ModelAdapter,
         default_ns: Union[
-            v0_5.ParameterizedSize.N,
-            Mapping[Tuple[MemberId, AxisId], v0_5.ParameterizedSize.N],
+            v0_5.ParameterizedSize_N,
+            Mapping[Tuple[MemberId, AxisId], v0_5.ParameterizedSize_N],
         ] = 10,
         default_batch_size: int = 1,
     ) -> None:
@@ -179,40 +179,17 @@ def get_output_sample_id(self, input_sample_id: SampleId):
             self.model_description.id or self.model_description.name
         )
 
-    def predict_sample_with_blocking(
+    def predict_sample_with_fixed_blocking(
         self,
         sample: Sample,
+        input_block_shape: Mapping[MemberId, Mapping[AxisId, int]],
+        *,
         skip_preprocessing: bool = False,
         skip_postprocessing: bool = False,
-        ns: Optional[
-            Union[
-                v0_5.ParameterizedSize.N,
-                Mapping[Tuple[MemberId, AxisId], v0_5.ParameterizedSize.N],
-            ]
-        ] = None,
-        batch_size: Optional[int] = None,
     ) -> Sample:
-        """predict a sample by splitting it into blocks according to the model and the `ns` parameter"""
         if not skip_preprocessing:
             self.apply_preprocessing(sample)
 
-        if isinstance(self.model_description, v0_4.ModelDescr):
-            raise NotImplementedError(
-                "predict with blocking not implemented for v0_4.ModelDescr {self.model_description.name}"
-            )
-
-        ns = ns or self._default_ns
-        if isinstance(ns, int):
-            ns = {
-                (ipt.id, a.id): ns
-                for ipt in self.model_description.inputs
-                for a in ipt.axes
-                if isinstance(a.size, v0_5.ParameterizedSize)
-            }
-        input_block_shape = self.model_description.get_tensor_sizes(
-            ns, batch_size or self._default_batch_size
-        ).inputs
-
         n_blocks, input_blocks = sample.split_into_blocks(
             input_block_shape,
             halo=self._default_input_halo,
@@ -239,6 +216,47 @@ def predict_sample_with_blocking(
 
         return predicted_sample
 
+    def predict_sample_with_blocking(
+        self,
+        sample: Sample,
+        skip_preprocessing: bool = False,
+        skip_postprocessing: bool = False,
+        ns: Optional[
+            Union[
+                v0_5.ParameterizedSize_N,
+                Mapping[Tuple[MemberId, AxisId], v0_5.ParameterizedSize_N],
+            ]
+        ] = None,
+        batch_size: Optional[int] = None,
+    ) -> Sample:
+        """predict a sample by splitting it into blocks according to the model and the `ns` parameter"""
+
+        if isinstance(self.model_description, v0_4.ModelDescr):
+            raise NotImplementedError(
+                "`predict_sample_with_blocking` not implemented for v0_4.ModelDescr"
+                + f" {self.model_description.name}."
+                + " Consider using `predict_sample_with_fixed_blocking`"
+            )
+
+        ns = ns or self._default_ns
+        if isinstance(ns, int):
+            ns = {
+                (ipt.id, a.id): ns
+                for ipt in self.model_description.inputs
+                for a in ipt.axes
+                if isinstance(a.size, v0_5.ParameterizedSize)
+            }
+        input_block_shape = self.model_description.get_tensor_sizes(
+            ns, batch_size or self._default_batch_size
+        ).inputs
+
+        return self.predict_sample_with_fixed_blocking(
+            sample,
+            input_block_shape=input_block_shape,
+            skip_preprocessing=skip_preprocessing,
+            skip_postprocessing=skip_postprocessing,
+        )
+
     # def predict(
     #     self,
     #     inputs: Predict_IO,
@@ -310,8 +328,8 @@ def create_prediction_pipeline(
     ),
     model_adapter: Optional[ModelAdapter] = None,
    ns: Union[
-        v0_5.ParameterizedSize.N,
-        Mapping[Tuple[MemberId, AxisId], v0_5.ParameterizedSize.N],
+        v0_5.ParameterizedSize_N,
+        Mapping[Tuple[MemberId, AxisId], v0_5.ParameterizedSize_N],
     ] = 10,
     **deprecated_kwargs: Any,
 ) -> PredictionPipeline:
```
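The tiling logic formerly embedded in `predict_sample_with_blocking` is now split in two: `predict_sample_with_fixed_blocking` takes explicit per-tensor block shapes, while `predict_sample_with_blocking` derives them from `ns` and delegates. A minimal sketch (the RDF source is a placeholder, and `get_test_inputs` from `digest_spec` is assumed here as a convenient way to get a matching `Sample`):

```python
from bioimageio.core import create_prediction_pipeline, load_model_description
from bioimageio.core.digest_spec import get_test_inputs

# placeholder source; any format 0.5 model description should behave the same way
model = load_model_description("https://example.com/some/model/rdf.yaml")
pipeline = create_prediction_pipeline(model)
sample = get_test_inputs(model)  # Sample built from the model's test tensors

# let the pipeline derive block shapes from the parameterized axis sizes (n=5 per axis)
output = pipeline.predict_sample_with_blocking(sample, ns=5)

# or hand over an explicit block shape per input tensor and axis
# (keys are the model's member/axis ids; shown only schematically here)
# output = pipeline.predict_sample_with_fixed_blocking(sample, input_block_shape={...})
```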

bioimageio/core/_resource_tests.py  (+8 -8)

```diff
@@ -1,7 +1,7 @@
 import traceback
 import warnings
 from itertools import product
-from typing import Dict, Hashable, List, Literal, Optional, Set, Tuple, Union
+from typing import Dict, Hashable, List, Literal, Optional, Sequence, Set, Tuple, Union
 
 import numpy as np
 from loguru import logger
@@ -57,7 +57,7 @@ def test_description(
     *,
     format_version: Union[Literal["discover", "latest"], str] = "discover",
     weight_format: Optional[WeightsFormat] = None,
-    devices: Optional[List[str]] = None,
+    devices: Optional[Sequence[str]] = None,
     absolute_tolerance: float = 1.5e-4,
     relative_tolerance: float = 1e-4,
     decimal: Optional[int] = None,
@@ -83,7 +83,7 @@ def load_description_and_test(
     *,
     format_version: Union[Literal["discover", "latest"], str] = "discover",
     weight_format: Optional[WeightsFormat] = None,
-    devices: Optional[List[str]] = None,
+    devices: Optional[Sequence[str]] = None,
     absolute_tolerance: float = 1.5e-4,
     relative_tolerance: float = 1e-4,
     decimal: Optional[int] = None,
@@ -138,12 +138,12 @@ def load_description_and_test(
 def _test_model_inference(
     model: Union[v0_4.ModelDescr, v0_5.ModelDescr],
     weight_format: WeightsFormat,
-    devices: Optional[List[str]],
+    devices: Optional[Sequence[str]],
     absolute_tolerance: float,
     relative_tolerance: float,
     decimal: Optional[int],
 ) -> None:
-    test_name = "Reproduce test outputs from test inputs"
+    test_name = f"Reproduce test outputs from test inputs ({weight_format})"
     logger.info("starting '{}'", test_name)
     error: Optional[str] = None
     tb: List[str] = []
@@ -209,15 +209,15 @@ def _test_model_inference(
 def _test_model_inference_parametrized(
     model: v0_5.ModelDescr,
     weight_format: WeightsFormat,
-    devices: Optional[List[str]],
+    devices: Optional[Sequence[str]],
 ) -> None:
     if not any(
         isinstance(a.size, v0_5.ParameterizedSize)
         for ipt in model.inputs
         for a in ipt.axes
     ):
         # no parameterized sizes => set n=0
-        ns: Set[v0_5.ParameterizedSize.N] = {0}
+        ns: Set[v0_5.ParameterizedSize_N] = {0}
     else:
         ns = {0, 1, 2}
 
@@ -236,7 +236,7 @@ def _test_model_inference_parametrized(
         # no batch axis
         batch_sizes = {1}
 
-    test_cases: Set[Tuple[v0_5.ParameterizedSize.N, BatchSize]] = {
+    test_cases: Set[Tuple[v0_5.ParameterizedSize_N, BatchSize]] = {
         (n, b) for n, b in product(sorted(ns), sorted(batch_sizes))
     }
     logger.info(
```
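With `devices` now typed as `Optional[Sequence[str]]`, callers may pass tuples as well as lists, and the per-weight-format test name makes it clearer which backend a summary entry refers to. A minimal sketch (the model source is a placeholder):

```python
from bioimageio.core import test_description

# devices may now be any Sequence[str], e.g. a tuple
summary = test_description(
    "https://example.com/some/model/rdf.yaml",
    weight_format="pytorch_state_dict",
    devices=("cpu",),
)
print(summary.status)
```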

bioimageio/core/axis.py  (+6 -21)

```diff
@@ -26,19 +26,6 @@ def _get_axis_type(a: Literal["b", "t", "i", "c", "x", "y", "z"]):
 S = TypeVar("S", bound=str)
 
 
-def _get_axis_id(a: Union[Literal["b", "t", "i", "c"], S]):
-    if a == "b":
-        return AxisId("batch")
-    elif a == "t":
-        return AxisId("time")
-    elif a == "i":
-        return AxisId("index")
-    elif a == "c":
-        return AxisId("channel")
-    else:
-        return AxisId(a)
-
-
 AxisId = v0_5.AxisId
 
 T = TypeVar("T")
@@ -47,7 +34,7 @@ def _get_axis_id(a: Union[Literal["b", "t", "i", "c"], S]):
 BatchSize = int
 
 AxisLetter = Literal["b", "i", "t", "c", "z", "y", "x"]
-AxisLike = Union[AxisLetter, v0_5.AnyAxis, "Axis"]
+AxisLike = Union[AxisId, AxisLetter, v0_5.AnyAxis, "Axis"]
 
 
 @dataclass
@@ -62,7 +49,7 @@ def create(cls, axis: AxisLike) -> Axis:
         elif isinstance(axis, Axis):
             return Axis(id=axis.id, type=axis.type)
         elif isinstance(axis, str):
-            return Axis(id=_get_axis_id(axis), type=_get_axis_type(axis))
+            return Axis(id=AxisId(axis), type=_get_axis_type(axis))
         elif isinstance(axis, v0_5.AxisBase):
             return Axis(id=AxisId(axis.id), type=axis.type)
         else:
@@ -71,7 +58,7 @@ def create(cls, axis: AxisLike) -> Axis:
 
 @dataclass
 class AxisInfo(Axis):
-    maybe_singleton: bool
+    maybe_singleton: bool  # TODO: replace 'maybe_singleton' with size min/max for better axis guessing
 
     @classmethod
     def create(cls, axis: AxisLike, maybe_singleton: Optional[bool] = None) -> AxisInfo:
@@ -80,18 +67,16 @@ def create(cls, axis: AxisLike, maybe_singleton: Optional[bool] = None) -> AxisInfo:
 
         axis_base = super().create(axis)
         if maybe_singleton is None:
-            if isinstance(axis, Axis):
-                maybe_singleton = False
-            elif isinstance(axis, str):
-                maybe_singleton = axis == "b"
+            if isinstance(axis, (Axis, str)):
+                maybe_singleton = True
             else:
                 if axis.size is None:
                     maybe_singleton = True
                 elif isinstance(axis.size, int):
                     maybe_singleton = axis.size == 1
                 elif isinstance(axis.size, v0_5.SizeReference):
                     maybe_singleton = (
-                        False  # TODO: check if singleton is ok for a `SizeReference`
+                        True  # TODO: check if singleton is ok for a `SizeReference`
                     )
                 elif isinstance(
                     axis.size, (v0_5.ParameterizedSize, v0_5.DataDependentSize)
```
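With `AxisId` added to `AxisLike` and the single-letter id mapping removed, string inputs keep their literal id, and string-like axes are now assumed to possibly be singletons. A small sketch of the changed behavior:

```python
from bioimageio.core.axis import Axis, AxisId, AxisInfo

# a single letter still determines the axis type, but the id is kept verbatim now
print(Axis.create("b"))  # id "b", type "batch"

# AxisId values are accepted directly as AxisLike
print(Axis.create(AxisId("x")))  # id "x", type "space"

# without further information, a string-like axis may be a singleton
print(AxisInfo.create("y").maybe_singleton)  # True
```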
