Skip to content

Commit 7a78174

Browse files
committed
First tests done
1 parent 1668502 commit 7a78174

12 files changed

+516
-134
lines changed

sclblonnx/__init__.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,8 @@
2727
from .node import \
2828
node, \
2929
add_node, \
30-
add_nodes
30+
add_nodes, \
31+
delete_node
3132

3233
from .constant import \
3334
constant, \

sclblonnx/input.py

+36-20
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,17 @@
11
from onnx import helper as xhelp
22
from onnx import onnx_ml_pb2 as xpb2
3-
4-
from sclblonnx.utils import _value, _data_type, _parse_element
3+
from sclblonnx.utils import _value, _data_type, _parse_element, _print
54

65

76
# list_inputs lists all inputs of a graph
87
def list_inputs(graph: xpb2.GraphProto):
9-
""" Tries to list the outputs of a given graph.
8+
""" Tries to list the inputs of a given graph.
109
1110
Args:
1211
graph the ONNX graph
1312
"""
1413
if type(graph) is not xpb2.GraphProto:
15-
print("graph is not a valid ONNX graph.")
14+
_print("graph is not a valid ONNX graph.")
1615
return False
1716

1817
i = 1
@@ -21,6 +20,9 @@ def list_inputs(graph: xpb2.GraphProto):
2120
print("Input {}: Name: '{}', Type: {}, Dimension: {}".format(i, name, dtype, shape))
2221
i += 1
2322

23+
if i == 1:
24+
print("No inputs found.")
25+
2426
return True
2527

2628

@@ -38,13 +40,14 @@ def add_input(
3840
name: String, the name of the input as used to determine the graph topology.
3941
data_type: String, the data type of the input. Run list_data_types() for an overview.
4042
dimensions: List[] specifying the dimensions of the input.
43+
**kwargs
4144
4245
Returns:
4346
The extended graph.
4447
4548
"""
4649
if type(graph) is not xpb2.GraphProto:
47-
print("graph is not a valid ONNX graph.")
50+
_print("graph is not a valid ONNX graph.")
4851
return False
4952

