
Commit d57f617

jianan-gu authored and pytorchmergebot committed
[Inductor][CPP] Avoid transpose with cpp micro-gemm for FlexAttention (pytorch#147069)
Pull Request resolved: pytorch#147069
Approved by: https://github.com/leslie-fang-intel, https://github.com/jgong5, https://github.com/drisspg
ghstack dependencies: pytorch#147068
1 parent 6c089f5 commit d57f617

4 files changed: +97, -85 lines

test/inductor/test_flex_attention.py

+1, -18
@@ -135,7 +135,7 @@ def create_block_mask_test(score_mod, query, key):
 )

 test_dtypes = (
-    [torch.float32, torch.bfloat16]
+    [torch.float32, torch.bfloat16, torch.float16]
     if torch.backends.mkldnn.is_available()
     and torch.ops.mkldnn._is_mkldnn_bf16_supported()
     else [torch.float32]
@@ -3677,23 +3677,6 @@ def test_cpu_error_message_return_lse(self):
         ):
             attention(query, key, value, return_lse=True)

-    @unittest.skipIf(TEST_ON_CUDA, "Testing CPU error message")
-    def test_validate_cpu_dtype_error_message(self):
-        make_tensor = functools.partial(
-            torch.randn,
-            (2, 2, 128, 16),
-            device="cpu",
-            dtype=torch.half,
-            requires_grad=False,
-        )
-        query, key, value = make_tensor(), make_tensor(), make_tensor()
-        attention = torch.compile(flex_attention)
-        with self.assertRaisesRegex(
-            torch._inductor.exc.InductorError,
-            r"`torch.float` and `torch.bfloat16` are supported in FlexAttention for CPU device. Found input tensors are `torch.float16`.",
-        ):
-            attention(query, key, value)
-
     @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU")
     def test_device_cuda_1(self):
         class TestModule(torch.nn.Module):
torch/_inductor/codegen/cpp_flex_attention_template.py

+93, -64
@@ -17,6 +17,7 @@
 from ..utils import parallel_num_threads
 from ..virtualized import V
 from .cpp_template import CppTemplate
+from .cpp_utils import GemmBlocking


 log = logging.getLogger(__name__)
@@ -195,6 +196,10 @@
 }
 """

+MICRO_GEMM_TEMPLATE = r"""
+GEMM_DEFINE
+"""
+
 ALLOCATE_BUFFER = r"""
   int64_t {{buffer_name}}_dtype_itemsize = std::is_same_v<{{buffer_dtype}}, at::BFloat16> ? 2 : 4;
   auto& {{buffer_name}}_allocator = *at::getCPUAllocator();
@@ -208,6 +213,7 @@
 #include <ATen/native/cpu/utils.h>
 #include <ATen/native/CPUBlas.h>
 #include <ATen/Context.h>
+{{template.codegen_micro_gemm(kernel.kernel_name)}}
 {{template.codegen_softmax_fusion(kernel.kernel_name)}}
 {{template.codegen_brgemm_pack_function(kernel.kernel_name)}}
 {%- set kernel_args = {"query": query, "key": key, "value": value,
@@ -329,7 +335,6 @@
       need_pack = gemm_size_per_thread / pack_size >= 4;
     }
   }
-
   // Pad is needed for packing when K is not even
   bool headSize_even = headSize % 2 == 0;
   int64_t eheadSize = need_pack && !headSize_even ? headSize + 1: headSize;
@@ -358,37 +363,37 @@
   {{template.codegen_allocate_buffer("transpose_buffer_ptr", "scalar_t", "num_thread*kvSplitSize*headSize")}}
   {{template.codegen_allocate_buffer("query_padding_ptr", "scalar_t", "num_thread*qSplitSize*eheadSize")}}

-  // Reorder K, V and transpose K
-  at::parallel_for(0, batchSize * num_head * kvSlice, 1, [&](int64_t begin, int64_t end) {
-    int ompIdx = at::get_thread_num();
-    int64_t i = 0, j = 0, l = 0, n = 0;
-    scalar_t* transpose_ptr = need_pack? transpose_buffer_ptr + ompIdx * kvSplitSize * headSize : nullptr;
-    at::native::data_index_init(begin, i, batchSize, j, num_head, l, kvSlice);
-    for ([[maybe_unused]] auto z : c10::irange(begin, end)) {
-      n = l * kvSplitSize;
-      int64_t cur_kvSplitSize = std::min(kvSplitSize, kvSize - n);
-      auto i_kv = is_broadcast_bs_kv ? i/bs_shards : i;
-      auto j_kv = is_broadcast_head_kv ? j/gqa_shards : j;
-      auto kv_block_num = n / cur_kvSplitSize;
-      auto kv_block_offset = n - kv_block_num * cur_kvSplitSize;
-      // getting kv indices by [BS, Head, 1, kv_block_num]
-      auto i_kvi = is_broadcast_bs_kvi ? i/bs_shards_kvi : i;
-      auto j_kvi = is_broadcast_head_kvi ? j/gqa_shards_kvi : j;
-      auto kv_logical_data = kv_indices_data + i_kvi * kviStrideB +
-                             j_kvi * kviStrideH + kv_block_num;
-      auto k_addr =
-          k_data + i_kv * kStrideB + j_kv * kStrideH + n * kStrideN;
-      auto v_addr =
-          v_data + i_kv * vStrideB + j_kv * vStrideH + n * vStrideN;
-      if (use_kv_indice) {
-        k_addr =
-            k_data + i_kv * kStrideB + j_kv * kStrideH +
-            (*kv_logical_data * cur_kvSplitSize + kv_block_offset) * kStrideN;
-        v_addr =
-            v_data + i_kv * vStrideB + j_kv * vStrideH +
-            (*kv_logical_data * cur_kvSplitSize + kv_block_offset) * vStrideN;
-      }
-      if (need_pack) {
+  if (need_pack) {
+    // Pack K, V
+    at::parallel_for(0, batchSize * num_head * kvSlice, 1, [&](int64_t begin, int64_t end) {
+      int ompIdx = at::get_thread_num();
+      int64_t i = 0, j = 0, l = 0, n = 0;
+      scalar_t* transpose_ptr = transpose_buffer_ptr + ompIdx * kvSplitSize * headSize;
+      at::native::data_index_init(begin, i, batchSize, j, num_head, l, kvSlice);
+      for ([[maybe_unused]] auto z : c10::irange(begin, end)) {
+        n = l * kvSplitSize;
+        int64_t cur_kvSplitSize = std::min(kvSplitSize, kvSize - n);
+        auto i_kv = is_broadcast_bs_kv ? i/bs_shards : i;
+        auto j_kv = is_broadcast_head_kv ? j/gqa_shards : j;
+        auto kv_block_num = n / cur_kvSplitSize;
+        auto kv_block_offset = n - kv_block_num * cur_kvSplitSize;
+        // getting kv indices by [BS, Head, 1, kv_block_num]
+        auto i_kvi = is_broadcast_bs_kvi ? i/bs_shards_kvi : i;
+        auto j_kvi = is_broadcast_head_kvi ? j/gqa_shards_kvi : j;
+        auto kv_logical_data = kv_indices_data + i_kvi * kviStrideB +
+                               j_kvi * kviStrideH + kv_block_num;
+        auto k_addr =
+            k_data + i_kv * kStrideB + j_kv * kStrideH + n * kStrideN;
+        auto v_addr =
+            v_data + i_kv * vStrideB + j_kv * vStrideH + n * vStrideN;
+        if (use_kv_indice) {
+          k_addr =
+              k_data + i_kv * kStrideB + j_kv * kStrideH +
+              (*kv_logical_data * cur_kvSplitSize + kv_block_offset) * kStrideN;
+          v_addr =
+              v_data + i_kv * vStrideB + j_kv * vStrideH +
+              (*kv_logical_data * cur_kvSplitSize + kv_block_offset) * vStrideN;
+        }
         // transpose [cur_kvSplitSize, headSize] -> [headSize, cur_kvSplitSize]
         at::native::utils::transpose<uint16_t>(
           cur_kvSplitSize,
@@ -417,23 +422,11 @@
             /* ld_src */ vStrideN,
             /* K */ cur_kvSplitSize,
             /* N */ headSize_v);
-      } else {
-        using trans_t = std::conditional_t<std::is_same_v<scalar_t, at::BFloat16>, uint16_t, float>;
-        at::native::utils::transpose<trans_t>(
-          cur_kvSplitSize,
-          headSize,
-          /* src_ptr */
-          reinterpret_cast<const trans_t*>(k_addr),
-          /* ld_src */ kStrideN,
-          /* dst */ reinterpret_cast<trans_t*>(key_reorder_ptr + i * num_head * eheadSize * kvSize +
-                      j * eheadSize * kvSize + n * eheadSize),
-          /* ld_dst */ cur_kvSplitSize);
+        // Move to the next query
+        at::native::data_index_step(i, batchSize, j, num_head, l, kvSlice);
       }
-      // Move to the next query
-      at::native::data_index_step(i, batchSize, j, num_head, l, kvSlice);
-    }
-  });
-
+    });
+  }
   // Attention loop below
   at::parallel_for(0, batchSize * num_head * qSlice, 1, [&](int64_t begin, int64_t end) {
     int64_t i = 0, j = 0, k = 0;
@@ -488,22 +481,26 @@
           auto kv_logical_data = kv_indices_data + i_kvi * kviStrideB +
                                  j_kvi * kviStrideH + kv_block_num;
           if (!need_pack) {
-            auto k_addr_t = key_reorder_ptr + i * num_head * eheadSize * kvSize +
-                            j * eheadSize * kvSize + n * eheadSize;
-            // TODO: use the micro-gemm template instead of brgemm API
-            at::native::cpublas::brgemm(
-              cur_qSplitSize,
-              cur_kvSplitSize,
-              eheadSize,
-              qStrideM,
-              cur_kvSplitSize,
-              cur_kvSplitSize,
-              false,
+            auto k_addr =
+                k_data + i_kv * kStrideB + j_kv * kStrideH + n * kStrideN;
+            if (use_kv_indice) {
+              k_addr =
+                  k_data + i_kv * kStrideB + j_kv * kStrideH +
+                  (*kv_logical_data * kvBlockSize + kv_block_offset) * kStrideN;
+            }
+
+            {{kernel.kernel_name}}_kernel_micro_gemm<static_cast<bool>(false)>(
               q_data + i * qStrideB + j * qStrideH +
                   m * qStrideM,
-              k_addr_t,
+              k_addr,
               qk_data,
-              need_pack);
+              cur_qSplitSize,
+              cur_kvSplitSize,
+              headSize,
+              qStrideM,
+              kStrideN,
+              cur_kvSplitSize);
+
           } else {
             at::native::cpublas::brgemm(
               cur_qSplitSize,
@@ -690,7 +687,7 @@ def __init__(
         kernel_input_name_to_buffer,
         block_vars,
     ) -> None:
-        assert layout.dtype in [torch.float, torch.bfloat16]
+        assert layout.dtype in [torch.float, torch.bfloat16, torch.float16]
         super().__init__("flex_attention", input_nodes, layout, parallel_num_threads())
         self.scale = scale
         self.score_mod = score_mod
@@ -958,6 +955,8 @@ def render( # type: ignore[override,return]
         query = kernel.permute(self.input_nodes[0], [0, 2, 1, 3])
         key = kernel.permute(self.input_nodes[1], [0, 2, 1, 3])
         value = kernel.permute(self.input_nodes[2], [0, 2, 1, 3])
+        self.accumulate_dtype = torch.float
+        self.input_dtype = query.layout.dtype

         num_threads = parallel_num_threads()
         buf_out = TensorBox.create(self.output_node)
@@ -975,8 +974,8 @@ def render( # type: ignore[override,return]
             score_mod_other_buffers=self.score_mod_other_buffers,
             mask_mod_other_buffers=self.mask_mod_other_buffers,
             scale=self.scale,
-            accumulate_dtype=torch.float,
-            query_dtype=query.layout.dtype,
+            accumulate_dtype=self.accumulate_dtype,
+            query_dtype=self.input_dtype,
             kvBlockSize=self.kv_block_size,
             template=self,
             output=buf_out,
@@ -1016,3 +1015,33 @@ def codegen_allocate_buffer(self, buffer_name: str, buffer_dtype, buffer_size):
                 buffer_size=buffer_size,
             )
         )
+
+    def micro_gemm_define(self, kernel_name: str):
+        from torch._inductor.codegen.cpp_gemm_template import (
+            CppTemplateKernel,
+            parallel_num_threads,
+        )
+        from torch._inductor.codegen.cpp_micro_gemm import CppMicroGemmFP32Vec
+        from torch._inductor.virtualized import V
+
+        micro_gemm = CppMicroGemmFP32Vec(
+            kernel_name + "_kernel_micro_gemm",
+            self.input_dtype,
+            self.input_dtype,
+            self.accumulate_dtype,
+            self.accumulate_dtype,
+            GemmBlocking(1, 16, 1),
+            1,
+            True,
+            True,
+        )
+
+        with V.set_graph_handler(V.graph):
+            kernel = CppTemplateKernel("cpp_micro_gemm", parallel_num_threads())
+            code = micro_gemm.codegen_define(kernel)
+            return code
+
+    def codegen_micro_gemm(self, kernel_name: str):
+        micro_gemm = self.micro_gemm_define(kernel_name)
+        GEMM_SOURCE_CODE = MICRO_GEMM_TEMPLATE.replace("GEMM_DEFINE", micro_gemm)
+        return self._template_from_string(GEMM_SOURCE_CODE).render()
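
To make the intent of the new non-packed branch concrete, here is a rough Python reference for the math the generated micro-gemm call writes into qk_data; this is a sketch with illustrative shapes, not the template's actual codegen, and it assumes the trans_b micro-gemm consumes the K tile in its original row-major layout (leading dimension kStrideN), so no per-block transpose or reorder of K is required:

import torch

# Illustrative tile sizes; cur_qSplitSize/cur_kvSplitSize/headSize vary at runtime.
cur_qSplitSize, cur_kvSplitSize, headSize = 32, 512, 64
q_block = torch.randn(cur_qSplitSize, headSize, dtype=torch.float16)
k_block = torch.randn(cur_kvSplitSize, headSize, dtype=torch.float16)  # untransposed K tile

# Accumulation happens in float32 (accumulate_dtype=torch.float in the template).
qk_block = q_block.float() @ k_block.float().transpose(0, 1)
assert qk_block.shape == (cur_qSplitSize, cur_kvSplitSize)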

torch/_inductor/codegen/cpp_micro_gemm.py

+1, -1
@@ -877,7 +877,7 @@ def codegen_define(self, kernel: CppTemplateKernel) -> str:
         if self.trans_b:
             # TODO supports tuning of sub_block_m/sub_block_n
             # to get better performance for specific shapes
-            sub_block_m = min(4, self.register_blocking.block_m)
+            sub_block_m = min(1, self.register_blocking.block_m)
             sub_block_n = min(4, self.register_blocking.block_n)
             # update options to generate kernel with trans_b and sub-block size
             options.update(

torch/_inductor/kernel/flex_attention.py

+2, -2
@@ -1099,9 +1099,9 @@ def convert_mask_graph_module(mask_graph):
         raise NotImplementedError(
             "Unsupported for now if query, key, value are the same buffer."
         )
-    if query.get_dtype() not in [torch.float, torch.bfloat16]:
+    if query.get_dtype() not in [torch.float, torch.bfloat16, torch.float16]:
         raise NotImplementedError(
-            "`torch.float` and `torch.bfloat16` are supported in FlexAttention for CPU device. "
+            "`torch.float` , `torch.float16` and `torch.bfloat16` are supported in FlexAttention for CPU device. "
             f"Found input tensors are `{query.get_dtype()}`."
         )
     score_mod_other_buffers = maybe_realize(score_mod_other_buffers)
