1
1
# type: ignore # TODO: type
2
- import warnings
2
+ from __future__ import annotations
3
3
from pathlib import Path
4
- from typing import Any , List , Sequence , cast
4
+ from typing import Any , List , Sequence , cast , Union
5
5
6
6
import numpy as np
7
7
from numpy .testing import assert_array_almost_equal
8
8
9
- from bioimageio .spec import load_description
10
- from bioimageio .spec .common import InvalidDescr
11
9
from bioimageio .spec .model import v0_4 , v0_5
12
10
13
11
from ...digest_spec import get_member_id , get_test_inputs
19
17
torch = None
20
18
21
19
22
- def add_onnx_weights (
23
- model_spec : "str | Path | v0_4.ModelDescr | v0_5.ModelDescr" ,
20
+ def convert_weights_to_onnx (
21
+ model_spec : Union [ v0_4 .ModelDescr , v0_5 .ModelDescr ] ,
24
22
* ,
25
23
output_path : Path ,
26
24
use_tracing : bool = True ,
27
25
test_decimal : int = 4 ,
28
26
verbose : bool = False ,
29
- opset_version : " int | None" = None ,
30
- ):
27
+ opset_version : int = 15 ,
28
+ ) -> v0_5 . OnnxWeightsDescr :
31
29
"""Convert model weights from format 'pytorch_state_dict' to 'onnx'.
32
30
33
31
Args:
@@ -36,16 +34,6 @@ def add_onnx_weights(
36
34
use_tracing: whether to use tracing or scripting to export the onnx format
37
35
test_decimal: precision for testing whether the results agree
38
36
"""
39
- if isinstance (model_spec , (str , Path )):
40
- loaded_spec = load_description (Path (model_spec ))
41
- if isinstance (loaded_spec , InvalidDescr ):
42
- raise ValueError (f"Bad resource description: { loaded_spec } " )
43
- if not isinstance (loaded_spec , (v0_4 .ModelDescr , v0_5 .ModelDescr )):
44
- raise TypeError (
45
- f"Path { model_spec } is a { loaded_spec .__class__ .__name__ } , expected a v0_4.ModelDescr or v0_5.ModelDescr"
46
- )
47
- model_spec = loaded_spec
48
-
49
37
state_dict_weights_descr = model_spec .weights .pytorch_state_dict
50
38
if state_dict_weights_descr is None :
51
39
raise ValueError (
@@ -54,9 +42,10 @@ def add_onnx_weights(
54
42
55
43
assert torch is not None
56
44
with torch .no_grad ():
57
-
58
45
sample = get_test_inputs (model_spec )
59
- input_data = [sample [get_member_id (ipt )].data .data for ipt in model_spec .inputs ]
46
+ input_data = [
47
+ sample .members [get_member_id (ipt )].data .data for ipt in model_spec .inputs
48
+ ]
60
49
input_tensors = [torch .from_numpy (ipt ) for ipt in input_data ]
61
50
model = load_torch_model (state_dict_weights_descr )
62
51
@@ -81,9 +70,9 @@ def add_onnx_weights(
81
70
try :
82
71
import onnxruntime as rt # pyright: ignore [reportMissingTypeStubs]
83
72
except ImportError :
84
- msg = "The onnx weights were exported, but onnx rt is not available and weights cannot be checked."
85
- warnings . warn ( msg )
86
- return
73
+ raise ImportError (
74
+ "The onnx weights were exported, but onnx rt is not available and weights cannot be checked."
75
+ )
87
76
88
77
# check the onnx model
89
78
sess = rt .InferenceSession (str (output_path ))
@@ -101,8 +90,11 @@ def add_onnx_weights(
101
90
try :
102
91
for exp , out in zip (expected_outputs , outputs ):
103
92
assert_array_almost_equal (exp , out , decimal = test_decimal )
104
- return 0
105
93
except AssertionError as e :
106
- msg = f"The onnx weights were exported, but results before and after conversion do not agree:\n { str (e )} "
107
- warnings .warn (msg )
108
- return 1
94
+ raise ValueError (
95
+ f"Results before and after weights conversion do not agree:\n { str (e )} "
96
+ )
97
+
98
+ return v0_5 .OnnxWeightsDescr (
99
+ source = output_path , parent = "pytorch_state_dict" , opset_version = opset_version
100
+ )
0 commit comments