Skip to content

Commit 4995e05

Browse files
sijiac authored and pytorchmergebot committed
[user-triton] handle inline_asm_case (pytorch#148043)
Summary: We currently fail the mutation analysis for all inline_asm ops. In this diff, we handle the case where "is_pure" is set to True, since that indicates the operation doesn't mutate its input values. Test Plan: ../buck-out/v2/gen/fbcode/854b9ed00d28c5c5/caffe2/test/inductor/__triton_kernels__/triton_kernels.par --r test_mutations_inline_asm_kernel ``` test_mutations_inline_asm_kernel_is_pure_true (caffe2.test.inductor.test_triton_kernels.MutationTests) ... W0226 18:10:34.261000 1906801 /data/users/sijiac/fbsource/fbcode/caffe2/torch/_higher_order_ops/triton_kernel_wrap.py:656] TTIR mutation analysis: Skipping pure tt.elementwise_inline_asm op (is_pure=True) ok ---------------------------------------------------------------------- Ran 2 tests in 0.706s OK ``` Differential Revision: D69878591 Pull Request resolved: pytorch#148043 Approved by: https://github.com/zou3519
1 parent 6f91720 commit 4995e05

File tree

3 files changed

+46
-5
lines changed

3 files changed

+46
-5
lines changed

test/inductor/test_triton_kernels.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3127,9 +3127,13 @@ def branch_with_multiple_yield_args(
31273127
{"ptr": t, "n_elements": 4, "BLOCK_SIZE": 4},
31283128
["ptr"],
31293129
],
3130-
# Cant optimize since the kernel contains a tl.inline_asm_elementwise
31313130
[
3132-
inline_asm_kernel,
3131+
inline_asm_kernel_is_pure_true,
3132+
{"X": t, "Y": t, "Z": t, "n": 4, "BLOCK": 4},
3133+
["Z"],
3134+
],
3135+
[
3136+
inline_asm_kernel_is_pure_false,
31333137
{"X": t, "Y": t, "Z": t, "n": 4, "BLOCK": 4},
31343138
["X", "Y", "Z"],
31353139
],

torch/_higher_order_ops/triton_kernel_wrap.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,9 @@ class Op:
158158
ret: Intermediate = dataclasses.field(repr=False)
159159
# used for scf.yield: see [Note: scf.yield fix-up]
160160
sub_idx: Optional[int] = None
161+
# used for tt.elementwise_inline_asm
162+
# `is_pure = True` assumes the asm block has no side-effects
163+
is_pure: bool = False
161164

162165
def __post_init__(self) -> None:
163166
if self.name == "tt.call":
@@ -572,14 +575,22 @@ def mlir_to_functions(op: "TritonIROperation") -> None:
572575
Intermediate(operand) for operand in operand_ids
573576
]
574577
block_ops = op_stack[parent_block_id]
578+
579+
is_pure = False
580+
# Handle the case for tt.elementwise_inline_asm to set `is_pure` for mutation analysis
581+
if name == "tt.elementwise_inline_asm":
582+
is_pure = op.get_bool_attr("pure")
583+
575584
if result_ids:
576585
for result_id in result_ids:
577586
res = Intermediate(result_id)
578-
block_ops[res].append(Op(name, callee, args, res))
587+
block_ops[res].append(Op(name, callee, args, res, is_pure=is_pure))
579588
else:
580589
next_fake_intermediate -= 1
581590
fake_res = Intermediate(next_fake_intermediate)
582-
block_ops[fake_res].append(Op(name, callee, args, fake_res))
591+
block_ops[fake_res].append(
592+
Op(name, callee, args, fake_res, is_pure=is_pure)
593+
)
583594

584595
ttir_module.walk(mlir_to_functions)
585596

@@ -640,7 +651,14 @@ def analyze_kernel_mutations(
640651
ops = functions[fn_name]
641652
for op_list in ops.values():
642653
for op in op_list:
654+
# Ops in UNKNOWN_OPS have effects that cannot be reliably analyzed, so they raise below.
655+
# The one exception is a `tt.elementwise_inline_asm` marked pure: we assume it does not mutate any input parameters and skip it.
643656
if op.name in UNKNOWN_OPS:
657+
if op.name == "tt.elementwise_inline_asm" and op.is_pure:
658+
log.warning(
659+
"TTIR mutation analysis: Skipping pure tt.elementwise_inline_asm op (is_pure=True)"
660+
)
661+
continue
644662
raise RuntimeError(
645663
f"ttir analysis hit an op we do not know how to analyze: {op.name}"
646664
)

torch/testing/_internal/triton_utils.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -346,7 +346,9 @@ def double_strided_kernel(
346346
tl.store(out_ptr + dst_offsets, src * 2.0)
347347

348348
@triton.jit
349-
def inline_asm_kernel(X, Y, Z, n: "tl.constexpr", BLOCK: "tl.constexpr"):
349+
def inline_asm_kernel_is_pure_true(
350+
X, Y, Z, n: "tl.constexpr", BLOCK: "tl.constexpr"
351+
):
350352
x = tl.load(X + tl.arange(0, BLOCK))
351353
y = tl.load(Y + tl.arange(0, BLOCK))
352354
s = tl.full([BLOCK], n, tl.int32)
@@ -360,6 +362,23 @@ def inline_asm_kernel(X, Y, Z, n: "tl.constexpr", BLOCK: "tl.constexpr"):
360362
)
361363
tl.store(Z + tl.arange(0, BLOCK), z)
362364

365+
@triton.jit
366+
def inline_asm_kernel_is_pure_false(
367+
X, Y, Z, n: "tl.constexpr", BLOCK: "tl.constexpr"
368+
):
369+
x = tl.load(X + tl.arange(0, BLOCK))
370+
y = tl.load(Y + tl.arange(0, BLOCK))
371+
s = tl.full([BLOCK], n, tl.int32)
372+
z = tl.inline_asm_elementwise(
373+
"shf.l.wrap.b32 $0, $1, $2, $3;",
374+
"=r,r, r, r",
375+
[x, y, s],
376+
dtype=tl.int32,
377+
is_pure=False,
378+
pack=1,
379+
)
380+
tl.store(Z + tl.arange(0, BLOCK), z)
381+
363382
@triton.jit
364383
def add_kernel_with_block_ptr(
365384
x_ptr,

0 commit comments

Comments
 (0)