Attention (including MHA, GQA, and MQA) is one of the key components of transformers and is performance-critical in many scenarios. To implement various optimizations, a fused attention layer and the corresponding utilities for the customized KV-cache it uses are introduced. As an example, fused attention reduces the MHA cost of a 1975-token llama-7b inference from 8521.276 ms to 248.586 ms for the first token, and from 17.044 ms to 7.944 ms for the next token[1].
Note that this doc assumes basic knowledge of this C++ graph implementation, including `ne_tensor::ne`, `ne_tensor::nb`, etc. The model builder can enable fused attention with the `ne_flash_attn*` operators defined in `core/ne_layers.h`, while the implementation of fused attention is mostly located in `core/layers/mha_dense.cpp`.
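For reference, `ne_tensor::ne` holds the number of elements in each dimension and `ne_tensor::nb` holds the stride in bytes of each dimension (a ggml-style convention). The toy helper below is only an illustration (not part of the library) of how an element is addressed with them:

```cpp
#include <cstddef>
#include <cstdint>

struct toy_tensor {   // hypothetical stand-in for ne_tensor
  int64_t ne[4];      // number of elements in each dimension
  size_t nb[4];       // stride in bytes of each dimension
  void* data;
};

// Address of element (i0, i1, i2, i3): the dot product of indices and byte strides.
static inline void* element_ptr(const toy_tensor& t, int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
  return static_cast<char*>(t.data) + i0 * t.nb[0] + i1 * t.nb[1] + i2 * t.nb[2] + i3 * t.nb[3];
}
```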
The memory for the KV-cache is allocated in `neural_speed/models/model_utils/model_utils.cpp`. As the fused attention implementation requires certain instruction extensions and has some other limitations, fused attention is enabled only if `bestla_reordered_attn_fp32_support()` agrees, denoted by `memory_type = NE_TYPE_BTLA`. Next, `get_batch_kv_elements_from_gpt_params()` gives the sizes of the K-cache and V-cache (in bytes if fused attention is enabled, or in elements if it is disabled) for each batch and each layer. The KV-cache is finally prepared with these two sizes by creating `ne_new_tensor`s inside `model.kv_self.k/.v` or `model.layer[il].k_cache/.v_cache`.
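As a rough, self-contained illustration of that flow (not the actual code in `model_utils.cpp`), the sketch below mirrors the decision and sizing logic; all `toy_*` names are hypothetical stand-ins for the real helpers, and the byte sizes are placeholders:

```cpp
#include <cstddef>
#include <vector>

enum toy_memory_type { TOY_TYPE_F16, TOY_TYPE_BTLA };

// Stand-in for bestla_reordered_attn_fp32_support(): the real check depends on
// the available instruction extensions and other limitations.
static bool toy_fused_attn_supported() { return true; }

// Stand-in for get_batch_kv_elements_from_gpt_params(): per-batch, per-layer
// K-/V-cache sizes -- in bytes when fused attention is enabled, in elements otherwise.
static void toy_batch_kv_sizes(bool fused, int n_ctx, int n_head_kv, int head_size,
                               std::size_t* k_size, std::size_t* v_size) {
  const std::size_t elems = static_cast<std::size_t>(n_ctx) * n_head_kv * head_size;
  *k_size = fused ? elems * 2 /* placeholder: the real byte size is kernel-defined */ : elems;
  *v_size = fused ? elems * 2 : elems;
}

int main() {
  const bool fused = toy_fused_attn_supported();
  const toy_memory_type memory_type = fused ? TOY_TYPE_BTLA : TOY_TYPE_F16;

  std::size_t k_size = 0, v_size = 0;
  toy_batch_kv_sizes(fused, /*n_ctx=*/2048, /*n_head_kv=*/32, /*head_size=*/128, &k_size, &v_size);

  // One K-cache and one V-cache buffer per layer (the real code instead creates
  // ne_new_tensor inside model.kv_self.k/.v or model.layer[il].k_cache/.v_cache).
  const int n_layer = 32;
  std::vector<std::vector<char>> k_cache(n_layer), v_cache(n_layer);
  for (auto& k : k_cache) k.resize(fused ? k_size : k_size * sizeof(float));
  for (auto& v : v_cache) v.resize(fused ? v_size : v_size * sizeof(float));
  (void)memory_type;
  return 0;
}
```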
The KV-cache is appended to every time a new pair of K and V tensors is generated by evaluating the inner product for QKV (`ne_mul_qkv`). This operation appends an additional K/V-tensor along the sequence-length dimension (i.e. resulting in `n_past = n_past + N`, where `n_past` is the number of previously cached tokens and `N` is the length of the current tokens).
The K and V tensors should both be permuted (with `ne_permute`) to `batch x N x head_num x head_size`. The permuted K/V tensors can then be passed to `ne_flash_attn_update_k()/_v()`, together with `n_past` as an offset in the sequence-length dimension, to concatenate them to a "view" of the cached K/V of the current layer and current batch.
Note: currently the K and V tensors to append must be contiguous in the head-size dimension (i.e. `head_dim` in some implementations). The strides (`ne_tensor::nb`) of the other 3 dimensions are configurable.
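Conceptually, the update copies the new K (or V) of length `N` into the per-layer cache at offset `n_past` along the sequence-length dimension; the real kernels additionally reorder the data into their customized layout. Below is a self-contained toy of just that append semantics (all names hypothetical, plain `float` storage instead of the real cache type):

```cpp
#include <cstddef>
#include <cstring>
#include <vector>

struct toy_kv_cache {
  int batch, head_num, head_size, n_ctx;  // n_ctx = capacity in tokens
  std::vector<float> data;                // layout: [batch][n_ctx][head_num][head_size]
};

// new_k layout assumed: [batch][N][head_num][head_size], contiguous in head_size.
void toy_update_k(toy_kv_cache& cache, const float* new_k, int N, int n_past) {
  const std::size_t row = static_cast<std::size_t>(cache.head_num) * cache.head_size;  // one token of K
  for (int b = 0; b < cache.batch; ++b) {
    float* dst = cache.data.data() + (static_cast<std::size_t>(b) * cache.n_ctx + n_past) * row;
    const float* src = new_k + static_cast<std::size_t>(b) * N * row;
    std::memcpy(dst, src, static_cast<std::size_t>(N) * row * sizeof(float));  // append tokens [n_past, n_past + N)
  }
}
```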
With the KV-cache of customized type and layout, the attention can be computed at its best performance. `ne_flash_attn` accepts a Q-tensor of `batch x head_num x N x head_size` and outputs a result tensor of `batch x N x head_num x head_size` (fully contiguous).
Note: similar to the append operation, the Q-tensor must be contiguous in the head-size dimension, while its strides (`ne_tensor::nb`) of the other 3 dimensions are configurable.
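To make the expected layouts concrete, here is a plain (unfused) reference computation using exactly the Q layout `batch x head_num x N x head_size` and output layout `batch x N x head_num x head_size` described above. It is a toy for clarity rather than the fused kernel, and the causal mask and ALiBi are omitted:

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Q:   [batch][head_num][N][head_size]
// K/V: [batch][seq][head_num][head_size]   (seq = n_past + N tokens in the cache)
// out: [batch][N][head_num][head_size]     (fully contiguous)
void toy_attention(const std::vector<float>& Q, const std::vector<float>& K, const std::vector<float>& V,
                   std::vector<float>& out, int batch, int head_num, int N, int seq, int head_size) {
  const float scale = 1.0f / std::sqrt(static_cast<float>(head_size));
  std::vector<float> score(seq);
  for (int b = 0; b < batch; ++b)
    for (int h = 0; h < head_num; ++h)
      for (int i = 0; i < N; ++i) {
        const float* q = &Q[((static_cast<std::size_t>(b) * head_num + h) * N + i) * head_size];
        float max = -INFINITY, sum = 0.0f;
        for (int j = 0; j < seq; ++j) {  // scores = scale * Q.K^T
          const float* k = &K[((static_cast<std::size_t>(b) * seq + j) * head_num + h) * head_size];
          float s = 0.0f;
          for (int d = 0; d < head_size; ++d) s += q[d] * k[d];
          score[j] = s * scale;
          max = std::max(max, score[j]);
        }
        for (int j = 0; j < seq; ++j) {  // softmax over the sequence dimension
          score[j] = std::exp(score[j] - max);
          sum += score[j];
        }
        float* o = &out[((static_cast<std::size_t>(b) * N + i) * head_num + h) * head_size];
        for (int d = 0; d < head_size; ++d) o[d] = 0.0f;
        for (int j = 0; j < seq; ++j) {  // out = softmax(scores) . V
          const float* v = &V[((static_cast<std::size_t>(b) * seq + j) * head_num + h) * head_size];
          const float w = score[j] / sum;
          for (int d = 0; d < head_size; ++d) o[d] += w * v[d];
        }
      }
}
```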
Fused attention is designed to easily support various models:
| Models | Attention Type & Support | Runtime ISA |
| --- | --- | --- |
| | MHA ✅ | AMX_BF16 |
| | GQA ✅ | |
| GPT-J-6B | MHA ✅ | |
| GPT-NeoX-20B, Dolly-v2-3B | MHA ✅ | |
| Qwen-7B, Qwen-14B | MHA ✅ | |
| MPT-7B, MPT-30B | MHA with ALiBi ✅ | |
| Falcon-7B, Falcon-40B | MQA, GQA ✅ | |
| BLOOM-7B | MHA with ALiBi 🚧 | |
| OPT-125m, OPT-350m, OPT-1.3B, OPT-13B | MHA 🚧 | |
| ChatGLM-6B, ChatGLM2-6B | MHA with the autoregressive-blank-infilling pattern 🚧, GQA ✅ | |
| StarCoder-1B, StarCoder-3B, StarCoder-15.5B | MHA ✅ | |

✅: Supported; 🚧: WIP
Currently, fused attention is only enabled when compiling Neural Speed with GCC 11 or newer.
Thanks to the mathematical nature of attention, the whole set of KV-cache operations and the fused attention can simply be parallelized along commonly parallelizable dimensions: pass each part to every KV-cache operation (and merge the results together if needed), as sketched below.
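For instance, splitting on the batch dimension could look like the following sketch (hypothetical helper names; plain threads as one possible mechanism). Each range updates and reads only its own slice of the KV-cache and writes a disjoint slice of the output, so no extra merging is needed beyond placing the slices side by side:

```cpp
#include <algorithm>
#include <thread>
#include <vector>

// Hypothetical per-range worker: runs the KV-cache append and the fused attention
// for batches [b_begin, b_end) only; its inputs and outputs are disjoint slices.
void attend_batch_range(int b_begin, int b_end /*, Q/K/V and cache slices ... */) {
  // ... ne_flash_attn_update_k/_v + ne_flash_attn restricted to this sub-range ...
  (void)b_begin;
  (void)b_end;
}

// Split the batch dimension across n_threads workers and join the results.
void parallel_attention(int batch, int n_threads) {
  std::vector<std::thread> workers;
  const int step = (batch + n_threads - 1) / n_threads;
  for (int b = 0; b < batch; b += step)
    workers.emplace_back(attend_batch_range, b, std::min(b + step, batch));
  for (auto& w : workers) w.join();
}
```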
[1] The data was tested on a single socket of Intel(R) Xeon(R) Platinum 8480L on commit 01a809d. Details below.
| | 1st-token, fused attn disabled | 1st-token, fused attn enabled | 4th-token, fused attn disabled | 4th-token, fused attn enabled |
| --- | --- | --- | --- | --- |
| total latency | 9748.26 ms | 1475.57 ms | 50.37 ms | 41.27 ms |
| fused-attn latency | / | 179.883 ms + 68.703 ms | / | 6.271 ms + 1.673 ms |
| est. non-attn latency | 1226.984 ms | 1226.984 ms | 33.326 ms | 33.326 ms |
| MHA cost comparison | 8521.276 ms | 248.586 ms | 17.044 ms | 7.944 ms |
(The 4th token is taken as an example of next-token performance.)
Raw logs:
```
# fused attn enabled
rm -rf bin && cmake .. -GNinja -DNS_BUILD_TESTS=ON -DNE_PROFILING=ON -DCMAKE_BUILD_TYPE=Release && ninja run_llama && env ENGINE_PROFILING=1 numactl -m 1 -C 56-111 bin/run_llama -m llama-7b-hf-pr447-q4j-sym-int8-fp32-g128.bin --seed 1234 -t 56 -b 2048 -c 2048 -n 4 --memory-auto -p "$(echo "$LUOYU_PROMPT" | cut -d' ' -f 1-1500)"
Welcome to use the llama on the ITREX!
...
=== GRAPH Profiling ===
perf_total_per_op_us[ ADD] = 51.409 ms
perf_total_per_op_us[ MUL] = 26.328 ms
perf_total_per_op_us[ RMS_NORM] = 42.445 ms
perf_total_per_op_us[ MUL_MAT] = 127.810 ms
perf_total_per_op_us[ RESHAPE] = 0.446 ms
perf_total_per_op_us[ VIEW] = 0.997 ms
perf_total_per_op_us[ PERMUTE] = 0.101 ms
perf_total_per_op_us[ TRANSPOSE] = 0.105 ms
perf_total_per_op_us[ GET_ROWS] = 8.342 ms
perf_total_per_op_us[ ROPE] = 44.115 ms
perf_total_per_op_us[ MUL_QKV] = 252.611 ms
perf_total_per_op_us[ FFN_SILU] = 668.217 ms
perf_total_per_op_us[ FLASH_ATTN] = 179.883 ms
perf_total_per_op_us[ FLASH_ATTN_KV_UPDATE] = 68.703 ms
perf_total_per_op_us[ INNER PRODUCT] = 0.000 ms
========================================
=== GRAPH Profiling ===
perf_total_per_op_us[ ADD] = 0.420 ms
perf_total_per_op_us[ MUL] = 0.447 ms
perf_total_per_op_us[ RMS_NORM] = 1.377 ms
perf_total_per_op_us[ RESHAPE] = 0.432 ms
perf_total_per_op_us[ VIEW] = 0.956 ms
perf_total_per_op_us[ PERMUTE] = 0.126 ms
perf_total_per_op_us[ TRANSPOSE] = 0.105 ms
perf_total_per_op_us[ GET_ROWS] = 0.024 ms
perf_total_per_op_us[ ROPE] = 1.992 ms
perf_total_per_op_us[ MUL_QKV] = 6.311 ms
perf_total_per_op_us[ FFN_SILU] = 14.597 ms
perf_total_per_op_us[ FLASH_ATTN] = 6.425 ms
perf_total_per_op_us[ FLASH_ATTN_KV_UPDATE] = 1.717 ms
perf_total_per_op_us[ INNER PRODUCT] = 3.535 ms
========================================
=== GRAPH Profiling ===
perf_total_per_op_us[ ADD] = 0.402 ms
perf_total_per_op_us[ MUL] = 0.358 ms
perf_total_per_op_us[ RMS_NORM] = 1.281 ms
perf_total_per_op_us[ RESHAPE] = 0.427 ms
perf_total_per_op_us[ VIEW] = 1.058 ms
perf_total_per_op_us[ PERMUTE] = 0.106 ms
perf_total_per_op_us[ TRANSPOSE] = 0.102 ms
perf_total_per_op_us[ GET_ROWS] = 0.024 ms
perf_total_per_op_us[ ROPE] = 1.919 ms
perf_total_per_op_us[ MUL_QKV] = 5.881 ms
perf_total_per_op_us[ FFN_SILU] = 14.522 ms
perf_total_per_op_us[ FLASH_ATTN] = 6.389 ms
perf_total_per_op_us[ FLASH_ATTN_KV_UPDATE] = 1.621 ms
perf_total_per_op_us[ INNER PRODUCT] = 3.339 ms
========================================
=== GRAPH Profiling ===
perf_total_per_op_us[ ADD] = 0.327 ms
perf_total_per_op_us[ MUL] = 0.361 ms
perf_total_per_op_us[ RMS_NORM] = 1.272 ms
perf_total_per_op_us[ RESHAPE] = 0.422 ms
perf_total_per_op_us[ VIEW] = 1.032 ms
perf_total_per_op_us[ PERMUTE] = 0.110 ms
perf_total_per_op_us[ TRANSPOSE] = 0.101 ms
perf_total_per_op_us[ GET_ROWS] = 0.023 ms
perf_total_per_op_us[ ROPE] = 1.967 ms
perf_total_per_op_us[ MUL_QKV] = 6.034 ms
perf_total_per_op_us[ FFN_SILU] = 14.527 ms
perf_total_per_op_us[ FLASH_ATTN] = 6.271 ms
perf_total_per_op_us[ FLASH_ATTN_KV_UPDATE] = 1.673 ms
perf_total_per_op_us[ INNER PRODUCT] = 3.444 ms
========================================
model_print_timings: load time = 2691.89 ms
model_print_timings: sample time = 2.36 ms / 4 runs ( 0.59 ms per token)
model_print_timings: prompt eval time = 1475.57 ms / 1975 tokens ( 0.75 ms per token)
model_print_timings: eval time = 124.68 ms / 3 runs ( 41.56 ms per token)
model_print_timings: total time = 2853.38 ms
========== eval time log of each prediction ==========
prediction 0, time: 1475.57ms
prediction 1, time: 42.19ms
prediction 2, time: 41.22ms
prediction 3, time: 41.27ms
```
```
# fused attn disabled
rm -rf bin && cmake .. -GNinja -DNS_BUILD_TESTS=ON -DNE_PROFILING=ON -DCMAKE_BUILD_TYPE=Release && ninja run_llama && env ENGINE_PROFILING=1 numactl -m 1 -C 56-111 bin/run_llama -m llama-7b-hf-pr447-q4j-sym-int8-fp32-g128.bin --seed 1234 -t 56 -b 2048 -c 2048 -n 4 --memory-f16 -p "$(echo "$LUOYU_PROMPT" | cut -d' ' -f 1-1500)"
Welcome to use the llama on the ITREX!
...
=== GRAPH Profiling ===
perf_total_per_op_us[ ADD] = 55.300 ms
perf_total_per_op_us[ MUL] = 40.209 ms
perf_total_per_op_us[ RMS_NORM] = 63.544 ms
perf_total_per_op_us[ MUL_MAT] = 6698.093 ms
perf_total_per_op_us[ SCALE] = 1325.542 ms
perf_total_per_op_us[ CPY] = 273.083 ms
perf_total_per_op_us[ RESHAPE] = 0.460 ms
perf_total_per_op_us[ VIEW] = 0.734 ms
perf_total_per_op_us[ PERMUTE] = 0.323 ms
perf_total_per_op_us[ TRANSPOSE] = 0.105 ms
perf_total_per_op_us[ GET_ROWS] = 8.467 ms
perf_total_per_op_us[ DIAG_MASK_INF] = 69.310 ms
perf_total_per_op_us[ SOFT_MAX] = 226.629 ms
perf_total_per_op_us[ ROPE] = 44.610 ms
perf_total_per_op_us[ MUL_QKV] = 264.430 ms
perf_total_per_op_us[ FFN_SILU] = 672.668 ms
perf_total_per_op_us[ INNER PRODUCT] = 0.000 ms
========================================
=== GRAPH Profiling ===
perf_total_per_op_us[ ADD] = 0.445 ms
perf_total_per_op_us[ MUL] = 0.405 ms
perf_total_per_op_us[ RMS_NORM] = 1.232 ms
perf_total_per_op_us[ MUL_MAT] = 10.702 ms
perf_total_per_op_us[ SCALE] = 0.952 ms
perf_total_per_op_us[ CPY] = 3.040 ms
perf_total_per_op_us[ RESHAPE] = 0.416 ms
perf_total_per_op_us[ VIEW] = 0.792 ms
perf_total_per_op_us[ PERMUTE] = 0.323 ms
perf_total_per_op_us[ TRANSPOSE] = 0.103 ms
perf_total_per_op_us[ GET_ROWS] = 0.023 ms
perf_total_per_op_us[ DIAG_MASK_INF] = 0.118 ms
perf_total_per_op_us[ SOFT_MAX] = 1.359 ms
perf_total_per_op_us[ ROPE] = 1.888 ms
perf_total_per_op_us[ MUL_QKV] = 6.133 ms
perf_total_per_op_us[ FFN_SILU] = 14.607 ms
perf_total_per_op_us[ INNER PRODUCT] = 3.504 ms
========================================
=== GRAPH Profiling ===
perf_total_per_op_us[ ADD] = 0.324 ms
perf_total_per_op_us[ MUL] = 0.402 ms
perf_total_per_op_us[ RMS_NORM] = 1.321 ms
perf_total_per_op_us[ MUL_MAT] = 10.624 ms
perf_total_per_op_us[ SCALE] = 0.954 ms
perf_total_per_op_us[ CPY] = 3.104 ms
perf_total_per_op_us[ RESHAPE] = 0.425 ms
perf_total_per_op_us[ VIEW] = 0.748 ms
perf_total_per_op_us[ PERMUTE] = 0.316 ms
perf_total_per_op_us[ TRANSPOSE] = 0.102 ms
perf_total_per_op_us[ GET_ROWS] = 0.021 ms
perf_total_per_op_us[ DIAG_MASK_INF] = 0.111 ms
perf_total_per_op_us[ SOFT_MAX] = 1.362 ms
perf_total_per_op_us[ ROPE] = 1.874 ms
perf_total_per_op_us[ MUL_QKV] = 6.001 ms
perf_total_per_op_us[ FFN_SILU] = 14.542 ms
perf_total_per_op_us[ INNER PRODUCT] = 3.314 ms
========================================
=== GRAPH Profiling ===
perf_total_per_op_us[ ADD] = 0.354 ms
perf_total_per_op_us[ MUL] = 0.391 ms
perf_total_per_op_us[ RMS_NORM] = 1.379 ms
perf_total_per_op_us[ MUL_MAT] = 10.610 ms
perf_total_per_op_us[ SCALE] = 0.964 ms
perf_total_per_op_us[ CPY] = 3.115 ms
perf_total_per_op_us[ RESHAPE] = 0.430 ms
perf_total_per_op_us[ VIEW] = 0.866 ms
perf_total_per_op_us[ PERMUTE] = 0.336 ms
perf_total_per_op_us[ TRANSPOSE] = 0.109 ms
perf_total_per_op_us[ GET_ROWS] = 0.022 ms
perf_total_per_op_us[ DIAG_MASK_INF] = 0.108 ms
perf_total_per_op_us[ SOFT_MAX] = 1.410 ms
perf_total_per_op_us[ ROPE] = 1.959 ms
perf_total_per_op_us[ MUL_QKV] = 5.826 ms
perf_total_per_op_us[ FFN_SILU] = 14.737 ms
perf_total_per_op_us[ INNER PRODUCT] = 3.378 ms
========================================
model_print_timings: load time = 10987.95 ms
model_print_timings: sample time = 2.38 ms / 4 runs ( 0.60 ms per token)
model_print_timings: prompt eval time = 9748.26 ms / 1975 tokens ( 4.94 ms per token)
model_print_timings: eval time = 150.93 ms / 3 runs ( 50.31 ms per token)
model_print_timings: total time = 11175.74 ms
========== eval time log of each prediction ==========
prediction 0, time: 9748.26ms
prediction 1, time: 50.56ms
prediction 2, time: 50.00ms
prediction 3, time: 50.37ms
```