Skip to content

Commit 2b645a7

Browse files
authored
Merge pull request #393 from bioimage-io/backward_comp
Improve backward compatibility
2 parents 60251e9 + cb1154d commit 2b645a7

File tree

8 files changed

+280
-53
lines changed

8 files changed

+280
-53
lines changed

README.md

+4
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,10 @@ The model specification and its validation tools can be found at <https://github
124124

125125
## Changelog
126126

127+
### 0.6.6
128+
129+
* add aliases to match previous API more closely
130+
127131
### 0.6.5
128132

129133
* improve adapter error messages

bioimageio/core/VERSION

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
{
2-
"version": "0.6.5"
2+
"version": "0.6.6"
33
}

bioimageio/core/__init__.py

+5
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,15 @@
2727
from .axis import AxisId as AxisId
2828
from .block_meta import BlockMeta as BlockMeta
2929
from .common import MemberId as MemberId
30+
from .prediction import predict as predict
31+
from .prediction import predict_many as predict_many
3032
from .sample import Sample as Sample
3133
from .tensor import Tensor as Tensor
3234
from .utils import VERSION
3335

3436
__version__ = VERSION
3537

38+
# aliases
3639
test_resource = test_description
40+
load_resource = load_description
41+
load_model = load_description

bioimageio/core/commands.py

