Describe the bug
Code generated by torch.compile returns wrong results in some PyTorch unit tests (UTs).
After debugging, I confirmed that tl.inline_asm_elementwise produces incorrect output on XPU.
Related issues:
intel/torch-xpu-ops#2328
intel/torch-xpu-ops#2330
Reproducer:
import torch
import triton
import triton.language as tl

DEVICE = triton.runtime.driver.active.get_active_torch_device()


@triton.jit
def triton_poi_fused_tanh_approx_0(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 4
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0), xmask)
    tmp1 = tl.inline_asm_elementwise('tanh.approx.f32 $0, $1;', '=r, r', [tmp0], dtype=tl.float32, is_pure=True, pack=1)
    tl.store(out_ptr0 + (x0), tmp1, xmask)


def tanh_approx(x: torch.Tensor):
    # We need to preallocate the output.
    output = torch.empty_like(x)
    assert x.device == DEVICE and output.device == DEVICE
    n_elements = output.numel()
    grid = lambda meta: (triton.cdiv(n_elements, meta['XBLOCK']), )
    triton_poi_fused_tanh_approx_0[grid](x, output, n_elements, XBLOCK=1024)
    return output


# Run reproducer
if __name__ == "__main__":
    torch.manual_seed(42)
    device = DEVICE
    inp = torch.randn(4, device=device)
    out = torch.zeros(4, device=device)
    print("Testing tanh_approx reproducer:")
    print(f"Input shape: {inp.shape}")
    print(f"Input device: {inp.device}")
    print(f"Input: {inp}")
    # Compiled mode
    output = tanh_approx(inp)
    print(f"Compiled result shape: {output.shape}")
    print(f"Compiled result: {output}")
    # Compare with torch.tanh
    expected = torch.tanh(inp)
    print(f"Expected result: {expected}")
    print(
        f"Match with torch.tanh: {torch.allclose(expected, output, atol=1e-3)}"
    )

Results on XPU:
Input shape: torch.Size([4])
Input device: xpu:0
Input: tensor([ 0.1940, 2.1614, -0.1721, 0.8491], device='xpu:0')
Compiled result shape: torch.Size([4])
Compiled result: tensor([2.6864e-38, 2.6866e-38, 2.6864e-38, 2.6866e-38], device='xpu:0')
Expected result: tensor([ 0.1916, 0.9738, -0.1704, 0.6906], device='xpu:0')
Match with torch.tanh: False
Results on CUDA:
Input shape: torch.Size([4])
Input device: cuda:0
Input: tensor([ 0.1940, 2.1614, -0.1721, 0.8491], device='cuda:0')
Compiled result shape: torch.Size([4])
Compiled result: tensor([ 0.1916, 0.9738, -0.1704, 0.6906], device='cuda:0')
Expected result: tensor([ 0.1916, 0.9738, -0.1704, 0.6906], device='cuda:0')
Match with torch.tanh: True
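The logged values can also be checked without a GPU. The sketch below is not part of the original reproducer. It copies the printed inputs and outputs from the two logs above and compares them against math.tanh from the standard library. The helper names (matches_tanh, max_abs_error) are hypothetical. The comparison uses an absolute tolerance only, which is a simplification of torch.allclose (that function also applies a relative tolerance).

```python
import math

# Values copied from the logs in this report (rounded to 4 decimals as printed).
INPUT = [0.1940, 2.1614, -0.1721, 0.8491]
XPU_RESULT = [2.6864e-38, 2.6866e-38, 2.6864e-38, 2.6866e-38]
CUDA_RESULT = [0.1916, 0.9738, -0.1704, 0.6906]


def matches_tanh(inputs, outputs, atol=1e-3):
    """Return True if every output is within atol of math.tanh(input)."""
    return all(abs(math.tanh(x) - y) <= atol for x, y in zip(inputs, outputs))


def max_abs_error(inputs, outputs):
    """Largest absolute deviation from math.tanh over all elements."""
    return max(abs(math.tanh(x) - y) for x, y in zip(inputs, outputs))


if __name__ == "__main__":
    print("XPU matches tanh: ", matches_tanh(INPUT, XPU_RESULT))
    print("CUDA matches tanh:", matches_tanh(INPUT, CUDA_RESULT))
    print("XPU max abs error:", max_abs_error(INPUT, XPU_RESULT))
```

On the logged data, the CUDA output agrees with the reference within the tolerance. The XPU output is close to zero for every element regardless of the input.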
Environment details
torch: 2.10.0a0+gitb29465b
pytorch-triton-xpu: 3.5.0+git1b0418a9