Skip to content

Commit 152fb72

Browse files
committed
resolve cyclic dependency
1 parent 1bff5d8 commit 152fb72

File tree

2 files changed

+72
-76
lines changed

2 files changed

+72
-76
lines changed

bioimageio/core/digest_spec.py

Lines changed: 71 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,15 @@
1717
Union,
1818
)
1919

20+
import numpy as np
2021
import xarray as xr
22+
from loguru import logger
2123
from numpy.typing import NDArray
2224
from typing_extensions import Unpack, assert_never
2325

26+
from bioimageio.core.common import MemberId, PerMember, SampleId
27+
from bioimageio.core.io import load_tensor
28+
from bioimageio.core.sample import Sample
2429
from bioimageio.spec._internal.io_utils import HashKwargs, download
2530
from bioimageio.spec.common import FileSource
2631
from bioimageio.spec.model import AnyModelDescr, v0_4, v0_5
@@ -32,10 +37,9 @@
3237
)
3338
from bioimageio.spec.utils import load_array
3439

35-
from .axis import AxisId, AxisInfo, PerAxis
40+
from .axis import AxisId, AxisInfo, AxisLike, PerAxis
3641
from .block_meta import split_multiple_shapes_into_blocks
3742
from .common import Halo, MemberId, PerMember, SampleId, TotalNumberOfBlocks
38-
from .io import get_tensor
3943
from .sample import (
4044
LinearSampleAxisTransform,
4145
Sample,
@@ -332,6 +336,27 @@ def get_io_sample_block_metas(
332336
)
333337

334338

339+
def get_tensor(
340+
src: Union[Tensor, xr.DataArray, NDArray[Any], Path],
341+
ipt: Union[v0_4.InputTensorDescr, v0_5.InputTensorDescr],
342+
):
343+
"""helper to cast/load various tensor sources"""
344+
345+
if isinstance(src, Tensor):
346+
return src
347+
348+
if isinstance(src, xr.DataArray):
349+
return Tensor.from_xarray(src)
350+
351+
if isinstance(src, np.ndarray):
352+
return Tensor.from_numpy(src, dims=get_axes_infos(ipt))
353+
354+
if isinstance(src, Path):
355+
return load_tensor(src, axes=get_axes_infos(ipt))
356+
357+
assert_never(src)
358+
359+
335360
def create_sample_for_model(
336361
model: AnyModelDescr,
337362
*,
@@ -371,3 +396,47 @@ def create_sample_for_model(
371396
stat={} if stat is None else stat,
372397
id=sample_id,
373398
)
399+
400+
401+
def load_sample_for_model(
402+
*,
403+
model: AnyModelDescr,
404+
paths: PerMember[Path],
405+
axes: Optional[PerMember[Sequence[AxisLike]]] = None,
406+
stat: Optional[Stat] = None,
407+
sample_id: Optional[SampleId] = None,
408+
):
409+
"""load a single sample from `paths` that can be processed by `model`"""
410+
411+
if axes is None:
412+
axes = {}
413+
414+
# make sure members are keyed by MemberId, not string
415+
paths = {MemberId(k): v for k, v in paths.items()}
416+
axes = {MemberId(k): v for k, v in axes.items()}
417+
418+
model_inputs = {get_member_id(d): d for d in model.inputs}
419+
420+
if unknown := {k for k in paths if k not in model_inputs}:
421+
raise ValueError(f"Got unexpected paths for {unknown}")
422+
423+
if unknown := {k for k in axes if k not in model_inputs}:
424+
raise ValueError(f"Got unexpected axes hints for: {unknown}")
425+
426+
members: Dict[MemberId, Tensor] = {}
427+
for m, p in paths.items():
428+
if m not in axes:
429+
axes[m] = get_axes_infos(model_inputs[m])
430+
logger.warning(
431+
"loading paths with {}'s default input axes {} for input '{}'",
432+
axes[m],
433+
model.id or model.name,
434+
m,
435+
)
436+
members[m] = load_tensor(p, axes[m])
437+
438+
return Sample(
439+
members=members,
440+
stat={} if stat is None else stat,
441+
id=sample_id or tuple(sorted(paths.values())),
442+
)

bioimageio/core/io.py

Lines changed: 1 addition & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,13 @@
11
from pathlib import Path
2-
from typing import Any, Dict, Optional, Sequence, Union
2+
from typing import Any, Optional, Sequence, Union
33

44
import imageio
5-
import numpy as np
6-
import xarray as xr
7-
from loguru import logger
85
from numpy.typing import NDArray
9-
from typing_extensions import assert_never
106

11-
from bioimageio.spec.model import AnyModelDescr, v0_4, v0_5
127
from bioimageio.spec.utils import load_array, save_array
138

149
from .axis import Axis, AxisLike
15-
from .common import MemberId, PerMember, SampleId
16-
from .digest_spec import get_axes_infos, get_member_id
1710
from .sample import Sample
18-
from .stat_measures import Stat
1911
from .tensor import Tensor
2012

2113

@@ -40,27 +32,6 @@ def load_tensor(path: Path, axes: Optional[Sequence[AxisLike]] = None) -> Tensor
4032
return Tensor.from_numpy(array, dims=axes)
4133

4234

43-
def get_tensor(
44-
src: Union[Tensor, xr.DataArray, NDArray[Any], Path],
45-
ipt: Union[v0_4.InputTensorDescr, v0_5.InputTensorDescr],
46-
):
47-
"""helper to cast/load various tensor sources"""
48-
49-
if isinstance(src, Tensor):
50-
return src
51-
52-
if isinstance(src, xr.DataArray):
53-
return Tensor.from_xarray(src)
54-
55-
if isinstance(src, np.ndarray):
56-
return Tensor.from_numpy(src, dims=get_axes_infos(ipt))
57-
58-
if isinstance(src, Path):
59-
return load_tensor(src, axes=get_axes_infos(ipt))
60-
61-
assert_never(src)
62-
63-
6435
def save_tensor(path: Path, tensor: Tensor) -> None:
6536
# TODO: save axis meta data
6637
data: NDArray[Any] = tensor.data.to_numpy()
@@ -82,47 +53,3 @@ def save_sample(path: Union[Path, str], sample: Sample) -> None:
8253

8354
for m, t in sample.members.items():
8455
save_tensor(Path(path.format(member_id=m)), t)
85-
86-
87-
def load_sample_for_model(
88-
*,
89-
model: AnyModelDescr,
90-
paths: PerMember[Path],
91-
axes: Optional[PerMember[Sequence[AxisLike]]] = None,
92-
stat: Optional[Stat] = None,
93-
sample_id: Optional[SampleId] = None,
94-
):
95-
"""load a single sample from `paths` that can be processed by `model`"""
96-
97-
if axes is None:
98-
axes = {}
99-
100-
# make sure members are keyed by MemberId, not string
101-
paths = {MemberId(k): v for k, v in paths.items()}
102-
axes = {MemberId(k): v for k, v in axes.items()}
103-
104-
model_inputs = {get_member_id(d): d for d in model.inputs}
105-
106-
if unknown := {k for k in paths if k not in model_inputs}:
107-
raise ValueError(f"Got unexpected paths for {unknown}")
108-
109-
if unknown := {k for k in axes if k not in model_inputs}:
110-
raise ValueError(f"Got unexpected axes hints for: {unknown}")
111-
112-
members: Dict[MemberId, Tensor] = {}
113-
for m, p in paths.items():
114-
if m not in axes:
115-
axes[m] = get_axes_infos(model_inputs[m])
116-
logger.warning(
117-
"loading paths with {}'s default input axes {} for input '{}'",
118-
axes[m],
119-
model.id or model.name,
120-
m,
121-
)
122-
members[m] = load_tensor(p, axes[m])
123-
124-
return Sample(
125-
members=members,
126-
stat={} if stat is None else stat,
127-
id=sample_id or tuple(sorted(paths.values())),
128-
)

0 commit comments

Comments
 (0)