
Commit 6f91720

coconutruben authored and pytorchmergebot committed
[inductor][ck] manual kBatch heuristic (pytorch#148118)
Summary: # Why Leverage kBatch parameter for large splitK examples for CK for better than ATEN performance # What replace default kBatch = 1 with a manual heuristic - if K > 16 * max (M,N) - leverage k_per_block, and K and number of SMs on the chip - upper bound to 128, lower bound to 1 This is better than defaulting to 1, cheap to calculate, and shows performance beyond ATEN This is of course subject to change and improvement Test Plan: with minor modifications to to run torch.mm on the shape `M, N, K = 2048, 2048, 524288` ``` buck2 run -c fbcode.re_gpu_tests=False mode/opt-amd-gpu fbcode//deeplearning/aot_inductor/benchmark/sampling:test_gemm_autotune_benchmark_AMD_block_0 ``` ``` AUTOTUNE mm(2048x524288, 524288x2048) rocm_ck_gemm_template_49 10.4972 ms 100.0% rocm_ck_gemm_template_8 10.6132 ms 98.9% rocm_ck_gemm_template_9 10.6907 ms 98.2% [...] mm 18.9880 ms 55.3% ``` Reviewed By: ColinPeppler Differential Revision: D70224591 Pull Request resolved: pytorch#148118 Approved by: https://github.com/ColinPeppler
1 parent 48c55a6 commit 6f91720

File tree

1 file changed

+27 -6 lines changed

torch/_inductor/codegen/rocm/ck_universal_gemm_template.py

Lines changed: 27 additions & 6 deletions
```diff
@@ -15,6 +15,7 @@
 from torch._inductor.codegen.rocm.compile_command import rocm_compile_command
 from torch._inductor.codegen.rocm.rocm_kernel import ROCmTemplateKernel
 from torch._inductor.ir import Buffer, Layout
+from torch._inductor.runtime.runtime_utils import next_power_of_2

 from ...utils import IndentedBuffer, try_import_ck_lib

@@ -876,6 +877,27 @@ def _is_rcr_f16(self):
             and Y_layout == "Row"
         )

+    # helper to calculate a potentially optimal kBatch(es) for a problem
+    def _get_kBatch(self, op):
+        # we only set a higher kBatch if K > 16 * the larger of M and N
+        # this is a hand-tuned heuristic to start
+        metas = [T.get_layout() for T in [*self.input_nodes]]
+        X_meta = metas[0]
+        W_meta = metas[1]
+        M = X_meta.size[-2]
+        K = X_meta.size[-1]
+        N = W_meta.size[-1]
+        if K < 16 * max(M, N):
+            return [1]
+        # Calculate the number of blocks needed for each dimension
+        total_k_blocks = math.ceil(K / op.k_per_block)
+        # we want to calculate how many blocks we need to fit per CU
+        cus = torch.cuda.get_device_properties(X_meta.device).multi_processor_count
+        # again, manual heuristics as much larger kBatch are significantly worse in
+        # initial testing
+        kBatch = min(max(next_power_of_2(total_k_blocks // cus), 1), 128)
+        return [kBatch]
+
     def gen_ops(self) -> list[InductorROCmOp]:
         """
         Creates a list of `CKGemmOperation` instances that match the GEMM operation this template represents.
@@ -905,14 +927,13 @@ def gen_ops(self) -> list[InductorROCmOp]:

         assert generator is not None

-        # NOTE(coconutruben): for now, we only support kBatch 1
-        # TODO(coconturuben): infer a better kBatch depending on the input shape
         # TODO(coconutruben): allow users to provide a list of kBatches to sweep over
-        kBatches = [1]
         rops = generator()
-        ops = [
-            InductorROCmOp(op=op, kBatch=kBatch) for op in rops for kBatch in kBatches
-        ]
+        ops = []
+        for o in rops:
+            kBatches = self._get_kBatch(o)
+            for kBatch in kBatches:
+                ops.append(InductorROCmOp(op=o, kBatch=kBatch))

         filtered_instances = list(filter(lambda op: self.filter_op(op), ops))
```