2
2
3
3
import importlib .util
4
4
from itertools import chain
5
+ from pathlib import Path
5
6
from typing import (
6
7
Any ,
7
8
Callable ,
16
17
Union ,
17
18
)
18
19
20
+ import numpy as np
21
+ import xarray as xr
22
+ from loguru import logger
19
23
from numpy .typing import NDArray
20
24
from typing_extensions import Unpack , assert_never
21
25
26
+ from bioimageio .core .common import MemberId , PerMember , SampleId
27
+ from bioimageio .core .io import load_tensor
28
+ from bioimageio .core .sample import Sample
22
29
from bioimageio .spec ._internal .io_utils import HashKwargs , download
23
30
from bioimageio .spec .common import FileSource
24
31
from bioimageio .spec .model import AnyModelDescr , v0_4 , v0_5
30
37
)
31
38
from bioimageio .spec .utils import load_array
32
39
33
- from .axis import AxisId , AxisInfo , PerAxis
40
+ from .axis import AxisId , AxisInfo , AxisLike , PerAxis
34
41
from .block_meta import split_multiple_shapes_into_blocks
35
42
from .common import Halo , MemberId , PerMember , SampleId , TotalNumberOfBlocks
36
43
from .sample import (
@@ -329,12 +336,35 @@ def get_io_sample_block_metas(
329
336
)
330
337
331
338
339
+ def get_tensor (
340
+ src : Union [Tensor , xr .DataArray , NDArray [Any ], Path ],
341
+ ipt : Union [v0_4 .InputTensorDescr , v0_5 .InputTensorDescr ],
342
+ ):
343
+ """helper to cast/load various tensor sources"""
344
+
345
+ if isinstance (src , Tensor ):
346
+ return src
347
+
348
+ if isinstance (src , xr .DataArray ):
349
+ return Tensor .from_xarray (src )
350
+
351
+ if isinstance (src , np .ndarray ):
352
+ return Tensor .from_numpy (src , dims = get_axes_infos (ipt ))
353
+
354
+ if isinstance (src , Path ):
355
+ return load_tensor (src , axes = get_axes_infos (ipt ))
356
+
357
+ assert_never (src )
358
+
359
+
332
360
def create_sample_for_model (
333
361
model : AnyModelDescr ,
334
362
* ,
335
363
stat : Optional [Stat ] = None ,
336
364
sample_id : SampleId = None ,
337
- inputs : Optional [PerMember [NDArray [Any ]]] = None , # TODO: make non-optional
365
+ inputs : Optional [
366
+ PerMember [Union [Tensor , xr .DataArray , NDArray [Any ], Path ]]
367
+ ] = None , # TODO: make non-optional
338
368
** kwargs : NDArray [Any ], # TODO: deprecate in favor of `inputs`
339
369
) -> Sample :
340
370
"""Create a sample from a single set of input(s) for a specific bioimage.io model
@@ -359,10 +389,54 @@ def create_sample_for_model(
359
389
360
390
return Sample (
361
391
members = {
362
- m : Tensor . from_numpy (inputs [m ], dims = get_axes_infos ( ipt ) )
392
+ m : get_tensor (inputs [m ], ipt )
363
393
for m , ipt in model_inputs .items ()
364
394
if m in inputs
365
395
},
366
396
stat = {} if stat is None else stat ,
367
397
id = sample_id ,
368
398
)
399
+
400
+
401
+ def load_sample_for_model (
402
+ * ,
403
+ model : AnyModelDescr ,
404
+ paths : PerMember [Path ],
405
+ axes : Optional [PerMember [Sequence [AxisLike ]]] = None ,
406
+ stat : Optional [Stat ] = None ,
407
+ sample_id : Optional [SampleId ] = None ,
408
+ ):
409
+ """load a single sample from `paths` that can be processed by `model`"""
410
+
411
+ if axes is None :
412
+ axes = {}
413
+
414
+ # make sure members are keyed by MemberId, not string
415
+ paths = {MemberId (k ): v for k , v in paths .items ()}
416
+ axes = {MemberId (k ): v for k , v in axes .items ()}
417
+
418
+ model_inputs = {get_member_id (d ): d for d in model .inputs }
419
+
420
+ if unknown := {k for k in paths if k not in model_inputs }:
421
+ raise ValueError (f"Got unexpected paths for { unknown } " )
422
+
423
+ if unknown := {k for k in axes if k not in model_inputs }:
424
+ raise ValueError (f"Got unexpected axes hints for: { unknown } " )
425
+
426
+ members : Dict [MemberId , Tensor ] = {}
427
+ for m , p in paths .items ():
428
+ if m not in axes :
429
+ axes [m ] = get_axes_infos (model_inputs [m ])
430
+ logger .warning (
431
+ "loading paths with {}'s default input axes {} for input '{}'" ,
432
+ axes [m ],
433
+ model .id or model .name ,
434
+ m ,
435
+ )
436
+ members [m ] = load_tensor (p , axes [m ])
437
+
438
+ return Sample (
439
+ members = members ,
440
+ stat = {} if stat is None else stat ,
441
+ id = sample_id or tuple (sorted (paths .values ())),
442
+ )
0 commit comments