from ..utils import parallel_num_threads
from ..virtualized import V
from .cpp_template import CppTemplate
+ from .cpp_utils import GemmBlocking


log = logging.getLogger(__name__)
}
"""

+ MICRO_GEMM_TEMPLATE = r"""
+ GEMM_DEFINE
+ """
+
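+ # GEMM_DEFINE above is a plain-text placeholder; codegen_micro_gemm() replaces it
+ # with the micro-GEMM kernel definition generated by CppMicroGemmFP32Vec before rendering.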
ALLOCATE_BUFFER = r"""
int64_t {{buffer_name}}_dtype_itemsize = std::is_same_v<{{buffer_dtype}}, at::BFloat16> ? 2 : 4;
auto& {{buffer_name}}_allocator = *at::getCPUAllocator();
#include <ATen/native/cpu/utils.h>
#include <ATen/native/CPUBlas.h>
#include <ATen/Context.h>
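+ // The generated micro-GEMM kernel definition is emitted here, ahead of the
+ // attention kernel body that calls it in the non-packed path.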
+ {{template.codegen_micro_gemm(kernel.kernel_name)}}
{{template.codegen_softmax_fusion(kernel.kernel_name)}}
{{template.codegen_brgemm_pack_function(kernel.kernel_name)}}
{%- set kernel_args = {"query": query, "key": key, "value": value,
need_pack = gemm_size_per_thread / pack_size >= 4;
}
}
-
// Pad is needed for packing when K is not even
bool headSize_even = headSize % 2 == 0;
int64_t eheadSize = need_pack && !headSize_even ? headSize + 1 : headSize;
{{template.codegen_allocate_buffer("transpose_buffer_ptr", "scalar_t", "num_thread*kvSplitSize*headSize")}}
{{template.codegen_allocate_buffer("query_padding_ptr", "scalar_t", "num_thread*qSplitSize*eheadSize")}}

- // Reorder K, V and transpose K
- at::parallel_for(0, batchSize * num_head * kvSlice, 1, [&](int64_t begin, int64_t end) {
- int ompIdx = at::get_thread_num();
- int64_t i = 0, j = 0, l = 0, n = 0;
- scalar_t* transpose_ptr = need_pack ? transpose_buffer_ptr + ompIdx * kvSplitSize * headSize : nullptr;
- at::native::data_index_init(begin, i, batchSize, j, num_head, l, kvSlice);
- for ([[maybe_unused]] auto z : c10::irange(begin, end)) {
- n = l * kvSplitSize;
- int64_t cur_kvSplitSize = std::min(kvSplitSize, kvSize - n);
- auto i_kv = is_broadcast_bs_kv ? i/bs_shards : i;
- auto j_kv = is_broadcast_head_kv ? j/gqa_shards : j;
- auto kv_block_num = n / cur_kvSplitSize;
- auto kv_block_offset = n - kv_block_num * cur_kvSplitSize;
- // getting kv indices by [BS, Head, 1, kv_block_num]
- auto i_kvi = is_broadcast_bs_kvi ? i/bs_shards_kvi : i;
- auto j_kvi = is_broadcast_head_kvi ? j/gqa_shards_kvi : j;
- auto kv_logical_data = kv_indices_data + i_kvi * kviStrideB +
- j_kvi * kviStrideH + kv_block_num;
- auto k_addr =
- k_data + i_kv * kStrideB + j_kv * kStrideH + n * kStrideN;
- auto v_addr =
- v_data + i_kv * vStrideB + j_kv * vStrideH + n * vStrideN;
- if (use_kv_indice) {
- k_addr =
- k_data + i_kv * kStrideB + j_kv * kStrideH +
- (*kv_logical_data * cur_kvSplitSize + kv_block_offset) * kStrideN;
- v_addr =
- v_data + i_kv * vStrideB + j_kv * vStrideH +
- (*kv_logical_data * cur_kvSplitSize + kv_block_offset) * vStrideN;
- }
- if (need_pack) {
+ if (need_pack) {
+ // Pack K, V
+ at::parallel_for(0, batchSize * num_head * kvSlice, 1, [&](int64_t begin, int64_t end) {
+ int ompIdx = at::get_thread_num();
+ int64_t i = 0, j = 0, l = 0, n = 0;
+ scalar_t* transpose_ptr = transpose_buffer_ptr + ompIdx * kvSplitSize * headSize;
+ at::native::data_index_init(begin, i, batchSize, j, num_head, l, kvSlice);
+ for ([[maybe_unused]] auto z : c10::irange(begin, end)) {
+ n = l * kvSplitSize;
+ int64_t cur_kvSplitSize = std::min(kvSplitSize, kvSize - n);
+ auto i_kv = is_broadcast_bs_kv ? i/bs_shards : i;
+ auto j_kv = is_broadcast_head_kv ? j/gqa_shards : j;
+ auto kv_block_num = n / cur_kvSplitSize;
+ auto kv_block_offset = n - kv_block_num * cur_kvSplitSize;
+ // getting kv indices by [BS, Head, 1, kv_block_num]
+ auto i_kvi = is_broadcast_bs_kvi ? i/bs_shards_kvi : i;
+ auto j_kvi = is_broadcast_head_kvi ? j/gqa_shards_kvi : j;
+ auto kv_logical_data = kv_indices_data + i_kvi * kviStrideB +
+ j_kvi * kviStrideH + kv_block_num;
+ auto k_addr =
+ k_data + i_kv * kStrideB + j_kv * kStrideH + n * kStrideN;
+ auto v_addr =
+ v_data + i_kv * vStrideB + j_kv * vStrideH + n * vStrideN;
+ if (use_kv_indice) {
+ k_addr =
+ k_data + i_kv * kStrideB + j_kv * kStrideH +
+ (*kv_logical_data * cur_kvSplitSize + kv_block_offset) * kStrideN;
+ v_addr =
+ v_data + i_kv * vStrideB + j_kv * vStrideH +
+ (*kv_logical_data * cur_kvSplitSize + kv_block_offset) * vStrideN;
+ }
// transpose [cur_kvSplitSize, headSize] -> [headSize, cur_kvSplitSize]
at::native::utils::transpose<uint16_t>(
cur_kvSplitSize,
/* ld_src */ vStrideN,
/* K */ cur_kvSplitSize,
/* N */ headSize_v);
- } else {
- using trans_t = std::conditional_t<std::is_same_v<scalar_t, at::BFloat16>, uint16_t, float>;
- at::native::utils::transpose<trans_t>(
- cur_kvSplitSize,
- headSize,
- /* src_ptr */
- reinterpret_cast<const trans_t*>(k_addr),
- /* ld_src */ kStrideN,
- /* dst */ reinterpret_cast<trans_t*>(key_reorder_ptr + i * num_head * eheadSize * kvSize +
- j * eheadSize * kvSize + n * eheadSize),
- /* ld_dst */ cur_kvSplitSize);
+ // Move to the next query
+ at::native::data_index_step(i, batchSize, j, num_head, l, kvSlice);
}
- // Move to the next query
- at::native::data_index_step(i, batchSize, j, num_head, l, kvSlice);
- }
- });
-
+ });
+ }
// Attention loop below
at::parallel_for(0, batchSize * num_head * qSlice, 1, [&](int64_t begin, int64_t end) {
int64_t i = 0, j = 0, k = 0;
auto kv_logical_data = kv_indices_data + i_kvi * kviStrideB +
j_kvi * kviStrideH + kv_block_num;
if (!need_pack) {
- auto k_addr_t = key_reorder_ptr + i * num_head * eheadSize * kvSize +
- j * eheadSize * kvSize + n * eheadSize;
- // TODO: use the micro-gemm template instead of brgemm API
- at::native::cpublas::brgemm(
- cur_qSplitSize,
- cur_kvSplitSize,
- eheadSize,
- qStrideM,
- cur_kvSplitSize,
- cur_kvSplitSize,
- false,
+ auto k_addr =
+ k_data + i_kv * kStrideB + j_kv * kStrideH + n * kStrideN;
+ if (use_kv_indice) {
+ k_addr =
+ k_data + i_kv * kStrideB + j_kv * kStrideH +
+ (*kv_logical_data * kvBlockSize + kv_block_offset) * kStrideN;
+ }
+
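+ // Non-packed path: compute the Q*K^T block with the generated micro-GEMM
+ // kernel on the unpacked key data (replaces the earlier brgemm on reordered keys).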
+ {{kernel.kernel_name}}_kernel_micro_gemm<static_cast<bool>(false)>(
q_data + i * qStrideB + j * qStrideH +
m * qStrideM,
- k_addr_t,
+ k_addr,
qk_data,
- need_pack);
+ cur_qSplitSize,
+ cur_kvSplitSize,
+ headSize,
+ qStrideM,
+ kStrideN,
+ cur_kvSplitSize);
+
} else {
at::native::cpublas::brgemm(
cur_qSplitSize,
@@ -690,7 +687,7 @@ def __init__(
        kernel_input_name_to_buffer,
        block_vars,
    ) -> None:
-         assert layout.dtype in [torch.float, torch.bfloat16]
+         assert layout.dtype in [torch.float, torch.bfloat16, torch.float16]
        super().__init__("flex_attention", input_nodes, layout, parallel_num_threads())
        self.scale = scale
        self.score_mod = score_mod
@@ -958,6 +955,8 @@ def render( # type: ignore[override,return]
        query = kernel.permute(self.input_nodes[0], [0, 2, 1, 3])
        key = kernel.permute(self.input_nodes[1], [0, 2, 1, 3])
        value = kernel.permute(self.input_nodes[2], [0, 2, 1, 3])
+         self.accumulate_dtype = torch.float
+         self.input_dtype = query.layout.dtype
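+         # Cached so the render() options below and micro_gemm_define() share the same dtypes.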

        num_threads = parallel_num_threads()
        buf_out = TensorBox.create(self.output_node)
@@ -975,8 +974,8 @@ def render( # type: ignore[override,return]
            score_mod_other_buffers=self.score_mod_other_buffers,
            mask_mod_other_buffers=self.mask_mod_other_buffers,
            scale=self.scale,
-             accumulate_dtype=torch.float,
-             query_dtype=query.layout.dtype,
+             accumulate_dtype=self.accumulate_dtype,
+             query_dtype=self.input_dtype,
            kvBlockSize=self.kv_block_size,
            template=self,
            output=buf_out,
@@ -1016,3 +1015,33 @@ def codegen_allocate_buffer(self, buffer_name: str, buffer_dtype, buffer_size):
                buffer_size=buffer_size,
            )
        )
+
+     def micro_gemm_define(self, kernel_name: str):
+         from torch._inductor.codegen.cpp_gemm_template import (
+             CppTemplateKernel,
+             parallel_num_threads,
+         )
+         from torch._inductor.codegen.cpp_micro_gemm import CppMicroGemmFP32Vec
+         from torch._inductor.virtualized import V
+
+         micro_gemm = CppMicroGemmFP32Vec(
+             kernel_name + "_kernel_micro_gemm",
+             self.input_dtype,
+             self.input_dtype,
+             self.accumulate_dtype,
+             self.accumulate_dtype,
+             GemmBlocking(1, 16, 1),
+             1,
+             True,
+             True,
+         )
+
+         with V.set_graph_handler(V.graph):
+             kernel = CppTemplateKernel("cpp_micro_gemm", parallel_num_threads())
+             code = micro_gemm.codegen_define(kernel)
+             return code
+
+     def codegen_micro_gemm(self, kernel_name: str):
+         micro_gemm = self.micro_gemm_define(kernel_name)
+         GEMM_SOURCE_CODE = MICRO_GEMM_TEMPLATE.replace("GEMM_DEFINE", micro_gemm)
+         return self._template_from_string(GEMM_SOURCE_CODE).render()
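
For orientation, the call in the `if (!need_pack)` branch above implies a generated kernel of roughly the following shape. This is only a sketch inferred from that call site (the `flex_attn` name prefix and the `scalar_t` alias are illustrative); the actual definition is produced by CppMicroGemmFP32Vec.codegen_define() and spliced in by codegen_micro_gemm():

    // Sketch only, not the emitted code: argument order mirrors the call site above.
    #include <cstdint>
    using scalar_t = float;  // bfloat16/half in the real template; float here for illustration
    template <bool accum>
    void flex_attn_kernel_micro_gemm(
        const scalar_t* A,   // query block        (lda = qStrideM)
        const scalar_t* B,   // key block          (ldb = kStrideN)
        float* C,            // Q*K^T accumulator  (ldc = cur_kvSplitSize)
        int64_t M,           // cur_qSplitSize
        int64_t N,           // cur_kvSplitSize
        int64_t K,           // headSize
        int64_t lda, int64_t ldb, int64_t ldc);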