Skip to content

Commit 7200411

Browse files
committed
Cleanup
1 parent f0487d0 commit 7200411

File tree

1 file changed

+0
-13
lines changed

1 file changed

+0
-13
lines changed

torch_ttnn/generate_op_accuracy_tests.py

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -145,10 +145,7 @@ def compute_key(node):
145145
def map_meta_to_aten_node(aten_graph):
146146
aten_name_to_node_map = defaultdict(list)
147147
for node in aten_graph.nodes:
148-
if "val" in node.meta:
149-
print("aten graph meta:", node, node.meta["val"])
150148
if node.op != "placeholder" and node.op != "output":
151-
print("aten meta:", node, node.meta)
152149
aten_name_to_node_map[compute_key(node)] = node
153150
return aten_name_to_node_map
154151

@@ -166,7 +163,6 @@ def map_aten_node_to_ttnn_node(ttnn_graph, output_nodes, aten_name_to_node_map):
166163
continue
167164
if "seq_nr" in node.meta:
168165
aten_node_name = compute_key(node)
169-
print("aten_node_name:", node, aten_node_name)
170166
aten_node = aten_name_to_node_map[aten_node_name]
171167
aten_to_ttnn_map[aten_node].append(node)
172168
# also append gettiem if exists
@@ -181,18 +177,13 @@ def process_ttnn_ops(ttnn_graph, aten_name_to_node_map, aten_to_ttnn_map):
181177
if node.op == "output":
182178
continue
183179
if node.op == "placeholder":
184-
# arg_nodes.append(node)
185180
continue
186-
# # val = node.meta["val"]
187-
# # print(f"{node.name} = torch.rand({tuple(val.size())}, dtype={val.dtype})")
188181
ttnn_all_nodes.append(node)
189182
# if ((from_node := node.meta.get("from_node", None)) is not None):
190183
if "seq_nr" in node.meta:
191-
print("ttnn meta:", node, node.meta["seq_nr"], node.meta["original_aten"]._name, str(node.meta["val"]))
192184
aten_node_name = compute_key(node)
193185
aten_node = aten_name_to_node_map[aten_node_name]
194186
# this is the last ttnn node for this aten op, compare the output of this
195-
print("aten_to_ttnn_map:", aten_to_ttnn_map[aten_node])
196187
if node == aten_to_ttnn_map[aten_node][-1]:
197188
# this will be converted to test_accuracy(node1, node2) later
198189
# do not emit if users are getitem
@@ -206,8 +197,6 @@ def process_ttnn_ops(ttnn_graph, aten_name_to_node_map, aten_to_ttnn_map):
206197
def generate_op_accuracy_tests(model_name, aten_fx_graphs, ttnn_fx_graphs, all_inputs, *, verbose=False):
207198
assert len(aten_fx_graphs) == len(ttnn_fx_graphs)
208199

209-
print("len graphs:", len(aten_fx_graphs), len(ttnn_fx_graphs))
210-
211200
test_accuracy_graph_codes = []
212201
output_nodes = []
213202
for aten_graph, ttnn_graph in zip(aten_fx_graphs, ttnn_fx_graphs):
@@ -227,8 +216,6 @@ def generate_op_accuracy_tests(model_name, aten_fx_graphs, ttnn_fx_graphs, all_i
227216
continue
228217
aten_all_nodes.append(node)
229218

230-
print("aten graph args:", arg_nodes)
231-
232219
# preprocess: map aten to ttnn ops. this is to know what is the last ttnn op in group to compare output
233220
aten_to_ttnn_map = map_aten_node_to_ttnn_node(ttnn_graph, output_nodes, aten_name_to_node_map)
234221

0 commit comments

Comments
 (0)