|
11 | 11 | from tests.utils import assert_with_pcc, comp_pcc, construct_pcc_assert_message
|
12 | 12 |
|
13 | 13 | wrapper_funcs = set()
|
| 14 | +rename_wrappers = set() |
14 | 15 |
|
15 | 16 |
|
16 | 17 | # Returns a list of inputs to the first graph of model
|
@@ -89,6 +90,15 @@ def node_to_python_code(node):
|
89 | 90 | ):
|
90 | 91 | lines = inspect.getsource(node.target)
|
91 | 92 | wrapper_funcs.add(lines)
|
| 93 | + # rename functions to avoid naming conflict |
| 94 | + func_name = node.target.__name__ |
| 95 | + rename_func = f""" |
| 96 | +ref = globals()["{func_name}"] |
| 97 | +globals()["{func_name}_wrapper"] = ref |
| 98 | +del globals()["{func_name}"] |
| 99 | +""" |
| 100 | + rename_wrappers.add(rename_func) |
| 101 | + opname += "_wrapper" |
92 | 102 |
|
93 | 103 | # find a better way to use registered custom builtins to replace TTNN constants
|
94 | 104 | statement = f"{node} = {opname}({node_args}, {format_dict(node.kwargs)})"
|
@@ -266,6 +276,9 @@ def generate_op_accuracy_tests(model_name, aten_fx_graphs, ttnn_fx_graphs, all_i
|
266 | 276 | # this needs to be done after the graph_code above because wrapper functions need to be resolved at that stage
|
267 | 277 | wrapper_code = list(wrapper_funcs)
|
268 | 278 |
|
| 279 | + # this needs to be done after defining wrapper functions |
| 280 | + rename_wrapper_code = list(rename_wrappers) |
| 281 | + |
269 | 282 | # test_accuracy helper code
|
270 | 283 | test_accuracy_code = """
|
271 | 284 | def test_accuracy(expected, actual):
|
@@ -294,7 +307,15 @@ def test_accuracy(expected, actual):
|
294 | 307 | forward(*inputs)
|
295 | 308 | """
|
296 | 309 |
|
297 |
| - full_code = import_code + pcc_funcs + wrapper_code + [test_accuracy_code] + test_accuracy_graph_codes + [main_code] |
| 310 | + full_code = ( |
| 311 | + import_code |
| 312 | + + pcc_funcs |
| 313 | + + wrapper_code |
| 314 | + + rename_wrapper_code |
| 315 | + + [test_accuracy_code] |
| 316 | + + test_accuracy_graph_codes |
| 317 | + + [main_code] |
| 318 | + ) |
298 | 319 | full_text = "\n".join(full_code)
|
299 | 320 |
|
300 | 321 | with open(directory / Path(f"{model_name}_code.py"), "w") as text_file:
|
|
0 commit comments