[kernel] Refactor FA kernel to be FA_transV when possible. (#568)
The Flash Attention transpose_V variant is significantly faster than the non-transpose_V variant, because many MMA intrinsics are mmtb (i.e., they take the B operand transposed) by default. Doing FA with transpose_V therefore allows better, more contiguous reads from shared memory into registers, which improves attention performance considerably. It also makes FP8 faster than FP16: I have verified that it improves SDXL performance on FP8, making the FP8 model faster than our FP16 model.
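For intuition, here is a minimal NumPy sketch (illustrative only, not the kernel code in this PR) of what the transpose_V variant changes: the second matmul consumes V stored as V^T, so both matmuls have an "A @ B^T" shape that lines up with mmtb-style intrinsics, and the reads along the reduction dimension become contiguous.

```python
import numpy as np

# Illustrative single-head shapes (batch dims omitted for clarity).
M, K1, K2, N = 4096, 64, 4096, 64
rng = np.random.default_rng(0)
Q = rng.standard_normal((M, K1)).astype(np.float32)
K = rng.standard_normal((K2, K1)).astype(np.float32)
V = rng.standard_normal((K2, N)).astype(np.float32)

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

# Non-transposed variant: the second matmul reads V as (K2, N),
# so the reduction dim K2 is not the contiguous axis of V.
P = softmax(Q @ K.T)            # (M, K2)
out_ref = P @ V                 # (M, N)

# transpose_V variant: keep V in memory as V^T with shape (N, K2),
# so both matmuls are "A @ B^T" shaped and the reduction dim is the
# contiguous axis of both operands.
V_T = np.ascontiguousarray(V.T)  # (N, K2)
out_tv = P @ V_T.T               # same result, different memory layout

assert np.allclose(out_ref, out_tv, atol=1e-4)
```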
I have also tested and confirmed that, if we do not find any producers to fuse with, the transpose seems to re-fuse back into the attention. Hence, the worst case is the same performance as before we un-split the transpose.
Some data from a microbenchmark with real sizes from SDXL:
```
(B0, B1, M, K1, K2, N): (2, 10, 4096, 64, 4096, 64)
Over 100 runs:
FP16 non transpose: 22.7 ms
FP8 non transpose: 23.8 ms
FP16 transpose: 20.1 ms
FP8 transpose: 17.5 ms
```
Additionally, this PR moves the reduction dimension of attention to the fastest-varying dimension. This is preferable because many optimization passes expect reduction dims to be the fastest dims, and it better matches the lowering passes from IREE.
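As a rough sketch of that layout change (shapes scaled down; names are illustrative and not taken from the PR), placing the reduction dims last makes them the contiguous, fastest-varying axes that reduction-dim matchers expect:

```python
import numpy as np

# Scaled-down illustrative shapes (real SDXL sizes are in the
# microbenchmark above); only the layout matters here.
B0, B1, M, K1, K2, N = 2, 10, 16, 64, 16, 64

# Reduction dims (K1 for Q@K^T, K2 for P@V^T) are the last axes,
# i.e. the contiguous ones in a row-major layout.
Q   = np.zeros((B0, B1, M, K1), dtype=np.float32)
K   = np.zeros((B0, B1, K2, K1), dtype=np.float32)
P   = np.zeros((B0, B1, M, K2), dtype=np.float32)
V_T = np.zeros((B0, B1, N, K2), dtype=np.float32)

S = np.einsum("bhmk,bhnk->bhmn", Q, K)    # contract over K1 (last axis)
O = np.einsum("bhmn,bhdn->bhmd", P, V_T)  # contract over K2 (last axis)
```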
Signed-off-by: Stanley Winata <[email protected]>