Update precision for optimized trigonometric functions (#276)
* Update precision for optimized trigonometric functions (tenstorrent/tt-llk-wh-b0#39)

* Test trigs with [-5, 5], a wider range of arguments

Unfortunately, we can't do this for `tan` yet: its reciprocal is still numerically unstable.

* Check the improved precision brought by tenstorrent/tt-metal#13339

Tangent is still unstable around odd multiples of π/2, but we've made it precise for most arguments away from 0.

* Cite tenstorrent/tt-metal#14414 in the xfailed test case for `ttnn.tan`

* Improve precision of `ttnn.tan`
jdh8 authored Nov 6, 2024
1 parent 38bfbf3 commit 196ae75
Showing 3 changed files with 29 additions and 17 deletions.
17 changes: 10 additions & 7 deletions tests/lowering/eltwise/unary/test_cos.py
@@ -3,6 +3,8 @@
 import pytest
 import ttnn
 
+from tests.utils import assert_with_pcc
+
 
 class CosModule(torch.nn.Module):
     def __init__(self):
@@ -13,22 +15,23 @@ def forward(self, x):


 @pytest.mark.parametrize(
-    "input_shapes",
-    [[(4, 4)]],
+    "input_shape",
+    ((4, 4), (1, 1066)),
 )
-def test_cos(device, input_shapes):
+def test_cos(device, input_shape):
     m = CosModule()
-    inputs = [torch.rand(shape, dtype=torch.bfloat16) for shape in input_shapes]
-    result_before = m.forward(*inputs)
+    input = torch.rand(input_shape, dtype=torch.bfloat16) * 10 - 5
+    result_before = m.forward(input)
     option = torch_ttnn.TorchTtnnOption(device=device)
     option.gen_graphviz = True
     # The compilation is lazy, so we need to run forward once to trigger the compilation
     m = torch.compile(m, backend=torch_ttnn.backend, options=option)
-    result_after = m.forward(*inputs)
+    result_after = m.forward(input)
     option._out_fx_graphs[0].print_tabular()
 
     # Check the graph has be rewritten and contain ttnn ops
     nodes = list(option._out_fx_graphs[0].nodes)
     assert [node.target for node in nodes].count(ttnn.cos) == 1
 
     # Check inference result
-    assert torch.allclose(result_before, result_after, rtol=0.2)
+    assert_with_pcc(result_before, result_after)
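All three tests now import `assert_with_pcc` from `tests.utils`, whose body is not part of this diff. A minimal sketch of such a helper, assuming it checks the Pearson correlation coefficient (PCC) of the reference and device outputs against a 0.99 threshold (the threshold, signature, and flattening below are assumptions, not taken from this commit):

import torch

def assert_with_pcc(expected: torch.Tensor, actual: torch.Tensor, pcc: float = 0.99):
    # Flatten both tensors and compute the metric in float32 so that bfloat16
    # rounding does not distort the correlation itself.
    x = expected.flatten().to(torch.float32)
    y = actual.flatten().to(torch.float32)
    # torch.corrcoef returns a 2x2 correlation matrix; the off-diagonal entry is the PCC.
    measured = torch.corrcoef(torch.stack([x, y]))[0, 1].item()
    assert measured >= pcc, f"PCC {measured:.6f} is below the required {pcc}"

Unlike `torch.allclose` with a relative tolerance, a correlation metric is not dominated by values near the zero crossings of sine and cosine, which makes it a more forgiving fit for the wider [-5, 5] input range.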
12 changes: 7 additions & 5 deletions tests/lowering/eltwise/unary/test_sin.py
@@ -3,6 +3,8 @@
 import pytest
 import ttnn
 
+from tests.utils import assert_with_pcc
+
 
 class SinModule(torch.nn.Module):
     def __init__(self):
@@ -13,12 +15,12 @@ def forward(self, x):


 @pytest.mark.parametrize(
-    ("input_shape", "init_offset"),
-    [((4, 4), 0)],
+    "input_shape",
+    ((4, 4), (1, 1066)),
 )
-def test_sin(device, input_shape, init_offset):
+def test_sin(device, input_shape):
     m = SinModule()
-    input = torch.rand(input_shape, dtype=torch.bfloat16) + init_offset
+    input = torch.rand(input_shape, dtype=torch.bfloat16) * 10 - 5
     result_before = m.forward(input)
     option = torch_ttnn.TorchTtnnOption(device=device)
     option.gen_graphviz = True
@@ -32,4 +34,4 @@ def test_sin(device, input_shape):
     assert [node.target for node in nodes].count(ttnn.sin) == 1
 
     # Check inference result
-    assert torch.allclose(result_before, result_after, rtol=0.1, atol=0.1)
+    assert_with_pcc(result_before, result_after)
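The new inputs for both `test_cos` and `test_sin` come from `torch.rand(...) * 10 - 5`: `torch.rand` is uniform over [0, 1), so this affine map yields values uniform over [-5, 5), spanning about 1.6 periods of sine and cosine (10 / 2π ≈ 1.6). A small host-side illustration of the mapping (the shape and seed below are arbitrary):

import torch

torch.manual_seed(0)  # arbitrary seed, only to make the illustration repeatable
x = torch.rand(1, 1066, dtype=torch.bfloat16) * 10 - 5
assert x.min() >= -5 and x.max() < 5  # uniform over [-5, 5), never reaching +5
print(x.min().item(), x.max().item())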
17 changes: 12 additions & 5 deletions tests/lowering/eltwise/unary/test_tan.py
@@ -3,6 +3,8 @@
 import pytest
 import ttnn
 
+from tests.utils import assert_with_pcc
+
 
 class TanModule(torch.nn.Module):
     def __init__(self):
@@ -13,12 +15,17 @@ def forward(self, x):


 @pytest.mark.parametrize(
-    ("input_shape", "init_offset"),
-    [((4, 4), 0)],
+    "input_shape, range",
+    (
+        ((4, 4), 1),
+        ((1, 1066), 1),
+        ((1, 1066), 1.5),
+        pytest.param((1, 1066), 1.6, marks=pytest.mark.xfail(reason="tt-metal#14414: inaccurate reciprocal")),
+    ),
 )
-def test_tan(device, input_shape, init_offset):
+def test_tan(device, input_shape, range):
     m = TanModule()
-    input = torch.rand(input_shape, dtype=torch.bfloat16) + init_offset
+    input = (torch.rand(input_shape, dtype=torch.bfloat16) * 2 - 1) * range
     result_before = m.forward(input)
     option = torch_ttnn.TorchTtnnOption(device=device)
     option.gen_graphviz = True
@@ -32,4 +39,4 @@ def test_tan(device, input_shape, range):
     assert [node.target for node in nodes].count(ttnn.tan) == 1
 
     # Check inference result
-    assert torch.allclose(result_before, result_after, rtol=0.1, atol=0.1)
+    assert_with_pcc(result_before, result_after)
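In `test_tan`, `(torch.rand(...) * 2 - 1) * range` draws inputs uniformly from [-range, range). Of the ranges tested, only the xfailed 1.6 crosses the pole of tangent at π/2 ≈ 1.5708, where tan(x) = sin(x) / cos(x) is dominated by the reciprocal of a vanishing cos(x); that is the reciprocal inaccuracy tracked by tenstorrent/tt-metal#14414. A short host-side illustration of how fast the reference value grows near the pole (plain `math`, nothing device-specific):

import math

# pi/2 ≈ 1.5708 lies inside the xfailed range (-1.6, 1.6) but outside (-1.5, 1.5).
for x in (1.5, 1.57, 1.5703125):  # 1.5703125 is the closest bfloat16 value below pi/2
    print(f"tan({x}) = {math.tan(x):.1f}")
# The output grows from about 14 to about 2067: near the pole, a single bfloat16
# step in the input changes tan(x) by well over a thousand, so any inaccuracy in
# the reciprocal makes the device result diverge from the reference.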
