Skip to content

tl.inline_asm_elementwise returns wrong results on Intel GPU #5580

@hoshibara

Description

@hoshibara

Describe the bug

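I found that code generated by torch.compile returns wrong results in some PyTorch unit tests (UTs).
After debugging, I confirmed that tl.inline_asm_elementwise produces incorrect output on XPU (Intel GPU). In the reproducer below, every output element is a tiny value (~2.69e-38) instead of tanh(x), while the same kernel gives correct results on CUDA.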

Related issues:
intel/torch-xpu-ops#2328
intel/torch-xpu-ops#2330

Reproducer:

import torch

import triton
import triton.language as tl

DEVICE = triton.runtime.driver.active.get_active_torch_device()

@triton.jit
def triton_poi_fused_tanh_approx_0(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    # Elementwise kernel: out_ptr0[i] = tanh.approx(in_ptr0[i]) for i < xnumel.
    # Each program instance handles one contiguous XBLOCK-sized chunk.
    #
    # The element count comes from the `xnumel` argument. The original
    # Inductor-style code overwrote it with the literal `xnumel = 4`, which
    # silently ignored the argument and truncated inputs longer than 4.
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)
    # Mask off lanes past the end so the tail block neither reads nor writes
    # out of bounds. Masked lanes load undefined values, but they are never stored.
    xmask = xindex < xnumel
    tmp0 = tl.load(in_ptr0 + xindex, xmask)
    # Inline assembly, one element per asm invocation (pack=1), with constraint
    # string '=r, r' (one output operand, one input operand).
    # NOTE(review): 'tanh.approx.f32' is written in NVIDIA PTX syntax. It matches
    # torch.tanh on CUDA but yields garbage values (~2.69e-38) on Intel XPU.
    # Confirm whether the XPU backend can lower this asm string at all.
    tmp1 = tl.inline_asm_elementwise('tanh.approx.f32 $0, $1;', '=r, r', [tmp0], dtype=tl.float32, is_pure=True, pack=1)
    tl.store(out_ptr0 + xindex, tmp1, xmask)

    
def tanh_approx(x: torch.Tensor):
    """Run the inline-asm tanh kernel over ``x`` and return a freshly allocated result.

    The kernel writes through an output pointer, so the result buffer is
    allocated up front with ``torch.empty_like``. Both the input and the result
    must live on ``DEVICE``, which is checked with ``assert``. The launch uses
    a fixed block size of 1024 elements per program.
    """
    result = torch.empty_like(x)
    assert x.device == DEVICE and result.device == DEVICE
    numel = result.numel()

    def grid(meta):
        # One program instance per XBLOCK-sized chunk of the flattened output.
        return (triton.cdiv(numel, meta['XBLOCK']), )

    triton_poi_fused_tanh_approx_0[grid](x, result, numel, XBLOCK=1024)
    return result

# Run reproducer
if __name__ == "__main__":
    # Fixed seed so the printed input/output values are reproducible across runs.
    torch.manual_seed(42)
    device = DEVICE
    inp = torch.randn(4, device=device)

    print("Testing tanh_approx reproducer:")
    print(f"Input shape: {inp.shape}")
    print(f"Input device: {inp.device}")
    print(f"Input: {inp}")

    # Result from the Triton kernel that uses tl.inline_asm_elementwise
    # (mirrors the code Inductor generates for torch.compile).
    output = tanh_approx(inp)
    print(f"Compiled result shape: {output.shape}")
    print(f"Compiled result: {output}")

    # Compare against the eager torch.tanh reference. The loose tolerance
    # allows for the approximate tanh instruction.
    expected = torch.tanh(inp)
    print(f"Expected result: {expected}")
    print(
        f"Match with torch.tanh: {torch.allclose(expected, output, atol=1e-3)}"
    )

Results on XPU:

Input shape: torch.Size([4])
Input device: xpu:0
Input: tensor([ 0.1940,  2.1614, -0.1721,  0.8491], device='xpu:0')
Compiled result shape: torch.Size([4])
Compiled result: tensor([2.6864e-38, 2.6866e-38, 2.6864e-38, 2.6866e-38], device='xpu:0')
Expected result: tensor([ 0.1916,  0.9738, -0.1704,  0.6906], device='xpu:0')
Match with torch.tanh: False

Results on CUDA:

Input shape: torch.Size([4])
Input device: cuda:0
Input: tensor([ 0.1940,  2.1614, -0.1721,  0.8491], device='cuda:0')
Compiled result shape: torch.Size([4])
Compiled result: tensor([ 0.1916,  0.9738, -0.1704,  0.6906], device='cuda:0')
Expected result: tensor([ 0.1916,  0.9738, -0.1704,  0.6906], device='cuda:0')
Match with torch.tanh: True

Environment details

torch: 2.10.0a0+gitb29465b
pytorch-triton-xpu: 3.5.0+git1b0418a9

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions