
Commit a7feae8

[kernel] Refactor FA kernel to be FA_transV when possible. (#568)
Flash Attention's transpose_V variant is significantly faster than the non-transpose_V variant because many MM intrinsics are mmtb by default. Doing FA with transpose_V therefore gives better, more contiguous reads from shared memory into registers, which improves attention performance substantially and also makes FP8 faster than FP16. I have tested that it indeed improves SDXL performance on FP8, making our FP8 model faster than our FP16 model. I have also tested/confirmed that, if we do not find any producers to fuse with, the transpose re-fuses back into the attention, so the worst case is the same performance as before we un-split the transpose.

Some data from a microbenchmark with a real size from SDXL:

```
(B0, B1, M, K1, K2, N): (2, 10, 4096, 64, 4096, 64)
Over 100 runs:
FP16 non transpose: 22.7 ms
FP8 non transpose: 23.8 ms
FP16 transpose: 20.1 ms
FP8 transpose: 17.5 ms
```

Additionally, this PR moves the reduction dimension of attention to the fastest dimension. This is preferable because many optimization passes expect reduction dims to be the fastest dims, and it matches IREE's lowering passes more closely.

Signed-off-by: Stanley Winata <[email protected]>
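As an aside for readers, the rewrite is numerically a no-op: pre-transposing V and contracting over its fastest dimension yields the same result as the original attention. Below is a minimal NumPy sketch (illustration only, not the kernel; shapes and helper names are hypothetical and just follow the (B0, B1, M, K1, K2, N) convention above):

```python
import numpy as np

# Toy shapes following the (B0, B1, M, K1, K2, N) naming above.
B0, B1, M, K1, K2, N = 2, 3, 8, 4, 8, 4

q = np.random.rand(B0, B1, M, K1).astype(np.float32)
k = np.random.rand(B0, B1, K2, K1).astype(np.float32)
v = np.random.rand(B0, B1, K2, N).astype(np.float32)
scale = np.float32(1.0 / np.sqrt(K1))

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

# Reference attention: softmax(Q K^T * scale) @ V
s = softmax(np.einsum("bhmk,bhsk->bhms", q, k) * scale)
ref = np.einsum("bhms,bhsn->bhmn", s, v)

# transpose_V variant: pre-transpose V to [N, K2] and contract over its
# fastest (last) dimension, mirroring what the refactored kernel does.
v_t = v.transpose(0, 1, 3, 2)
out = np.einsum("bhms,bhns->bhmn", s, v_t)

assert np.allclose(ref, out, atol=1e-5)
```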
1 parent 4dd2fc8 commit a7feae8

File tree

1 file changed: +7 -3 lines changed


sharktank/sharktank/kernels/templates/flash_attention.mlir

Lines changed: 7 additions & 3 deletions
```diff
@@ -7,6 +7,7 @@
 !q_type = tensor<?x?x{{l}}x{{d}}x{{i_type}}>
 !k_type = tensor<?x?x{{s}}x{{d}}x{{i_type}}>
 !v_type = tensor<?x?x{{s}}x{{e}}x{{i_type}}>
+!trans_v_type = tensor<?x?x{{e}}x{{s}}x{{i_type}}>
 !o_type = tensor<?x?x{{l}}x{{e}}x{{o_type}}>
 !o_dyn_type = tensor<?x?x?x?x{{o_type}}>
 !s_type = tensor<{{scale_type}}>
@@ -32,16 +33,19 @@ util.func private @sharktank_flash_attention_{{l}}_{{s}}_{{d}}_{{e}}_{{i_type}}_
 
   %scale = tensor.extract %s[] : !s_type
 
+  %init_trans_v = tensor.empty(%b0, %b1) : !trans_v_type
+  %transpose_v = linalg.transpose ins(%v: !v_type) outs(%init_trans_v: !trans_v_type) permutation = [0, 1, 3, 2]
+
   %empty_dyn = tensor.empty(%b0, %b1, %l, %e) : !o_dyn_type
   %empty = tensor.cast %empty_dyn : !o_dyn_type to !o_type
 
   %atten = iree_linalg_ext.attention {indexing_maps = [
      affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>,
-     affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d3)>,
+     affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d3)>,
      affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d5)>,
      affine_map<(d0, d1, d2, d3, d4, d5) -> ()>,
-     affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>]}
-     ins(%q, %k, %v, %scale : !q_type, !k_type, !v_type, {{scale_type}}) outs(%empty : !o_type) {
+     affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d4)>]}
+     ins(%q, %k, %transpose_v, %scale : !q_type, !k_type, !v_type, {{scale_type}}) outs(%empty : !o_type) {
   ^bb0(%score: f32):
     iree_linalg_ext.yield %score : f32
   } -> !o_type
```
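As a reading aid for the new indexing maps, here is a small, hypothetical Python sketch of how the iteration dims d0..d5 map onto the operand shapes after this change. The dimension names follow the (B0, B1, M, K1, K2, N) benchmark naming in the commit message and are not part of the template; note that d5, the softmax·V reduction dim (K2), is now the fastest iteration dim.

```python
# Hypothetical dimension assignment matching the new affine maps:
#   d0 = B0, d1 = B1, d2 = M, d3 = K1 (Q·K reduction),
#   d4 = N (output/head dim), d5 = K2 (softmax·V reduction, now fastest).
dims = {"d0": "B0", "d1": "B1", "d2": "M", "d3": "K1", "d4": "N", "d5": "K2"}

maps = {
    "Q":       ["d0", "d1", "d2", "d3"],  # B0 x B1 x M  x K1
    "K":       ["d0", "d1", "d5", "d3"],  # B0 x B1 x K2 x K1
    "V_trans": ["d0", "d1", "d4", "d5"],  # B0 x B1 x N  x K2 (transposed V)
    "Out":     ["d0", "d1", "d2", "d4"],  # B0 x B1 x M  x N
}

for name, idx in maps.items():
    print(name, "->", [dims[d] for d in idx])
```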