5053
dtype = _data_type(data_type)
@@ -54,21 +57,34 @@ def add_input(
5457
try:
5558
graph.input.append(xhelp.make_tensor_value_info(name, dtype, dimensions, **kwargs), *kwargs)
5659
except Exception as e:
57-
print("Unable to add the input: " + str(e))
60+
_print("Unable to add the input: " + str(e))
5861
return False
5962
return graph
6063

6164

6265
# rename_input renames an input
6366
def rename_input(graph, current_name, new_name):
64-
# We have to rename the input itself:
67+
""" Rename an input to a graph
68+
69+
Args:
70+
graph: A graph, onnx.onnx_ml_pb2.GraphProto.
71+
current_name: String, the current input name.
72+
new_name: String, the name desired input name.
73+
74+
Returns:
75+
The changed graph.
76+
"""
77+
if type(graph) is not xpb2.GraphProto:
78+
_print("graph is not a valid ONNX graph.")
79+
return False
80+
6581
found = False
6682
for input in graph.input:
6783
if input.name == current_name:
6884
input.name = new_name
6985
found = True
7086
if not found:
71-
print("err")
87+
_print("Unable to find the input to rename.")
7288
return False
7389

7490
# And rename it in every nodes that takes this as input:
@@ -93,14 +109,15 @@ def replace_input(
93109
graph: A graph, onnx.onnx_ml_pb2.GraphProto.
94110
name: String, the name of the input as used to determine the graph topology.
95111
data_type: String, the data type of the input. Run list_data_types() for an overview.
96-
dimensions: List[] specifying the dimensions of the input.
112+
dimensions: List[] specifying the dimensions of the input.,
113+
**kwargs
97114
98115
Returns:
99116
The extended graph.
100117
101118
"""
102119
if type(graph) is not xpb2.GraphProto:
103-
print("graph is not a valid ONNX graph.")
120+
_print("graph is not a valid ONNX graph.")
104121
return graph
105122

106123
# Remove the named input
@@ -111,33 +128,32 @@ def replace_input(
111128
graph.input.remove(elem)
112129
found = True
113130
except Exception as e:
114-
print("Unable to iterate the inputs. " + str(e))
131+
_print("Unable to iterate the inputs. " + str(e))
115132
return False
116133
if not found:
117-
print("Unable to find the input by name.")
134+
_print("Unable to find the input by name.")
118135

119136
# Create the new value
120137
try:
121138
val = _value(name, data_type, dimensions, **kwargs)
122139
except Exception as e:
123-
print("Unable to create value. " + str(e))
140+
_print("Unable to create value. " + str(e))
124141
return False
125142

126143
# Add the value to the input
127144
try:
128145
graph.input.append(val, *kwargs)
129146
except Exception as e:
130-
print("Unable to add the input: " + str(e))
147+
_print("Unable to add the input: " + str(e))
131148
return False
132149

133150
return graph
134151

135152

136153
# delete_input deletes an existing input
137154
def delete_input(
138-
graph: xpb2.GraphProto,
139-
name: str,
140-
**kwargs):
155+
graph: xpb2.GraphProto,
156+
name: str):
141157
""" Removes an existing input of a graph by name
142158
143159
Args:
@@ -159,10 +175,10 @@ def delete_input(
159175
graph.input.remove(elem)
160176
found = True
161177
except Exception as e:
162-
print("Unable to iterate the inputs. " + str(e))
178+
_print("Unable to iterate the inputs. " + str(e))
163179
return False
164180
if not found:
165-
print("Unable to find the input by name.")
181+
_print("Unable to find the input by name.")
166182
return False
167183

168-
return graph
184+
return graph

sclblonnx/node.py

+54-7
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from onnx import helper as xhelp
22
from onnx import onnx_ml_pb2 as xpb2
3-
43
import sclblonnx._globals as glob
4+
from sclblonnx.utils import _print
55

66

77
# Node creates a new node
@@ -18,6 +18,7 @@ def node(
1818
inputs: [] list of inputs (names to determine the graph topology)
1919
outputs: [] list of outputs (names to determine the graph topology)
2020
name: The name of this node (Optional)
21+
**kwargs
2122
"""
2223
if not name:
2324
name = "sclbl-onnx-node" + str(glob.NODE_COUNT)
@@ -26,7 +27,7 @@ def node(
2627
try:
2728
node = xhelp.make_node(op_type, inputs, outputs, name, **kwargs)
2829
except Exception as e:
29-
print("Unable to create node: " + str(e))
30+
_print("Unable to create node: " + str(e))
3031
return False
3132
return node
3233

@@ -43,20 +44,23 @@ def add_node(
4344
Args:
4445
graph: A graph, onnx.onnx_ml_pb2.GraphProto.
4546
node: A node, onnx.onnx_ml_pb2.NodeProto.
47+
**kwargs
4648
4749
Returns:
4850
The extended graph.
4951
"""
5052
if type(graph) is not xpb2.GraphProto:
51-
print("graph is not a valid ONNX graph.")
53+
_print("The graph is not a valid ONNX graph.")
5254
return False
55+
5356
if type(node) is not xpb2.NodeProto:
54-
print("node is not a valid ONNX node.")
57+
_print("The node is not a valid ONNX node.")
5558
return False
59+
5660
try:
5761
graph.node.append(node, **kwargs)
5862
except Exception as e:
59-
print("Unable to extend graph: " + str(e))
63+
_print("Unable to extend graph: " + str(e))
6064
return False
6165
return graph
6266

@@ -72,6 +76,7 @@ def add_nodes(
7276
Args:
7377
graph: A graph, onnx.onnx_ml_pb2.GraphProto.
7478
nodes: A list of nodes, [onnx.onnx_ml_pb2.NodeProto].
79+
**kwargs
7580
7681
Returns:
7782
The extended graph.
@@ -80,8 +85,50 @@ def add_nodes(
8085
print("graph is not a valid ONNX graph.")
8186
return False
8287

83-
for node in nodes:
88+
for node in nodes: # error handling in add_node
8489
graph = add_node(graph, node, **kwargs)
8590
if not graph:
8691
return False
87-
return graph
92+
93+
return graph
94+
95+
96+
# delete_node deletes a node from a graph
97+
def delete_node(
98+
graph: xpb2.GraphProto,
99+
node_name: str = "",
100+
**kwargs):
101+
""" Add node appends a node to graph g and returns the extended graph
102+
103+
Prints a message and returns False if fails.
104+
105+
Args:
106+
graph: A graph, onnx.onnx_ml_pb2.GraphProto.
107+
node_name: Name of the node to remove.
108+
**kwargs
109+
110+
Returns:
111+
The extended graph.
112+
"""
113+
if type(graph) is not xpb2.GraphProto:
114+
_print("The graph is not a valid ONNX graph.")
115+
return False
116+
117+
if not node_name:
118+
_print("Please specify a node name.")
119+
return False
120+
121+
found = False
122+
try:
123+
for elem in graph.node:
124+
if elem.name == node_name:
125+
graph.node.remove(elem)
126+
found = True
127+
except Exception as e:
128+
_print("Unable to iterate the nodes. " + str(e))
129+
return False
130+
if not found:
131+
_print("Unable to find the node by name.")
132+
return False
133+
134+
return graph

0 commit comments

Comments
 (0)