Skip to content

Commit 52b47f3

Browse files
committed
Dynamically rename wrapper functions to avoid naming conflicts
1 parent 57c8a5e commit 52b47f3

File tree

1 file changed

+22
-1
lines changed

1 file changed

+22
-1
lines changed

torch_ttnn/generate_op_accuracy_tests.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from tests.utils import assert_with_pcc, comp_pcc, construct_pcc_assert_message
1212

1313
wrapper_funcs = set()
14+
rename_wrappers = set()
1415

1516

1617
# Returns a list of inputs to the first graph of model
@@ -89,6 +90,15 @@ def node_to_python_code(node):
8990
):
9091
lines = inspect.getsource(node.target)
9192
wrapper_funcs.add(lines)
93+
# rename functions to avoid naming conflict
94+
func_name = node.target.__name__
95+
rename_func = f"""
96+
ref = globals()["{func_name}"]
97+
globals()["{func_name}_wrapper"] = ref
98+
del globals()["{func_name}"]
99+
"""
100+
rename_wrappers.add(rename_func)
101+
opname += "_wrapper"
92102

93103
# find a better way to use registered custom builtins to replace TTNN constants
94104
statement = f"{node} = {opname}({node_args}, {format_dict(node.kwargs)})"
@@ -266,6 +276,9 @@ def generate_op_accuracy_tests(model_name, aten_fx_graphs, ttnn_fx_graphs, all_i
266276
# this needs to be done after the graph_code above because wrapper functions need to be resolved at that stage
267277
wrapper_code = list(wrapper_funcs)
268278

279+
# this needs to be done after defining wrapper functions
280+
rename_wrapper_code = list(rename_wrappers)
281+
269282
# test_accuracy helper code
270283
test_accuracy_code = """
271284
def test_accuracy(expected, actual):
@@ -294,7 +307,15 @@ def test_accuracy(expected, actual):
294307
forward(*inputs)
295308
"""
296309

297-
full_code = import_code + pcc_funcs + wrapper_code + [test_accuracy_code] + test_accuracy_graph_codes + [main_code]
310+
full_code = (
311+
import_code
312+
+ pcc_funcs
313+
+ wrapper_code
314+
+ rename_wrapper_code
315+
+ [test_accuracy_code]
316+
+ test_accuracy_graph_codes
317+
+ [main_code]
318+
)
298319
full_text = "\n".join(full_code)
299320

300321
with open(directory / Path(f"{model_name}_code.py"), "w") as text_file:

0 commit comments

Comments
 (0)