|
1 | 1 | # Owner(s): ["module: inductor"]
|
| 2 | +import contextlib |
2 | 3 | import os
|
3 | 4 | import unittest
|
4 | 5 | from typing import Callable, List, Optional
|
@@ -980,6 +981,214 @@ def test_tuning_pool_multiple_devices(self):
|
980 | 981 | tuning_pool.terminate()
|
981 | 982 |
|
982 | 983 |
|
@instantiate_parametrized_tests
class TestPrologueFusion(TestCase):
    """Tests that pointwise prologues (casts, broadcasts, arithmetic) fuse into
    Triton matmul templates, by inspecting the generated wrapper code.

    All tests force max-autotune with TRITON-only GEMM backends (see
    ``setUpClass``) and assert on kernel/alloc/dealloc counts via ``check_code``.
    """

    @classmethod
    def setUpClass(cls):
        """Patch inductor config for the whole class via an ExitStack."""
        super().setUpClass()
        cls._stack = contextlib.ExitStack()
        cls._stack.enter_context(
            config.patch(
                {
                    "max_autotune": True,
                    "prologue_fusion": True,
                    "benchmark_epilogue_fusion": False,
                    "shape_padding": False,
                    "max_autotune_gemm_backends": "TRITON",
                    "test_configs.max_mm_configs": 4,  # significantly speeds up tests
                }
            )
        )

    @classmethod
    def tearDownClass(cls):
        # Close the ExitStack so the config.patch above is undone; without this
        # the patched config leaks into any other test classes run in-process.
        cls._stack.close()
        super().tearDownClass()

    def check_code(self, code_str, num_kernels, num_allocs, num_deallocs):
        """Assert on the structure of generated wrapper code.

        Args:
            code_str: generated Python wrapper source (from run_and_get_code).
            num_kernels: exact number of ``.run`` kernel launches expected.
            num_allocs: exact number of ``empty_strided`` allocations, or
                ``None`` to skip this check.
            num_deallocs: exact number of ``del`` statements, or ``None`` to
                skip this check.
        """
        FileCheck().check("def call").check_count(
            ".run", num_kernels, exactly=True
        ).run(code_str)

        if num_allocs is not None:
            FileCheck().check("def call").check_count(
                "empty_strided", num_allocs, exactly=True
            ).run(code_str)

        if num_deallocs is not None:
            FileCheck().check("def call").check_count(
                "del", num_deallocs, exactly=True
            ).run(code_str)

    @parametrize("sizes", ((64, 128, 256), (128, 128, 128), (63, 120, 250)))
    def test_upcast(self, sizes):
        # An fp16 -> fp32 upcast prologue should fuse into the mm kernel:
        # one kernel, one output alloc, both inputs freed.
        M, K, N = sizes

        x = torch.rand([M, K], dtype=torch.float16, device="cuda")
        y = torch.rand([K, N], dtype=torch.float, device="cuda")

        def foo(x, y):
            return x.to(y.dtype) @ y

        out, code = run_and_get_code(torch.compile(foo), x, y)
        self.assertEqual(out, foo(x, y), atol=0.05, rtol=0.05)
        self.check_code(code[0], num_kernels=1, num_allocs=1, num_deallocs=2)

    def test_downcast(self):
        # per heuristics, dont fuse a downcast into a mm because it would lead to more reads inside kernel
        M, K, N = (64, 128, 256)
        x = torch.rand([M, K], dtype=torch.float, device="cuda")
        y = torch.rand([K, N], dtype=torch.float16, device="cuda")

        def foo(x, y):
            return x.to(y.dtype) @ y

        out, code = run_and_get_code(torch.compile(foo), x, y)
        self.assertEqual(out, foo(x, y), atol=0.05, rtol=0.05)
        self.check_code(code[0], num_kernels=2, num_allocs=2, num_deallocs=3)

    @parametrize("sizes", ((64, 128, 256), (64, 64, 64), (64, 120, 64)))
    def test_multiple_fusions(self, sizes):
        # Both input prologues and the epilogue should fuse into one kernel.
        M, K, N = sizes

        def foo(x, y):
            return ((x - 1.1) @ (y + 1.1)) * 1.1

        x = torch.rand([M, K], dtype=torch.float, device="cuda")
        y = torch.rand([K, N], dtype=torch.float, device="cuda")

        out, code = run_and_get_code(torch.compile(foo), x, y)
        self.assertEqual(out, foo(x, y), atol=0.05, rtol=0.05)
        self.check_code(code[0], num_kernels=1, num_allocs=1, num_deallocs=2)

        # check that we do not CSE any variables between prologues, epilogues
        FileCheck().check("def triton").check_count("= 1.1", 3, exactly=True).check(
            "tl.store"
        ).run(code[0])

    @parametrize("sizes", ((64, 128, 256), (128, 128, 128), (63, 120, 250)))
    def test_multiple_inputs(self, sizes):
        # A prologue reading two inputs (x + y) should still fuse: one kernel,
        # all three inputs deallocated.
        M, K, N = sizes

        def foo(x, y, z):
            return (x + y).to(torch.float) @ z

        x = torch.rand([M, K], dtype=torch.float16, device="cuda")
        y = torch.rand([M, K], dtype=torch.float16, device="cuda")
        z = torch.rand([K, N], dtype=torch.float, device="cuda")
        out_eager = foo(x, y, z)
        out, code = run_and_get_code(torch.compile(foo), x, y, z)
        self.assertEqual(out, out_eager, atol=0.05, rtol=0.05)
        self.check_code(code[0], num_kernels=1, num_allocs=1, num_deallocs=3)

    def test_storage_offset_prologue(self):
        # Prologues on views with nonzero storage offset (slices of `a`)
        # should fuse into a single mm kernel.
        def foo(a):
            q = a[:64, :]
            k = a[64:, :]
            return torch.mm(q + 2, k - 2)

        inp = torch.randn(128, 64, device="cuda")
        out, code = run_and_get_code(torch.compile(foo), inp)
        self.assertEqual(out, foo(inp), atol=0.05, rtol=0.05)
        self.check_code(code[0], num_kernels=1, num_allocs=1, num_deallocs=1)

    @config.patch(realize_reads_threshold=1, realize_opcount_threshold=1)
    @parametrize("sizes", ((64, 128, 256), (128, 128, 128), (63, 120, 250)))
    def test_prologue_multiple_nodes(self, sizes):
        # Low realize thresholds force each pointwise op into its own node;
        # a chain of such nodes should still fuse into one kernel.
        M, K, N = sizes

        def foo(x, y):
            return ((((x * 2) - 1) / 2) @ (y * 4)) * 3.0

        x = torch.rand([M, K], dtype=torch.float, device="cuda")
        y = torch.rand([K, N], dtype=torch.float, device="cuda")

        out, code = run_and_get_code(torch.compile(foo), x, y)
        self.assertEqual(out, foo(x, y), atol=0.05, rtol=0.05)
        self.check_code(code[0], num_kernels=1, num_allocs=1, num_deallocs=2)

    @parametrize("K", (63, 64))
    def test_broadcast_x(self, K):
        # Broadcasting a [1, 1] input up to the mm's K dim in the prologue.
        def foo(x, y):
            return (x.expand([1, y.shape[0]]) + 1) @ y

        x = torch.rand([1, 1], dtype=torch.float, device="cuda")
        y = torch.rand([K, 128], dtype=torch.float, device="cuda")

        out, code = run_and_get_code(torch.compile(foo, dynamic=True), x, y)
        self.assertEqual(out, foo(x, y), atol=0.05, rtol=0.05)
        self.check_code(code[0], num_kernels=1, num_allocs=1, num_deallocs=2)

    def test_broadcast_y(self):
        # Degenerate K = N = 1 second operand with a dynamic first dim.
        def foo(x, y):
            return x @ y

        M = 20
        N = K = 1
        x = torch.rand([M, K], dtype=torch.float, device="cuda")
        y = torch.rand([K, N], dtype=torch.float, device="cuda")
        torch._dynamo.mark_dynamic(x, 0)

        out, code = run_and_get_code(torch.compile(foo, dynamic=True), x, y)
        self.assertEqual(out, foo(x, y), atol=0.05, rtol=0.05)
        self.check_code(code[0], num_kernels=1, num_allocs=1, num_deallocs=2)

    @config.patch(realize_reads_threshold=1, realize_opcount_threshold=1)
    @parametrize("benchmark_fusion", (True, False))
    def test_prologue_read_into_both_inputs(self, benchmark_fusion):
        M = K = N = 256

        # not supported today. it could be, but typically the pointwise nodes would get
        # inlined into separate nodes.

        def foo(x):
            y = (x + 1) * 2
            return y @ (y - 2)

        with config.patch(benchmark_epilogue_fusion=benchmark_fusion):
            x = torch.rand([M, K], dtype=torch.float, device="cuda")

            out, code = run_and_get_code(torch.compile(foo), x)
            self.assertEqual(out, foo(x), atol=0.05, rtol=0.05)
            # not guaranteed to fuse, but still checking correctness
            if not benchmark_fusion:
                self.check_code(
                    code[0], num_kernels=2, num_allocs=None, num_deallocs=None
                )

    @config.patch(realize_reads_threshold=1, realize_opcount_threshold=1)
    @config.patch(allow_buffer_reuse=False)
    def test_mismatched_prologue_group(self):
        # `a` broadcasts along dim 0 while `b` does not, so the prologue group
        # shapes mismatch and only partial fusion is expected (two kernels).
        def foo(x, y, z):
            a = (x + 2) * 2
            b = a * y
            return b @ z

        x = torch.rand([1, 256], device="cuda")
        y = torch.rand([256, 256], device="cuda")
        z = torch.rand([256, 128], device="cuda")

        out, code = run_and_get_code(torch.compile(foo), x, y, z)
        self.assertEqual(out, foo(x, y, z), atol=0.05, rtol=0.05)
        # theres one more dealloc than there should be because of a buffer reuse. TODO:
        # not sure why disabling buffer reuse doesnt stop
        self.check_code(code[0], num_kernels=2, num_allocs=2, num_deallocs=4)

    @config.patch(shape_padding=True)
    @config.patch(force_shape_pad=True)
    @parametrize("sizes", ((250, 245, 128), (250, 256, 128), (256, 128, 62)))
    def test_prologue_masked_load(self, sizes):
        M, K, N = sizes

        def foo(x, y):
            return x @ y

        # cat will turn into masked load
        # TODO - we should not attempt fusion if it turns an aligned load
        # into an unaligned load
        x = torch.rand([250, 245], device="cuda")
        y = torch.rand([245, 128], device="cuda")

        out, code = run_and_get_code(torch.compile(foo), x, y)
        self.assertEqual(out, foo(x, y), atol=0.05, rtol=0.05)
        self.check_code(code[0], num_kernels=1, num_allocs=1, num_deallocs=2)
| 1190 | + |
| 1191 | + |
983 | 1192 | if __name__ == "__main__":
|
984 | 1193 | from torch._inductor.utils import is_big_gpu
|
985 | 1194 |
|
|
0 commit comments