1
1
from pathlib import Path
2
- from typing import Any , Dict , Optional , Sequence , Union
2
+ from typing import Any , Optional , Sequence , Union
3
3
4
4
import imageio
5
- import numpy as np
6
- import xarray as xr
7
- from loguru import logger
8
5
from numpy .typing import NDArray
9
- from typing_extensions import assert_never
10
6
11
- from bioimageio .spec .model import AnyModelDescr , v0_4 , v0_5
12
7
from bioimageio .spec .utils import load_array , save_array
13
8
14
9
from .axis import Axis , AxisLike
15
- from .common import MemberId , PerMember , SampleId
16
- from .digest_spec import get_axes_infos , get_member_id
17
10
from .sample import Sample
18
- from .stat_measures import Stat
19
11
from .tensor import Tensor
20
12
21
13
@@ -40,27 +32,6 @@ def load_tensor(path: Path, axes: Optional[Sequence[AxisLike]] = None) -> Tensor
40
32
return Tensor .from_numpy (array , dims = axes )
41
33
42
34
43
- def get_tensor (
44
- src : Union [Tensor , xr .DataArray , NDArray [Any ], Path ],
45
- ipt : Union [v0_4 .InputTensorDescr , v0_5 .InputTensorDescr ],
46
- ):
47
- """helper to cast/load various tensor sources"""
48
-
49
- if isinstance (src , Tensor ):
50
- return src
51
-
52
- if isinstance (src , xr .DataArray ):
53
- return Tensor .from_xarray (src )
54
-
55
- if isinstance (src , np .ndarray ):
56
- return Tensor .from_numpy (src , dims = get_axes_infos (ipt ))
57
-
58
- if isinstance (src , Path ):
59
- return load_tensor (src , axes = get_axes_infos (ipt ))
60
-
61
- assert_never (src )
62
-
63
-
64
35
def save_tensor (path : Path , tensor : Tensor ) -> None :
65
36
# TODO: save axis meta data
66
37
data : NDArray [Any ] = tensor .data .to_numpy ()
@@ -82,47 +53,3 @@ def save_sample(path: Union[Path, str], sample: Sample) -> None:
82
53
83
54
for m , t in sample .members .items ():
84
55
save_tensor (Path (path .format (member_id = m )), t )
85
-
86
-
87
- def load_sample_for_model (
88
- * ,
89
- model : AnyModelDescr ,
90
- paths : PerMember [Path ],
91
- axes : Optional [PerMember [Sequence [AxisLike ]]] = None ,
92
- stat : Optional [Stat ] = None ,
93
- sample_id : Optional [SampleId ] = None ,
94
- ):
95
- """load a single sample from `paths` that can be processed by `model`"""
96
-
97
- if axes is None :
98
- axes = {}
99
-
100
- # make sure members are keyed by MemberId, not string
101
- paths = {MemberId (k ): v for k , v in paths .items ()}
102
- axes = {MemberId (k ): v for k , v in axes .items ()}
103
-
104
- model_inputs = {get_member_id (d ): d for d in model .inputs }
105
-
106
- if unknown := {k for k in paths if k not in model_inputs }:
107
- raise ValueError (f"Got unexpected paths for { unknown } " )
108
-
109
- if unknown := {k for k in axes if k not in model_inputs }:
110
- raise ValueError (f"Got unexpected axes hints for: { unknown } " )
111
-
112
- members : Dict [MemberId , Tensor ] = {}
113
- for m , p in paths .items ():
114
- if m not in axes :
115
- axes [m ] = get_axes_infos (model_inputs [m ])
116
- logger .warning (
117
- "loading paths with {}'s default input axes {} for input '{}'" ,
118
- axes [m ],
119
- model .id or model .name ,
120
- m ,
121
- )
122
- members [m ] = load_tensor (p , axes [m ])
123
-
124
- return Sample (
125
- members = members ,
126
- stat = {} if stat is None else stat ,
127
- id = sample_id or tuple (sorted (paths .values ())),
128
- )
0 commit comments