Skip to content

Commit fbbfc57

Browse files
committed
Fix weight converters and return their corresponding v5 weight descr
1 parent 9e1e1fe commit fbbfc57

File tree

7 files changed

+102
-113
lines changed

7 files changed

+102
-113
lines changed

bioimageio/core/weight_converter/keras/_tensorflow.py

+18-5
Original file line number | Diff line number | Diff line change
@@ -5,6 +5,9 @@
55
from typing import no_type_check
66
from zipfile import ZipFile
77

8+
from bioimageio.spec._internal.version_type import Version
9+
from bioimageio.spec.model import v0_5
10+
811
try:
912
import tensorflow.saved_model
1013
except Exception:
@@ -39,7 +42,7 @@ def _convert_tf1(
3942
input_name: str,
4043
output_name: str,
4144
zip_weights: bool,
42-
):
45+
) -> v0_5.TensorflowSavedModelBundleWeightsDescr:
4346
try:
4447
# try to build the tf model with the keras import from tensorflow
4548
from bioimageio.core.weight_converter.keras._tensorflow import (
@@ -77,10 +80,16 @@ def build_tf_model():
7780
output_path = _zip_model_bundle(output_path)
7881
print("TensorFlow model exported to", output_path)
7982

80-
return 0
83+
return v0_5.TensorflowSavedModelBundleWeightsDescr(
84+
source=output_path,
85+
parent="keras_hdf5",
86+
tensorflow_version=Version(tensorflow.__version__),
87+
)
8188

8289

83-
def _convert_tf2(keras_weight_path: Path, output_path: Path, zip_weights: bool):
90+
def _convert_tf2(
91+
keras_weight_path: Path, output_path: Path, zip_weights: bool
92+
) -> v0_5.TensorflowSavedModelBundleWeightsDescr:
8493
try:
8594
# try to build the tf model with the keras import from tensorflow
8695
from bioimageio.core.weight_converter.keras._tensorflow import keras
@@ -95,12 +104,16 @@ def _convert_tf2(keras_weight_path: Path, output_path: Path, zip_weights: bool):
95104
output_path = _zip_model_bundle(output_path)
96105
print("TensorFlow model exported to", output_path)
97106

98-
return 0
107+
return v0_5.TensorflowSavedModelBundleWeightsDescr(
108+
source=output_path,
109+
parent="keras_hdf5",
110+
tensorflow_version=tensorflow.__version__,
111+
)
99112

100113

101114
def convert_weights_to_tensorflow_saved_model_bundle(
102115
model: ModelDescr, output_path: Path
103-
):
116+
) -> v0_5.TensorflowSavedModelBundleWeightsDescr:
104117
"""Convert model weights from format 'keras_hdf5' to 'tensorflow_saved_model_bundle'.
105118
106119
Adapted from

bioimageio/core/weight_converter/torch/_onnx.py

+19-27
Original file line number | Diff line number | Diff line change
@@ -1,13 +1,11 @@
11
# type: ignore # TODO: type
2-
import warnings
2+
from __future__ import annotations
33
from pathlib import Path
4-
from typing import Any, List, Sequence, cast
4+
from typing import Any, List, Sequence, cast, Union
55

66
import numpy as np
77
from numpy.testing import assert_array_almost_equal
88

9-
from bioimageio.spec import load_description
10-
from bioimageio.spec.common import InvalidDescr
119
from bioimageio.spec.model import v0_4, v0_5
1210

1311
from ...digest_spec import get_member_id, get_test_inputs
@@ -19,15 +17,15 @@
1917
torch = None
2018

2119

22-
def add_onnx_weights(
23-
model_spec: "str | Path | v0_4.ModelDescr | v0_5.ModelDescr",
20+
def convert_weights_to_onnx(
21+
model_spec: Union[v0_4.ModelDescr, v0_5.ModelDescr],
2422
*,
2523
output_path: Path,
2624
use_tracing: bool = True,
2725
test_decimal: int = 4,
2826
verbose: bool = False,
29-
opset_version: "int | None" = None,
30-
):
27+
opset_version: int = 15,
28+
) -> v0_5.OnnxWeightsDescr:
3129
"""Convert model weights from format 'pytorch_state_dict' to 'onnx'.
3230
3331
Args:
@@ -36,16 +34,6 @@ def add_onnx_weights(
3634
use_tracing: whether to use tracing or scripting to export the onnx format
3735
test_decimal: precision for testing whether the results agree
3836
"""
39-
if isinstance(model_spec, (str, Path)):
40-
loaded_spec = load_description(Path(model_spec))
41-
if isinstance(loaded_spec, InvalidDescr):
42-
raise ValueError(f"Bad resource description: {loaded_spec}")
43-
if not isinstance(loaded_spec, (v0_4.ModelDescr, v0_5.ModelDescr)):
44-
raise TypeError(
45-
f"Path {model_spec} is a {loaded_spec.__class__.__name__}, expected a v0_4.ModelDescr or v0_5.ModelDescr"
46-
)
47-
model_spec = loaded_spec
48-
4937
state_dict_weights_descr = model_spec.weights.pytorch_state_dict
5038
if state_dict_weights_descr is None:
5139
raise ValueError(
@@ -54,9 +42,10 @@ def add_onnx_weights(
5442

5543
assert torch is not None
5644
with torch.no_grad():
57-
5845
sample = get_test_inputs(model_spec)
59-
input_data = [sample[get_member_id(ipt)].data.data for ipt in model_spec.inputs]
46+
input_data = [
47+
sample.members[get_member_id(ipt)].data.data for ipt in model_spec.inputs
48+
]
6049
input_tensors = [torch.from_numpy(ipt) for ipt in input_data]
6150
model = load_torch_model(state_dict_weights_descr)
6251

@@ -81,9 +70,9 @@ def add_onnx_weights(
8170
try:
8271
import onnxruntime as rt # pyright: ignore [reportMissingTypeStubs]
8372
except ImportError:
84-
msg = "The onnx weights were exported, but onnx rt is not available and weights cannot be checked."
85-
warnings.warn(msg)
86-
return
73+
raise ImportError(
74+
"The onnx weights were exported, but onnx rt is not available and weights cannot be checked."
75+
)
8776

8877
# check the onnx model
8978
sess = rt.InferenceSession(str(output_path))
@@ -101,8 +90,11 @@ def add_onnx_weights(
10190
try:
10291
for exp, out in zip(expected_outputs, outputs):
10392
assert_array_almost_equal(exp, out, decimal=test_decimal)
104-
return 0
10593
except AssertionError as e:
106-
msg = f"The onnx weights were exported, but results before and after conversion do not agree:\n {str(e)}"
107-
warnings.warn(msg)
108-
return 1
94+
raise ValueError(
95+
f"Results before and after weights conversion do not agree:\n {str(e)}"
96+
)
97+
98+
return v0_5.OnnxWeightsDescr(
99+
source=output_path, parent="pytorch_state_dict", opset_version=opset_version
100+
)

bioimageio/core/weight_converter/torch/_torchscript.py

+26-27
Original file line number | Diff line number | Diff line change
@@ -1,9 +1,11 @@
11
# type: ignore # TODO: type
2+
from __future__ import annotations
23
from pathlib import Path
34
from typing import List, Sequence, Union
45

56
import numpy as np
67
from numpy.testing import assert_array_almost_equal
8+
from torch.jit import ScriptModule
79
from typing_extensions import Any, assert_never
810

911
from bioimageio.spec.model import v0_4, v0_5
@@ -17,12 +19,11 @@
1719
torch = None
1820

1921

20-
# FIXME: remove Any
2122
def _check_predictions(
2223
model: Any,
2324
scripted_model: Any,
24-
model_spec: "v0_4.ModelDescr | v0_5.ModelDescr",
25-
input_data: Sequence["torch.Tensor"],
25+
model_spec: v0_4.ModelDescr | v0_5.ModelDescr,
26+
input_data: Sequence[torch.Tensor],
2627
):
2728
assert torch is not None
2829

@@ -77,22 +78,27 @@ def _check(input_: Sequence[torch.Tensor]) -> None:
7778
else:
7879
assert_never(axis.size)
7980

80-
half_step = [st // 2 for st in step]
81+
input_data = input_data[0]
82+
max_shape = input_data.shape
8183
max_steps = 4
8284

8385
# check that input and output agree for decreasing input sizes
8486
for step_factor in range(1, max_steps + 1):
8587
slice_ = tuple(
86-
slice(None) if st == 0 else slice(step_factor * st, -step_factor * st)
87-
for st in half_step
88-
)
89-
this_input = [inp[slice_] for inp in input_data]
90-
this_shape = this_input[0].shape
91-
if any(tsh < msh for tsh, msh in zip(this_shape, min_shape)):
92-
raise ValueError(
93-
f"Mismatched shapes: {this_shape}. Expected at least {min_shape}"
88+
(
89+
slice(None)
90+
if step_dim == 0
91+
else slice(0, max_dim - step_factor * step_dim, 1)
9492
)
95-
_check(this_input)
93+
for max_dim, step_dim in zip(max_shape, step)
94+
)
95+
sliced_input = input_data[slice_]
96+
if any(
97+
sliced_dim < min_dim
98+
for sliced_dim, min_dim in zip(sliced_input.shape, min_shape)
99+
):
100+
return
101+
_check([sliced_input])
96102

97103

98104
def convert_weights_to_torchscript(
@@ -107,7 +113,6 @@ def convert_weights_to_torchscript(
107113
output_path: where to save the torchscript weights
108114
use_tracing: whether to use tracing or scripting to export the torchscript format
109115
"""
110-
111116
state_dict_weights_descr = model_descr.weights.pytorch_state_dict
112117
if state_dict_weights_descr is None:
113118
raise ValueError(
@@ -118,26 +123,20 @@ def convert_weights_to_torchscript(
118123

119124
with torch.no_grad():
120125
input_data = [torch.from_numpy(inp.astype("float32")) for inp in input_data]
121-
122126
model = load_torch_model(state_dict_weights_descr)
123-
124-
# FIXME: remove Any
125-
if use_tracing:
126-
scripted_model: Any = torch.jit.trace(model, input_data)
127-
else:
128-
scripted_model: Any = torch.jit.script(model)
129-
127+
scripted_module: ScriptModule = (
128+
torch.jit.trace(model, input_data)
129+
if use_tracing
130+
else torch.jit.script(model)
131+
)
130132
_check_predictions(
131133
model=model,
132-
scripted_model=scripted_model,
134+
scripted_model=scripted_module,
133135
model_spec=model_descr,
134136
input_data=input_data,
135137
)
136138

137-
# save the torchscript model
138-
scripted_model.save(
139-
str(output_path)
140-
) # does not support Path, so need to cast to str
139+
scripted_module.save(str(output_path))
141140

142141
return v0_5.TorchscriptWeightsDescr(
143142
source=output_path,

setup.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -47,7 +47,7 @@
4747
extras_require={
4848
"pytorch": ["torch>=1.6", "torchvision", "keras>=3.0"],
4949
"tensorflow": ["tensorflow", "keras>=2.15"],
50-
"onnx": ["onnxruntime"],
50+
"onnx": ["onnxruntime", "onnx"],
5151
"dev": [
5252
"black",
5353
# "crick", # currently requires python<=3.9

tests/weight_converter/keras/test_tensorflow.py

+12-28
Original file line number | Diff line number | Diff line change
@@ -3,49 +3,33 @@
33
from pathlib import Path
44

55
import pytest
6-
76
from bioimageio.spec import load_description
8-
from bioimageio.spec.model.v0_5 import ModelDescr
7+
from bioimageio.spec.model import v0_5
98

9+
from bioimageio.core.weight_converter.keras._tensorflow import (
10+
convert_weights_to_tensorflow_saved_model_bundle,
11+
)
1012

11-
@pytest.mark.skip(
12-
"tensorflow converter not updated yet"
13-
) # TODO: test tensorflow converter
14-
def test_tensorflow_converter(any_keras_model: Path, tmp_path: Path):
15-
from bioimageio.core.weight_converter.keras import (
16-
convert_weights_to_tensorflow_saved_model_bundle,
17-
)
1813

19-
out_path = tmp_path / "weights"
14+
@pytest.mark.skip()
15+
def test_tensorflow_converter(any_keras_model: Path, tmp_path: Path):
2016
model = load_description(any_keras_model)
21-
assert isinstance(model, ModelDescr), model.validation_summary.format()
17+
out_path = tmp_path / "weights.h5"
2218
ret_val = convert_weights_to_tensorflow_saved_model_bundle(model, out_path)
2319
assert out_path.exists()
24-
assert (out_path / "variables").exists()
25-
assert (out_path / "saved_model.pb").exists()
26-
assert (
27-
ret_val == 0
28-
) # check for correctness is done in converter and returns 0 if it passes
20+
assert isinstance(ret_val, v0_5.TensorflowSavedModelBundleWeightsDescr)
21+
assert ret_val.source == out_path
2922

3023

31-
@pytest.mark.skip(
32-
"tensorflow converter not updated yet"
33-
) # TODO: test tensorflow converter
24+
@pytest.mark.skip()
3425
def test_tensorflow_converter_zipped(any_keras_model: Path, tmp_path: Path):
35-
from bioimageio.core.weight_converter.keras import (
36-
convert_weights_to_tensorflow_saved_model_bundle,
37-
)
38-
3926
out_path = tmp_path / "weights.zip"
4027
model = load_description(any_keras_model)
41-
assert isinstance(model, ModelDescr), model.validation_summary.format()
4228
ret_val = convert_weights_to_tensorflow_saved_model_bundle(model, out_path)
29+
4330
assert out_path.exists()
44-
assert (
45-
ret_val == 0
46-
) # check for correctness is done in converter and returns 0 if it passes
31+
assert isinstance(ret_val, v0_5.TensorflowSavedModelBundleWeightsDescr)
4732

48-
# make sure that the zip package was created correctly
4933
expected_names = {"saved_model.pb", "variables/variables.index"}
5034
with zipfile.ZipFile(out_path, "r") as f:
5135
names = set([name for name in f.namelist()])
+15-10
Original file line number | Diff line number | Diff line change
@@ -1,18 +1,23 @@
11
# type: ignore # TODO enable type checking
22
import os
3-
from pathlib import Path
43

5-
import pytest
4+
from bioimageio.spec import load_description
5+
from bioimageio.spec.model import v0_5
66

7+
from bioimageio.core.weight_converter.torch._onnx import convert_weights_to_onnx
78

8-
@pytest.mark.skip("onnx converter not updated yet") # TODO: test onnx converter
9-
def test_onnx_converter(convert_to_onnx: Path, tmp_path: Path):
10-
from bioimageio.core.weight_converter.torch._onnx import convert_weights_to_onnx
119

10+
def test_onnx_converter(convert_to_onnx, tmp_path):
11+
bio_model = load_description(convert_to_onnx)
1212
out_path = tmp_path / "weights.onnx"
13-
ret_val = convert_weights_to_onnx(convert_to_onnx, out_path, test_decimal=3)
13+
opset_version = 15
14+
ret_val = convert_weights_to_onnx(
15+
model_spec=bio_model,
16+
output_path=out_path,
17+
test_decimal=3,
18+
opset_version=opset_version,
19+
)
1420
assert os.path.exists(out_path)
15-
if not pytest.skip_onnx:
16-
assert (
17-
ret_val == 0
18-
) # check for correctness is done in converter and returns 0 if it passes
21+
assert isinstance(ret_val, v0_5.OnnxWeightsDescr)
22+
assert ret_val.opset_version == opset_version
23+
assert ret_val.source == out_path
Original file line number | Diff line number | Diff line change
@@ -1,22 +1,18 @@
11
# type: ignore # TODO enable type checking
2-
from pathlib import Path
3-
42
import pytest
3+
from bioimageio.spec import load_description
4+
from bioimageio.spec.model import v0_5
55

6-
from bioimageio.spec.model import v0_4, v0_5
7-
6+
from bioimageio.core.weight_converter.torch._torchscript import (
7+
convert_weights_to_torchscript,
8+
)
89

9-
@pytest.mark.skip(
10-
"torchscript converter not updated yet"
11-
) # TODO: test torchscript converter
12-
def test_torchscript_converter(
13-
any_torch_model: "v0_4.ModelDescr | v0_5.ModelDescr", tmp_path: Path
14-
):
15-
from bioimageio.core.weight_converter.torch import convert_weights_to_torchscript
1610

11+
@pytest.mark.skip()
12+
def test_torchscript_converter(any_torch_model, tmp_path):
13+
bio_model = load_description(any_torch_model)
1714
out_path = tmp_path / "weights.pt"
18-
ret_val = convert_weights_to_torchscript(any_torch_model, out_path)
15+
ret_val = convert_weights_to_torchscript(bio_model, out_path)
1916
assert out_path.exists()
20-
assert (
21-
ret_val == 0
22-
) # check for correctness is done in converter and returns 0 if it passes
17+
assert isinstance(ret_val, v0_5.TorchscriptWeightsDescr)
18+
assert ret_val.source == out_path

0 commit comments

Comments
 (0)