2
2
from onnx import helper as xhelp
3
3
from onnx import onnx_ml_pb2 as xpb2
4
4
5
- from sclblonnx .utils import _data_type
5
+ from sclblonnx .utils import _data_type , _print
6
6
from sclblonnx .node import add_node
7
7
8
8
@@ -17,10 +17,15 @@ def constant(name: str,
17
17
name: Name of the (output value of the) constant node to determine the graph topology
18
18
value: Values of the node (as a np.array)
19
19
data_type: Data type of the node
20
+ **kwargs
20
21
21
22
Returns:
22
- The extended graph .
23
+ A constant node .
23
24
"""
25
+ if not name :
26
+ _print ("Unable to create unnamed constant." )
27
+ return False
28
+
24
29
dtype = _data_type (data_type )
25
30
if not dtype :
26
31
return False
@@ -30,7 +35,7 @@ def constant(name: str,
30
35
value = xhelp .make_tensor (name = name + "-values" , data_type = dtype ,
31
36
dims = value .shape , vals = value .flatten ()), ** kwargs )
32
37
except Exception as e :
33
- print ("Unable to create the constant node: " + str (e ))
38
+ _print ("Unable to create the constant node: " + str (e ))
34
39
return False
35
40
36
41
return constant_node
@@ -43,7 +48,9 @@ def add_constant(
43
48
value : np .array ,
44
49
data_type : str ,
45
50
** kwargs ):
46
- """ Add a constant node to an existing graph
51
+ """ Create and add a constant node to an existing graph.
52
+
53
+ Note: use add_node() if you want to add an existing constant node to an existing graph
47
54
48
55
Args:
49
56
graph: A graph, onnx.onnx_ml_pb2.GraphProto.
@@ -67,12 +74,17 @@ def add_constant(
67
74
value = xhelp .make_tensor (name = name + "-values" , data_type = dtype ,
68
75
dims = value .shape , vals = value .flatten ()), ** kwargs )
69
76
except Exception as e :
70
- print ("Unable to create the constant node: " + str (e ))
77
+ _print ("Unable to create the constant node: " + str (e ))
78
+ return False
79
+
80
+ try :
81
+ graph = add_node (graph , constant_node , ** kwargs )
82
+ except Exception as e :
83
+ _print ("Unable to add the constant node to the graph: " + str (e ))
71
84
return False
72
85
73
- graph = add_node (graph , constant_node , ** kwargs )
74
86
if not graph :
75
- print ("Unable to add constant node to graph." )
87
+ _print ("Unable to add constant node to graph." )
76
88
return False
77
89
return graph
78
90
0 commit comments