|
15 | 15 | from torch._inductor.codegen.rocm.compile_command import rocm_compile_command
|
16 | 16 | from torch._inductor.codegen.rocm.rocm_kernel import ROCmTemplateKernel
|
17 | 17 | from torch._inductor.ir import Buffer, Layout
|
| 18 | +from torch._inductor.runtime.runtime_utils import next_power_of_2 |
18 | 19 |
|
19 | 20 | from ...utils import IndentedBuffer, try_import_ck_lib
|
20 | 21 |
|
@@ -876,6 +877,27 @@ def _is_rcr_f16(self):
|
876 | 877 | and Y_layout == "Row"
|
877 | 878 | )
|
878 | 879 |
|
| 880 | + # helper to calculate a potentially optimal kBatch(es) for a problem |
| 881 | + def _get_kBatch(self, op): |
| 882 | + # we only set a higher kBatch if K > 16 * the larger of M and N |
| 883 | + # this is a hand-tuned heuristic to start |
| 884 | + metas = [T.get_layout() for T in [*self.input_nodes]] |
| 885 | + X_meta = metas[0] |
| 886 | + W_meta = metas[1] |
| 887 | + M = X_meta.size[-2] |
| 888 | + K = X_meta.size[-1] |
| 889 | + N = W_meta.size[-1] |
| 890 | + if K < 16 * max(M, N): |
| 891 | + return [1] |
| 892 | + # Calculate the number of blocks needed for each dimension |
| 893 | + total_k_blocks = math.ceil(K / op.k_per_block) |
| 894 | + # we want to calculate how many blocks we need to fit per CU |
| 895 | + cus = torch.cuda.get_device_properties(X_meta.device).multi_processor_count |
| 896 | + # again, manual heuristics as much larger kBatch are significantly worse in |
| 897 | + # initial testing |
| 898 | + kBatch = min(max(next_power_of_2(total_k_blocks // cus), 1), 128) |
| 899 | + return [kBatch] |
| 900 | + |
879 | 901 | def gen_ops(self) -> list[InductorROCmOp]:
|
880 | 902 | """
|
881 | 903 | Creates a list of `CKGemmOperation` instances that match the GEMM operation this template represents.
|
@@ -905,14 +927,13 @@ def gen_ops(self) -> list[InductorROCmOp]:
|
905 | 927 |
|
906 | 928 | assert generator is not None
|
907 | 929 |
|
908 |
| - # NOTE(coconutruben): for now, we only support kBatch 1 |
909 |
| - # TODO(coconturuben): infer a better kBatch depending on the input shape |
910 | 930 | # TODO(coconutruben): allow users to provide a list of kBatches to sweep over
|
911 |
| - kBatches = [1] |
912 | 931 | rops = generator()
|
913 |
| - ops = [ |
914 |
| - InductorROCmOp(op=op, kBatch=kBatch) for op in rops for kBatch in kBatches |
915 |
| - ] |
| 932 | + ops = [] |
| 933 | + for o in rops: |
| 934 | + kBatches = self._get_kBatch(o) |
| 935 | + for kBatch in kBatches: |
| 936 | + ops.append(InductorROCmOp(op=o, kBatch=kBatch)) |
916 | 937 |
|
917 | 938 | filtered_instances = list(filter(lambda op: self.filter_op(op), ops))
|
918 | 939 |
|
|
0 commit comments