1
1
from onnx import helper as xhelp
2
2
from onnx import onnx_ml_pb2 as xpb2
3
-
4
- from sclblonnx .utils import _value , _data_type , _parse_element
3
+ from sclblonnx .utils import _value , _data_type , _parse_element , _print
5
4
6
5
7
6
# list_inputs lists all inputs of a graph
8
7
def list_inputs (graph : xpb2 .GraphProto ):
9
- """ Tries to list the outputs of a given graph.
8
+ """ Tries to list the inputs of a given graph.
10
9
11
10
Args:
12
11
graph the ONNX graph
13
12
"""
14
13
if type (graph ) is not xpb2 .GraphProto :
15
- print ("graph is not a valid ONNX graph." )
14
+ _print ("graph is not a valid ONNX graph." )
16
15
return False
17
16
18
17
i = 1
@@ -21,6 +20,9 @@ def list_inputs(graph: xpb2.GraphProto):
21
20
print ("Input {}: Name: '{}', Type: {}, Dimension: {}" .format (i , name , dtype , shape ))
22
21
i += 1
23
22
23
+ if i == 1 :
24
+ print ("No inputs found." )
25
+
24
26
return True
25
27
26
28
@@ -38,13 +40,14 @@ def add_input(
38
40
name: String, the name of the input as used to determine the graph topology.
39
41
data_type: String, the data type of the input. Run list_data_types() for an overview.
40
42
dimensions: List[] specifying the dimensions of the input.
43
+ **kwargs
41
44
42
45
Returns:
43
46
The extended graph.
44
47
45
48
"""
46
49
if type (graph ) is not xpb2 .GraphProto :
47
- print ("graph is not a valid ONNX graph." )
50
+ _print ("graph is not a valid ONNX graph." )
48
51
return False
49
52
50
53
dtype = _data_type (data_type )
@@ -54,21 +57,34 @@ def add_input(
54
57
try :
55
58
graph .input .append (xhelp .make_tensor_value_info (name , dtype , dimensions , ** kwargs ), * kwargs )
56
59
except Exception as e :
57
- print ("Unable to add the input: " + str (e ))
60
+ _print ("Unable to add the input: " + str (e ))
58
61
return False
59
62
return graph
60
63
61
64
62
65
# rename_input renames an input
63
66
def rename_input (graph , current_name , new_name ):
64
- # We have to rename the input itself:
67
+ """ Rename an input to a graph
68
+
69
+ Args:
70
+ graph: A graph, onnx.onnx_ml_pb2.GraphProto.
71
+ current_name: String, the current input name.
72
+ new_name: String, the name desired input name.
73
+
74
+ Returns:
75
+ The changed graph.
76
+ """
77
+ if type (graph ) is not xpb2 .GraphProto :
78
+ _print ("graph is not a valid ONNX graph." )
79
+ return False
80
+
65
81
found = False
66
82
for input in graph .input :
67
83
if input .name == current_name :
68
84
input .name = new_name
69
85
found = True
70
86
if not found :
71
- print ( "err " )
87
+ _print ( "Unable to find the input to rename. " )
72
88
return False
73
89
74
90
# And rename it in every nodes that takes this as input:
@@ -93,14 +109,15 @@ def replace_input(
93
109
graph: A graph, onnx.onnx_ml_pb2.GraphProto.
94
110
name: String, the name of the input as used to determine the graph topology.
95
111
data_type: String, the data type of the input. Run list_data_types() for an overview.
96
- dimensions: List[] specifying the dimensions of the input.
112
+ dimensions: List[] specifying the dimensions of the input.,
113
+ **kwargs
97
114
98
115
Returns:
99
116
The extended graph.
100
117
101
118
"""
102
119
if type (graph ) is not xpb2 .GraphProto :
103
- print ("graph is not a valid ONNX graph." )
120
+ _print ("graph is not a valid ONNX graph." )
104
121
return graph
105
122
106
123
# Remove the named input
@@ -111,33 +128,32 @@ def replace_input(
111
128
graph .input .remove (elem )
112
129
found = True
113
130
except Exception as e :
114
- print ("Unable to iterate the inputs. " + str (e ))
131
+ _print ("Unable to iterate the inputs. " + str (e ))
115
132
return False
116
133
if not found :
117
- print ("Unable to find the input by name." )
134
+ _print ("Unable to find the input by name." )
118
135
119
136
# Create the new value
120
137
try :
121
138
val = _value (name , data_type , dimensions , ** kwargs )
122
139
except Exception as e :
123
- print ("Unable to create value. " + str (e ))
140
+ _print ("Unable to create value. " + str (e ))
124
141
return False
125
142
126
143
# Add the value to the input
127
144
try :
128
145
graph .input .append (val , * kwargs )
129
146
except Exception as e :
130
- print ("Unable to add the input: " + str (e ))
147
+ _print ("Unable to add the input: " + str (e ))
131
148
return False
132
149
133
150
return graph
134
151
135
152
136
153
# delete_input deletes an existing input
137
154
def delete_input (
138
- graph : xpb2 .GraphProto ,
139
- name : str ,
140
- ** kwargs ):
155
+ graph : xpb2 .GraphProto ,
156
+ name : str ):
141
157
""" Removes an existing input of a graph by name
142
158
143
159
Args:
@@ -159,10 +175,10 @@ def delete_input(
159
175
graph .input .remove (elem )
160
176
found = True
161
177
except Exception as e :
162
- print ("Unable to iterate the inputs. " + str (e ))
178
+ _print ("Unable to iterate the inputs. " + str (e ))
163
179
return False
164
180
if not found :
165
- print ("Unable to find the input by name." )
181
+ _print ("Unable to find the input by name." )
166
182
return False
167
183
168
- return graph
184
+ return graph
0 commit comments