Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Generate standalone code that compares the accuracy between corresponding aten ops and lowered TTNN ops #611

Merged
merged 36 commits into from
Jan 29, 2025
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
793132d
Add option to generate standalone python script for model runs
kevinwuTT Dec 16, 2024
efa1ebd
Refactor some variables and path names
kevinwuTT Dec 16, 2024
672f741
Add lzma compression for input pickle files
kevinwuTT Dec 17, 2024
61d4f7b
Add action to run accuracy tests in CI
kevinwuTT Dec 17, 2024
9fb9353
Try different for loop in actions
kevinwuTT Dec 18, 2024
f92d244
Fix gen accuracy flag to the right action
kevinwuTT Dec 18, 2024
b1ce493
Fix file paths
kevinwuTT Dec 18, 2024
39b81e7
Add ttnn.squeeze for lowering embedding when the rank of input is 1
kevinwuTT Dec 21, 2024
2119337
Support metadata for get_attr nodes
kevinwuTT Dec 21, 2024
b2ab995
Support torch.device cases
kevinwuTT Dec 21, 2024
a5bdb40
Cleanup
kevinwuTT Dec 21, 2024
f0487d0
Merge branch 'main' into kw/gen_acc_tests
kevinwuTT Dec 21, 2024
7200411
Cleanup
kevinwuTT Dec 21, 2024
4a27c75
Revert "Add ttnn.squeeze for lowering embedding when the rank of inpu…
kevinwuTT Dec 23, 2024
d94b511
Remove uneeded code
kevinwuTT Dec 23, 2024
57c8a5e
Support get_attr nodes
kevinwuTT Dec 23, 2024
52b47f3
Dynamically rename wrapper functions to avoid naming conflicts
kevinwuTT Dec 23, 2024
d45109a
Some get_attr nodes do not have tensor_meta or val metadata
kevinwuTT Dec 26, 2024
2b960d1
Merge branch 'main' into kw/gen_acc_tests
kevinwuTT Dec 27, 2024
cd5a3ca
Refactor
kevinwuTT Jan 1, 2025
7020388
Merge branch 'main' into kw/gen_acc_tests
kevinwuTT Jan 1, 2025
6bbb613
Support cases where aten ops are decomposed and the decomposed ops ha…
kevinwuTT Jan 2, 2025
50e00f8
Refactor
kevinwuTT Jan 4, 2025
12065a7
Merge branch 'main' into kw/gen_acc_tests
kevinwuTT Jan 4, 2025
d3cf5db
Fix oops
kevinwuTT Jan 4, 2025
5f04797
Fix issue with 0-d tensor and add section on how to use this feature
kevinwuTT Jan 6, 2025
a0941b6
Merge branch 'main' into kw/gen_acc_tests
kevinwuTT Jan 6, 2025
4f815b7
Move autogen directory under tests/autogen_accuracy_tests and add pri…
kevinwuTT Jan 6, 2025
17ea0ed
Use safetensors instead of pickle
kevinwuTT Jan 6, 2025
125935d
Add small helper function for detecting regular operations
kevinwuTT Jan 6, 2025
f247e4b
Revert "Use safetensors instead of pickle"
kevinwuTT Jan 6, 2025
4524d66
Fix accuracy test paths
kevinwuTT Jan 6, 2025
278e8e2
Merge branch 'main' into kw/gen_acc_tests
kevinwuTT Jan 17, 2025
acdc5e8
Merge branch 'main' into kw/gen_acc_tests
kevinwuTT Jan 23, 2025
a1206eb
Merge branch 'main' into kw/gen_acc_tests
kevinwuTT Jan 27, 2025
46e489f
Move gen standalone code to tools
kevinwuTT Jan 27, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,16 @@
import subprocess
import sys

import torch_ttnn.generate_op_accuracy_tests as generate_op_accuracy_tests

mb_in_bytes = 1048576


def pytest_addoption(parser):
parser.addoption("--input_var_only_native", action="store_true")
parser.addoption("--input_var_check_ttnn", action="store_true")
parser.addoption("--input_var_check_accu", action="store_true")
parser.addoption("--gen_op_accuracy_tests", action="store_true")


@pytest.fixture(scope="session")
Expand Down Expand Up @@ -121,13 +124,22 @@ def compile_and_run(device, reset_torch_dynamo, request):
run_mem_analysis=False,
metrics_path=model_name,
verbose=True,
gen_op_accuracy_tests=request.config.getoption("--gen_op_accuracy_tests"),
)

start = time.perf_counter() * 1000

outputs_after = model_tester.test_model(as_ttnn=True, option=option)

