Skip to content

Commit 1668502

Browse files
committed
Tests for constant.
1 parent c241600 commit 1668502

9 files changed

+52
-25
lines changed

sclblonnx/constant.py

+19-7
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from onnx import helper as xhelp
33
from onnx import onnx_ml_pb2 as xpb2
44

5-
from sclblonnx.utils import _data_type
5+
from sclblonnx.utils import _data_type, _print
66
from sclblonnx.node import add_node
77

88

@@ -17,10 +17,15 @@ def constant(name: str,
1717
name: Name of the (output value of the) constant node to determine the graph topology
1818
value: Values of the node (as a np.array)
1919
data_type: Data type of the node
20+
**kwargs
2021
2122
Returns:
22-
The extended graph.
23+
A constant node.
2324
"""
25+
if not name:
26+
_print("Unable to create unnamed constant.")
27+
return False
28+
2429
dtype = _data_type(data_type)
2530
if not dtype:
2631
return False
@@ -30,7 +35,7 @@ def constant(name: str,
3035
value=xhelp.make_tensor(name=name + "-values", data_type=dtype,
3136
dims=value.shape, vals=value.flatten()), **kwargs)
3237
except Exception as e:
33-
print("Unable to create the constant node: " + str(e))
38+
_print("Unable to create the constant node: " + str(e))
3439
return False
3540

3641
return constant_node
@@ -43,7 +48,9 @@ def add_constant(
4348
value: np.array,
4449
data_type: str,
4550
**kwargs):
46-
""" Add a constant node to an existing graph
51+
""" Create and add a constant node to an existing graph.
52+
53+
Note: use add_node() if you want to add an existing constant node to an existing graph
4754
4855
Args:
4956
graph: A graph, onnx.onnx_ml_pb2.GraphProto.
@@ -67,12 +74,17 @@ def add_constant(
6774
value=xhelp.make_tensor(name=name + "-values", data_type=dtype,
6875
dims=value.shape, vals=value.flatten()), **kwargs)
6976
except Exception as e:
70-
print("Unable to create the constant node: " + str(e))
77+
_print("Unable to create the constant node: " + str(e))
78+
return False
79+
80+
try:
81+
graph = add_node(graph, constant_node, **kwargs)
82+
except Exception as e:
83+
_print("Unable to add the constant node to the graph: " + str(e))
7184
return False
7285

73-
graph = add_node(graph, constant_node, **kwargs)
7486
if not graph:
75-
print("Unable to add constant node to graph.")
87+
_print("Unable to add constant node to graph.")
7688
return False
7789
return graph
7890

sclblonnx/utils.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ def _data_type(data_string: str):
145145
for key, val in glob.DATA_TYPES.items():
146146
if key == data_string:
147147
return val
148-
print("Data type not found. Use `list_data_types()` to list all supported data types.")
148+
_print("Data type not found. Use `list_data_types()` to list all supported data types.")
149149
return False
150150

151151

test/.tmp.onnx

-17
This file was deleted.

test/test_constant.py

+32
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
from sclblonnx import constant, empty_graph, node, add_node, add_constant, add_output, run, display, clean
2+
import numpy as np
3+
4+
5+
def test_constant():
6+
c = constant("", np.array([1,2,]), "FLOAT")
7+
assert not c, "Constant creation should have failed without a name."
8+
c = constant("constant", np.array([1,2,]), "NONE")
9+
assert not c, "Constant creation should have failed without a valid data type."
10+
c = constant("constant", np.array([1,2,]), "FLOAT")
11+
check = getattr(c, "output", False)
12+
assert check[0] == "constant", "Constant creation should have worked."
13+
14+
15+
def test_add_constant():
16+
17+
# Simple add graph
18+
g = empty_graph()
19+
n1 = node('Add', inputs=['x1', 'x2'], outputs=['sum'])
20+
g = add_node(g, n1)
21+
22+
# Add input and constant
23+
g = add_constant(g, 'x1', np.array([1]), "INT64")
24+
g = add_constant(g, 'x2', np.array([5]), "INT64")
25+
26+
# Output:
27+
g = add_output(g, 'sum', "INT64", [1])
28+
29+
# This works, but seems to fail for other data types...
30+
result = run(g, inputs={}, outputs=["sum"])
31+
assert result[0] == 6, "Add constant failed."
32+
# todo(McK): Does not work for INT16 / INT8, check?

test/test_input.py

Whitespace-only changes.

test/test_node.py

Whitespace-only changes.

test/test_output.py

Whitespace-only changes.

test/test_utils.py

Whitespace-only changes.

test/test_validate.py

Whitespace-only changes.

0 commit comments

Comments
 (0)