-
Notifications
You must be signed in to change notification settings - Fork 77
Open
Description
Describe the bug
Works with pytorch-triton-xpu 3.4 but fails with 3.5.
torch_compile_debug/torchinductor/output_code.py
# AOT ID: ['19_inference']
from ctypes import c_void_p, c_long, c_int
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from cmath import nanj
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile
from torch._inductor.codegen.memory_planning import _align as align
from torch import device, empty_strided
from torch._inductor.async_compile import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
import triton
import triton.language as tl
from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
from torch._C import _xpu_getCurrentRawStream as get_raw_stream
# Short module-level aliases for torch op namespaces and private torch._C
# helpers. In the code visible here, only assert_size_stride,
# empty_strided_xpu, reinterpret_tensor and async_compile are used further
# down; the remaining aliases are defined but not referenced in this file.
aten = torch.ops.aten
inductor_ops = torch.ops.inductor
_quantized = torch.ops._quantized
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
assert_alignment = torch._C._dynamo.guards.assert_alignment
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia
reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
alloc_from_pool = torch.ops.inductor._alloc_from_pool
# Compiler handle used by every async_compile.triton(...) call below.
async_compile = AsyncCompile()
empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p
# kernel path: /tmp/torchinductor_zhaoqion/hl/chldpaymqfqmwrak2vrlm6zzxyeozecmdsi7lwrapjakl2iek3pi.py
# Topologically Sorted Source Nodes: [getitem_1, position_ids_expanded], Original ATen: [aten.unsqueeze, aten._to_copy]
# Source node to ATen node mapping:
# getitem_1 => unsqueeze_3
# position_ids_expanded => convert_element_type
# Graph fragment:
# %arg0_1 : Tensor "i64[3, 1, 5390][5390, 5390, 1]xpu:0" = PlaceHolder[target=arg0_1]
# %unsqueeze_3 : Tensor "i64[3, 1, 1, 5390][5390, 5390, 5390, 1]xpu:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%arg0_1, 2), kwargs = {})
# %convert_element_type : Tensor "f32[3, 1, 1, 5390][5390, 5390, 5390, 1]xpu:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%unsqueeze_3, torch.float32), kwargs = {})
# return %expand_2
# 1-D pointwise kernel over 16170 elements: loads an int64 value from
# in_ptr0 at the flat index, converts it to float32, and stores it at
# out_ptr0[x0 + 5408*x1] with x0 = index % 5390 and x1 = index // 5390.
# The output therefore uses a row stride of 5408 instead of 5390, which
# matches the strides buf0 is allocated with in Runner.call.
# NOTE(review): the kernel source inside the string appears to have lost its
# indentation in this paste; it is reproduced here unchanged.
triton_poi_fused__to_copy_unsqueeze_0 = async_compile.triton('triton_poi_fused__to_copy_unsqueeze_0', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.pointwise(
size_hints={'x': 16384},
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='xpu', index=0, multi_processor_count=64, cc={'architecture': 13136561920, 'device_id': 3029, 'driver_version': '1.6.33578+38', 'gpu_eu_count': 512, 'gpu_subslice_count': 64, 'has_atomic64': True, 'has_bfloat16_conversions': True, 'has_fp16': True, 'has_fp64': True, 'has_subgroup_2d_block_io': True, 'has_subgroup_matrix_multiply_accumulate': True, 'has_subgroup_matrix_multiply_accumulate_tensor_float32': False, 'max_compute_units': 512, 'max_num_sub_groups': 64, 'max_work_group_size': 1024, 'name': 'Intel(R) Data Center GPU Max 1550', 'platform_name': 'Intel(R) oneAPI Unified Runtime over Level-Zero', 'sub_group_sizes': [16, 32], 'total_memory': 68702699520, 'type': 'gpu', 'vendor': 'Intel(R) Corporation', 'version': '12.60.7'}, major=None, regs_per_multiprocessor=None, max_threads_per_multi_processor=None, warp_size=32), 'constants': {}, 'native_matmul': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}], 'enable_fp_fusion': True},
inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_unsqueeze_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': '4DBA2B7984248D3650FA4B1AB2DC642F5864F56A87FFAEA50778AE4FC7494A6D', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 258720}},
min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_unsqueeze_0(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 16170
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = xindex < xnumel
x2 = xindex
x0 = (xindex % 5390)
x1 = xindex // 5390
tmp0 = tl.load(in_ptr0 + (x2), xmask)
tmp1 = tmp0.to(tl.float32)
tl.store(out_ptr0 + (x0 + 5408*x1), tmp1, xmask)
''', device_str='xpu')
# kernel path: /tmp/torchinductor_zhaoqion/td/ctdpohrnnpdxigh6uaxxfi7majmxwickypjg7xliklb4l3ar76qh.py
# Topologically Sorted Source Nodes: [matmul, freqs, setitem, getitem_3, , _generalized_scatter_1, setitem_1], Original ATen: [aten.view, aten.transpose, aten.select, aten.slice]
# Source node to ATen node mapping:
# => select_int, select_int_1, slice_scatter_default, slice_scatter_default_1
# _generalized_scatter_1 => select_scatter_default
# freqs => permute
# getitem_3 => select_1, slice_1
# matmul => view_2
# setitem => permute_1, permute_2, view_3, view_4
# setitem_1 => permute_8, permute_9, select_8, slice_7, view_10, view_11
# Graph fragment:
# %bmm : Tensor "f32[3, 64, 5390][344960, 5390, 1]xpu:0" = PlaceHolder[target=bmm]
# %view_2 : Tensor "f32[3, 1, 64, 5390][344960, 344960, 5390, 1]xpu:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%bmm, [3, 1, 64, 5390]), kwargs = {})
# %permute : Tensor "f32[3, 1, 5390, 64][344960, 344960, 1, 5390]xpu:0"[num_users=2] = call_function[target=torch.ops.aten.permute.default](args = (%view_2, [0, 1, 3, 2]), kwargs = {})
# %view_3 : Tensor "f32[3, 1, 64, 5390][344960, 344960, 5390, 1]xpu:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%bmm, [3, 1, 64, 5390]), kwargs = {})
# %permute_1 : Tensor "f32[3, 1, 5390, 64][344960, 344960, 1, 5390]xpu:0"[num_users=3] = call_function[target=torch.ops.aten.permute.default](args = (%view_3, [0, 1, 3, 2]), kwargs = {})
# %select_1 : Tensor "f32[1, 5390, 64][344960, 1, 5390]xpu:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%permute, 0, 1), kwargs = {})
# %slice_1 : Tensor "f32[1, 5390, 20][344960, 1, 16170]xpu:0"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%select_1, 2, 1, 60, 3), kwargs = {})
# %select_int : Tensor "f32[1, 5390, 64][344960, 1, 5390]xpu:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%permute_1, 0, 0), kwargs = {})
# %slice_scatter_default : Tensor "f32[1, 5390, 64][344960, 1, 5390]xpu:0"[num_users=1] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%select_int, %slice_1, 2, 1, 60, 3), kwargs = {})
# %select_scatter_default : Tensor "f32[3, 1, 5390, 64][344960, 344960, 1, 5390]xpu:0"[num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%permute_1, %slice_scatter_default, 0, 0), kwargs = {})
# %permute_2 : Tensor "f32[3, 1, 64, 5390][344960, 344960, 5390, 1]xpu:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%select_scatter_default, [0, 1, 3, 2]), kwargs = {})
# %view_4 : Tensor "f32[3, 64, 5390][344960, 5390, 1]xpu:0"[num_users=3] = call_function[target=torch.ops.aten.reshape.default](args = (%permute_2, [3, 64, 5390]), kwargs = {})
# %view_11 : Tensor "f32[3, 1, 64, 5390][344960, 344960, 5390, 1]xpu:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%view_4, [3, 1, 64, 5390]), kwargs = {})
# %permute_9 : Tensor "f32[3, 1, 5390, 64][344960, 344960, 1, 5390]xpu:0"[num_users=3] = call_function[target=torch.ops.aten.permute.default](args = (%view_11, [0, 1, 3, 2]), kwargs = {})
# %view_10 : Tensor "f32[3, 1, 64, 5390][344960, 344960, 5390, 1]xpu:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%view_4, [3, 1, 64, 5390]), kwargs = {})
# %permute_8 : Tensor "f32[3, 1, 5390, 64][344960, 344960, 1, 5390]xpu:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_10, [0, 1, 3, 2]), kwargs = {})
# %select_8 : Tensor "f32[1, 5390, 64][344960, 1, 5390]xpu:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%permute_8, 0, 2), kwargs = {})
# %slice_7 : Tensor "f32[1, 5390, 20][344960, 1, 16170]xpu:0"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%select_8, 2, 2, 60, 3), kwargs = {})
# %select_int_1 : Tensor "f32[1, 5390, 64][344960, 1, 5390]xpu:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%permute_9, 0, 0), kwargs = {})
# %slice_scatter_default_1 : Tensor "f32[1, 5390, 64][344960, 1, 5390]xpu:0"[num_users=1] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%select_int_1, %slice_7, 2, 2, 60, 3), kwargs = {})
# return %slice_scatter_default_1
# 2-D pointwise kernel over y in [0, 64) and x in [0, 5390). It reads in_ptr0
# using the layout x + 5390*y (+ a per-batch offset of 344960) and writes the
# result transposed to out_ptr0[y0 + 64*x1].
# Effective selection per output row y0, as written in the kernel body:
#   * 2 <= y0 < 60 and (y0 - 2) % 3 == 0 -> value from offset 700700 + ...
#     (= 2*344960 + 5390*y0, i.e. batch 2, row y0)
#   * 1 <= y0 < 60 and (y0 - 1) % 3 == 0 -> value from offset 350350 + ...
#     (= 344960 + 5390*y0, i.e. batch 1, row y0)
#   * otherwise                          -> value at x1 + 5390*y0 (batch 0)
# NOTE(review): tmp12 (2 == 0) and tmp19 (1 == 0) are constant-false and tmp31
# (0 == 0) is constant-true, so the tmp23/tmp24 loads never reach the stored
# result; they are kept exactly as generated.
# NOTE(review): the kernel source inside the string appears to have lost its
# indentation in this paste; it is reproduced here unchanged.
triton_poi_fused_select_slice_transpose_view_1 = async_compile.triton('triton_poi_fused_select_slice_transpose_view_1', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.pointwise(
size_hints={'y': 64, 'x': 8192}, tile_hint=TileHint.SQUARE,
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'ynumel': 'i32', 'xnumel': 'i32', 'YBLOCK': 'constexpr', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='xpu', index=0, multi_processor_count=64, cc={'architecture': 13136561920, 'device_id': 3029, 'driver_version': '1.6.33578+38', 'gpu_eu_count': 512, 'gpu_subslice_count': 64, 'has_atomic64': True, 'has_bfloat16_conversions': True, 'has_fp16': True, 'has_fp64': True, 'has_subgroup_2d_block_io': True, 'has_subgroup_matrix_multiply_accumulate': True, 'has_subgroup_matrix_multiply_accumulate_tensor_float32': False, 'max_compute_units': 512, 'max_num_sub_groups': 64, 'max_work_group_size': 1024, 'name': 'Intel(R) Data Center GPU Max 1550', 'platform_name': 'Intel(R) oneAPI Unified Runtime over Level-Zero', 'sub_group_sizes': [16, 32], 'total_memory': 68702699520, 'type': 'gpu', 'vendor': 'Intel(R) Corporation', 'version': '12.60.7'}, major=None, regs_per_multiprocessor=None, max_threads_per_multi_processor=None, warp_size=32), 'constants': {}, 'native_matmul': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}], 'enable_fp_fusion': True},
inductor_meta={'grid_type': 'Grid2D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_select_slice_transpose_view_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 5, 'num_store': 1, 'num_reduction': 0, 'backend_hash': '4DBA2B7984248D3650FA4B1AB2DC642F5864F56A87FFAEA50778AE4FC7494A6D', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'y': 2759680, 'x': 6899200}},
min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_select_slice_transpose_view_1(in_ptr0, out_ptr0, ynumel, xnumel, YBLOCK : tl.constexpr, XBLOCK : tl.constexpr):
ynumel = 64
xnumel = 5390
yoffset = tl.program_id(1) * YBLOCK
yindex = yoffset + tl.arange(0, YBLOCK)[:, None]
ymask = yindex < ynumel
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[None, :]
xmask = xindex < xnumel
y0 = yindex
x1 = xindex
tmp39 = tl.load(in_ptr0 + (x1 + 5390*y0), xmask & ymask, eviction_policy='evict_last')
tmp0 = y0
tmp1 = tl.full([1, 1], 2, tl.int64)
tmp2 = tmp0 >= tmp1
tmp3 = tl.full([1, 1], 60, tl.int64)
tmp4 = tmp0 < tmp3
tmp5 = (((-2) + y0) % 3)
tmp6 = tl.full([1, 1], 0, tl.int64)
tmp7 = tmp5 == tmp6
tmp8 = tmp2 & tmp4
tmp9 = tmp8 & tmp7
tmp10 = tl.full([1, 1], 2, tl.int32)
tmp11 = tl.full([1, 1], 0, tl.int32)
tmp12 = tmp10 == tmp11
tmp13 = tl.broadcast_to(2 + 3*(triton_helpers.div_floor_integer((-2) + y0, 3)), [YBLOCK, XBLOCK])
tmp14 = tl.full([1, 1], 1, tl.int64)
tmp15 = tmp13 >= tmp14
tmp16 = tl.full([1, 1], 60, tl.int64)
tmp17 = tmp13 < tmp16
tmp18 = tl.full([1, 1], 0, tl.int64)
tmp19 = tmp14 == tmp18
tmp20 = tmp15 & tmp17
tmp21 = tmp20 & tmp19
tmp22 = tmp21 & tmp9
tmp23 = tl.load(in_ptr0 + (350350 + x1 + 16170*(triton_helpers.div_floor_integer((-2) + y0, 3))), tmp22 & xmask & ymask, eviction_policy='evict_last', other=0.0)
tmp24 = tl.load(in_ptr0 + (10780 + x1 + 16170*(triton_helpers.div_floor_integer((-2) + y0, 3))), tmp9 & xmask & ymask, eviction_policy='evict_last', other=0.0)
tmp25 = tl.where(tmp21, tmp23, tmp24)
tmp26 = tl.load(in_ptr0 + (700700 + x1 + 16170*(triton_helpers.div_floor_integer((-2) + y0, 3))), tmp9 & xmask & ymask, eviction_policy='evict_last', other=0.0)
tmp27 = tl.where(tmp12, tmp25, tmp26)
tmp28 = tl.full(tmp27.shape, 0.0, tmp27.dtype)
tmp29 = tl.where(tmp9, tmp27, tmp28)
tmp30 = tl.full([1, 1], 0, tl.int32)
tmp31 = tmp30 == tmp30
tmp32 = tl.full([1, 1], 1, tl.int64)
tmp33 = tmp0 >= tmp32
tmp34 = (((-1) + y0) % 3)
tmp35 = tmp34 == tmp6
tmp36 = tmp33 & tmp4
tmp37 = tmp36 & tmp35
tmp38 = tl.load(in_ptr0 + (350350 + x1 + 16170*(triton_helpers.div_floor_integer((-1) + y0, 3))), tmp37 & xmask & ymask, eviction_policy='evict_last', other=0.0)
tmp40 = tl.where(tmp37, tmp38, tmp39)
tmp41 = tl.where(tmp31, tmp40, tmp39)
tmp42 = tl.where(tmp9, tmp29, tmp41)
tl.store(out_ptr0 + (y0 + 64*x1), tmp42, xmask & ymask)
''', device_str='xpu')
# kernel path: /tmp/torchinductor_zhaoqion/ea/ceafclnzlnyjxa5rwhsk3z32gzruq7ipgzzt6swxjjudpt4sckve.py
# Topologically Sorted Source Nodes: [matmul, freqs, setitem, getitem_3, , _generalized_scatter_1, setitem_1, _generalized_scatter_3, emb, cos, cos_1, to, sin, sin_1, to_1], Original ATen: [aten.view, aten.transpose, aten.select, aten.slice, aten.cat, aten.cos, aten.mul, aten._to_copy, aten.sin]
# Source node to ATen node mapping:
# => select_int, slice_scatter_default
# _generalized_scatter_1 => select_scatter_default
# _generalized_scatter_3 => select_scatter_default_1
# cos => cos
# cos_1 => mul
# emb => clone, expand_3, permute_12, select_11, unsqueeze_4, view_15
# freqs => permute
# getitem_3 => select_1, slice_1
# matmul => view_2
# setitem => permute_1, permute_2, view_3, view_4
# setitem_1 => permute_10, permute_9, view_11
# sin => sin
# sin_1 => mul_1
# to => convert_element_type_1
# to_1 => convert_element_type_2
# Graph fragment:
# %slice_scatter_default_1 : Tensor "f32[1, 5390, 64][344960, 64, 1]xpu:0" = PlaceHolder[target=slice_scatter_default_1]
# %bmm : Tensor "f32[3, 64, 5390][344960, 5390, 1]xpu:0" = PlaceHolder[target=bmm]
# %view_2 : Tensor "f32[3, 1, 64, 5390][344960, 344960, 5390, 1]xpu:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%bmm, [3, 1, 64, 5390]), kwargs = {})
# %permute : Tensor "f32[3, 1, 5390, 64][344960, 344960, 1, 5390]xpu:0"[num_users=2] = call_function[target=torch.ops.aten.permute.default](args = (%view_2, [0, 1, 3, 2]), kwargs = {})
# %view_3 : Tensor "f32[3, 1, 64, 5390][344960, 344960, 5390, 1]xpu:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%bmm, [3, 1, 64, 5390]), kwargs = {})
# %permute_1 : Tensor "f32[3, 1, 5390, 64][344960, 344960, 1, 5390]xpu:0"[num_users=3] = call_function[target=torch.ops.aten.permute.default](args = (%view_3, [0, 1, 3, 2]), kwargs = {})
# %select_1 : Tensor "f32[1, 5390, 64][344960, 1, 5390]xpu:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%permute, 0, 1), kwargs = {})
# %slice_1 : Tensor "f32[1, 5390, 20][344960, 1, 16170]xpu:0"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%select_1, 2, 1, 60, 3), kwargs = {})
# %select_int : Tensor "f32[1, 5390, 64][344960, 1, 5390]xpu:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%permute_1, 0, 0), kwargs = {})
# %slice_scatter_default : Tensor "f32[1, 5390, 64][344960, 1, 5390]xpu:0"[num_users=1] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%select_int, %slice_1, 2, 1, 60, 3), kwargs = {})
# %select_scatter_default : Tensor "f32[3, 1, 5390, 64][344960, 344960, 1, 5390]xpu:0"[num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%permute_1, %slice_scatter_default, 0, 0), kwargs = {})
# %permute_2 : Tensor "f32[3, 1, 64, 5390][344960, 344960, 5390, 1]xpu:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%select_scatter_default, [0, 1, 3, 2]), kwargs = {})
# %view_4 : Tensor "f32[3, 64, 5390][344960, 5390, 1]xpu:0"[num_users=3] = call_function[target=torch.ops.aten.reshape.default](args = (%permute_2, [3, 64, 5390]), kwargs = {})
# %view_11 : Tensor "f32[3, 1, 64, 5390][344960, 344960, 5390, 1]xpu:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%view_4, [3, 1, 64, 5390]), kwargs = {})
# %permute_9 : Tensor "f32[3, 1, 5390, 64][344960, 344960, 1, 5390]xpu:0"[num_users=3] = call_function[target=torch.ops.aten.permute.default](args = (%view_11, [0, 1, 3, 2]), kwargs = {})
# %select_scatter_default_1 : Tensor "f32[3, 1, 5390, 64][344960, 344960, 1, 5390]xpu:0"[num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%permute_9, %slice_scatter_default_1, 0, 0), kwargs = {})
# %permute_10 : Tensor "f32[3, 1, 64, 5390][344960, 344960, 5390, 1]xpu:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%select_scatter_default_1, [0, 1, 3, 2]), kwargs = {})
# %permute_12 : Tensor "f32[3, 1, 5390, 64][344960, 344960, 1, 5390]xpu:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%permute_10, [0, 1, 3, 2]), kwargs = {})
# %select_11 : Tensor "f32[1, 5390, 64][344960, 1, 5390]xpu:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%permute_12, 0, 0), kwargs = {})
# %unsqueeze_4 : Tensor "f32[1, 5390, 1, 64][344960, 1, 344960, 5390]xpu:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%select_11, 2), kwargs = {})
# %expand_3 : Tensor "f32[1, 5390, 2, 64][344960, 1, 0, 5390]xpu:0"[num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%unsqueeze_4, [1, 5390, 2, 64]), kwargs = {})
# %clone : Tensor "f32[1, 5390, 2, 64][689920, 128, 64, 1]xpu:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%expand_3,), kwargs = {memory_format: torch.contiguous_format})
# %view_15 : Tensor "f32[1, 5390, 128][689920, 128, 1]xpu:0"[num_users=2] = call_function[target=torch.ops.aten.reshape.default](args = (%clone, [1, 5390, 128]), kwargs = {})
# %cos : Tensor "f32[1, 5390, 128][689920, 128, 1]xpu:0"[num_users=1] = call_function[target=torch.ops.aten.cos.default](args = (%view_15,), kwargs = {})
# %mul : Tensor "f32[1, 5390, 128][689920, 128, 1]xpu:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%cos, 1.0), kwargs = {})
# %convert_element_type_1 : Tensor "f16[1, 5390, 128][689920, 128, 1]xpu:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul, torch.float16), kwargs = {})
# %sin : Tensor "f32[1, 5390, 128][689920, 128, 1]xpu:0"[num_users=1] = call_function[target=torch.ops.aten.sin.default](args = (%view_15,), kwargs = {})
# %mul_1 : Tensor "f32[1, 5390, 128][689920, 128, 1]xpu:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%sin, 1.0), kwargs = {})
# %convert_element_type_2 : Tensor "f16[1, 5390, 128][689920, 128, 1]xpu:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_1, torch.float16), kwargs = {})
# return %convert_element_type_1,%convert_element_type_2
# 2-D pointwise kernel over y in [0, 128) and x in [0, 5390). It stores
# cos(v) to out_ptr0[y0 + 128*x1] and sin(v) to out_ptr1[y0 + 128*x1], each
# multiplied by the constant 1.0. Indexing with (y0 % 64) means the 64 input
# columns are repeated twice across the 128 output columns.
# NOTE(review): tmp1 (0 == 0) is constant-true, so tmp17 always equals tmp2
# (the in_ptr0 load at 64*x1 + y0 % 64); the in_ptr1 loads (tmp13, tmp14) never
# reach the stored result. They are kept exactly as generated.
# NOTE(review): the signature declares out_ptr0/out_ptr1 as '*fp16' while the
# values are cast with .to(tl.float32) before tl.store; presumably the store
# performs the final down-conversion -- confirm against the Triton version in use.
# NOTE(review): this is the kernel named in the traceback further down, where
# loading its binary fails with ZE_RESULT_ERROR_MODULE_BUILD_FAILURE.
# NOTE(review): the kernel source inside the string appears to have lost its
# indentation in this paste; it is reproduced here unchanged.
triton_poi_fused__to_copy_cat_cos_mul_select_sin_slice_transpose_view_2 = async_compile.triton('triton_poi_fused__to_copy_cat_cos_mul_select_sin_slice_transpose_view_2', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.pointwise(
size_hints={'y': 128, 'x': 8192}, tile_hint=TileHint.DEFAULT,
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*fp32', 'in_ptr1': '*fp32', 'out_ptr0': '*fp16', 'out_ptr1': '*fp16', 'ynumel': 'i32', 'xnumel': 'i32', 'YBLOCK': 'constexpr', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='xpu', index=0, multi_processor_count=64, cc={'architecture': 13136561920, 'device_id': 3029, 'driver_version': '1.6.33578+38', 'gpu_eu_count': 512, 'gpu_subslice_count': 64, 'has_atomic64': True, 'has_bfloat16_conversions': True, 'has_fp16': True, 'has_fp64': True, 'has_subgroup_2d_block_io': True, 'has_subgroup_matrix_multiply_accumulate': True, 'has_subgroup_matrix_multiply_accumulate_tensor_float32': False, 'max_compute_units': 512, 'max_num_sub_groups': 64, 'max_work_group_size': 1024, 'name': 'Intel(R) Data Center GPU Max 1550', 'platform_name': 'Intel(R) oneAPI Unified Runtime over Level-Zero', 'sub_group_sizes': [16, 32], 'total_memory': 68702699520, 'type': 'gpu', 'vendor': 'Intel(R) Corporation', 'version': '12.60.7'}, major=None, regs_per_multiprocessor=None, max_threads_per_multi_processor=None, warp_size=32), 'constants': {}, 'native_matmul': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}], 'enable_fp_fusion': True},
inductor_meta={'grid_type': 'Grid2D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_cat_cos_mul_select_sin_slice_transpose_view_2', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 3, 'num_store': 2, 'num_reduction': 0, 'backend_hash': '4DBA2B7984248D3650FA4B1AB2DC642F5864F56A87FFAEA50778AE4FC7494A6D', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'y': 6899200, 'x': 5519360}},
min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_cat_cos_mul_select_sin_slice_transpose_view_2(in_ptr0, in_ptr1, out_ptr0, out_ptr1, ynumel, xnumel, YBLOCK : tl.constexpr, XBLOCK : tl.constexpr):
ynumel = 128
xnumel = 5390
yoffset = tl.program_id(1) * YBLOCK
yindex = yoffset + tl.arange(0, YBLOCK)[:, None]
ymask = yindex < ynumel
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[None, :]
xmask = xindex < xnumel
x1 = xindex
y0 = yindex
tmp2 = tl.load(in_ptr0 + (64*x1 + ((y0 % 64))), xmask & ymask, eviction_policy='evict_last')
tmp14 = tl.load(in_ptr1 + (x1 + 5390*((y0 % 64))), xmask & ymask, eviction_policy='evict_last')
tmp0 = tl.full([1, 1], 0, tl.int32)
tmp1 = tmp0 == tmp0
tmp3 = (y0 % 64)
tmp4 = tl.full([1, 1], 1, tl.int64)
tmp5 = tmp3 >= tmp4
tmp6 = tl.full([1, 1], 60, tl.int64)
tmp7 = tmp3 < tmp6
tmp8 = (((-1) + ((y0 % 64))) % 3)
tmp9 = tl.full([1, 1], 0, tl.int64)
tmp10 = tmp8 == tmp9
tmp11 = tmp5 & tmp7
tmp12 = tmp11 & tmp10
tmp13 = tl.load(in_ptr1 + (350350 + x1 + 16170*(triton_helpers.div_floor_integer((-1) + ((y0 % 64)), 3))), tmp12 & xmask & ymask, eviction_policy='evict_last', other=0.0)
tmp15 = tl.where(tmp12, tmp13, tmp14)
tmp16 = tl.where(tmp1, tmp15, tmp14)
tmp17 = tl.where(tmp1, tmp2, tmp16)
tmp18 = tl_math.cos(tmp17)
tmp19 = 1.0
tmp20 = tmp18 * tmp19
tmp21 = tmp20.to(tl.float32)
tmp22 = tl_math.sin(tmp17)
tmp23 = tmp22 * tmp19
tmp24 = tmp23.to(tl.float32)
tl.store(out_ptr0 + (y0 + 128*x1), tmp21, xmask & ymask)
tl.store(out_ptr1 + (y0 + 128*x1), tmp24, xmask & ymask)
''', device_str='xpu')
# Hand the module globals to the compiler handle, then drop the handle.
# NOTE(review): presumably this blocks until the kernels submitted above are
# compiled and rebinds their names in globals() -- confirm against
# torch._inductor.async_compile.AsyncCompile.wait.
async_compile.wait(globals())
del async_compile
class Runner:
    """Wrapper object holding the compiled graph's entry point.

    ``call`` executes the three Triton kernels and the batched matmul that
    make up this graph; ``partitions`` is a list of callables that
    ``recursively_apply_fns`` can rewrite in place.
    """

    def __init__(self, partitions):
        """Store the list of partition callables."""
        self.partitions = partitions

    def recursively_apply_fns(self, fns):
        """Replace each partition ``c`` with ``fn(c)``.

        ``fns`` and ``self.partitions`` are paired positionally with ``zip``,
        so the result has the length of the shorter of the two.
        """
        self.partitions = [fn(c) for fn, c in zip(fns, self.partitions)]

    def call(self, args):
        """Run the compiled graph on the XPU device with index 0.

        Args:
            args: mutable list of exactly two tensors ``[arg0_1, arg1_1]``.
                The list is cleared after unpacking. ``arg0_1`` must have
                size (3, 1, 5390) and strides (5390, 5390, 1); ``arg1_1`` must
                have size (64,) and stride (1,).

        Returns:
            Tuple ``(buf3, buf4)`` of float16 tensors with size
            (1, 5390, 128). According to the graph comments on the last
            kernel, ``buf3`` holds the cosine and ``buf4`` the sine output.
        """
        arg0_1, arg1_1 = args
        # NOTE(review): presumably cleared so the caller's list does not keep
        # the inputs alive -- the `del` statements below rely on that.
        args.clear()
        assert_size_stride(arg0_1, (3, 1, 5390), (5390, 5390, 1))
        assert_size_stride(arg1_1, (64, ), (1, ))
        with torch.xpu._DeviceGuard(0):
            torch.xpu.set_device(0)
            # float32 copy of arg0_1; rows use a stride of 5408 instead of 5390.
            buf0 = empty_strided_xpu((3, 1, 1, 5390), (5408, 5408, 16224, 1), torch.float32)
            # Topologically Sorted Source Nodes: [getitem_1, position_ids_expanded], Original ATen: [aten.unsqueeze, aten._to_copy]
            # [Provenance debug handles] triton_poi_fused__to_copy_unsqueeze_0:1
            stream0 = get_raw_stream(0)
            triton_poi_fused__to_copy_unsqueeze_0.run(arg0_1, buf0, 16170, stream=stream0)
            del arg0_1
            buf1 = empty_strided_xpu((3, 64, 5390), (344960, 5390, 1), torch.float32)
            # Topologically Sorted Source Nodes: [getitem, inv_freq_expanded, matmul, getitem_1, position_ids_expanded], Original ATen: [aten.unsqueeze, aten.expand, aten.view, aten._to_copy, aten.bmm]
            # [Provenance debug handles] extern_kernels.bmm:2
            # arg1_1 is viewed as (3, 64, 1) with batch stride 0, i.e. the same
            # 64 values are reused for each of the 3 batches.
            extern_kernels.bmm(reinterpret_tensor(arg1_1, (3, 64, 1), (0, 1, 0), 0), reinterpret_tensor(buf0, (3, 1, 5390), (5408, 0, 1), 0), out=buf1)
            del arg1_1
            del buf0
            buf2 = empty_strided_xpu((1, 5390, 64), (344960, 64, 1), torch.float32)
            # Topologically Sorted Source Nodes: [matmul, freqs, setitem, getitem_3, , _generalized_scatter_1, setitem_1], Original ATen: [aten.view, aten.transpose, aten.select, aten.slice]
            # [Provenance debug handles] triton_poi_fused_select_slice_transpose_view_1:3
            stream0 = get_raw_stream(0)
            triton_poi_fused_select_slice_transpose_view_1.run(buf1, buf2, 64, 5390, stream=stream0)
            buf3 = empty_strided_xpu((1, 5390, 128), (689920, 128, 1), torch.float16)
            buf4 = empty_strided_xpu((1, 5390, 128), (689920, 128, 1), torch.float16)
            # Topologically Sorted Source Nodes: [matmul, freqs, setitem, getitem_3, , _generalized_scatter_1, setitem_1, _generalized_scatter_3, emb, cos, cos_1, to, sin, sin_1, to_1], Original ATen: [aten.view, aten.transpose, aten.select, aten.slice, aten.cat, aten.cos, aten.mul, aten._to_copy, aten.sin]
            # [Provenance debug handles] triton_poi_fused__to_copy_cat_cos_mul_select_sin_slice_transpose_view_2:4
            stream0 = get_raw_stream(0)
            triton_poi_fused__to_copy_cat_cos_mul_select_sin_slice_transpose_view_2.run(buf2, buf1, buf3, buf4, 128, 5390, stream=stream0)
            del buf1
            del buf2
        return (buf3, buf4, )
# Module-level Runner created with an empty partition list; its bound methods
# are exported under the module-level names `call` and `recursively_apply_fns`.
runner = Runner(partitions=[])
call = runner.call
recursively_apply_fns = runner.recursively_apply_fns
def benchmark_compiled_module(times=10, repeat=10):
    """Benchmark the module-level ``call`` on generated inputs.

    Builds two ``rand_strided`` tensors on ``xpu:0`` with the same sizes and
    strides that ``call`` asserts on (int64 of size (3, 1, 5390) and float32
    of size (64,)), then passes a zero-argument wrapper around ``call`` to
    ``print_performance``.

    Args:
        times: forwarded to ``print_performance`` as ``times``.
        repeat: forwarded to ``print_performance`` as ``repeat``.

    Returns:
        Whatever ``print_performance`` returns.
    """
    from torch._dynamo.testing import rand_strided
    from torch._inductor.utils import print_performance

    arg0_1 = rand_strided((3, 1, 5390), (5390, 5390, 1), device='xpu:0', dtype=torch.int64)
    arg1_1 = rand_strided((64, ), (1, ), device='xpu:0', dtype=torch.float32)

    def fn():
        # A fresh list per invocation: `call` clears the list it is given.
        return call([arg0_1, arg1_1])

    return print_performance(fn, times=times, repeat=repeat)
if __name__ == "__main__":
    # Script entry point: hand the benchmark function to Inductor's
    # compiled-module driver.
    from torch._inductor.wrapper_benchmark import compiled_module_main
    compiled_module_main('None', benchmark_compiled_module)
output
L0 build module failed. Log:
Error during Intel loadBinary: ZE_RESULT_ERROR_MODULE_BUILD_FAILURE
Traceback (most recent call last):
File "/home1/zhaoqion/llm/torch_compile_debug/run_2025_11_26_22_40_13_041309-pid_39960/torchinductor/model__19_inference_19.19/output_code.py", line 234, in <module>
triton_poi_fused__to_copy_cat_cos_mul_select_sin_slice_transpose_view_2 = async_compile.triton('triton_poi_fused__to_copy_cat_cos_mul_select_sin_slice_transpose_view_2', '''
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home1/zhaoqion/miniforge3/envs/qwen3/lib/python3.12/site-packages/torch/_inductor/async_compile.py", line 477, in triton
kernel.precompile(
File "/home1/zhaoqion/miniforge3/envs/qwen3/lib/python3.12/site-packages/torch/_inductor/runtime/triton_heuristics.py", line 439, in precompile
self._make_launchers()
File "/home1/zhaoqion/miniforge3/envs/qwen3/lib/python3.12/site-packages/torch/_inductor/runtime/triton_heuristics.py", line 603, in _make_launchers
launchers.append(result.make_launcher())
^^^^^^^^^^^^^^^^^^^^^^
File "/home1/zhaoqion/miniforge3/envs/qwen3/lib/python3.12/site-packages/torch/_inductor/runtime/triton_heuristics.py", line 1838, in make_launcher
binary._init_handles()
File "/home1/zhaoqion/miniforge3/envs/qwen3/lib/python3.12/site-packages/triton/compiler/compiler.py", line 462, in _init_handles
self.module, self.function, self.n_regs, self.n_spills, self.n_max_threads = driver.active.utils.load_binary(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home1/zhaoqion/miniforge3/envs/qwen3/lib/python3.12/site-packages/triton/backends/intel/driver.py", line 213, in load_binary
return self.shared_library.load_binary(args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: ZE_RESULT_ERROR_MODULE_BUILD_FAILURE
Environment details
pytorch-triton-xpu 3.5.0+git1b0418a9
PVC
Metadata
Metadata
Assignees
Labels
No labels