This repository was archived by the owner on Aug 30, 2024. It is now read-only.

Fused Attention

Attention (including MHA, GQA, and MQA) is a key part of transformers and is performance-critical in many scenarios. To enable various optimizations, a fused attention layer is introduced together with utilities for the customized KV-cache it uses. As an example, for a 1975-token llama-7b inference, fused attention reduces the cost of MHA from 8521.276 ms to 248.586 ms for the first token, and from 17.044 ms to 7.944 ms for a next token[1].

Note that this doc assumes basic knowledge of this cpp graph implementation, including ne_tensor::ne, ne_tensor::nb, etc. Model builders can enable fused attention with the ne_flash_attn* operators defined in core/ne_layers.h, while the implementation of fused attention is mostly located in core/layers/mha_dense.cpp.

KV-cache initialization

The memory for the KV-cache is allocated in neural_speed/models/model_utils/model_utils.cpp. Because the fused attention implementation requires certain instruction extensions and may have other limitations, fused attention is enabled only if bestla_reordered_attn_fp32_support() agrees, which is denoted by memory_type = NE_TYPE_BTLA. Next, get_batch_kv_elements_from_gpt_params() gives the sizes of the k-cache and v-cache respectively for each batch and each layer (in bytes if fused attention is enabled, or in elements if it is disabled). Finally, the KV-cache is prepared with these 2 sizes by creating ne_new_tensor inside model.kv_self.k/.v or model.layer[il].k_cache/.v_cache.
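The difference between a size in elements and a size in bytes can be illustrated with a minimal sketch. It assumes a plain, non-reordered cache; the `KvShape` struct and both helper functions are hypothetical names used only for illustration. The byte size of the reordered cache used by fused attention depends on the library's internal layout and is not reproduced here.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical description of one layer's cache for one batch entry.
struct KvShape {
  std::size_t n_ctx;        // maximum number of cached tokens
  std::size_t head_num_kv;  // number of K/V heads (equals head_num for MHA)
  std::size_t head_size;    // elements per head
};

// Element count of a plain (non-fused) K-cache or V-cache.
inline std::size_t kv_elements(const KvShape& s) {
  assert(s.n_ctx > 0 && s.head_num_kv > 0 && s.head_size > 0);
  return s.n_ctx * s.head_num_kv * s.head_size;
}

// Byte count of the same cache when every element occupies elem_bytes.
inline std::size_t kv_bytes(const KvShape& s, std::size_t elem_bytes) {
  return kv_elements(s) * elem_bytes;
}
```

For example, a 2048-token context with 32 K/V heads of size 128 needs 8388608 elements per layer and batch entry, which is 16777216 bytes when each element takes 2 bytes.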

KV-cache Append

The KV-cache is appended every time a new pair of K and V is generated by evaluating the inner product for QKV (ne_mul_qkv). This operation appends an additional K/V tensor along the sequence-length dimension (i.e. resulting in n_past = n_past + N, where n_past is the number of previously cached tokens and N is the length of the current tokens).

The K and V tensors should both be permuted (with ne_permute) to batch x N x head_num x head_size. The permuted K/V tensors can then be passed to ne_flash_attn_update_k()/_v(), together with n_past as an offset in the sequence-length dimension, to concatenate them to a "view" of the cached K/V of the current layer and current batch.

Note: Currently the K and V tensors to append must be contiguous in the dimension of head size (i.e. head_dim in some implementations). The strides (ne_tensor::nb) of the other 3 dimensions are configurable.
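The append step and its contiguity requirement can be sketched as follows. This is a minimal illustration, not the library's implementation: `StridedKv`, `PlainCache`, and `append_kv` are hypothetical names, the cache uses a plain batch x head_num x n_ctx x head_size float layout instead of the customized one, and strides are counted in elements, whereas ne_tensor::nb counts bytes.

```cpp
#include <cassert>
#include <cstddef>
#include <cstring>
#include <vector>

// Hypothetical strided view of a K (or V) tensor to append.
// Logical shape: batch x n_tokens x head_num x head_size.
struct StridedKv {
  const float* data;
  std::size_t batch, n_tokens, head_num, head_size;
  std::size_t stride_batch, stride_token, stride_head;  // head_size stride is 1
};

// Hypothetical plain cache: batch x head_num x n_ctx x head_size, contiguous.
struct PlainCache {
  std::vector<float> data;
  std::size_t batch, head_num, n_ctx, head_size;
  PlainCache(std::size_t b, std::size_t h, std::size_t c, std::size_t d)
      : data(b * h * c * d, 0.0f), batch(b), head_num(h), n_ctx(c), head_size(d) {}
  float* row(std::size_t b, std::size_t h, std::size_t t) {
    return data.data() + ((b * head_num + h) * n_ctx + t) * head_size;
  }
};

// Copies src into the cache at sequence offset n_past; returns the new n_past.
inline std::size_t append_kv(PlainCache& cache, const StridedKv& src, std::size_t n_past) {
  assert(src.batch == cache.batch && src.head_num == cache.head_num);
  assert(src.head_size == cache.head_size);
  assert(n_past + src.n_tokens <= cache.n_ctx);
  for (std::size_t b = 0; b < src.batch; ++b)
    for (std::size_t t = 0; t < src.n_tokens; ++t)
      for (std::size_t h = 0; h < src.head_num; ++h) {
        const float* from =
            src.data + b * src.stride_batch + t * src.stride_token + h * src.stride_head;
        // One memcpy per row is valid only because head_size is contiguous.
        std::memcpy(cache.row(b, h, n_past + t), from, src.head_size * sizeof(float));
      }
  return n_past + src.n_tokens;
}
```

The three configurable strides let the source be any permuted view, while the innermost head_size run is copied as one block.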

Fused Attention Computation

With the KV-cache of customized type and layout, the attention can be computed at its best performance. ne_flash_attn accepts a Q-tensor of batch x head_num x N x head_size and outputs a result tensor of batch x N x head_num x head_size (totally contiguous).

Note: similar to the append operation, the Q-tensor must be contiguous in the dimension of head size, and its strides (ne_tensor::nb) of the other 3 dimensions are configurable.
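The input and output shapes above can be illustrated with a reference computation. This is a sketch under stated assumptions, not the kernel in core/layers/mha_dense.cpp: `attention_ref` is a hypothetical function, all tensors are plain contiguous floats, the K/V caches use a batch x head_num x n_ctx x head_size layout, and a causal mask is assumed. It folds scaling, masking, softmax, and the multiplication with V into one pass over the keys by using an online softmax.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// q:   batch x head_num x n x head_size
// k,v: batch x head_num x n_ctx x head_size, valid for n_past + n tokens
// out: batch x n x head_num x head_size
inline std::vector<float> attention_ref(const std::vector<float>& q, const std::vector<float>& k,
                                        const std::vector<float>& v, std::size_t batch,
                                        std::size_t head_num, std::size_t n,
                                        std::size_t head_size, std::size_t n_ctx,
                                        std::size_t n_past) {
  assert(q.size() == batch * head_num * n * head_size);
  assert(k.size() == batch * head_num * n_ctx * head_size && v.size() == k.size());
  assert(n_past + n <= n_ctx);
  const float scale = 1.0f / std::sqrt(static_cast<float>(head_size));
  std::vector<float> out(batch * n * head_num * head_size, 0.0f);
  std::vector<float> acc(head_size);
  for (std::size_t b = 0; b < batch; ++b)
    for (std::size_t h = 0; h < head_num; ++h)
      for (std::size_t i = 0; i < n; ++i) {
        const float* qi = &q[((b * head_num + h) * n + i) * head_size];
        float run_max = -std::numeric_limits<float>::infinity();
        float run_sum = 0.0f;
        std::fill(acc.begin(), acc.end(), 0.0f);
        // Causal mask: query i sits at absolute position n_past + i.
        for (std::size_t j = 0; j <= n_past + i; ++j) {
          const float* kj = &k[((b * head_num + h) * n_ctx + j) * head_size];
          const float* vj = &v[((b * head_num + h) * n_ctx + j) * head_size];
          float s = 0.0f;
          for (std::size_t d = 0; d < head_size; ++d) s += qi[d] * kj[d];
          s *= scale;
          // Online softmax: rescale earlier partial sums when the maximum grows.
          const float new_max = std::max(run_max, s);
          const float correction = std::exp(run_max - new_max);
          const float p = std::exp(s - new_max);
          run_sum = run_sum * correction + p;
          for (std::size_t d = 0; d < head_size; ++d) acc[d] = acc[d] * correction + p * vj[d];
          run_max = new_max;
        }
        // The output swaps the head and token dimensions relative to q.
        float* o = &out[((b * n + i) * head_num + h) * head_size];
        for (std::size_t d = 0; d < head_size; ++d) o[d] = acc[d] / run_sum;
      }
  return out;
}
```

Because the scores never leave the inner loop, no N x (n_past + N) score matrix is materialized, which is the general idea behind fusing the separate MUL_MAT, SCALE, DIAG_MASK_INF, and SOFT_MAX steps.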

Supported Models

Fused attention is designed to easily support various models:

| Models | Attention Type & Support | Runtime ISA |
| --- | --- | --- |
| LLaMA-7B, LLaMA-13B, LLaMA2-7B, LLaMA2-13B | MHA ✅ | AMX_BF16 |
| LLaMA2-70B | GQA ✅ | |
| GPT-J-6B | MHA ✅ | |
| GPT-NeoX-20B, Dolly-v2-3B | MHA ✅ | |
| Qwen-7B, Qwen-14B | MHA ✅ | |
| MPT-7B, MPT-30B | MHA with ALiBi | |
| Falcon-7B, Falcon-40B | MQA, GQA ✅ | |
| BLOOM-7B | MHA with ALiBi 🚧 | |
| OPT-125m, OPT-350m, OPT-1.3B, OPT-13B | MHA 🚧 | |
| ChatGLM-6B, ChatGLM2-6B | MHA with the autoregressive-blank-infilling pattern 🚧, GQA ✅ | |
| StarCoder-1B, StarCoder-3B, StarCoder-15.5B | MHA ✅ | |

✅: Supported; 🚧: WIP
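The attention types in the table differ in how query heads share K/V heads. A minimal sketch of that mapping follows; `kv_head_for` is a hypothetical helper, and it assumes the common convention that consecutive query heads form one group. It does not describe how the library itself indexes heads.

```cpp
#include <cassert>
#include <cstddef>

// Maps a query head to the K/V head it reads.
// head_num_kv == head_num: MHA; head_num_kv == 1: MQA; otherwise: GQA.
inline std::size_t kv_head_for(std::size_t q_head, std::size_t head_num,
                               std::size_t head_num_kv) {
  assert(head_num_kv > 0 && head_num % head_num_kv == 0);
  assert(q_head < head_num);
  const std::size_t group_size = head_num / head_num_kv;  // query heads per K/V head
  return q_head / group_size;
}
```

With 8 query heads, query head 5 reads K/V head 5 under MHA, head 0 under MQA, and head 1 under GQA with 2 K/V heads. The fewer K/V heads there are, the smaller the KV-cache becomes.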

Limitations

Currently, fused attention is only enabled when Neural Speed is compiled with GCC 11 or later.

Tips for parallelism

Thanks to the mathematical nature of attention, one can simply parallelize all KV-cache operations and the fused attention along commonly parallelizable dimensions. Just pass each part to every KV-cache operation (and merge them together if needed).
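As a minimal sketch of the partitioning step, the helper below splits an independent dimension into contiguous ranges that could each be handed to a separate worker. `split_range` is a hypothetical helper, and choosing the head dimension as the one to split is an assumption made for illustration; dispatching the work and merging the results are left out.

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Splits [0, total) into at most `parts` contiguous, near-equal [begin, end) ranges.
inline std::vector<std::pair<std::size_t, std::size_t>> split_range(std::size_t total,
                                                                    std::size_t parts) {
  assert(parts > 0);
  std::vector<std::pair<std::size_t, std::size_t>> ranges;
  const std::size_t base = total / parts;
  const std::size_t extra = total % parts;  // the first `extra` ranges get one more item
  std::size_t begin = 0;
  for (std::size_t p = 0; p < parts && begin < total; ++p) {
    const std::size_t len = base + (p < extra ? 1 : 0);
    ranges.emplace_back(begin, begin + len);
    begin += len;
  }
  return ranges;
}
```

For instance, 32 heads over 4 workers gives the ranges [0, 8), [8, 16), [16, 24), and [24, 32). Each range touches a disjoint slice of the K/V tensors and of the output, so the parts need no synchronization until they are merged.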

References

[1] The data was tested on a single socket of Intel(R) Xeon(R) Platinum 8480L on commit 01a809d. Details below.

| | 1st-token fused attn disabled | 1st-token fused attn enabled | 4th-token fused attn disabled | 4th-token fused attn enabled |
| --- | --- | --- | --- | --- |
| total latency | 9748.26ms | 1475.57ms | 50.37ms | 41.27ms |
| fused-attn lat | / | 179.883ms + 68.703ms | / | 6.271ms + 1.673ms |
| est non-attn lat | 1226.984ms | 1226.984ms | 33.326ms | 33.326ms |
| MHA cost compare | 8521.276ms | 248.586ms | 17.044ms | 7.944ms |

(The 4th token is taken as an example of next-token performance.)
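The rows of the table are related by simple arithmetic. As a worked check, assuming that the estimated non-attention latency is the same with and without fused attention (superscripts denote the 1st and 4th token, and the speedup ratios are derived here rather than quoted from the measurements):

```latex
\begin{aligned}
t^{(1)}_{\text{non-attn}} &= 1475.57 - (179.883 + 68.703) = 1226.984\ \text{ms} \\
t^{(1)}_{\text{MHA, disabled}} &= 9748.26 - 1226.984 = 8521.276\ \text{ms} \\
\text{speedup}^{(1)} &= 8521.276 / 248.586 \approx 34.3 \\
t^{(4)}_{\text{non-attn}} &= 41.27 - (6.271 + 1.673) = 33.326\ \text{ms} \\
t^{(4)}_{\text{MHA, disabled}} &= 50.37 - 33.326 = 17.044\ \text{ms} \\
\text{speedup}^{(4)} &= 17.044 / 7.944 \approx 2.15
\end{aligned}
```

The two fused-attention terms in each sum are the FLASH_ATTN and FLASH_ATTN_KV_UPDATE entries of the corresponding profiling block.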

Raw logs:

# fused attn enabled
rm -rf bin && cmake .. -GNinja -DNS_BUILD_TESTS=ON -DNE_PROFILING=ON -DCMAKE_BUILD_TYPE=Release && ninja run_llama && env ENGINE_PROFILING=1 numactl -m 1 -C 56-111 bin/run_llama -m llama-7b-hf-pr447-q4j-sym-int8-fp32-g128.bin --seed 1234 -t 56 -b 2048 -c 2048 -n 4 --memory-auto -p "$(echo "$LUOYU_PROMPT" | cut -d' ' -f 1-1500)"
Welcome to use the llama on the ITREX!
...
=== GRAPH Profiling ===
perf_total_per_op_us[                     ADD] =  51.409 ms
perf_total_per_op_us[                     MUL] =  26.328 ms
perf_total_per_op_us[                RMS_NORM] =  42.445 ms
perf_total_per_op_us[                 MUL_MAT] = 127.810 ms
perf_total_per_op_us[                 RESHAPE] =   0.446 ms
perf_total_per_op_us[                    VIEW] =   0.997 ms
perf_total_per_op_us[                 PERMUTE] =   0.101 ms
perf_total_per_op_us[               TRANSPOSE] =   0.105 ms
perf_total_per_op_us[                GET_ROWS] =   8.342 ms
perf_total_per_op_us[                    ROPE] =  44.115 ms
perf_total_per_op_us[                 MUL_QKV] = 252.611 ms
perf_total_per_op_us[                FFN_SILU] = 668.217 ms
perf_total_per_op_us[              FLASH_ATTN] = 179.883 ms
perf_total_per_op_us[    FLASH_ATTN_KV_UPDATE] =  68.703 ms
perf_total_per_op_us[           INNER PRODUCT] =   0.000 ms
========================================
=== GRAPH Profiling ===
perf_total_per_op_us[                     ADD] =   0.420 ms
perf_total_per_op_us[                     MUL] =   0.447 ms
perf_total_per_op_us[                RMS_NORM] =   1.377 ms
perf_total_per_op_us[                 RESHAPE] =   0.432 ms
perf_total_per_op_us[                    VIEW] =   0.956 ms
perf_total_per_op_us[                 PERMUTE] =   0.126 ms
perf_total_per_op_us[               TRANSPOSE] =   0.105 ms
perf_total_per_op_us[                GET_ROWS] =   0.024 ms
perf_total_per_op_us[                    ROPE] =   1.992 ms
perf_total_per_op_us[                 MUL_QKV] =   6.311 ms
perf_total_per_op_us[                FFN_SILU] =  14.597 ms
perf_total_per_op_us[              FLASH_ATTN] =   6.425 ms
perf_total_per_op_us[    FLASH_ATTN_KV_UPDATE] =   1.717 ms
perf_total_per_op_us[           INNER PRODUCT] =   3.535 ms
========================================
=== GRAPH Profiling ===
perf_total_per_op_us[                     ADD] =   0.402 ms
perf_total_per_op_us[                     MUL] =   0.358 ms
perf_total_per_op_us[                RMS_NORM] =   1.281 ms
perf_total_per_op_us[                 RESHAPE] =   0.427 ms
perf_total_per_op_us[                    VIEW] =   1.058 ms
perf_total_per_op_us[                 PERMUTE] =   0.106 ms
perf_total_per_op_us[               TRANSPOSE] =   0.102 ms
perf_total_per_op_us[                GET_ROWS] =   0.024 ms
perf_total_per_op_us[                    ROPE] =   1.919 ms
perf_total_per_op_us[                 MUL_QKV] =   5.881 ms
perf_total_per_op_us[                FFN_SILU] =  14.522 ms
perf_total_per_op_us[              FLASH_ATTN] =   6.389 ms
perf_total_per_op_us[    FLASH_ATTN_KV_UPDATE] =   1.621 ms
perf_total_per_op_us[           INNER PRODUCT] =   3.339 ms
========================================
=== GRAPH Profiling ===
perf_total_per_op_us[                     ADD] =   0.327 ms
perf_total_per_op_us[                     MUL] =   0.361 ms
perf_total_per_op_us[                RMS_NORM] =   1.272 ms
perf_total_per_op_us[                 RESHAPE] =   0.422 ms
perf_total_per_op_us[                    VIEW] =   1.032 ms
perf_total_per_op_us[                 PERMUTE] =   0.110 ms
perf_total_per_op_us[               TRANSPOSE] =   0.101 ms
perf_total_per_op_us[                GET_ROWS] =   0.023 ms
perf_total_per_op_us[                    ROPE] =   1.967 ms
perf_total_per_op_us[                 MUL_QKV] =   6.034 ms
perf_total_per_op_us[                FFN_SILU] =  14.527 ms
perf_total_per_op_us[              FLASH_ATTN] =   6.271 ms
perf_total_per_op_us[    FLASH_ATTN_KV_UPDATE] =   1.673 ms
perf_total_per_op_us[           INNER PRODUCT] =   3.444 ms
========================================

model_print_timings:        load time =  2691.89 ms
model_print_timings:      sample time =     2.36 ms /     4 runs   (    0.59 ms per token)
model_print_timings: prompt eval time =  1475.57 ms /  1975 tokens (    0.75 ms per token)
model_print_timings:        eval time =   124.68 ms /     3 runs   (   41.56 ms per token)
model_print_timings:       total time =  2853.38 ms
========== eval time log of each prediction ==========
prediction   0, time: 1475.57ms
prediction   1, time: 42.19ms
prediction   2, time: 41.22ms
prediction   3, time: 41.27ms

# fused attn disabled
rm -rf bin && cmake .. -GNinja -DNS_BUILD_TESTS=ON -DNE_PROFILING=ON -DCMAKE_BUILD_TYPE=Release && ninja run_llama && env ENGINE_PROFILING=1 numactl -m 1 -C 56-111 bin/run_llama -m llama-7b-hf-pr447-q4j-sym-int8-fp32-g128.bin --seed 1234 -t 56 -b 2048 -c 2048 -n 4 --memory-f16 -p "$(echo "$LUOYU_PROMPT" | cut -d' ' -f 1-1500)"
Welcome to use the llama on the ITREX!
...
=== GRAPH Profiling ===
perf_total_per_op_us[                     ADD] =  55.300 ms
perf_total_per_op_us[                     MUL] =  40.209 ms
perf_total_per_op_us[                RMS_NORM] =  63.544 ms
perf_total_per_op_us[                 MUL_MAT] = 6698.093 ms
perf_total_per_op_us[                   SCALE] = 1325.542 ms
perf_total_per_op_us[                     CPY] = 273.083 ms
perf_total_per_op_us[                 RESHAPE] =   0.460 ms
perf_total_per_op_us[                    VIEW] =   0.734 ms
perf_total_per_op_us[                 PERMUTE] =   0.323 ms
perf_total_per_op_us[               TRANSPOSE] =   0.105 ms
perf_total_per_op_us[                GET_ROWS] =   8.467 ms
perf_total_per_op_us[           DIAG_MASK_INF] =  69.310 ms
perf_total_per_op_us[                SOFT_MAX] = 226.629 ms
perf_total_per_op_us[                    ROPE] =  44.610 ms
perf_total_per_op_us[                 MUL_QKV] = 264.430 ms
perf_total_per_op_us[                FFN_SILU] = 672.668 ms
perf_total_per_op_us[           INNER PRODUCT] =   0.000 ms
========================================
=== GRAPH Profiling ===
perf_total_per_op_us[                     ADD] =   0.445 ms
perf_total_per_op_us[                     MUL] =   0.405 ms
perf_total_per_op_us[                RMS_NORM] =   1.232 ms
perf_total_per_op_us[                 MUL_MAT] =  10.702 ms
perf_total_per_op_us[                   SCALE] =   0.952 ms
perf_total_per_op_us[                     CPY] =   3.040 ms
perf_total_per_op_us[                 RESHAPE] =   0.416 ms
perf_total_per_op_us[                    VIEW] =   0.792 ms
perf_total_per_op_us[                 PERMUTE] =   0.323 ms
perf_total_per_op_us[               TRANSPOSE] =   0.103 ms
perf_total_per_op_us[                GET_ROWS] =   0.023 ms
perf_total_per_op_us[           DIAG_MASK_INF] =   0.118 ms
perf_total_per_op_us[                SOFT_MAX] =   1.359 ms
perf_total_per_op_us[                    ROPE] =   1.888 ms
perf_total_per_op_us[                 MUL_QKV] =   6.133 ms
perf_total_per_op_us[                FFN_SILU] =  14.607 ms
perf_total_per_op_us[           INNER PRODUCT] =   3.504 ms
========================================
=== GRAPH Profiling ===
perf_total_per_op_us[                     ADD] =   0.324 ms
perf_total_per_op_us[                     MUL] =   0.402 ms
perf_total_per_op_us[                RMS_NORM] =   1.321 ms
perf_total_per_op_us[                 MUL_MAT] =  10.624 ms
perf_total_per_op_us[                   SCALE] =   0.954 ms
perf_total_per_op_us[                     CPY] =   3.104 ms
perf_total_per_op_us[                 RESHAPE] =   0.425 ms
perf_total_per_op_us[                    VIEW] =   0.748 ms
perf_total_per_op_us[                 PERMUTE] =   0.316 ms
perf_total_per_op_us[               TRANSPOSE] =   0.102 ms
perf_total_per_op_us[                GET_ROWS] =   0.021 ms
perf_total_per_op_us[           DIAG_MASK_INF] =   0.111 ms
perf_total_per_op_us[                SOFT_MAX] =   1.362 ms
perf_total_per_op_us[                    ROPE] =   1.874 ms
perf_total_per_op_us[                 MUL_QKV] =   6.001 ms
perf_total_per_op_us[                FFN_SILU] =  14.542 ms
perf_total_per_op_us[           INNER PRODUCT] =   3.314 ms
========================================
=== GRAPH Profiling ===
perf_total_per_op_us[                     ADD] =   0.354 ms
perf_total_per_op_us[                     MUL] =   0.391 ms
perf_total_per_op_us[                RMS_NORM] =   1.379 ms
perf_total_per_op_us[                 MUL_MAT] =  10.610 ms
perf_total_per_op_us[                   SCALE] =   0.964 ms
perf_total_per_op_us[                     CPY] =   3.115 ms
perf_total_per_op_us[                 RESHAPE] =   0.430 ms
perf_total_per_op_us[                    VIEW] =   0.866 ms
perf_total_per_op_us[                 PERMUTE] =   0.336 ms
perf_total_per_op_us[               TRANSPOSE] =   0.109 ms
perf_total_per_op_us[                GET_ROWS] =   0.022 ms
perf_total_per_op_us[           DIAG_MASK_INF] =   0.108 ms
perf_total_per_op_us[                SOFT_MAX] =   1.410 ms
perf_total_per_op_us[                    ROPE] =   1.959 ms
perf_total_per_op_us[                 MUL_QKV] =   5.826 ms
perf_total_per_op_us[                FFN_SILU] =  14.737 ms
perf_total_per_op_us[           INNER PRODUCT] =   3.378 ms
========================================

model_print_timings:        load time = 10987.95 ms
model_print_timings:      sample time =     2.38 ms /     4 runs   (    0.60 ms per token)
model_print_timings: prompt eval time =  9748.26 ms /  1975 tokens (    4.94 ms per token)
model_print_timings:        eval time =   150.93 ms /     3 runs   (   50.31 ms per token)
model_print_timings:       total time = 11175.74 ms
========== eval time log of each prediction ==========
prediction   0, time: 9748.26ms
prediction   1, time: 50.56ms
prediction   2, time: 50.00ms
prediction   3, time: 50.37ms