Skip to content

Commit 44c5078

Browse files
committed
Minor fix of check function (validating dynamic input size), and quite an overhaul of the merge functionality. Now including merge, split, join, and concat. These operations are somewhat challenging, so the example_merge.py provides a number of possible use cases for simple graphs. The test_merge.py currently contains very limited tests, and should still be expanded, but for now this update should cover mucht of the discussion in scailable#21 and scailable#11.
1 parent 120b1d4 commit 44c5078

11 files changed

+615
-402
lines changed

examples/.temp.onnx

-440 KB
Binary file not shown.

examples/example_merge.py

+113
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
import copy
2+
import sclblonnx as so
3+
import numpy as np
4+
"""
5+
EXAMPLE MERGE: a number of examples usages of the merge, join, split, and concat functions.
6+
7+
Note that merge(), join(), and split() are high level wrappers around concat(), each effectively assuming that the
8+
resulting graph is "complete" (i.e., it is a valid onnx graph including input and output). Concat itself is more
9+
flexible and can be used for intermediate merging/concatenation of partial graphs (i.e., graphs that are not yet
10+
finished).
11+
12+
Below we provide a number of examples of each of the functions. We recommend using so.display() throughout to visualize
13+
the resulting graphs and truly understand how the graphs are joined together. Examples are all very simple (small graphs,
14+
scalar operations, etc.), but don't underestimate the complexities involved; with larger graphs the behavior of
15+
the concat function can be challenging.
16+
"""
17+
18+
# # Lets start by creating a few simple (and complete) graphs which we will use throughout:
19+
# # Simple absolute value graph:
20+
g1 = so.empty_graph("g_1")
21+
n1 = so.node('Abs', inputs=['in_1_1'], outputs=['out_1_1'], name="node_1_1")
22+
g1 = so.add_input(g1, 'in_1_1', "FLOAT", [1])
23+
g1 = so.add_output(g1, 'out_1_1', "FLOAT", [1])
24+
g1 = so.add_node(g1, n1)
25+
# so.display(g1)
26+
# data = {"in_1_1": np.array([2]).astype(np.float32)}
27+
# print(so.run(g1, inputs=data, outputs=["out_1_1"]))
28+
29+
# # Simple max value graph:
30+
g2= so.empty_graph("g_2")
31+
n2 = so.node('Max', inputs=['in_2_1', 'in_2_2'], outputs=['out_2_1'], name="node_2_1")
32+
g2 = so.add_input(g2, 'in_2_1', "FLOAT", [1])
33+
g2 = so.add_constant(g2, "in_2_2", np.array([10]), "FLOAT")
34+
g2 = so.add_output(g2, 'out_2_1', "FLOAT", [1])
35+
g2 = so.add_node(g2, n2)
36+
# so.display(g2)
37+
# data = {"in_2_1": np.array([2]).astype(np.float32)}
38+
# print(so.run(g2, inputs=data, outputs=["out_2_1"]))
39+
40+
# # Simple add two values graph:
41+
g3 = so.empty_graph("g_3")
42+
n3 = so.node('Add', inputs=['in_3_1', 'in_3_2'], outputs=['out_3_1'], name="node_3_1")
43+
g3 = so.add_input(g3, 'in_3_1', "FLOAT", [1])
44+
g3 = so.add_input(g3, 'in_3_2', "FLOAT", [1])
45+
g3 = so.add_output(g3, 'out_3_1', "FLOAT", [1])
46+
g3 = so.add_node(g3, n3)
47+
# so.display(g3)
48+
# data = {
49+
# "in_3_1": np.array([2]).astype(np.float32),
50+
# "in_3_2": np.array([5]).astype(np.float32)}
51+
# print(so.run(g3, inputs=data, outputs=["out_3_1"]))
52+
53+
54+
# # MERGE:
55+
# # Merge takes two complete graphs and links the output of the parent to the inputs of the child.
56+
# # Merge assumes the result is complete.
57+
g_merge = so.merge(sg1=g1, sg2=g2, io_match=[("out_1_1", "in_2_1")])
58+
# so.display(g_merge)
59+
# data = {"in_1_1": np.array([2]).astype(np.float32)}
60+
# print(so.run(g_merge, inputs=data, outputs=["out_2_1"]))
61+
62+
63+
# # JOIN:
64+
# # Join takes two parents and links their outputs to one child
65+
# # Join assumes the result is complete.
66+
g_join = so.join(pg1=g1, pg2=g2, cg=g3, pg1_match=[("out_1_1", "in_3_1")], pg2_match=[("out_2_1", "in_3_2")])
67+
# so.display(g_join)
68+
# data = {
69+
# "in_1_1": np.array([2]).astype(np.float32),
70+
# "in_2_1": np.array([2]).astype(np.float32)}
71+
# print(so.run(g_join, inputs=data, outputs=["out_3_1"]))
72+
73+
74+
# # SPLIT:
75+
# # Split takes a single parent and links its output to the inputs of two children.
76+
# # Split assumes the result is complete.
77+
g_split = so.split(pg=g3, cg1=g1, cg2=g2, cg1_match=[("out_3_1", "in_1_1")], cg2_match=[("out_3_1", "in_2_1")])
78+
# so.display(g_split)
79+
# data = {
80+
# "in_3_1": np.array([2]).astype(np.float32),
81+
# "in_3_2": np.array([5]).astype(np.float32)}
82+
# print(so.run(g_split, inputs=data, outputs=["out_1_1", "out_2_1"]))
83+
84+
85+
# # CONCAT
86+
# # Here we provide a number of uses of concat, please inspect the resulting graphs
87+
# # Note, these result are by default not checked for completeness. Hence, the returned graph need not contain
88+
# # valid inputs and outputs.
89+
g_c1 = so.concat(g1, g2) # Note, these are just the two graphs "side-by-side"
90+
g_c2 = so.concat(g1, g2, io_match=[("out_1_1", "in_2_1")]) # Merge
91+
g_c3 = so.concat(g1, g2, io_match=[("out_2_1", "in_1_1")]) # No merge
92+
g_c4 = so.concat(g2, g1, io_match=[("out_2_1", "in_1_1")]) # Merge flipped, the order matters
93+
g_c5 = so.concat(g1, g2, io_match=[("out_1_1", "in_2_1")], rename_nodes=False) # Akin g_c2, but without the node names changed
94+
95+
g4 = copy.deepcopy(g1) # an exact copy of g1
96+
g_c6 = so.concat(g1, g4) # Ugly...
97+
g_c7 = so.concat(g1, g4, rename_edges=True, rename_io=True) # Side by side
98+
99+
g5 = copy.deepcopy(g4) # Another exact copy,
100+
g5 = so.delete_input(g5, "in_1_1") # Removing input and output
101+
g5 = so.delete_output(g5, "out_1_1")
102+
g_c8 = so.concat(g1, g5) # Edge created, but unable to link a single output to two named edges
103+
104+
g6 = so.empty_graph("g_6")
105+
n4 = so.node('Add', inputs=['in_1_1', 'in_6_2'], outputs=['out_6_1'], name="node_6_1")
106+
g6 = so.add_node(g6, n4)
107+
g_c9 = so.concat(g1, g6) # Similarly named edges are also linked
108+
g_c10 = so.concat(g1, g6, rename_edges=True) # All edges renamed, but not i/o broken
109+
g_c11 = so.concat(g1, g6, rename_edges=True, rename_io=True) # g6 did not have inputs and outputs
110+
g_c12 = so.concat(g1, g6, edge_match=[("out_1_1", "in_6_2")]) # Explicit edge matching (akin io_match but for internal edges)
111+
112+
# # Again, please use so.display(g..) to see the results of the above uses of concat.
113+
19 Bytes
Binary file not shown.

sclblonnx/__init__.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,11 @@
4949
delete_output
5050

5151
from .merge import \
52-
merge
52+
merge, \
53+
join, \
54+
split, \
55+
concat, \
56+
postfix_names
5357

5458

5559

0 commit comments

Comments
 (0)