Skip to content

Commit

Permalink
Dynamically rename wrapper functions to avoid naming conflicts
Browse files Browse the repository at this point in the history
  • Loading branch information
kevinwuTT committed Dec 24, 2024
1 parent 57c8a5e commit 52b47f3
Showing 1 changed file with 22 additions and 1 deletion.
23 changes: 22 additions & 1 deletion torch_ttnn/generate_op_accuracy_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from tests.utils import assert_with_pcc, comp_pcc, construct_pcc_assert_message

wrapper_funcs = set()
rename_wrappers = set()


# Returns a list of inputs to the first graph of model
Expand Down Expand Up @@ -89,6 +90,15 @@ def node_to_python_code(node):
):
lines = inspect.getsource(node.target)
wrapper_funcs.add(lines)
# rename functions to avoid naming conflict
func_name = node.target.__name__
rename_func = f"""
ref = globals()["{func_name}"]
globals()["{func_name}_wrapper"] = ref
del globals()["{func_name}"]
"""
rename_wrappers.add(rename_func)
opname += "_wrapper"

# find a better way to use registered custom builtins to replace TTNN constants
statement = f"{node} = {opname}({node_args}, {format_dict(node.kwargs)})"
Expand Down Expand Up @@ -266,6 +276,9 @@ def generate_op_accuracy_tests(model_name, aten_fx_graphs, ttnn_fx_graphs, all_i
# this needs to be done after the graph_code above because wrapper functions need to be resolved at that stage
wrapper_code = list(wrapper_funcs)

# this needs to be done after defining wrapper functions
rename_wrapper_code = list(rename_wrappers)

# test_accuracy helper code
test_accuracy_code = """
def test_accuracy(expected, actual):
Expand Down Expand Up @@ -294,7 +307,15 @@ def test_accuracy(expected, actual):
forward(*inputs)
"""

full_code = import_code + pcc_funcs + wrapper_code + [test_accuracy_code] + test_accuracy_graph_codes + [main_code]
full_code = (
import_code
+ pcc_funcs
+ wrapper_code
+ rename_wrapper_code
+ [test_accuracy_code]
+ test_accuracy_graph_codes
+ [main_code]
)
full_text = "\n".join(full_code)

with open(directory / Path(f"{model_name}_code.py"), "w") as text_file:
Expand Down

0 comments on commit 52b47f3

Please sign in to comment.