
Commit b731ced

eellison authored and pytorchmergebot committed
Prologue Fusion (pytorch#134532)
This PR extends our ability to fuse pointwise nodes onto triton templates with the ability to fuse pointwise nodes into triton templates - prologue fusion.

Similar to the store_output api:

`{{store_output(("idx_m", "idx_n"), "acc", "mask")}}`

and the modification api:

```
{{ modification(
    subgraph_number=0,
    output_name="post_mod_scores",
    score="qk",
    out="qk"
) | indent_except_first(1) }}
```

we now have:

```
{{load_input("B", "b", ("idx_m", "idx_n"), mask=None if EVEN_K else "b_mask", indent_width=8)}}
```

Because we now load the input with explicit indices and a mask, I needed to rewrite the mm kernel so that it no longer advances the [pointers by BLOCK_K](https://github.com/pytorch/pytorch/blob/bb03ef7acadf7bbc3287b0ada58e9476eda6e0fe/torch/_inductor/kernel/mm.py#L110-L111) on every iteration and instead computes indices from the k_idx of each loop iteration. This did not make any perf difference.

There are a couple of main use cases for prologue fusion:

- Fusing dequants into a matmul, particularly in more bandwidth-bound scenarios.
- Fusing a gather into a matmul. This is useful particularly in MoE. See pytorch#134535 for more details.

Prologue fusion is generally much less profitable than epilogue fusion, because the fused op must be applied to an element of an input on every loop iteration of the matmul, compared to only once in the epilogue (gather into matmul is a potential exception). Accordingly, we are much less aggressive about attempting prologue fusion: we only attempt it if it does not increase the number of memory bytes read inside the triton template, multiplied by a small factor to allow gathers. This blocks reliably unprofitable fusions such as an fp32->fp16 downcast inside the kernel. In a future PR we could potentially add an API for being more aggressive when we know we are in a bandwidth-bound regime. See: https://github.com/pytorch/pytorch/pull/134532/files#diff-d2539c9c8dc6a3d7e457767a880612e96d3c85752a77ead49a9e4e00a3e4c3c7R3060-R3066

Other notes:

By default we upcast to fp32 inside every kernel. This matches eager numerics. That is fine for the epilogue because it is only done once (although it is probably unnecessary for, say, a relu), but it tanks perf for the prologue. I am currently using the `codegen_upcast_to_fp32` option to avoid it, but that will not work for libdevice calls that require fp32. We will need pytorch#136778 and dtype-aware codegen to upcast fp16 ops into libdevice calls.

With prologue fusion, we now have essentially separate kernels for each input and for the output. I had to increase the number of fields that are swapped out in `set_subgraph_body` by a large number :/ I also updated the fusion logic because the inputs will have a different group than the outputs. Maybe as part of enabling multiple outputs this could get cleaned up a bit.

Pull Request resolved: pytorch#134532
Approved by: https://github.com/jansel
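For illustration, a minimal user-level sketch of the first use case, modeled on the new `test_upcast` test added in this PR. It assumes a CUDA device capable of running the Triton gemm templates; the config keys mirror the `TestPrologueFusion` setup in `test/inductor/test_max_autotune.py` below and are an assumption about one way to opt in, not the only supported combination.

```python
import torch
import torch._inductor.config as config


def foo(x, y):
    # the fp16 -> fp32 upcast is a pointwise prologue that can now be fused
    # into the Triton mm template instead of being materialized separately
    return x.to(y.dtype) @ y


x = torch.rand([64, 128], dtype=torch.float16, device="cuda")
y = torch.rand([128, 256], dtype=torch.float32, device="cuda")

# Config values mirror the new TestPrologueFusion setup in this PR.
with config.patch(
    {
        "max_autotune": True,
        "prologue_fusion": True,
        "max_autotune_gemm_backends": "TRITON",
    }
):
    out = torch.compile(foo)(x, y)

torch.testing.assert_close(out, foo(x, y), atol=0.05, rtol=0.05)
```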
1 parent ceb664a commit b731ced

File tree

16 files changed: +883 additions, -107 deletions

test/inductor/test_max_autotune.py

Lines changed: 209 additions & 0 deletions

@@ -1,4 +1,5 @@
 # Owner(s): ["module: inductor"]
+import contextlib
 import os
 import unittest
 from typing import Callable, List, Optional
@@ -980,6 +981,214 @@ def test_tuning_pool_multiple_devices(self):
         tuning_pool.terminate()


+@instantiate_parametrized_tests
+class TestPrologueFusion(TestCase):
+    @classmethod
+    def setUpClass(cls):
+        super().setUpClass()
+        cls._stack = contextlib.ExitStack()
+        cls._stack.enter_context(
+            config.patch(
+                {
+                    "max_autotune": True,
+                    "prologue_fusion": True,
+                    "benchmark_epilogue_fusion": False,
+                    "shape_padding": False,
+                    "max_autotune_gemm_backends": "TRITON",
+                    "test_configs.max_mm_configs": 4,  # significantly speeds up tests
+                }
+            )
+        )
+
+    def check_code(self, code_str, num_kernels, num_allocs, num_deallocs):
+        FileCheck().check("def call").check_count(
+            ".run", num_kernels, exactly=True
+        ).run(code_str)
+
+        if num_allocs is not None:
+            FileCheck().check("def call").check_count(
+                "empty_strided", num_allocs, exactly=True
+            ).run(code_str)
+
+        if num_deallocs is not None:
+            FileCheck().check("def call").check_count(
+                "del", num_deallocs, exactly=True
+            ).run(code_str)
+
+    @parametrize("sizes", ((64, 128, 256), (128, 128, 128), (63, 120, 250)))
+    def test_upcast(self, sizes):
+        M, K, N = sizes
+
+        x = torch.rand([M, K], dtype=torch.float16, device="cuda")
+        y = torch.rand([K, N], dtype=torch.float, device="cuda")
+
+        def foo(x, y):
+            return x.to(y.dtype) @ y
+
+        out, code = run_and_get_code(torch.compile(foo), x, y)
+        self.assertEqual(out, foo(x, y), atol=0.05, rtol=0.05)
+        self.check_code(code[0], num_kernels=1, num_allocs=1, num_deallocs=2)
+
+    def test_downcast(self):
+        # per heuristics, dont fuse a downcast into a mm because it would lead to more reads inside kernel
+        M, K, N = (64, 128, 256)
+        x = torch.rand([M, K], dtype=torch.float, device="cuda")
+        y = torch.rand([K, N], dtype=torch.float16, device="cuda")
+
+        def foo(x, y):
+            return x.to(y.dtype) @ y
+
+        out, code = run_and_get_code(torch.compile(foo), x, y)
+        self.assertEqual(out, foo(x, y), atol=0.05, rtol=0.05)
+        self.check_code(code[0], num_kernels=2, num_allocs=2, num_deallocs=3)
+
+    @parametrize("sizes", ((64, 128, 256), (64, 64, 64), (64, 120, 64)))
+    def test_multiple_fusions(self, sizes):
+        M, K, N = sizes
+
+        def foo(x, y):
+            return ((x - 1.1) @ (y + 1.1)) * 1.1
+
+        x = torch.rand([M, K], dtype=torch.float, device="cuda")
+        y = torch.rand([K, N], dtype=torch.float, device="cuda")
+
+        out, code = run_and_get_code(torch.compile(foo), x, y)
+        self.assertEqual(out, foo(x, y), atol=0.05, rtol=0.05)
+        self.check_code(code[0], num_kernels=1, num_allocs=1, num_deallocs=2)
+
+        # check that we do not CSE any variables between prologues, epilogues
+        FileCheck().check("def triton").check_count("= 1.1", 3, exactly=True).check(
+            "tl.store"
+        ).run(code[0])
+
+    @parametrize("sizes", ((64, 128, 256), (128, 128, 128), (63, 120, 250)))
+    def test_multiple_inputs(self, sizes):
+        M, K, N = sizes
+
+        def foo(x, y, z):
+            return (x + y).to(torch.float) @ z
+
+        x = torch.rand([M, K], dtype=torch.float16, device="cuda")
+        y = torch.rand([M, K], dtype=torch.float16, device="cuda")
+        z = torch.rand([K, N], dtype=torch.float, device="cuda")
+        out_eager = foo(x, y, z)
+        out, code = run_and_get_code(torch.compile(foo), x, y, z)
+        self.assertEqual(out, out_eager, atol=0.05, rtol=0.05)
+        self.check_code(code[0], num_kernels=1, num_allocs=1, num_deallocs=3)
+
+    def test_storage_offset_prologue(self):
+        def foo(a):
+            q = a[:64, :]
+            k = a[64:, :]
+            return torch.mm(q + 2, k - 2)
+
+        inp = torch.randn(128, 64, device="cuda")
+        out, code = run_and_get_code(torch.compile(foo), inp)
+        self.assertEqual(out, foo(inp), atol=0.05, rtol=0.05)
+        self.check_code(code[0], num_kernels=1, num_allocs=1, num_deallocs=1)
+
+    @config.patch(realize_reads_threshold=1, realize_opcount_threshold=1)
+    @parametrize("sizes", ((64, 128, 256), (128, 128, 128), (63, 120, 250)))
+    def test_prologue_multiple_nodes(self, sizes):
+        M, K, N = sizes
+
+        def foo(x, y):
+            return ((((x * 2) - 1) / 2) @ (y * 4)) * 3.0
+
+        x = torch.rand([M, K], dtype=torch.float, device="cuda")
+        y = torch.rand([K, N], dtype=torch.float, device="cuda")
+
+        out, code = run_and_get_code(torch.compile(foo), x, y)
+        self.assertEqual(out, foo(x, y), atol=0.05, rtol=0.05)
+        self.check_code(code[0], num_kernels=1, num_allocs=1, num_deallocs=2)
+
+    @parametrize("K", (63, 64))
+    def test_broadcast_x(self, K):
+        def foo(x, y):
+            return (x.expand([1, y.shape[0]]) + 1) @ y
+
+        x = torch.rand([1, 1], dtype=torch.float, device="cuda")
+        y = torch.rand([K, 128], dtype=torch.float, device="cuda")
+
+        out, code = run_and_get_code(torch.compile(foo, dynamic=True), x, y)
+        self.assertEqual(out, foo(x, y), atol=0.05, rtol=0.05)
+        self.check_code(code[0], num_kernels=1, num_allocs=1, num_deallocs=2)
+
+    def test_broadcast_y(self):
+        def foo(x, y):
+            return x @ y
+
+        M = 20
+        N = K = 1
+        x = torch.rand([M, K], dtype=torch.float, device="cuda")
+        y = torch.rand([K, N], dtype=torch.float, device="cuda")
+        torch._dynamo.mark_dynamic(x, 0)
+
+        out, code = run_and_get_code(torch.compile(foo, dynamic=True), x, y)
+        self.assertEqual(out, foo(x, y), atol=0.05, rtol=0.05)
+        self.check_code(code[0], num_kernels=1, num_allocs=1, num_deallocs=2)
+
+    @config.patch(realize_reads_threshold=1, realize_opcount_threshold=1)
+    @parametrize("benchmark_fusion", (True, False))
+    def test_prologue_read_into_both_inputs(self, benchmark_fusion):
+        M = K = N = 256
+
+        # not supported today. it could be, but typically the pointwise nodes would get
+        # inlined into separate nodes.
+
+        def foo(x):
+            y = (x + 1) * 2
+            return y @ (y - 2)
+
+        with config.patch(benchmark_epilogue_fusion=benchmark_fusion):
+            x = torch.rand([M, K], dtype=torch.float, device="cuda")
+
+            out, code = run_and_get_code(torch.compile(foo), x)
+            self.assertEqual(out, foo(x), atol=0.05, rtol=0.05)
+            # not guaranteed to fuse, but still checking correctness
+            if not benchmark_fusion:
+                self.check_code(
+                    code[0], num_kernels=2, num_allocs=None, num_deallocs=None
+                )
+
+    @config.patch(realize_reads_threshold=1, realize_opcount_threshold=1)
+    @config.patch(allow_buffer_reuse=False)
+    def test_mismatched_prologue_group(self):
+        def foo(x, y, z):
+            a = (x + 2) * 2
+            b = a * y
+            return b @ z
+
+        x = torch.rand([1, 256], device="cuda")
+        y = torch.rand([256, 256], device="cuda")
+        z = torch.rand([256, 128], device="cuda")
+
+        out, code = run_and_get_code(torch.compile(foo), x, y, z)
+        self.assertEqual(out, foo(x, y, z), atol=0.05, rtol=0.05)
+        # theres one more dealloc than there should be because of a buffer reuse. TODO:
+        # not sure why disabling buffer reuse doesnt stop
+        self.check_code(code[0], num_kernels=2, num_allocs=2, num_deallocs=4)
+
+    @config.patch(shape_padding=True)
+    @config.patch(force_shape_pad=True)
+    @parametrize("sizes", ((250, 245, 128), (250, 256, 128), (256, 128, 62)))
+    def test_prologue_masked_load(self, sizes):
+        M, K, N = sizes
+
+        def foo(x, y):
+            return x @ y
+
+        # cat will turn into masked load
+        # TODO - we should not attempt fusion if it turns an aligned load
+        # into an unaligned load
+        x = torch.rand([250, 245], device="cuda")
+        y = torch.rand([245, 128], device="cuda")
+
+        out, code = run_and_get_code(torch.compile(foo), x, y)
+        self.assertEqual(out, foo(x, y), atol=0.05, rtol=0.05)
+        self.check_code(code[0], num_kernels=1, num_allocs=1, num_deallocs=2)
+
+
 if __name__ == "__main__":
     from torch._inductor.utils import is_big_gpu

torch/_inductor/choices.py

Lines changed: 11 additions & 1 deletion

@@ -306,8 +306,18 @@ def score_fusion(
             abs(node1.min_order - node2.max_order),
             abs(node2.min_order - node1.max_order),
         )
+
+        # prologue fusion always last
+        if node2.is_template():
+            template_score = 0
+        else:
+            template_score = 1 + (
+                (node1.is_template() == config.epilogue_fusion_first)
+                and memory_score > 0
+            )
+
         return (
-            node1.is_template() == config.epilogue_fusion_first and memory_score > 0,
+            template_score,
             node1.is_reduction() == node2.is_reduction() and memory_score > 0,
             memory_score,
             proximity_score,

torch/_inductor/codecache.py

Lines changed: 1 addition & 1 deletion

@@ -1112,8 +1112,8 @@ def iterate_over_candidates() -> Generator[CompiledFxGraph, None, None]:
         metrics.CachedMetricsHelper.apply_deltas(graph.metrics_deltas)
         counters["inductor"] += graph.counter_deltas

-        output_code_log.debug("Output code written to: %s", artifact_path)
         output_code_log.debug("Output code: \n%s", code)
+        output_code_log.debug("Output code written to: %s", artifact_path)
         # On cache hit, use artifact path as filename
         trace_structured(
             "inductor_output_code",

torch/_inductor/codegen/common.py

Lines changed: 10 additions & 2 deletions

@@ -1,4 +1,6 @@
 # mypy: allow-untyped-defs
+from __future__ import annotations
+
 import contextlib
 import dataclasses
 import enum
@@ -18,10 +20,16 @@
     List,
     NamedTuple,
     Optional,
+    Set,
     Tuple,
+    TYPE_CHECKING,
     Union,
 )

+
+if TYPE_CHECKING:
+    from typing import Never
+
 import sympy

 import torch
@@ -1460,7 +1468,7 @@ def __init__(
         self.invalidated_stores = OrderedSet()  # type: ignore[var-annotated]
         self.varname_map = varname_map or {}

-    def invalidate(self, keep_vars: OrderedSet[str]):
+    def invalidate(self, keep_vars: Union[OrderedSet[str], Set[Never]]):
         for name, tmp in list(self.store_cache.items()):
             if tmp not in keep_vars:
                 del self.store_cache[name]
@@ -2326,7 +2334,7 @@ def maybe_append_choice(self, choices, **kwargs):
         except NotImplementedError as e:
             return e

-    def generate(self, **kwargs) -> "torch._inductor.ir.ChoiceCaller":
+    def generate(self, **kwargs) -> torch._inductor.ir.ChoiceCaller:
         """
         Generates a ChoiceCaller instance from the given arguments.
         """

torch/_inductor/codegen/cpp.py

Lines changed: 2 additions & 0 deletions

@@ -4802,10 +4802,12 @@ def codegen_template(
         self,
         template_node: BaseSchedulerNode,
         epilogue_nodes: Sequence[BaseSchedulerNode],
+        prologue_nodes: Sequence[BaseSchedulerNode],
     ):
         """
         Codegen a CPP template, possibly with fused epilogues
         """
+        assert not prologue_nodes
         counters["inductor"]["cpp_epilogue_fusion_counter"] += len(epilogue_nodes)
         assert self.is_cpp_template(
             template_node

torch/_inductor/codegen/cuda/cuda_cpp_scheduling.py

Lines changed: 1 addition & 0 deletions

@@ -82,6 +82,7 @@ def codegen_template(
         self,
         template_node: BaseSchedulerNode,
         epilogue_nodes: Sequence[BaseSchedulerNode],
+        prologue_nodes: Sequence[BaseSchedulerNode],
     ):
         """
         Codegen a CUDA template, possibly with fused epilogues

torch/_inductor/codegen/cuda_combined_scheduling.py

Lines changed: 8 additions & 5 deletions

@@ -60,20 +60,23 @@ def codegen_template(
         self,
         template_node: BaseSchedulerNode,
         epilogue_nodes: Sequence[BaseSchedulerNode],
+        prologue_nodes: Sequence[BaseSchedulerNode],
     ):
         if self._cuda_cpp_scheduling.is_cuda_cpp_template(template_node):
-            assert epilogue_nodes is None or len(epilogue_nodes) == 0
+            assert not epilogue_nodes
+            assert not prologue_nodes
             return self._cuda_cpp_scheduling.codegen_template(
-                template_node, epilogue_nodes
+                template_node, epilogue_nodes, prologue_nodes
             )
         elif self._rocm_cpp_scheduling.is_rocm_cpp_template(template_node):
-            assert epilogue_nodes is None or len(epilogue_nodes) == 0
+            assert not epilogue_nodes
+            assert not prologue_nodes
             return self._rocm_cpp_scheduling.codegen_template(
-                template_node, epilogue_nodes
+                template_node, epilogue_nodes, prologue_nodes
             )
         else:
             return self._triton_scheduling.codegen_template(
-                template_node, epilogue_nodes
+                template_node, epilogue_nodes, prologue_nodes
             )

     def codegen_node(self, node: Union[FusedSchedulerNode, SchedulerNode]):

torch/_inductor/codegen/rocm/rocm_cpp_scheduling.py

Lines changed: 1 addition & 0 deletions

@@ -77,6 +77,7 @@ def codegen_template(
         self,
         template_node: BaseSchedulerNode,
         epilogue_nodes: Sequence[BaseSchedulerNode],
+        prologue_nodes: Sequence[BaseSchedulerNode],
     ):
         """
         Codegen a ROCm template, possibly with fused epilogues
