forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtest_pytorch_jit_onnx.py
95 lines (78 loc) · 2.78 KB
/
test_pytorch_jit_onnx.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
# Owner(s): ["module: onnx"]
import onnxruntime
import torch
from torch._C import parse_ir
from torch.onnx import verification
from test_pytorch_common import TestCase, run_tests
def _jit_graph_to_onnx_model(graph, operator_export_type, opset_version):
    """Export a ``torch::jit::Graph`` to a serialized ONNX ``ModelProto``.

    Testing-only helper: it keeps just the steps needed to convert a JIT IR
    graph, and it never touches real PyTorch modules or PyTorch tensor inputs.

    Args:
        graph: the JIT IR graph to export (e.g. the result of ``parse_ir``).
        operator_export_type: a ``torch.onnx.OperatorExportTypes`` value.
        opset_version: the ONNX opset version to target.

    Returns:
        The first element of the 4-tuple returned by ``Graph._export_onnx``
        (the serialized model proto).
    """
    from torch.onnx.symbolic_helper import (
        _set_onnx_shape_inference,
        _set_opset_version,
    )
    from torch.onnx.utils import _optimize_graph

    # Some ops' symbolic functions build sub-graphs from their inputs' types,
    # so shape inference has to be switched on before optimizing.
    # NOTE(review): these setters appear to change exporter-wide state that is
    # never restored here -- confirm that is acceptable across tests.
    _set_onnx_shape_inference(True)
    _set_opset_version(opset_version)

    optimized_graph = _optimize_graph(graph, operator_export_type, params_dict={})

    # Positional arguments for Graph._export_onnx. Their order must match the
    # binding's signature exactly, so they are kept together in one tuple.
    export_args = (
        {},
        opset_version,
        {},
        False,
        operator_export_type,
        False,
        False,
        {},
        True,
        "",
        {},
    )
    model_proto, _, _, _ = optimized_graph._export_onnx(*export_args)
    return model_proto
class _TestJITIRToONNX:
    """Abstract base class for JIT IR -> ONNX test cases.

    Intentionally not a subclass of ``unittest.TestCase`` so that unittest /
    pytest don't collect and run it directly. ``TestCase`` is supplied as the
    base class when concrete subclasses are created; see ``MakeTestCase()``.
    """

    opset_version = -1  # Sub-classes must override
    ort_providers = ["CPUExecutionProvider"]

    def run_test(self, graph_ir, example_inputs, *, rtol=1e-3, atol=1e-7):
        """Check that ONNX Runtime reproduces the JIT interpreter's outputs.

        Args:
            graph_ir: textual TorchScript IR, parsed with ``parse_ir``.
            example_inputs: inputs fed to both the JIT interpreter and ORT.
            rtol: relative tolerance used when comparing outputs.
            atol: absolute tolerance used when comparing outputs.
        """
        graph = parse_ir(graph_ir)
        # Compute the reference outputs before exporting.
        # NOTE(review): the export path runs graph optimization, which may
        # rewrite ``graph`` in place -- keep this ordering unless confirmed.
        jit_outs = torch._C._jit_interpret_graph(graph, example_inputs)
        onnx_proto = _jit_graph_to_onnx_model(
            graph, torch.onnx.OperatorExportTypes.ONNX, self.opset_version
        )
        ort_sess = onnxruntime.InferenceSession(
            onnx_proto, providers=self.ort_providers
        )
        ort_outs = verification._run_ort(ort_sess, example_inputs)
        verification._compare_ort_pytorch_outputs(
            ort_outs, jit_outs, rtol=rtol, atol=atol
        )

    def test_example_ir(self):
        """Export a single ``aten::add`` over two Float(2, 3) inputs."""
        graph_ir = """
        graph(%1 : Float(2, 3),
              %2 : Float(2, 3)):
          %3 : int = prim::Constant[value=1]()
          %4 : Float(2, 3) = aten::add(%1, %2, %3)
          return (%4)
        """
        a = torch.randn(2, 3)
        b = torch.randn(2, 3)
        self.run_test(graph_ir, (a, b))
def MakeTestCase(opset_version: int) -> type:
    """Create a concrete ``TestCase`` subclass running the JIT IR -> ONNX tests.

    The generated class is named ``TestJITIRToONNX_opset<opset_version>``. Its
    namespace is a copy of ``_TestJITIRToONNX``'s attributes with
    ``opset_version`` overridden, and its only base is ``TestCase``.

    Args:
        opset_version: the ONNX opset version the generated tests target.

    Returns:
        The newly created test class.
    """
    name = f"TestJITIRToONNX_opset{opset_version}"
    # Skip the per-class ``__dict__``/``__weakref__`` descriptors. They are
    # bound to ``_TestJITIRToONNX``; if copied into an unrelated class they
    # shadow the base's own and make ``instance.__dict__`` / ``vars(instance)``
    # raise TypeError.
    namespace = {
        key: value
        for key, value in _TestJITIRToONNX.__dict__.items()
        if key not in ("__dict__", "__weakref__")
    }
    namespace["opset_version"] = opset_version
    return type(name, (TestCase,), namespace)
# Concrete, collectable test class targeting ONNX opset 14.
TestJITIRToONNX_opset14 = MakeTestCase(opset_version=14)

if __name__ == "__main__":
    run_tests()