Skip to content

Commit

Permalink
Workaround for metadata for outputs with tuples like layer_norm
Browse files Browse the repository at this point in the history
  • Loading branch information
kevinwuTT committed Dec 23, 2024
1 parent 2d2ffc9 commit 326af78
Showing 1 changed file with 7 additions and 1 deletion.
8 changes: 7 additions & 1 deletion torch_ttnn/generate_op_accuracy_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,13 @@ def rename_input_args_from_graph_break(output_nodes, node):
def compute_key(node):
if node.op == "get_attr":
return str()
tensor_meta = node.meta["tensor_meta"] if "tensor_meta" in node.meta else node.meta["val"]
if "tensor_meta" in node.meta:
tensor_meta = node.meta["tensor_meta"]
else:
tensor_meta = node.meta["val"]
# Workaround for layer_norm and other ops that has a list for "val"
if isinstance(tensor_meta, tuple):
tensor_meta = node.meta["val"][0]
return str(node.meta["seq_nr"]) + node.meta["original_aten"]._name + str(tensor_meta)


Expand Down

0 comments on commit 326af78

Please sign in to comment.