end = time.perf_counter() * 1000
comp_runtime_metrics = {"success": True, "run_time": round(end - start, 2)}

# set to one variable?
if request.config.getoption("--gen_op_accuracy_tests"):
generate_op_accuracy_tests.generate_op_accuracy_tests(
model_name, option._aten_fx_graphs, option._out_fx_graphs, option._all_inputs
)

if len(option._out_fx_graphs) > 0:
option._out_fx_graphs[0].print_tabular()
if model_name not in ["speecht5-tts"]:
Expand Down
19 changes: 19 additions & 0 deletions torch_ttnn/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from pathlib import Path
import os
from torch_ttnn.handle_input_aliasing import insert_clones_for_input_aliasing
import torch_ttnn.generate_op_accuracy_tests as generate_op_accuracy_tests
import torch_ttnn.metrics as metrics
from torch_ttnn import mem_utils

Expand All @@ -28,6 +29,7 @@ def __init__(
tracer_option=None,
bypass_compile=False,
use_less_ttnn_op_types=True,
gen_op_accuracy_tests=False,
):
self.device = device
self.gen_graphviz = gen_graphviz
Expand All @@ -44,11 +46,19 @@ def __init__(
self.original_schema_list = list()
self.compiled_schema_list = list()

# Used for generate standalone python script
self.gen_op_accuracy_tests = gen_op_accuracy_tests
self._aten_fx_graphs = list()
self._all_inputs = None

def reset_containers(self):
self._out_fx_graphs = list()
self.original_schema_list = list()


from pdb import set_trace as bp


def register_ttnn_objects(option: TorchTtnnOption):
"""
torch.fx builds a source object as a string, calls builtin compile(), and finally
Expand Down Expand Up @@ -95,6 +105,11 @@ def aten_backend(
from .handle_input_aliasing import remove_clones_for_input_aliasing

gm = remove_clones_for_input_aliasing(gm)

# Save aten graph if requested
if options.gen_op_accuracy_tests:
option._aten_fx_graphs.append(gm.graph)

# Save the number of aten ops before compilation
if option.metrics_path:
option.original_schema_list.extend(metrics.collect_input_variations_from_list_nodes(gm.graph.nodes))
Expand Down Expand Up @@ -217,6 +232,10 @@ def ttnn_backend(
example_inputs: List[torch.Tensor],
options: TorchTtnnOption = None,
) -> torch.fx.GraphModule:
# Save all parameters and inputs if requested
if options.gen_op_accuracy_tests:
options._all_inputs = generate_op_accuracy_tests.generate_flat_args(gm, example_inputs)

tracer_option = options.tracer_option
if tracer_option is not None:
from ..tracer import Tracer
Expand Down
280 changes: 280 additions & 0 deletions torch_ttnn/generate_op_accuracy_tests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,280 @@
import inspect
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

move to tools/?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

would be good to find a better location, instead of a top level module dir

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moved to tools/. In a future update, I will further separate them into

tools\
|-- export code base module
    |-- generate accuracy code
    |-- generate profiling code

import pickle
import torch.utils._pytree as pytree
import ttnn
import types

from collections import defaultdict
from pathlib import Path
from tests.utils import assert_with_pcc, comp_pcc, construct_pcc_assert_message

wrapper_funcs = set()


# Returns a list of inputs to the first graph of model
def generate_flat_args(gm, example_inputs):
full_args = []
# torch/_functorch/aot_autograd.py::aot_module_simplified
params = {
**dict(gm.named_parameters(remove_duplicate=False)),
**dict(gm.named_buffers(remove_duplicate=False)),
}
params_flat, params_spec = pytree.tree_flatten(params)
params_flat = list(params_flat)
full_args.extend(params_flat)
full_args.extend(example_inputs)

return full_args


# rename node names because some wrapper or built-in functions have the same name
def rename_nodes(graph, prefix):
for node in graph.nodes:
if node.op != "placeholder" and node.op != "output":
# simplify or put this in a new function
opname = str(node.target) if str(node.target).startswith("aten.") else node.target.__name__
if not opname.startswith("aten.") and not opname.startswith("ttnn."):
node._rename(f"{prefix}_{node.name}")
return graph


def format_dict(obj):
to_kwargs = [f"{k} = {v}" for k, v in obj.items()]
return ", ".join(to_kwargs)


# collect wrapper functions if required
def node_to_python_code(node):
# assume no placeholder and output?
assert node.op not in ["placeholder", "output"]

node_args = ", ".join([str(arg) for arg in node.args])

if node.target.__name__ == "getitem":
return f"{node} = {node.args[0]}[{node.args[1]}]"

# simplify or put this in a new function
opname = str(node.target) if str(node.target).startswith("aten.") else node.target.__name__
if (
not opname.startswith("aten.")
and not opname.startswith("ttnn.")
and not isinstance(node.target, types.BuiltinFunctionType)
):
lines = inspect.getsource(node.target)
wrapper_funcs.add(lines)

# find a better way to use registered custom builtins to replace TTNN constants
# statement = f"{node} = {opname}({str(node.args)[1:-1]}, {format_dict(node.kwargs)})"
statement = f"{node} = {opname}({node_args}, {format_dict(node.kwargs)})"
replace_map = {
"ttnn_Specified_Device": "device",
"ttnn_TILE_LAYOUT": "ttnn.TILE_LAYOUT",
"ttnn_ROW_MAJOR_LAYOUT": "ttnn.ROW_MAJOR_LAYOUT",
"ttnn_L1_MEMORY_CONFIG": "ttnn.L1_MEMORY_CONFIG",
"ttnn_DRAM_MEMORY_CONFIG": "ttnn.DRAM_MEMORY_CONFIG",
"ttnn_uint32": "ttnn.uint32",
"ttnn_bfloat16": "ttnn.bfloat16",
}

for k, v in replace_map.items():
statement = statement.replace(k, v)
return statement


def users_have_getitem(node):
for user in list(node.users.keys()):
if user.op != "output" and user.op != "placeholder" and user.target.__name__ == "getitem":
return user
return None


# does this modify node in-place?
def rename_input_args_from_graph_break(output_nodes, node):
# check for previous outputs
if node.name == "clone":
for out_arg in reversed(output_nodes[-1]):
if out_arg.name.startswith("primals"):
if out_arg.name in [a.name for a in output_nodes[-1]]:
node.replace_all_uses_with(out_arg, delete_user_cb=lambda node: node != out_arg)
else:
node._rename(out_arg.name)
break
if node.name.startswith("tangent"):
first_primal_idx = 0
for i, out_arg in enumerate(output_nodes[-1]):
if out_arg.name.startswith("primals"):
first_primal_idx = i
break
tangent_node = output_nodes[-1][first_primal_idx - 1]
if tangent_node.name in [a.name for a in output_nodes[-1]]:
node.replace_all_uses_with(tangent_node, delete_user_cb=lambda node: node != tangent_node)
else:
node._rename(tangent_node.name)


def compute_key(node):
return str(node.meta["seq_nr"]) + node.meta["original_aten"]._name + str(node.meta["val"])


def map_meta_to_aten_node(aten_graph):
aten_name_to_node_map = defaultdict(list)
for node in aten_graph.nodes:
if "val" in node.meta:
print("aten graph meta:", node, node.meta["val"])
if node.op != "placeholder" and node.op != "output":
print("aten meta:", node, node.meta)
aten_name_to_node_map[compute_key(node)] = node
return aten_name_to_node_map


def map_aten_node_to_ttnn_node(ttnn_graph, output_nodes, aten_name_to_node_map):
aten_to_ttnn_map = defaultdict(list)
for node in ttnn_graph.nodes:
if node.op == "placeholder":
rename_input_args_from_graph_break(output_nodes, node)
continue

if node.op != "placeholder" and node.op != "output":
# ignore to_torch
if node.target == ttnn.to_torch:
continue
if "seq_nr" in node.meta:
aten_node_name = compute_key(node)
print("aten_node_name:", node, aten_node_name)
aten_node = aten_name_to_node_map[aten_node_name]
aten_to_ttnn_map[aten_node].append(node)
# also append gettiem if exists
return aten_to_ttnn_map


# gather ttnn ops into a list and insert tuple in between
# this does not alter the graph, but prepare the list to translate to textual code
def process_ttnn_ops(ttnn_graph, aten_name_to_node_map, aten_to_ttnn_map):
ttnn_all_nodes = []
for node in ttnn_graph.nodes:
if node.op == "output":
continue
if node.op == "placeholder":
# arg_nodes.append(node)
continue
# # val = node.meta["val"]
# # print(f"{node.name} = torch.rand({tuple(val.size())}, dtype={val.dtype})")
ttnn_all_nodes.append(node)
# if ((from_node := node.meta.get("from_node", None)) is not None):
if "seq_nr" in node.meta:
print("ttnn meta:", node, node.meta["seq_nr"], node.meta["original_aten"]._name, str(node.meta["val"]))
aten_node_name = compute_key(node)
aten_node = aten_name_to_node_map[aten_node_name]
# this is the last ttnn node for this aten op, compare the output of this
print("aten_to_ttnn_map:", aten_to_ttnn_map[aten_node])
if node == aten_to_ttnn_map[aten_node][-1]:
# this will be converted to test_accuracy(node1, node2) later
# do not emit if users are getitem
if not users_have_getitem(node):
if (getitem := users_have_getitem(aten_node)) is not None:
aten_node = getitem
ttnn_all_nodes.append((aten_node, node))
return ttnn_all_nodes


def generate_op_accuracy_tests(model_name, aten_fx_graphs, ttnn_fx_graphs, all_inputs, *, verbose=False):
assert len(aten_fx_graphs) == len(ttnn_fx_graphs)

print("len graphs:", len(aten_fx_graphs), len(ttnn_fx_graphs))

test_accuracy_graph_codes = []
output_nodes = []
for aten_graph, ttnn_graph in zip(aten_fx_graphs, ttnn_fx_graphs):
ttnn_graph = rename_nodes(ttnn_graph, "ttnn_prefix")

# map meta data to aten node
# TODO: make sure key matches the correct group of aten ops
aten_name_to_node_map = map_meta_to_aten_node(aten_graph)

# gather all aten nodes and arg nodes
aten_all_nodes = []
arg_nodes = []
for node in aten_graph.nodes:
if node.op == "placeholder":
rename_input_args_from_graph_break(output_nodes, node)
arg_nodes.append(node)
continue
aten_all_nodes.append(node)

print("aten graph args:", arg_nodes)

# preprocess: map aten to ttnn ops. this is to know what is the last ttnn op in group to compare output
aten_to_ttnn_map = map_aten_node_to_ttnn_node(ttnn_graph, output_nodes, aten_name_to_node_map)

# interleave aten with ttnn code and insert test_accuracy code at the end of each section
ttnn_all_nodes = process_ttnn_ops(ttnn_graph, aten_name_to_node_map, aten_to_ttnn_map)

# finally convert interleaved nodes to python code for this graph
arg_node_names = [node.name for node in arg_nodes]

forward_signature = f"def forward({', '.join(arg_node_names)}):"
# comment out signature if not the first graph
graph_code = [forward_signature] if len(output_nodes) == 0 else [" # " + forward_signature]
graph_code.append(" device = ttnn.open_device(device_id=0, l1_small_size=16384)")
for node in aten_all_nodes:
if node.op == "output":
output_nodes.append(node.args[0])
graph_code.append(f" # return {node.args[0]}")
continue
else:
graph_code.append(f" {node_to_python_code(node)}")
for node in ttnn_all_nodes:
if isinstance(node, tuple):
graph_code.append(f" test_accuracy({node[0]}, {node[1]})")
else:
graph_code.append(f" {node_to_python_code(node)}")
graph_code.append(" ttnn.close_device(device)")

test_accuracy_graph_codes.append("\n".join(graph_code))

# arrange full code
import_code = ["import ttnn", "import torch", "import numpy as np", "aten = torch.ops.aten"]

# this needs to be done after the graph_code above because wrapper functions need to be resolved at that stage
wrapper_code = list(wrapper_funcs)

# test_accuracy helper code
test_accuracy_code = """
def test_accuracy(tensor1, tensor2):
if isinstance(tensor2, ttnn.Tensor):
tensor2 = ttnn.to_torch(tensor2)
assert_with_pcc(tensor1, tensor2, pcc = 0.90)
"""

# pcc functions
pcc_funcs = [
inspect.getsource(comp_pcc),
inspect.getsource(construct_pcc_assert_message),
inspect.getsource(assert_with_pcc),
]

directory = Path("accuracy_tests")
directory.mkdir(parents=True, exist_ok=True)

# main
input_pkl_file = Path(f"{model_name}_inputs.pickle")
full_input_pkl_path = directory / input_pkl_file
main_code = f"""
import pickle
if __name__ == "__main__":
try:
file = open("{full_input_pkl_path}", "rb")
except:
file = open("{input_pkl_file}", "rb")
inputs = pickle.load(file)
forward(*inputs)
"""

full_code = import_code + pcc_funcs + wrapper_code + [test_accuracy_code] + test_accuracy_graph_codes + [main_code]
full_text = "\n".join(full_code)

with open(directory / Path(f"{model_name}_code.py"), "w") as text_file:
print(full_text, file=text_file)

with open(directory / Path(f"{model_name}_inputs.pickle"), "wb") as f:
pickle.dump(all_inputs, f)
Loading