Skip to content

Commit ebe126b

Browse files
authored
Merge pull request #394 from bioimage-io/update_usage
Update model usage example
2 parents 2b645a7 + 6d6117a commit ebe126b

File tree

7 files changed

+163
-376
lines changed

7 files changed

+163
-376
lines changed

README.md

+4
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,10 @@ The model specification and its validation tools can be found at <https://github
124124

125125
## Changelog
126126

127+
### 0.6.7
128+
129+
* `predict()` argument `inputs` may be sample
130+
127131
### 0.6.6
128132

129133
* add aliases to match previous API more closely

bioimageio/core/VERSION

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
{
2-
"version": "0.6.6"
2+
"version": "0.6.7"
33
}

bioimageio/core/io.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,10 @@ def save_sample(path: Union[Path, str], sample: Sample) -> None:
4747
`path` must contain `{member_id}` and may contain `{sample_id}`,
4848
which are resolved with the `sample` object.
4949
"""
50-
path = str(path).format(sample_id=sample.id)
5150
if "{member_id}" not in path:
5251
raise ValueError(f"missing `{{member_id}}` in path {path}")
5352

53+
path = str(path).format(sample_id=sample.id, member_id="{member_id}")
54+
5455
for m, t in sample.members.items():
5556
save_tensor(Path(path.format(member_id=m)), t)

bioimageio/core/prediction.py

+8-5
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def predict(
3939
model: Union[
4040
PermissiveFileSource, v0_4.ModelDescr, v0_5.ModelDescr, PredictionPipeline
4141
],
42-
inputs: PerMember[Union[Tensor, xr.DataArray, NDArray[Any], Path]],
42+
inputs: Union[Sample, PerMember[Union[Tensor, xr.DataArray, NDArray[Any], Path]]],
4343
sample_id: Hashable = "sample",
4444
blocksize_parameter: Optional[
4545
Union[
@@ -56,7 +56,7 @@ def predict(
5656
Args:
5757
model: model to predict with.
5858
May be given as RDF source, model description or prediction pipeline.
59-
inputs: the named input(s) for this model as a dictionary
59+
inputs: the input sample or the named input(s) for this model as a dictionary
6060
sample_id: the sample id.
6161
blocksize_parameter: (optional) tile the input into blocks parametrized by
6262
blocksize according to any parametrized axis sizes defined in the model RDF
@@ -82,9 +82,12 @@ def predict(
8282

8383
pp = create_prediction_pipeline(model)
8484

85-
sample = create_sample_for_model(
86-
pp.model_description, inputs=inputs, sample_id=sample_id
87-
)
85+
if isinstance(inputs, Sample):
86+
sample = inputs
87+
else:
88+
sample = create_sample_for_model(
89+
pp.model_description, inputs=inputs, sample_id=sample_id
90+
)
8891

8992
if blocksize_parameter is None:
9093
output = pp.predict_sample_without_blocking(

bioimageio/core/proc_setup.py

+11-7
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,18 @@
1111

1212
from typing_extensions import assert_never
1313

14+
from bioimageio.core.common import MemberId
15+
from bioimageio.core.digest_spec import get_member_ids
1416
from bioimageio.spec.model import AnyModelDescr, v0_4, v0_5
1517
from bioimageio.spec.model.v0_5 import TensorId
1618

17-
from .proc_ops import AddKnownDatasetStats, Processing, UpdateStats, get_proc_class
19+
from .proc_ops import (
20+
AddKnownDatasetStats,
21+
EnsureDtype,
22+
Processing,
23+
UpdateStats,
24+
get_proc_class,
25+
)
1826
from .sample import Sample
1927
from .stat_calculators import StatsCalculator
2028
from .stat_measures import DatasetMeasure, Measure, MeasureValue
@@ -87,12 +95,8 @@ def _prepare_setup_pre_and_postprocessing(model: AnyModelDescr) -> _SetupProcess
8795
pre_measures: Set[Measure] = set()
8896
post_measures: Set[Measure] = set()
8997

90-
if isinstance(model, v0_4.ModelDescr):
91-
input_ids = {TensorId(str(d.name)) for d in model.inputs}
92-
output_ids = {TensorId(str(d.name)) for d in model.outputs}
93-
else:
94-
input_ids = {d.id for d in model.inputs}
95-
output_ids = {d.id for d in model.outputs}
98+
input_ids = set(get_member_ids(model.inputs))
99+
output_ids = set(get_member_ids(model.outputs))
96100

97101
def prepare_procs(tensor_descrs: Sequence[TensorDescr]):
98102
procs: List[Processing] = []

example/dataset_creation.ipynb

-125
This file was deleted.

0 commit comments

Comments
 (0)