Skip to content

Commit a5bdb40

Browse files
committed
Cleanup
1 parent b2ab995 commit a5bdb40

File tree

1 file changed

+4
-5
lines changed

1 file changed

+4
-5
lines changed

torch_ttnn/generate_op_accuracy_tests.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,6 @@ def node_to_python_code(node):
9191
wrapper_funcs.add(lines)
9292

9393
# find a better way to use registered custom builtins to replace TTNN constants
94-
# statement = f"{node} = {opname}({str(node.args)[1:-1]}, {format_dict(node.kwargs)})"
9594
statement = f"{node} = {opname}({node_args}, {format_dict(node.kwargs)})"
9695
replace_map = {
9796
"ttnn_Specified_Device": "device",
@@ -275,10 +274,10 @@ def generate_op_accuracy_tests(model_name, aten_fx_graphs, ttnn_fx_graphs, all_i
275274

276275
# test_accuracy helper code
277276
test_accuracy_code = """
278-
def test_accuracy(tensor1, tensor2):
279-
if isinstance(tensor2, ttnn.Tensor):
280-
tensor2 = ttnn.to_torch(tensor2)
281-
assert_with_pcc(tensor1, tensor2, pcc = 0.90)
277+
def test_accuracy(expected, actual):
278+
if isinstance(actual, ttnn.Tensor):
279+
actual = ttnn.to_torch(actual)
280+
assert_with_pcc(expected, actual, pcc = 0.90)
282281
"""
283282

284283
# pcc functions

0 commit comments

Comments
 (0)