Skip to content

Commit 27a6ed5

Browse files
authored
Enable ONNX test in CI (#2363)
* Enable ONNX test in CI
1 parent 1f863f9 commit 27a6ed5

File tree

3 files changed

+6
-20
lines changed

3 files changed

+6
-20
lines changed

requirements/developer.txt

+2
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,5 @@ twine==4.0.2
1515
mypy==1.3.0
1616
torchpippy==0.1.1
1717
intel_extension_for_pytorch==2.0.100; sys_platform != 'win32' and sys_platform != 'darwin'
18+
onnxruntime==1.15.0
19+
onnx==1.14.0

test/pytest/test_onnx.py

+1-17
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,7 @@
11
import subprocess
22

3-
import pytest
43
import torch
5-
6-
try:
7-
import onnx
8-
import torch.onnx
9-
10-
print(
11-
onnx.__version__
12-
) # Adding this so onnx import doesn't get removed by pre-commit
13-
ONNX_ENABLED = True
14-
except:
15-
ONNX_ENABLED = False
4+
import torch.onnx
165

176

187
class ToyModel(torch.nn.Module):
@@ -28,7 +17,6 @@ def forward(self, x):
2817

2918

3019
# For a custom model you still need to manually author your converter, as far as I can tell there isn't a nice out of the box that exists
31-
@pytest.mark.skipif(ONNX_ENABLED == False, reason="ONNX is not installed")
3220
def test_convert_to_onnx():
3321
model = ToyModel()
3422
dummy_input = torch.randn(1, 1)
@@ -55,7 +43,6 @@ def test_convert_to_onnx():
5543
)
5644

5745

58-
@pytest.mark.skipif(ONNX_ENABLED == False, reason="ONNX is not installed")
5946
def test_model_packaging_and_start():
6047
subprocess.run("mkdir model_store", shell=True)
6148
subprocess.run(
@@ -65,7 +52,6 @@ def test_model_packaging_and_start():
6552
)
6653

6754

68-
@pytest.mark.skipif(ONNX_ENABLED == False, reason="ONNX is not installed")
6955
def test_model_start():
7056
subprocess.run(
7157
"torchserve --start --ncs --model-store model_store --models onnx.mar",
@@ -74,14 +60,12 @@ def test_model_start():
7460
)
7561

7662

77-
@pytest.mark.skipif(ONNX_ENABLED == False, reason="ONNX is not installed")
7863
def test_inference():
7964
subprocess.run(
8065
"curl -X POST http://127.0.0.1:8080/predictions/onnx --data-binary '1'",
8166
shell=True,
8267
)
8368

8469

85-
@pytest.mark.skipif(ONNX_ENABLED == False, reason="ONNX is not installed")
8670
def test_stop():
8771
subprocess.run("torchserve --stop", shell=True, check=True)

ts/torch_handler/base_handler.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -77,10 +77,10 @@
7777
ONNX_AVAILABLE = False
7878

7979

80-
def setup_ort_session(model_pt_path):
80+
def setup_ort_session(model_pt_path, map_location):
8181
providers = (
8282
["CUDAExecutionProvider", "CPUExecutionProvider"]
83-
if self.map_location == "cuda"
83+
if map_location == "cuda"
8484
else ["CPUExecutionProvider"]
8585
)
8686

@@ -168,7 +168,7 @@ def initialize(self, context):
168168

169169
# Convert your model by following instructions: https://pytorch.org/tutorials/advanced/super_resolution_with_onnxruntime.html
170170
elif self.model_pt_path.endswith(".onnx") and ONNX_AVAILABLE:
171-
self.model = setup_ort_session(self.model_pt_path)
171+
self.model = setup_ort_session(self.model_pt_path, self.map_location)
172172
logger.info("Succesfully setup ort session")
173173

174174
else:

0 commit comments

Comments (0)