Skip to content

Commit beadefd

Browse files
Extend tests for build_spec functionality
1 parent 55ed672 commit beadefd

File tree

2 files changed

+14
-2
lines changed

2 files changed

+14
-2
lines changed

tests/build_spec/test_add_weights.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from bioimageio.core import export_resource_package, load_raw_resource_description, load_resource_description
2+
from bioimageio.core.resource_tests import test_model as _test_model
23

34

45
def _test_add_weights(model, tmp_path, base_weights, added_weights, **kwargs):
@@ -28,6 +29,10 @@ def _test_add_weights(model, tmp_path, base_weights, added_weights, **kwargs):
2829
for weight in new_rdf.weights.values():
2930
assert weight.source.exists()
3031

32+
test_res = _test_model(out_path, added_weights)
33+
test_res = _test_model(out_path)
34+
assert test_res["error"] is None
35+
3136

3237
def test_add_torchscript(unet2d_nuclei_broad_model, tmp_path):
3338
_test_add_weights(unet2d_nuclei_broad_model, tmp_path, "pytorch_state_dict", "torchscript")

tests/build_spec/test_build_spec.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from bioimageio.core import load_raw_resource_description, load_resource_description
44
from bioimageio.core.resource_io import nodes
55
from bioimageio.core.resource_io.utils import resolve_source
6+
from bioimageio.core.resource_tests import test_model as _test_model
67
from marshmallow import missing
78

89

@@ -81,17 +82,19 @@ def _test_build_spec(
8182
output_path=out_path,
8283
add_deepimagej_config=add_deepimagej_config,
8384
maintainers=[{"github_user": "jane_doe"}],
85+
input_names=[inp.name for inp in model_spec.inputs],
86+
output_names=[out.name for out in model_spec.outputs],
8487
)
8588
if architecture is not None:
8689
kwargs["architecture"] = architecture
8790
if model_kwargs is not None:
88-
kwargs["kwargs"] = model_kwargs
91+
kwargs["model_kwargs"] = model_kwargs
8992
if tensorflow_version is not None:
9093
kwargs["tensorflow_version"] = tensorflow_version
9194
if opset_version is not None:
9295
kwargs["opset_version"] = opset_version
9396
if use_implicit_output_shape:
94-
kwargs["input_name"] = ["input"]
97+
kwargs["input_names"] = ["input"]
9598
kwargs["output_reference"] = ["input"]
9699
kwargs["output_scale"] = [[1.0, 1.0, 1.0, 1.0]]
97100
kwargs["output_offset"] = [[0.0, 0.0, 0.0, 0.0]]
@@ -134,6 +137,10 @@ def _test_build_spec(
134137
assert attachments.files is not missing
135138
assert n_processing == len(attachments.files)
136139

140+
# test inference for the model to ensure that the weights were written correctly
141+
test_res = _test_model(out_path)
142+
assert test_res["error"] is None
143+
137144

138145
def test_build_spec_pytorch(any_torch_model, tmp_path):
139146
_test_build_spec(any_torch_model, tmp_path / "model.zip", "pytorch_state_dict")

0 commit comments

Comments
 (0)