Skip to content

Commit 039da0f

Browse files
committed
test_main.py almost done...
1 parent 95a825b commit 039da0f

File tree

2 files changed

+30
-10
lines changed

2 files changed

+30
-10
lines changed

sclblonnx/main.py

+10-6
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,8 @@ def graph_from_file(
6363
def graph_to_file(
6464
graph: xpb2.GraphProto,
6565
filename: str,
66-
_producer: str = "sclblonnx"):
66+
_producer: str = "sclblonnx",
67+
**kwargs):
6768
""" graph_to_file stores an onnx graph to a .onnx file
6869
6970
Stores a graph to a file
@@ -84,13 +85,13 @@ def graph_to_file(
8485
_print("Unable to save: Graph is not an ONNX graph")
8586

8687
try:
87-
mod = xhelp.make_model(graph, producer_name=_producer)
88+
mod = xhelp.make_model(graph, producer_name=_producer, **kwargs)
8889
except Exception as e:
8990
print("Unable to convert graph to model: " + str(e))
9091
return False
9192

9293
try:
93-
xsave(mod, filename)
94+
xsave(mod, filename, **kwargs)
9495
except Exception as e:
9596
print("Unable to save the model: " + str(e))
9697
return False
@@ -103,7 +104,8 @@ def run(
103104
graph: xpb2.GraphProto,
104105
inputs: {},
105106
outputs: [],
106-
_tmpfile: str = ".tmp.onnx"):
107+
_tmpfile: str = ".tmp.onnx",
108+
**kwargs):
107109
""" run executes a give graph with the given input and returns the output
108110
109111
Args:
@@ -122,7 +124,7 @@ def run(
122124
return False
123125

124126
try:
125-
sess = xrt.InferenceSession(_tmpfile)
127+
sess = xrt.InferenceSession(_tmpfile, **kwargs)
126128
out = sess.run(outputs, inputs)
127129
except Exception as e:
128130
_print("Failed to run the model: " + str(e))
@@ -139,7 +141,7 @@ def run(
139141
# display uses Netron to display a graph
140142
def display(
141143
graph: xpb2.GraphProto,
142-
_tmpfile: str = '.temp.onnx'):
144+
_tmpfile: str = '.tmp.onnx'):
143145
""" display a onnx graph using netron.
144146
145147
Pass a graph to the display function to open it in Netron.
@@ -267,6 +269,7 @@ def list_data_types():
267269
""" List all available data types. """
268270
_print(json.dumps(glob.DATA_TYPES, indent=2), "MSG")
269271
_print("Note: STRINGS are not supported at this time.", "LIT")
272+
return True
270273

271274

272275
# list_operators prints all operators available within Scailable
@@ -279,6 +282,7 @@ def list_operators():
279282
print("Unable to locate the ONNX_VERSION INFO.")
280283
return False
281284
_print(json.dumps(glob.ONNX_VERSION_INFO['operators'], indent=2), "MSG")
285+
return True
282286

283287

284288
# No command line options for this script:

test/test_main.py

+20-4
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import os
22
import numpy as np
33
from onnx import onnx_ml_pb2 as xpb2
4-
from sclblonnx import empty_graph, graph_from_file, graph_to_file, run
4+
from sclblonnx import empty_graph, graph_from_file, graph_to_file, run, list_data_types, list_operators, sclbl_input
55

66

77
def test_empty_graph():
@@ -44,6 +44,22 @@ def test_display():
4444
return True # No test for display
4545

4646

47-
# def test_scblbl_input():
48-
# def test_list_data_types
49-
# def test_list_operators
47+
def test_scblbl_input():
48+
example = {"in": np.array([1,2,3,4]).astype(np.int32)}
49+
result = sclbl_input(example)
50+
assert result == '{"input": CAQQBkoQAQAAAAIAAAADAAAABAAAAA==, "type":"pb"}', "PB output not correct."
51+
52+
example = {"x1": np.array([1,2,3,4]).astype(np.int32), "x2": np.array([1,2,3,4]).astype(np.int32)}
53+
result = sclbl_input(example)
54+
assert result == '{"input": ["CAQQBkoQAQAAAAIAAAADAAAABAAAAA==","CAQQBkoQAQAAAAIAAAADAAAABAAAAA=="], "type":"pb"}', "PB output 2 not correct."
55+
56+
# todo: Do raw tests.
57+
58+
59+
def test_list_data_types():
60+
test = list_data_types()
61+
assert test, "Data types should be listed."
62+
63+
def test_list_operators():
64+
test = list_operators()
65+
assert test, "Operators should be listed."

0 commit comments

Comments
 (0)