-2
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
load_description_and_validate_format_only,
1212
save_bioimageio_package,
1313
)
14-
from bioimageio.spec.collection import CollectionDescr
1514
from bioimageio.spec.dataset import DatasetDescr
1615
from bioimageio.spec.model import ModelDescr
1716
from bioimageio.spec.model.v0_5 import WeightsFormat
@@ -94,7 +93,6 @@ def validate_format(
9493
model RDF {ModelDescr.implemented_format_version}
9594
dataset RDF {DatasetDescr.implemented_format_version}
9695
notebook RDF {NotebookDescr.implemented_format_version}
97-
collection RDF {CollectionDescr.implemented_format_version}
9896
9997
"""
10098

bioimageio/core/digest_spec.py

+77-3
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import importlib.util
44
from itertools import chain
5+
from pathlib import Path
56
from typing import (
67
Any,
78
Callable,
@@ -16,9 +17,15 @@
1617
Union,
1718
)
1819

20+
import numpy as np
21+
import xarray as xr
22+
from loguru import logger
1923
from numpy.typing import NDArray
2024
from typing_extensions import Unpack, assert_never
2125

26+
from bioimageio.core.common import MemberId, PerMember, SampleId
27+
from bioimageio.core.io import load_tensor
28+
from bioimageio.core.sample import Sample
2229
from bioimageio.spec._internal.io_utils import HashKwargs, download
2330
from bioimageio.spec.common import FileSource
2431
from bioimageio.spec.model import AnyModelDescr, v0_4, v0_5
@@ -30,7 +37,7 @@
3037
)
3138
from bioimageio.spec.utils import load_array
3239

33-
from .axis import AxisId, AxisInfo, PerAxis
40+
from .axis import AxisId, AxisInfo, AxisLike, PerAxis
3441
from .block_meta import split_multiple_shapes_into_blocks
3542
from .common import Halo, MemberId, PerMember, SampleId, TotalNumberOfBlocks
3643
from .sample import (
@@ -329,12 +336,35 @@ def get_io_sample_block_metas(
329336
)
330337

331338

339+
def get_tensor(
340+
src: Union[Tensor, xr.DataArray, NDArray[Any], Path],
341+
ipt: Union[v0_4.InputTensorDescr, v0_5.InputTensorDescr],
342+
):
343+
"""helper to cast/load various tensor sources"""
344+
345+
if isinstance(src, Tensor):
346+
return src
347+
348+
if isinstance(src, xr.DataArray):
349+
return Tensor.from_xarray(src)
350+
351+
if isinstance(src, np.ndarray):
352+
return Tensor.from_numpy(src, dims=get_axes_infos(ipt))
353+
354+
if isinstance(src, Path):
355+
return load_tensor(src, axes=get_axes_infos(ipt))
356+
357+
assert_never(src)
358+
359+
332360
def create_sample_for_model(
333361
model: AnyModelDescr,
334362
*,
335363
stat: Optional[Stat] = None,
336364
sample_id: SampleId = None,
337-
inputs: Optional[PerMember[NDArray[Any]]] = None, # TODO: make non-optional
365+
inputs: Optional[
366+
PerMember[Union[Tensor, xr.DataArray, NDArray[Any], Path]]
367+
] = None, # TODO: make non-optional
338368
**kwargs: NDArray[Any], # TODO: deprecate in favor of `inputs`
339369
) -> Sample:
340370
"""Create a sample from a single set of input(s) for a specific bioimage.io model
@@ -359,10 +389,54 @@ def create_sample_for_model(
359389

360390
return Sample(
361391
members={
362-
m: Tensor.from_numpy(inputs[m], dims=get_axes_infos(ipt))
392+
m: get_tensor(inputs[m], ipt)
363393
for m, ipt in model_inputs.items()
364394
if m in inputs
365395
},
366396
stat={} if stat is None else stat,
367397
id=sample_id,
368398
)
399+
400+
401+
def load_sample_for_model(
402+
*,
403+
model: AnyModelDescr,
404+
paths: PerMember[Path],
405+
axes: Optional[PerMember[Sequence[AxisLike]]] = None,
406+
stat: Optional[Stat] = None,
407+
sample_id: Optional[SampleId] = None,
408+
):
409+
"""load a single sample from `paths` that can be processed by `model`"""
410+
411+
if axes is None:
412+
axes = {}
413+
414+
# make sure members are keyed by MemberId, not string
415+
paths = {MemberId(k): v for k, v in paths.items()}
416+
axes = {MemberId(k): v for k, v in axes.items()}
417+
418+
model_inputs = {get_member_id(d): d for d in model.inputs}
419+
420+
if unknown := {k for k in paths if k not in model_inputs}:
421+
raise ValueError(f"Got unexpected paths for {unknown}")
422+
423+
if unknown := {k for k in axes if k not in model_inputs}:
424+
raise ValueError(f"Got unexpected axes hints for: {unknown}")
425+
426+
members: Dict[MemberId, Tensor] = {}
427+
for m, p in paths.items():
428+
if m not in axes:
429+
axes[m] = get_axes_infos(model_inputs[m])
430+
logger.warning(
431+
"loading paths with {}'s default input axes {} for input '{}'",
432+
axes[m],
433+
model.id or model.name,
434+
m,
435+
)
436+
members[m] = load_tensor(p, axes[m])
437+
438+
return Sample(
439+
members=members,
440+
stat={} if stat is None else stat,
441+
id=sample_id or tuple(sorted(paths.values())),
442+
)

bioimageio/core/io.py

+20-45
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,13 @@
11
from pathlib import Path
2-
from typing import Any, Dict, Optional, Sequence
2+
from typing import Any, Optional, Sequence, Union
33

44
import imageio
5-
from loguru import logger
65
from numpy.typing import NDArray
76

8-
from bioimageio.spec.model import AnyModelDescr
9-
from bioimageio.spec.utils import load_array
7+
from bioimageio.spec.utils import load_array, save_array
108

119
from .axis import Axis, AxisLike
12-
from .common import MemberId, PerMember, SampleId
13-
from .digest_spec import get_axes_infos, get_member_id
1410
from .sample import Sample
15-
from .stat_measures import Stat
1611
from .tensor import Tensor
1712

1813

@@ -26,6 +21,7 @@ def load_image(path: Path, is_volume: bool) -> NDArray[Any]:
2621

2722

2823
def load_tensor(path: Path, axes: Optional[Sequence[AxisLike]] = None) -> Tensor:
24+
# TODO: load axis meta data
2925
array = load_image(
3026
path,
3127
is_volume=(
@@ -36,45 +32,24 @@ def load_tensor(path: Path, axes: Optional[Sequence[AxisLike]] = None) -> Tensor
3632
return Tensor.from_numpy(array, dims=axes)
3733

3834

39-
def load_sample_for_model(
40-
*,
41-
model: AnyModelDescr,
42-
paths: PerMember[Path],
43-
axes: Optional[PerMember[Sequence[AxisLike]]] = None,
44-
stat: Optional[Stat] = None,
45-
sample_id: Optional[SampleId] = None,
46-
):
47-
"""load a single sample from `paths` that can be processed by `model`"""
48-
49-
if axes is None:
50-
axes = {}
51-
52-
# make sure members are keyed by MemberId, not string
53-
paths = {MemberId(k): v for k, v in paths.items()}
54-
axes = {MemberId(k): v for k, v in axes.items()}
55-
56-
model_inputs = {get_member_id(d): d for d in model.inputs}
35+
def save_tensor(path: Path, tensor: Tensor) -> None:
36+
# TODO: save axis meta data
37+
data: NDArray[Any] = tensor.data.to_numpy()
38+
if path.suffix == ".npy":
39+
save_array(path, data)
40+
else:
41+
imageio.volwrite(path, data)
5742

58-
if unknown := {k for k in paths if k not in model_inputs}:
59-
raise ValueError(f"Got unexpected paths for {unknown}")
6043

61-
if unknown := {k for k in axes if k not in model_inputs}:
62-
raise ValueError(f"Got unexpected axes hints for: {unknown}")
44+
def save_sample(path: Union[Path, str], sample: Sample) -> None:
45+
"""save a sample to path
6346
64-
members: Dict[MemberId, Tensor] = {}
65-
for m, p in paths.items():
66-
if m not in axes:
67-
axes[m] = get_axes_infos(model_inputs[m])
68-
logger.warning(
69-
"loading paths with {}'s default input axes {} for input '{}'",
70-
axes[m],
71-
model.id or model.name,
72-
m,
73-
)
74-
members[m] = load_tensor(p, axes[m])
47+
`path` must contain `{member_id}` and may contain `{sample_id}`,
48+
which are resolved with the `sample` object.
49+
"""
50+
path = str(path).format(sample_id=sample.id)
51+
if "{member_id}" not in path:
52+
raise ValueError(f"missing `{{member_id}}` in path {path}")
7553

76-
return Sample(
77-
members=members,
78-
stat={} if stat is None else stat,
79-
id=sample_id or tuple(sorted(paths.values())),
80-
)
54+
for m, t in sample.members.items():
55+
save_tensor(Path(path.format(member_id=m)), t)

0 commit comments

Comments
 (0)