Generate standalone code that compares the accuracy between corresponding aten ops and lowered TTNN ops #611
Merged

Commits (36):
793132d  Add option to generate standalone python script for model runs (kevinwuTT)
efa1ebd  Refactor some variables and path names (kevinwuTT)
672f741  Add lzma compression for input pickle files (kevinwuTT)
61d4f7b  Add action to run accuracy tests in CI (kevinwuTT)
9fb9353  Try different for loop in actions (kevinwuTT)
f92d244  Fix gen accuracy flag to the right action (kevinwuTT)
b1ce493  Fix file paths (kevinwuTT)
39b81e7  Add ttnn.squeeze for lowering embedding when the rank of input is 1 (kevinwuTT)
2119337  Support metadata for get_attr nodes (kevinwuTT)
b2ab995  Support torch.device cases (kevinwuTT)
a5bdb40  Cleanup (kevinwuTT)
f0487d0  Merge branch 'main' into kw/gen_acc_tests (kevinwuTT)
7200411  Cleanup (kevinwuTT)
4a27c75  Revert "Add ttnn.squeeze for lowering embedding when the rank of inpu… (kevinwuTT)
d94b511  Remove uneeded code (kevinwuTT)
57c8a5e  Support get_attr nodes (kevinwuTT)
52b47f3  Dynamically rename wrapper functions to avoid naming conflicts (kevinwuTT)
d45109a  Some get_attr nodes do not have tensor_meta or val metadata (kevinwuTT)
2b960d1  Merge branch 'main' into kw/gen_acc_tests (kevinwuTT)
cd5a3ca  Refactor (kevinwuTT)
7020388  Merge branch 'main' into kw/gen_acc_tests (kevinwuTT)
6bbb613  Support cases where aten ops are decomposed and the decomposed ops ha… (kevinwuTT)
50e00f8  Refactor (kevinwuTT)
12065a7  Merge branch 'main' into kw/gen_acc_tests (kevinwuTT)
d3cf5db  Fix oops (kevinwuTT)
5f04797  Fix issue with 0-d tensor and add section on how to use this feature (kevinwuTT)
a0941b6  Merge branch 'main' into kw/gen_acc_tests (kevinwuTT)
4f815b7  Move autogen directory under tests/autogen_accuracy_tests and add pri… (kevinwuTT)
17ea0ed  Use safetensors instead of pickle (kevinwuTT)
125935d  Add small helper function for detecting regular operations (kevinwuTT)
f247e4b  Revert "Use safetensors instead of pickle" (kevinwuTT)
4524d66  Fix accuracy test paths (kevinwuTT)
278e8e2  Merge branch 'main' into kw/gen_acc_tests (kevinwuTT)
acdc5e8  Merge branch 'main' into kw/gen_acc_tests (kevinwuTT)
a1206eb  Merge branch 'main' into kw/gen_acc_tests (kevinwuTT)
46e489f  Move gen standalone code to tools (kevinwuTT)
@@ -0,0 +1,280 @@
import inspect
import pickle
import torch.utils._pytree as pytree
import ttnn
import types

from collections import defaultdict
from pathlib import Path
from tests.utils import assert_with_pcc, comp_pcc, construct_pcc_assert_message

wrapper_funcs = set()

# Returns a list of inputs to the first graph of the model
def generate_flat_args(gm, example_inputs):
    full_args = []
    # see torch/_functorch/aot_autograd.py::aot_module_simplified
    params = {
        **dict(gm.named_parameters(remove_duplicate=False)),
        **dict(gm.named_buffers(remove_duplicate=False)),
    }
    params_flat, params_spec = pytree.tree_flatten(params)
    params_flat = list(params_flat)
    full_args.extend(params_flat)
    full_args.extend(example_inputs)

    return full_args

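# --- Editor's note (not part of the diff): a minimal sketch, assuming a toy
# --- GraphModule, of what generate_flat_args returns -- parameters and buffers
# --- first (in pytree dict order), then the example inputs:
#
#   gm = torch.fx.symbolic_trace(torch.nn.Linear(4, 2))
#   flat = generate_flat_args(gm, [torch.rand(1, 4)])
#   assert len(flat) == 3  # weight + bias + the single example input
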
# Rename nodes because some wrapper or built-in functions share the same name
def rename_nodes(graph, prefix):
    for node in graph.nodes:
        if node.op != "placeholder" and node.op != "output":
            # TODO: simplify or move this into a helper
            opname = str(node.target) if str(node.target).startswith("aten.") else node.target.__name__
            if not opname.startswith("aten.") and not opname.startswith("ttnn."):
                node._rename(f"{prefix}_{node.name}")
    return graph

def format_dict(obj):
    to_kwargs = [f"{k} = {v}" for k, v in obj.items()]
    return ", ".join(to_kwargs)

# Convert one fx node into a line of Python code, collecting the source of any
# wrapper functions so they can be emitted alongside the generated script
def node_to_python_code(node):
    # placeholder and output nodes are handled by the caller
    assert node.op not in ["placeholder", "output"]

    node_args = ", ".join([str(arg) for arg in node.args])

    if node.target.__name__ == "getitem":
        return f"{node} = {node.args[0]}[{node.args[1]}]"

    # TODO: simplify or move this into a helper
    opname = str(node.target) if str(node.target).startswith("aten.") else node.target.__name__
    if (
        not opname.startswith("aten.")
        and not opname.startswith("ttnn.")
        and not isinstance(node.target, types.BuiltinFunctionType)
    ):
        lines = inspect.getsource(node.target)
        wrapper_funcs.add(lines)

    # TODO: find a better way to use registered custom builtins to replace TTNN constants
    statement = f"{node} = {opname}({node_args}, {format_dict(node.kwargs)})"
    replace_map = {
        "ttnn_Specified_Device": "device",
        "ttnn_TILE_LAYOUT": "ttnn.TILE_LAYOUT",
        "ttnn_ROW_MAJOR_LAYOUT": "ttnn.ROW_MAJOR_LAYOUT",
        "ttnn_L1_MEMORY_CONFIG": "ttnn.L1_MEMORY_CONFIG",
        "ttnn_DRAM_MEMORY_CONFIG": "ttnn.DRAM_MEMORY_CONFIG",
        "ttnn_uint32": "ttnn.uint32",
        "ttnn_bfloat16": "ttnn.bfloat16",
    }

    for k, v in replace_map.items():
        statement = statement.replace(k, v)
    return statement

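# --- Editor's note (not part of the diff): for a hypothetical ttnn.add node
# --- with args (arg0, arg1) and kwargs {"memory_config": ttnn_L1_MEMORY_CONFIG},
# --- the emitted statement would look roughly like:
#
#   add_1 = ttnn.add(arg0, arg1, memory_config = ttnn.L1_MEMORY_CONFIG)
#
# --- replace_map rewrites the mangled constant names that appear in the graph
# --- (e.g. ttnn_Specified_Device) back into real identifiers (device).
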
def users_have_getitem(node):
    for user in list(node.users.keys()):
        if user.op != "output" and user.op != "placeholder" and user.target.__name__ == "getitem":
            return user
    return None

# Note: this modifies nodes in-place (via replace_all_uses_with / _rename)
def rename_input_args_from_graph_break(output_nodes, node):
    # check for outputs of the previous graph
    if node.name == "clone":
        for out_arg in reversed(output_nodes[-1]):
            if out_arg.name.startswith("primals"):
                if out_arg.name in [a.name for a in output_nodes[-1]]:
                    node.replace_all_uses_with(out_arg, delete_user_cb=lambda node: node != out_arg)
                else:
                    node._rename(out_arg.name)
                break
    if node.name.startswith("tangent"):
        first_primal_idx = 0
        for i, out_arg in enumerate(output_nodes[-1]):
            if out_arg.name.startswith("primals"):
                first_primal_idx = i
                break
        tangent_node = output_nodes[-1][first_primal_idx - 1]
        if tangent_node.name in [a.name for a in output_nodes[-1]]:
            node.replace_all_uses_with(tangent_node, delete_user_cb=lambda node: node != tangent_node)
        else:
            node._rename(tangent_node.name)

# Key: seq_nr + original aten op name + str(val), used to match aten and ttnn nodes
def compute_key(node):
    return str(node.meta["seq_nr"]) + node.meta["original_aten"]._name + str(node.meta["val"])

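# --- Editor's note (not part of the diff): for a hypothetical node with
# --- seq_nr=12, original_aten=aten.add.Tensor, and a FakeTensor val, the key
# --- is the plain string concatenation
#
#   "12" + "aten.add.Tensor" + "FakeTensor(..., size=(1, 32))"
#
# --- which is only ever used for exact matching between the two graphs.
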
def map_meta_to_aten_node(aten_graph):
    aten_name_to_node_map = defaultdict(list)
    for node in aten_graph.nodes:
        if "val" in node.meta:
            print("aten graph meta:", node, node.meta["val"])
        if node.op != "placeholder" and node.op != "output":
            print("aten meta:", node, node.meta)
            aten_name_to_node_map[compute_key(node)] = node
    return aten_name_to_node_map

def map_aten_node_to_ttnn_node(ttnn_graph, output_nodes, aten_name_to_node_map):
    aten_to_ttnn_map = defaultdict(list)
    for node in ttnn_graph.nodes:
        if node.op == "placeholder":
            rename_input_args_from_graph_break(output_nodes, node)
            continue

        if node.op != "placeholder" and node.op != "output":
            # ignore to_torch
            if node.target == ttnn.to_torch:
                continue
            if "seq_nr" in node.meta:
                aten_node_name = compute_key(node)
                print("aten_node_name:", node, aten_node_name)
                aten_node = aten_name_to_node_map[aten_node_name]
                aten_to_ttnn_map[aten_node].append(node)
    # TODO: also append getitem if it exists
    return aten_to_ttnn_map

# Gather ttnn ops into a list and insert (aten_node, ttnn_node) tuples in between.
# This does not alter the graph; it only prepares the list that is later
# translated into textual code.
def process_ttnn_ops(ttnn_graph, aten_name_to_node_map, aten_to_ttnn_map):
    ttnn_all_nodes = []
    for node in ttnn_graph.nodes:
        if node.op == "output":
            continue
        if node.op == "placeholder":
            continue
        ttnn_all_nodes.append(node)
        if "seq_nr" in node.meta:
            print("ttnn meta:", node, node.meta["seq_nr"], node.meta["original_aten"]._name, str(node.meta["val"]))
            aten_node_name = compute_key(node)
            aten_node = aten_name_to_node_map[aten_node_name]
            print("aten_to_ttnn_map:", aten_to_ttnn_map[aten_node])
            # if this is the last ttnn node for this aten op, compare its output
            if node == aten_to_ttnn_map[aten_node][-1]:
                # this tuple is converted to test_accuracy(node1, node2) later;
                # do not emit the comparison if users are getitem
                if not users_have_getitem(node):
                    if (getitem := users_have_getitem(aten_node)) is not None:
                        aten_node = getitem
                    ttnn_all_nodes.append((aten_node, node))
    return ttnn_all_nodes

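# --- Editor's note (not part of the diff): if a hypothetical aten.relu lowers
# --- to the ttnn sequence [from_torch_1, relu_1, to_torch_1], the returned
# --- list would look roughly like
#
#   [from_torch_1, relu_1, (aten_relu_node, relu_1), to_torch_1]
#
# --- where the tuple is later emitted as test_accuracy(aten_relu_node, relu_1).
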
def generate_op_accuracy_tests(model_name, aten_fx_graphs, ttnn_fx_graphs, all_inputs, *, verbose=False):
    assert len(aten_fx_graphs) == len(ttnn_fx_graphs)

    print("len graphs:", len(aten_fx_graphs), len(ttnn_fx_graphs))

    test_accuracy_graph_codes = []
    output_nodes = []
    for aten_graph, ttnn_graph in zip(aten_fx_graphs, ttnn_fx_graphs):
        ttnn_graph = rename_nodes(ttnn_graph, "ttnn_prefix")

        # map metadata to aten nodes
        # TODO: make sure the key matches the correct group of aten ops
        aten_name_to_node_map = map_meta_to_aten_node(aten_graph)

        # gather all aten nodes and arg nodes
        aten_all_nodes = []
        arg_nodes = []
        for node in aten_graph.nodes:
            if node.op == "placeholder":
                rename_input_args_from_graph_break(output_nodes, node)
                arg_nodes.append(node)
                continue
            aten_all_nodes.append(node)

        print("aten graph args:", arg_nodes)

        # preprocess: map aten to ttnn ops, so we know the last ttnn op of each
        # group whose output should be compared
        aten_to_ttnn_map = map_aten_node_to_ttnn_node(ttnn_graph, output_nodes, aten_name_to_node_map)

        # interleave aten with ttnn code and insert test_accuracy code at the end of each section
        ttnn_all_nodes = process_ttnn_ops(ttnn_graph, aten_name_to_node_map, aten_to_ttnn_map)

        # finally, convert the interleaved nodes to python code for this graph
        arg_node_names = [node.name for node in arg_nodes]

        forward_signature = f"def forward({', '.join(arg_node_names)}):"
        # comment out the signature if this is not the first graph
        graph_code = [forward_signature] if len(output_nodes) == 0 else ["    # " + forward_signature]
        graph_code.append("    device = ttnn.open_device(device_id=0, l1_small_size=16384)")
        for node in aten_all_nodes:
            if node.op == "output":
                output_nodes.append(node.args[0])
                graph_code.append(f"    # return {node.args[0]}")
                continue
            else:
                graph_code.append(f"    {node_to_python_code(node)}")
        for node in ttnn_all_nodes:
            if isinstance(node, tuple):
                graph_code.append(f"    test_accuracy({node[0]}, {node[1]})")
            else:
                graph_code.append(f"    {node_to_python_code(node)}")
        graph_code.append("    ttnn.close_device(device)")

        test_accuracy_graph_codes.append("\n".join(graph_code))

    # arrange the full code
    import_code = ["import ttnn", "import torch", "import numpy as np", "aten = torch.ops.aten"]

    # this must happen after building graph_code above, because the wrapper
    # functions are only collected at that stage
    wrapper_code = list(wrapper_funcs)

    # test_accuracy helper code
    test_accuracy_code = """
def test_accuracy(tensor1, tensor2):
    if isinstance(tensor2, ttnn.Tensor):
        tensor2 = ttnn.to_torch(tensor2)
    assert_with_pcc(tensor1, tensor2, pcc=0.90)
"""

    # pcc functions
    pcc_funcs = [
        inspect.getsource(comp_pcc),
        inspect.getsource(construct_pcc_assert_message),
        inspect.getsource(assert_with_pcc),
    ]

    directory = Path("accuracy_tests")
    directory.mkdir(parents=True, exist_ok=True)

    # main
    input_pkl_file = Path(f"{model_name}_inputs.pickle")
    full_input_pkl_path = directory / input_pkl_file
    main_code = f"""
import pickle
if __name__ == "__main__":
    try:
        file = open("{full_input_pkl_path}", "rb")
    except OSError:
        file = open("{input_pkl_file}", "rb")
    inputs = pickle.load(file)
    forward(*inputs)
"""

    full_code = import_code + pcc_funcs + wrapper_code + [test_accuracy_code] + test_accuracy_graph_codes + [main_code]
    full_text = "\n".join(full_code)

    with open(directory / Path(f"{model_name}_code.py"), "w") as text_file:
        print(full_text, file=text_file)

    with open(directory / Path(f"{model_name}_inputs.pickle"), "wb") as f:
        pickle.dump(all_inputs, f)
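
For context, a minimal sketch of how this generator might be driven. The import path, the placeholder lists, and the way the fx graphs are captured are assumptions for illustration, not part of this PR; in practice the aten and ttnn graphs are collected while compiling a model with the ttnn backend:

# Hypothetical driver -- illustrative only
import torch
from tools.generate_op_accuracy_tests import generate_op_accuracy_tests  # assumed path

aten_graphs = [...]  # fx graphs of aten ops, captured during torch.compile
ttnn_graphs = [...]  # the corresponding graphs after lowering to ttnn
flat_inputs = [...]  # parameters/buffers + example inputs (see generate_flat_args)

generate_op_accuracy_tests("my_model", aten_graphs, ttnn_graphs, flat_inputs)
# Writes accuracy_tests/my_model_code.py and accuracy_tests/my_model_inputs.pickle;
# running the generated script replays both graphs and asserts PCC >= 0.90
# between each aten op and its lowered ttnn counterpart.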
Review discussion:
move to tools/?
would be good to find a better location, instead of a top-level module dir
Moved to tools/. In a future update, I will further separate them into …