@@ -145,10 +145,7 @@ def compute_key(node):
145
145
def map_meta_to_aten_node (aten_graph ):
146
146
aten_name_to_node_map = defaultdict (list )
147
147
for node in aten_graph .nodes :
148
- if "val" in node .meta :
149
- print ("aten graph meta:" , node , node .meta ["val" ])
150
148
if node .op != "placeholder" and node .op != "output" :
151
- print ("aten meta:" , node , node .meta )
152
149
aten_name_to_node_map [compute_key (node )] = node
153
150
return aten_name_to_node_map
154
151
@@ -166,7 +163,6 @@ def map_aten_node_to_ttnn_node(ttnn_graph, output_nodes, aten_name_to_node_map):
166
163
continue
167
164
if "seq_nr" in node .meta :
168
165
aten_node_name = compute_key (node )
169
- print ("aten_node_name:" , node , aten_node_name )
170
166
aten_node = aten_name_to_node_map [aten_node_name ]
171
167
aten_to_ttnn_map [aten_node ].append (node )
172
168
# also append gettiem if exists
@@ -181,18 +177,13 @@ def process_ttnn_ops(ttnn_graph, aten_name_to_node_map, aten_to_ttnn_map):
181
177
if node .op == "output" :
182
178
continue
183
179
if node .op == "placeholder" :
184
- # arg_nodes.append(node)
185
180
continue
186
- # # val = node.meta["val"]
187
- # # print(f"{node.name} = torch.rand({tuple(val.size())}, dtype={val.dtype})")
188
181
ttnn_all_nodes .append (node )
189
182
# if ((from_node := node.meta.get("from_node", None)) is not None):
190
183
if "seq_nr" in node .meta :
191
- print ("ttnn meta:" , node , node .meta ["seq_nr" ], node .meta ["original_aten" ]._name , str (node .meta ["val" ]))
192
184
aten_node_name = compute_key (node )
193
185
aten_node = aten_name_to_node_map [aten_node_name ]
194
186
# this is the last ttnn node for this aten op, compare the output of this
195
- print ("aten_to_ttnn_map:" , aten_to_ttnn_map [aten_node ])
196
187
if node == aten_to_ttnn_map [aten_node ][- 1 ]:
197
188
# this will be converted to test_accuracy(node1, node2) later
198
189
# do not emit if users are getitem
@@ -206,8 +197,6 @@ def process_ttnn_ops(ttnn_graph, aten_name_to_node_map, aten_to_ttnn_map):
206
197
def generate_op_accuracy_tests (model_name , aten_fx_graphs , ttnn_fx_graphs , all_inputs , * , verbose = False ):
207
198
assert len (aten_fx_graphs ) == len (ttnn_fx_graphs )
208
199
209
- print ("len graphs:" , len (aten_fx_graphs ), len (ttnn_fx_graphs ))
210
-
211
200
test_accuracy_graph_codes = []
212
201
output_nodes = []
213
202
for aten_graph , ttnn_graph in zip (aten_fx_graphs , ttnn_fx_graphs ):
@@ -227,8 +216,6 @@ def generate_op_accuracy_tests(model_name, aten_fx_graphs, ttnn_fx_graphs, all_i
227
216
continue
228
217
aten_all_nodes .append (node )
229
218
230
- print ("aten graph args:" , arg_nodes )
231
-
232
219
# preprocess: map aten to ttnn ops. this is to know what is the last ttnn op in group to compare output
233
220
aten_to_ttnn_map = map_aten_node_to_ttnn_node (ttnn_graph , output_nodes , aten_name_to_node_map )
234
221
0 commit comments