Skip to content

Commit e74fdbe

Browse files
kundaMwiza and pytorchmergebot
authored and committed
[inductor] ignore block ptr advancements for removed buffers (pytorch#148087)
Follow-up to pytorch#147193. Some buffers are removed only when the kernel context is exited, so the block pointer advancement lines (`tl.advance`) are now emitted as deferred lines (`DeferredLine`) instead of being skipped eagerly via `is_buffer_removed` during codegen. Added `use_block_ptr` as a parameter to the test case that fails if run with block ptrs enabled. Fixes #ISSUE_NUMBER. Pull Request resolved: pytorch#148087. Approved by: https://github.com/jansel, https://github.com/eellison
1 parent d174562 commit e74fdbe

File tree

2 files changed

+13
-8
lines changed

2 files changed

+13
-8
lines changed

test/inductor/test_torchinductor.py

+9-4
Original file line numberDiff line numberDiff line change
@@ -816,10 +816,10 @@ def wrapper(self):
816816

817817
def skip_if_not_triton(fn):
818818
@functools.wraps(fn)
819-
def wrapper(self):
819+
def wrapper(self, *args, **kwargs):
820820
if not is_triton_backend(self.device):
821821
raise unittest.SkipTest(f"triton backend is required for {self.device}")
822-
return fn(self)
822+
return fn(self, *args, **kwargs)
823823

824824
return wrapper
825825

@@ -9614,7 +9614,11 @@ def fn(x):
96149614
],
96159615
)
96169616

9617-
def test_tmp_not_defined_issue1(self):
9617+
@parametrize(
9618+
"use_block_ptr",
9619+
[subtest(False), subtest(True, decorators=[skip_if_not_triton])],
9620+
)
9621+
def test_tmp_not_defined_issue1(self, use_block_ptr):
96189622
def forward(
96199623
primals_3,
96209624
primals_4,
@@ -9651,7 +9655,8 @@ def forward(
96519655
(torch.Size([1, 512, 1]), torch.float32),
96529656
]
96539657
inps = [torch.randn(shape, dtype=dtype) for (shape, dtype) in inps]
9654-
self.common(forward, inps, atol=1e-05, rtol=2e-05)
9658+
with config.patch("triton.use_block_ptr", use_block_ptr):
9659+
self.common(forward, inps, atol=1e-05, rtol=2e-05)
96559660

96569661
@unittest.skipIf(
96579662
os.environ.get("BUILD_ENVIRONMENT", "").startswith("parallelnative"),

torch/_inductor/codegen/triton.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,6 @@
7474
DeferredLine,
7575
IndentedBuffer,
7676
InplacedBuffer,
77-
is_buffer_removed,
7877
OpOverrides,
7978
PythonPrinter,
8079
RemovedArg,
@@ -3142,8 +3141,6 @@ def codegen_body(self):
31423141
for block_ptr, advancement in self.pointer_advancements[
31433142
tree.symt
31443143
].items():
3145-
if is_buffer_removed(self.block_ptr_to_buffer[block_ptr]):
3146-
continue
31473144
# Subtract any advancements made in the previous loop level.
31483145
if level < len(loop_trees) - 1:
31493146
prev_tree = loop_trees[level + 1]
@@ -3158,7 +3155,10 @@ def codegen_body(self):
31583155
]
31593156

31603157
self.body.writeline(
3161-
f"{block_ptr} = tl.advance({block_ptr}, {V.kernel.index_to_str(advancement)})"
3158+
DeferredLine(
3159+
self.block_ptr_to_buffer[block_ptr],
3160+
f"{block_ptr} = tl.advance({block_ptr}, {V.kernel.index_to_str(advancement)})",
3161+
)
31623162
)
31633163

31643164
# Invalidate any cache entries that came from inside the loop.

0 commit comments

Comments
 (0)