Skip to content

Commit 2c1db80

Browse files
committed
update build_model and tests
1 parent 348dc92 commit 2c1db80

File tree

3 files changed

+10
-8
lines changed

3 files changed

+10
-8
lines changed

bioimageio/core/build_spec/build_model.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@
1212
import bioimageio.spec.model as model_spec
1313
from bioimageio.core import export_resource_package, load_raw_resource_description
1414
from bioimageio.core.resource_io.nodes import URI
15-
from bioimageio.core.resource_io.utils import resolve_local_source, resolve_source
1615
from bioimageio.spec.shared.raw_nodes import ImportableModule, ImportableSourceFile
16+
from bioimageio.spec.shared.utils import resolve_local_source, resolve_source
1717

1818
try:
1919
from typing import get_args
@@ -81,6 +81,7 @@ def _get_weights(
8181
model_kwargs=None,
8282
tensorflow_version=None,
8383
opset_version=None,
84+
dependencies=None,
8485
**kwargs,
8586
):
8687
weight_path = resolve_source(original_weight_source, root)
@@ -100,6 +101,8 @@ def _get_weights(
100101
weights = model_spec.raw_nodes.PytorchStateDictWeightsEntry(
101102
source=weight_source, sha256=weight_hash, **weight_kwargs
102103
)
104+
if dependencies is not None:
105+
weight_kwargs["dependencies"] = _get_dependencies(dependencies, root)
103106

104107
elif weight_type == "onnx":
105108
if opset_version is None:
@@ -745,6 +748,7 @@ def build_model(
745748
model_kwargs,
746749
tensorflow_version=tensorflow_version,
747750
opset_version=opset_version,
751+
dependencies=dependencies,
748752
**weight_kwargs,
749753
)
750754

@@ -813,8 +817,7 @@ def build_model(
813817

814818
if attachments is not None:
815819
kwargs["attachments"] = spec.rdf.raw_nodes.Attachments(**attachments)
816-
if dependencies is not None:
817-
kwargs["dependencies"] = _get_dependencies(dependencies, root)
820+
818821
if maintainers is not None:
819822
kwargs["maintainers"] = [model_spec.raw_nodes.Maintainer(**m) for m in maintainers]
820823
if parent is not None:

tests/build_spec/test_build_spec.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,14 @@ def _test_build_spec(
1919
):
2020
from bioimageio.core.build_spec import build_model
2121

22-
model_spec = load_raw_resource_description(spec_path)
22+
model_spec = load_raw_resource_description(spec_path, update_to_format="latest")
2323
root = model_spec.root_path
2424
assert isinstance(model_spec, spec.model.raw_nodes.Model)
2525
weight_source = model_spec.weights[weight_type].source
2626

2727
cite = {entry.text: entry.doi if entry.url is missing else entry.url for entry in model_spec.cite}
2828

29+
dep_file = None
2930
if weight_type == "pytorch_state_dict":
3031
weight_spec = model_spec.weights["pytorch_state_dict"]
3132
model_kwargs = None if weight_spec.kwargs is missing else weight_spec.kwargs
@@ -35,6 +36,7 @@ def _test_build_spec(
3536
arch_path = os.path.abspath(os.path.join(root, arch_path))
3637
assert os.path.exists(arch_path)
3738
architecture = f"{arch_path}:{cls_name}"
39+
dep_file = None if weight_spec.dependencies is missing else resolve_source(weight_spec.dependencies.file, root)
3840
weight_type_ = None # the weight type can be auto-detected
3941
elif weight_type == "torchscript":
4042
architecture = None
@@ -45,7 +47,6 @@ def _test_build_spec(
4547
model_kwargs = None
4648
weight_type_ = None # the weight type can be auto-detected
4749

48-
dep_file = None if model_spec.dependencies is missing else resolve_source(model_spec.dependencies.file, root)
4950
authors = [{"name": auth.name, "affiliation": auth.affiliation} for auth in model_spec.authors]
5051

5152
input_axes = [input_.axes for input_ in model_spec.inputs]

tests/resource_io/test_load_rdf.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,8 @@
11
import os.path
22
import pathlib
33
from pathlib import Path
4-
from tempfile import NamedTemporaryFile
54

65
import pytest
7-
from marshmallow import ValidationError
86

97
from bioimageio.core.resource_io.utils import resolve_source
108

@@ -80,7 +78,7 @@ def test_load_remote_model_with_folders():
8078

8179
# todo: point to real model with nested folders, not this temporary sandbox one
8280
rdf_url = "https://sandbox.zenodo.org/record/892199/files/rdf.yaml"
83-
raw_model = load_raw_resource_description(rdf_url)
81+
raw_model = load_raw_resource_description(rdf_url, update_to_format="latest")
8482
assert isinstance(raw_model, raw_nodes.Model)
8583
model = load_resource_description(rdf_url)
8684
assert isinstance(model, nodes.Model)

0 commit comments

Comments
 (0)