From 196ae752825cab6c532fa8ed741f3acd030385a3 Mon Sep 17 00:00:00 2001 From: Chen-Pang He Date: Wed, 6 Nov 2024 19:36:22 +0800 Subject: [PATCH] Update precision for optimized trigonometric functions (#276) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Update precision for optimized trigonometric functions (tenstorrent/tt-llk-wh-b0#39) * Test trigs with [-5, 5], a wider range of arguments Unfortunately, we can't do this to `tan` yet. Reciprocal is still numerically unstable. * Check improved precision by tenstorrent/tt-metal#13339 Tangent is still unstable around odd multiples of π/2, but we've still make it precise for most arguments away from 0. * Cite tenstorrent/tt-metal#14414 in the xfailed test case for `ttnn.tan` * Improve precision of `ttnn.tan` --- tests/lowering/eltwise/unary/test_cos.py | 17 ++++++++++------- tests/lowering/eltwise/unary/test_sin.py | 12 +++++++----- tests/lowering/eltwise/unary/test_tan.py | 17 ++++++++++++----- 3 files changed, 29 insertions(+), 17 deletions(-) diff --git a/tests/lowering/eltwise/unary/test_cos.py b/tests/lowering/eltwise/unary/test_cos.py index 6f7db3087..b1c4d02ba 100644 --- a/tests/lowering/eltwise/unary/test_cos.py +++ b/tests/lowering/eltwise/unary/test_cos.py @@ -3,6 +3,8 @@ import pytest import ttnn +from tests.utils import assert_with_pcc + class CosModule(torch.nn.Module): def __init__(self): @@ -13,22 +15,23 @@ def forward(self, x): @pytest.mark.parametrize( - "input_shapes", - [[(4, 4)]], + "input_shape", + ((4, 4), (1, 1066)), ) -def test_cos(device, input_shapes): +def test_cos(device, input_shape): m = CosModule() - inputs = [torch.rand(shape, dtype=torch.bfloat16) for shape in input_shapes] - result_before = m.forward(*inputs) + input = torch.rand(input_shape, dtype=torch.bfloat16) * 10 - 5 + result_before = m.forward(input) option = torch_ttnn.TorchTtnnOption(device=device) option.gen_graphviz = True # The compilation is lazy, so we need to run forward once to trigger the compilation m = torch.compile(m, backend=torch_ttnn.backend, options=option) - result_after = m.forward(*inputs) + result_after = m.forward(input) option._out_fx_graphs[0].print_tabular() # Check the graph has be rewritten and contain ttnn ops nodes = list(option._out_fx_graphs[0].nodes) assert [node.target for node in nodes].count(ttnn.cos) == 1 + # Check inference result - assert torch.allclose(result_before, result_after, rtol=0.2) + assert_with_pcc(result_before, result_after) diff --git a/tests/lowering/eltwise/unary/test_sin.py b/tests/lowering/eltwise/unary/test_sin.py index 198c5fc64..bc441dc64 100644 --- a/tests/lowering/eltwise/unary/test_sin.py +++ b/tests/lowering/eltwise/unary/test_sin.py @@ -3,6 +3,8 @@ import pytest import ttnn +from tests.utils import assert_with_pcc + class SinModule(torch.nn.Module): def __init__(self): @@ -13,12 +15,12 @@ def forward(self, x): @pytest.mark.parametrize( - ("input_shape", "init_offset"), - [((4, 4), 0)], + "input_shape", + ((4, 4), (1, 1066)), ) -def test_sin(device, input_shape, init_offset): +def test_sin(device, input_shape): m = SinModule() - input = torch.rand(input_shape, dtype=torch.bfloat16) + init_offset + input = torch.rand(input_shape, dtype=torch.bfloat16) * 10 - 5 result_before = m.forward(input) option = torch_ttnn.TorchTtnnOption(device=device) option.gen_graphviz = True @@ -32,4 +34,4 @@ def test_sin(device, input_shape, init_offset): assert [node.target for node in nodes].count(ttnn.sin) == 1 # Check inference result - assert torch.allclose(result_before, result_after, rtol=0.1, atol=0.1) + assert_with_pcc(result_before, result_after) diff --git a/tests/lowering/eltwise/unary/test_tan.py b/tests/lowering/eltwise/unary/test_tan.py index 60bed5a2a..48d525d21 100644 --- a/tests/lowering/eltwise/unary/test_tan.py +++ b/tests/lowering/eltwise/unary/test_tan.py @@ -3,6 +3,8 @@ import pytest import ttnn +from tests.utils import assert_with_pcc + class TanModule(torch.nn.Module): def __init__(self): @@ -13,12 +15,17 @@ def forward(self, x): @pytest.mark.parametrize( - ("input_shape", "init_offset"), - [((4, 4), 0)], + "input_shape, range", + ( + ((4, 4), 1), + ((1, 1066), 1), + ((1, 1066), 1.5), + pytest.param((1, 1066), 1.6, marks=pytest.mark.xfail(reason="tt-metal#14414: inaccurate reciprocal")), + ), ) -def test_tan(device, input_shape, init_offset): +def test_tan(device, input_shape, range): m = TanModule() - input = torch.rand(input_shape, dtype=torch.bfloat16) + init_offset + input = (torch.rand(input_shape, dtype=torch.bfloat16) * 2 - 1) * range result_before = m.forward(input) option = torch_ttnn.TorchTtnnOption(device=device) option.gen_graphviz = True @@ -32,4 +39,4 @@ def test_tan(device, input_shape, init_offset): assert [node.target for node in nodes].count(ttnn.tan) == 1 # Check inference result - assert torch.allclose(result_before, result_after, rtol=0.1, atol=0.1) + assert_with_pcc(result_before, result_after)