Skip to content

Commit 165e335

Browse files
leslie-fang-intelpytorchmergebot
authored andcommitted
[Inductor][CPP] Fix the vec codegen for tanh (pytorch#148254)
**Summary** Fix pytorch#148241, The previous vectorized code generation for `tanh` used a decomposed implementation, leading to numerical differences that were further amplified by `atan2`. For example, in the given test case after `tanh`, the eager output at `[0,0,11,47]` was `-5.820766091346741e-10`, while the compiled output was `1.4319084584712982e-08`, resulting in different `atan2` outputs of `-2.3561` and `0.7853`. This issue is fixed by switching to the Sleef implementation. **Test Plan** ``` python -u -m pytest -s -v test/inductor/test_cpu_repro.py -k test_tanh_atan2 ``` Pull Request resolved: pytorch#148254 Approved by: https://github.com/malfet, https://github.com/jgong5
1 parent 118a165 commit 165e335

File tree

2 files changed

+16
-4
lines changed

2 files changed

+16
-4
lines changed

test/inductor/test_cpu_repro.py

+15
Original file line numberDiff line numberDiff line change
@@ -1028,6 +1028,21 @@ def fn(a):
10281028
a = torch.randn(1, 3)
10291029
self.common(fn, (a,))
10301030

1031+
def test_tanh_atan2(self):
1032+
# https://github.com/pytorch/pytorch/issues/148241
1033+
class Model(torch.nn.Module):
1034+
def __init__(self):
1035+
super().__init__()
1036+
self.shrink = nn.Tanhshrink()
1037+
1038+
def forward(self, x):
1039+
x = self.shrink(x)
1040+
x = torch.atan2(x, x)
1041+
return x
1042+
1043+
x = torch.randn(1, 3, 64, 64)
1044+
self.common(Model(), (x,))
1045+
10311046
def test_index_propagation_issue_102065(self):
10321047
def fn(x):
10331048
x = torch.arange(x.numel())

torch/_inductor/codegen/cpp.py

+1-4
Original file line numberDiff line numberDiff line change
@@ -1406,10 +1406,7 @@ def tan(a):
14061406

14071407
@staticmethod
14081408
def tanh(a):
1409-
vec_one = f"decltype({a})(1)"
1410-
vec_two = f"decltype({a})(2)"
1411-
vec_minus_two = f"decltype({a})(-2)"
1412-
return f"{vec_two} / ({vec_one} + ({vec_minus_two} * {a}).exp()) - {vec_one}"
1409+
return f"{a}.tanh()"
14131410

14141411
@staticmethod
14151412
def reciprocal(a):

0 commit comments

Comments
 (0)