Skip to content

Commit 57c8a5e

Browse files
committed
Support get_attr nodes
1 parent d94b511 commit 57c8a5e

File tree

3 files changed

+18
-5
lines changed

3 files changed

+18
-5
lines changed

torch_ttnn/backend.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,10 @@ def aten_backend(
106106

107107
# Save aten graph if requested
108108
if options.gen_op_accuracy_tests:
109-
option._aten_fx_graphs.append(copy.deepcopy(gm.graph))
109+
# Will this hamper memory usage?
110+
graph_copy = copy.deepcopy(gm.graph)
111+
graph_copy.owning_module = gm
112+
option._aten_fx_graphs.append(graph_copy)
110113

111114
# Save the number of aten ops before compilation
112115
if option.metrics_path:
@@ -237,7 +240,7 @@ def ttnn_backend(
237240
options: TorchTtnnOption = None,
238241
) -> torch.fx.GraphModule:
239242
# Save all parameters and inputs if requested
240-
if options.gen_op_accuracy_tests:
243+
if options.gen_op_accuracy_tests and options._all_inputs is None:
241244
options._all_inputs = generate_op_accuracy_tests.generate_flat_args(gm, example_inputs)
242245

243246
tracer_option = options.tracer_option

torch_ttnn/generate_op_accuracy_tests.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def get_opname(node):
4343
# rename node names because some wrapper or built-in functions have the same name
4444
def rename_nodes(graph, prefix):
4545
for node in graph.nodes:
46-
if node.op != "placeholder" and node.op != "output":
46+
if node.op != "placeholder" and node.op != "output" and node.op != "get_attr":
4747
# simplify or put this in a new function
4848
opname = get_opname(node)
4949
if not opname.startswith("aten.") and not opname.startswith("ttnn."):
@@ -139,7 +139,14 @@ def rename_input_args_from_graph_break(output_nodes, node):
139139

140140

141141
def compute_key(node):
142-
return str(node.meta["seq_nr"]) + node.meta["original_aten"]._name + str(node.meta["val"])
142+
if "tensor_meta" in node.meta:
143+
tensor_meta = node.meta["tensor_meta"]
144+
else:
145+
tensor_meta = node.meta["val"]
146+
# Workaround for layer_norm and other ops that have a list for "val"
147+
if isinstance(tensor_meta, tuple):
148+
tensor_meta = node.meta["val"][0]
149+
return str(node.meta["seq_nr"]) + node.meta["original_aten"]._name + str(tensor_meta)
143150

144151

145152
def map_meta_to_aten_node(aten_graph):

torch_ttnn/passes/constant_folding_pass.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,10 @@ def _evaluate_node(self, gm: torch.fx.GraphModule, node):
7777
def _replace_with_constant(self, gm: torch.fx.GraphModule, node, value):
7878
with gm.graph.inserting_before(node):
7979
new_node = gm.graph.create_node("get_attr", target=f"_folded_{node.name}", args=(), kwargs=None)
80-
new_node.meta = node.meta
80+
# Copying all of the meta causes some ops to be missing
81+
new_node.meta["seq_nr"] = node.meta["seq_nr"]
82+
new_node.meta["original_aten"] = node.meta["original_aten"]
83+
new_node.meta["tensor_meta"] = node.meta["tensor_meta"]
8184

8285
gm.register_parameter(f"_folded_{node.name}", torch.nn.Parameter(value, requires_grad=False))
8386

0 commit comments

Comments (0)