Skip to content

Commit 9332de8

Browse files
Merge pull request #201 from bioimage-io/more-build-spec-tests2
Extend tests for build_spec functionality
2 parents 1309b55 + 963b445 commit 9332de8

File tree

4 files changed

+34
-2
lines changed

4 files changed

+34
-2
lines changed

bioimageio/core/build_spec/add_weights.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,13 @@ def add_weights(
6363
except Exception as e:
6464
raise e
6565
finally:
66+
# clean up tmp files
67+
os.remove(weight_out)
6668
if tmp_arch is not None:
6769
os.remove(tmp_arch)
70+
# for some reason the weights are also copied to the cwd.
71+
# not sure why this happens, but it needs to be cleaned up, unless these are the input weigths
72+
weights_cwd = Path(os.path.split(weight_uri)[1])
73+
if weights_cwd.exists() and weights_cwd.absolute() != Path(weight_uri).absolute():
74+
os.remove(weights_cwd)
6875
return model

tests/build_spec/test_add_weights.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1+
import os
12
from bioimageio.core import export_resource_package, load_raw_resource_description, load_resource_description
3+
from bioimageio.core.resource_tests import test_model as _test_model
24

35

46
def _test_add_weights(model, tmp_path, base_weights, added_weights, **kwargs):
@@ -28,6 +30,13 @@ def _test_add_weights(model, tmp_path, base_weights, added_weights, **kwargs):
2830
for weight in new_rdf.weights.values():
2931
assert weight.source.exists()
3032

33+
test_res = _test_model(out_path, added_weights)
34+
test_res = _test_model(out_path)
35+
assert test_res["error"] is None
36+
37+
# make sure the weights were cleaned from the cwd
38+
assert not os.path.exists(os.path.split(weight_path)[1])
39+
3140

3241
def test_add_torchscript(unet2d_nuclei_broad_model, tmp_path):
3342
_test_add_weights(unet2d_nuclei_broad_model, tmp_path, "pytorch_state_dict", "torchscript")

tests/build_spec/test_build_spec.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from bioimageio.core import load_raw_resource_description, load_resource_description
44
from bioimageio.core.resource_io import nodes
55
from bioimageio.core.resource_io.utils import resolve_source
6+
from bioimageio.core.resource_tests import test_model as _test_model
67
from marshmallow import missing
78

89

@@ -81,17 +82,19 @@ def _test_build_spec(
8182
output_path=out_path,
8283
add_deepimagej_config=add_deepimagej_config,
8384
maintainers=[{"github_user": "jane_doe"}],
85+
input_names=[inp.name for inp in model_spec.inputs],
86+
output_names=[out.name for out in model_spec.outputs],
8487
)
8588
if architecture is not None:
8689
kwargs["architecture"] = architecture
8790
if model_kwargs is not None:
88-
kwargs["kwargs"] = model_kwargs
91+
kwargs["model_kwargs"] = model_kwargs
8992
if tensorflow_version is not None:
9093
kwargs["tensorflow_version"] = tensorflow_version
9194
if opset_version is not None:
9295
kwargs["opset_version"] = opset_version
9396
if use_implicit_output_shape:
94-
kwargs["input_name"] = ["input"]
97+
kwargs["input_names"] = ["input"]
9598
kwargs["output_reference"] = ["input"]
9699
kwargs["output_scale"] = [[1.0, 1.0, 1.0, 1.0]]
97100
kwargs["output_offset"] = [[0.0, 0.0, 0.0, 0.0]]
@@ -134,6 +137,10 @@ def _test_build_spec(
134137
assert attachments.files is not missing
135138
assert n_processing == len(attachments.files)
136139

140+
# test inference for the model to ensure that the weights were written correctly
141+
test_res = _test_model(out_path)
142+
assert test_res["error"] is None
143+
137144

138145
def test_build_spec_pytorch(any_torch_model, tmp_path):
139146
_test_build_spec(any_torch_model, tmp_path / "model.zip", "pytorch_state_dict")

tests/weight_converter/keras/test_tensorflow.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
import zipfile
2+
3+
14
def test_tensorflow_converter(any_keras_model, tmp_path):
25
from bioimageio.core.weight_converter.keras import convert_weights_to_tensorflow_saved_model_bundle
36

@@ -16,3 +19,9 @@ def test_tensorflow_converter_zipped(any_keras_model, tmp_path):
1619
ret_val = convert_weights_to_tensorflow_saved_model_bundle(any_keras_model, out_path)
1720
assert out_path.exists()
1821
assert ret_val == 0 # check for correctness is done in converter and returns 0 if it passes
22+
23+
# make sure that the zip package was created correctly
24+
expected_names = {"saved_model.pb", "variables/variables.index"}
25+
with zipfile.ZipFile(out_path, "r") as f:
26+
names = set([name for name in f.namelist()])
27+
assert len(expected_names - names) == 0

0 commit comments

Comments
 (0)