@@ -63,7 +63,8 @@ def graph_from_file(
63
63
def graph_to_file (
64
64
graph : xpb2 .GraphProto ,
65
65
filename : str ,
66
- _producer : str = "sclblonnx" ):
66
+ _producer : str = "sclblonnx" ,
67
+ ** kwargs ):
67
68
""" graph_to_file stores an onnx graph to a .onnx file
68
69
69
70
Stores a graph to a file
@@ -84,13 +85,13 @@ def graph_to_file(
84
85
_print ("Unable to save: Graph is not an ONNX graph" )
85
86
86
87
try :
87
- mod = xhelp .make_model (graph , producer_name = _producer )
88
+ mod = xhelp .make_model (graph , producer_name = _producer , ** kwargs )
88
89
except Exception as e :
89
90
print ("Unable to convert graph to model: " + str (e ))
90
91
return False
91
92
92
93
try :
93
- xsave (mod , filename )
94
+ xsave (mod , filename , ** kwargs )
94
95
except Exception as e :
95
96
print ("Unable to save the model: " + str (e ))
96
97
return False
@@ -103,7 +104,8 @@ def run(
103
104
graph : xpb2 .GraphProto ,
104
105
inputs : {},
105
106
outputs : [],
106
- _tmpfile : str = ".tmp.onnx" ):
107
+ _tmpfile : str = ".tmp.onnx" ,
108
+ ** kwargs ):
107
109
""" run executes a give graph with the given input and returns the output
108
110
109
111
Args:
@@ -122,7 +124,7 @@ def run(
122
124
return False
123
125
124
126
try :
125
- sess = xrt .InferenceSession (_tmpfile )
127
+ sess = xrt .InferenceSession (_tmpfile , ** kwargs )
126
128
out = sess .run (outputs , inputs )
127
129
except Exception as e :
128
130
_print ("Failed to run the model: " + str (e ))
@@ -139,7 +141,7 @@ def run(
139
141
# display uses Netron to display a graph
140
142
def display (
141
143
graph : xpb2 .GraphProto ,
142
- _tmpfile : str = '.temp .onnx' ):
144
+ _tmpfile : str = '.tmp .onnx' ):
143
145
""" display a onnx graph using netron.
144
146
145
147
Pass a graph to the display function to open it in Netron.
@@ -267,6 +269,7 @@ def list_data_types():
267
269
""" List all available data types. """
268
270
_print (json .dumps (glob .DATA_TYPES , indent = 2 ), "MSG" )
269
271
_print ("Note: STRINGS are not supported at this time." , "LIT" )
272
+ return True
270
273
271
274
272
275
# list_operators prints all operators available within Scailable
@@ -279,6 +282,7 @@ def list_operators():
279
282
print ("Unable to locate the ONNX_VERSION INFO." )
280
283
return False
281
284
_print (json .dumps (glob .ONNX_VERSION_INFO ['operators' ], indent = 2 ), "MSG" )
285
+ return True
282
286
283
287
284
288
# No command line options for this script:
0 commit comments