@@ -358,42 +358,42 @@ Status LeanAttention(
  constexpr bool is_bf16 = false;

  ORT_RETURN_IF_ERROR(onnxruntime::lean::mha_fwd_kvcache(
-      device_prop, stream,
-      data.q,
-      data.k,  // k_cache
-      data.v,  // v_cache
-      nullptr,  // new_k (we have appended new_k to k_cache)
-      nullptr,  // new_v (we have appended new_v to k_cache)
-      data.output,
-      reinterpret_cast<void*>(data.softmax_lse),
-      nullptr,  // seqlens_k
-      nullptr,  // cos_cache
-      nullptr,  // sin_cache
-      nullptr,  // block_table
-      parameters.batch_size,
-      parameters.num_heads,
-      parameters.num_heads,  // num_heads_k
-      parameters.head_size,
-      parameters.sequence_length,  // seqlen_q
-      parameters.total_sequence_length,  // seqlen_k
-      0,  // seqlen_k_new
-      0,  // rotary_dim
-      scale,  // softmax_scale
-      parameters.is_unidirectional,
-      is_bf16,
-      false,  // past_bsnh
-      data.num_splits,
-      data.grid_dim_z,
-      data.max_tiles_per_tb,
-      data.high_load_tbs,
-      data.tiles_per_head,
-      reinterpret_cast<void*>(data.softmax_lse_accum),
-      reinterpret_cast<void*>(data.out_accum),
-      data.lean_sync_flag,
-      -1,  // local_window_size
-      false,  // is_rotary_interleaved
-      false  // is_packed_qkv
-      ));
+      device_prop, stream,
+      data.q,
+      data.k,  // k_cache
+      data.v,  // v_cache
+      nullptr,  // new_k (we have appended new_k to k_cache)
+      nullptr,  // new_v (we have appended new_v to k_cache)
+      data.output,
+      reinterpret_cast<void*>(data.softmax_lse),
+      nullptr,  // seqlens_k
+      nullptr,  // cos_cache
+      nullptr,  // sin_cache
+      nullptr,  // block_table
+      parameters.batch_size,
+      parameters.num_heads,
+      parameters.num_heads,  // num_heads_k
+      parameters.head_size,
+      parameters.sequence_length,  // seqlen_q
+      parameters.total_sequence_length,  // seqlen_k
+      0,  // seqlen_k_new
+      0,  // rotary_dim
+      scale,  // softmax_scale
+      parameters.is_unidirectional,
+      is_bf16,
+      false,  // past_bsnh
+      data.num_splits,
+      data.grid_dim_z,
+      data.max_tiles_per_tb,
+      data.high_load_tbs,
+      data.tiles_per_head,
+      reinterpret_cast<void*>(data.softmax_lse_accum),
+      reinterpret_cast<void*>(data.out_accum),
+      data.lean_sync_flag,
+      -1,  // local_window_size
+      false,  // is_rotary_interleaved
+      false  // is_packed_qkv
+      ));

  return Status::OK();
}
@@ -414,8 +414,6 @@ Status LeanAttention(
}
#endif

-
-
template <typename T>
Status CudnnFlashAttention(
    cudnnHandle_t cudnn_handle,
@@ -439,21 +437,21 @@ Status CudnnFlashAttention(
      data.k,
      data.v,
      attention_bias,
-      nullptr,  // (optional) mask_sequence_lengths_q
-      mask_sequence_lengths_kv,  // (optional) mask_sequence_lengths_kv
+      nullptr,  // (optional) mask_sequence_lengths_q
+      mask_sequence_lengths_kv,  // (optional) mask_sequence_lengths_kv
      parameters.batch_size,
-      parameters.num_heads,  // num_heads_q,
-      parameters.num_heads,  // num_heads_kv,
-      parameters.head_size,  // head_size_qk
-      parameters.v_head_size,  // head_size_v
-      parameters.sequence_length,  // sequence_length_q
-      parameters.total_sequence_length,  // sequence_length_kv
-      scale,  // scaling factor applied prior softmax
-      parameters.is_unidirectional,  // causal
-      is_bf16,  // True if bfloat16, otherwise float16
-      parameters.broadcast_attn_bias_dim_0,  // broadcast attention bias dimension 0 or not
-      parameters.broadcast_attn_bias_dim_1,  // broadcast attention bias dimension 1 or not
-      0,  // sliding window length. 0 means no sliding window.
+      parameters.num_heads,  // num_heads_q,
+      parameters.num_heads,  // num_heads_kv,
+      parameters.head_size,  // head_size_qk
+      parameters.v_head_size,  // head_size_v
+      parameters.sequence_length,  // sequence_length_q
+      parameters.total_sequence_length,  // sequence_length_kv
+      scale,  // scaling factor applied prior softmax
+      parameters.is_unidirectional,  // causal
+      is_bf16,  // True if bfloat16, otherwise float16
+      parameters.broadcast_attn_bias_dim_0,  // broadcast attention bias dimension 0 or not
+      parameters.broadcast_attn_bias_dim_1,  // broadcast attention bias dimension 1 or not
+      0,  // sliding window length. 0 means no sliding window.
      data.qkv_format,
      cudnn_handle,
      ort_stream,
@@ -540,10 +538,9 @@ Status EfficientAttention(

template <typename T, typename QK>
Status LaunchDecoderMaskedMultiHeadAttention(
-    const DecoderMaskedMultiHeadAttentionParameters& parameters,
-    cudaStream_t stream,
-    const int head_size) {
-
+    const DecoderMaskedMultiHeadAttentionParameters& parameters,
+    cudaStream_t stream,
+    const int head_size) {
  DUMP_STRING_INIT();
  DUMP_STRING("DMMHA parameters...");
  DUMP_STRING("is_mha = ", (parameters.is_mha == true));
@@ -763,7 +760,7 @@ Status UnfusedAttention(
  if (nullptr != data.output_qk) {
    int64_t qk_size = (int64_t)batch_size * num_heads * sequence_length * total_sequence_length;
    ORT_RETURN_IF_ERROR(
-        (CopyQK<T, QK>(stream, static_cast<int>(qk_size), data.scratch, reinterpret_cast<QK*>(data.output_qk))));
+        (CopyQK<T, QK>(stream, static_cast<int>(qk_size), data.scratch, reinterpret_cast<QK*>(data.output_qk))));
  }
  ORT_RETURN_IF_ERROR(
      ComputeSoftmax<T>(
@@ -802,7 +799,7 @@ Status ConcatPastToPresent(int batch_size, int num_heads, int qk_head_size, int
  // past_v (BxNxPxH) + v (BxNxLxH) => present_v (BxNxTxH)
  // When there is past state, the head size for Q/K/V shall be same: H == H_v.

-  if (nullptr != data.present) {  // Attention op
+  if (nullptr != data.present) {  // Attention op
    assert(data.qkv_format == AttentionQkvFormat::Q_K_V_BNSH ||
           data.qkv_format == AttentionQkvFormat::Q_K_V_BNSH_QKV_BS3NH);

@@ -811,12 +808,10 @@ Status ConcatPastToPresent(int batch_size, int num_heads, int qk_head_size, int
        stream, total_sequence_length, sequence_length, batch_size, qk_head_size, num_heads,
        max_threads_per_block, 2, data.past, data.k, data.present));

-
-
    // Update pointers to present_k and present_v.
    data.k = data.present;
    data.v = data.present + batch_size * num_heads * total_sequence_length * qk_head_size;
-  } else {  // MultiHeadAttention op
+  } else {  // MultiHeadAttention op
    if (nullptr != data.present_key) {
      ORT_ENFORCE(data.qkv_format == AttentionQkvFormat::Q_K_V_BNSH ||
                  data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH_BNSH_BNSH);
@@ -826,16 +821,16 @@ Status ConcatPastToPresent(int batch_size, int num_heads, int qk_head_size, int

      ORT_RETURN_IF_ERROR(
          LaunchConcatTensorToTensor(stream, total_sequence_length, sequence_length,
-                                     batch_size, qk_head_size, num_heads,
-                                     max_threads_per_block, 1, data.past_key, data.k, data.present_key));
+                                     batch_size, qk_head_size, num_heads,
+                                     max_threads_per_block, 1, data.past_key, data.k, data.present_key));
      ORT_RETURN_IF_ERROR(
          LaunchConcatTensorToTensor(stream, total_sequence_length, sequence_length,
-                                     batch_size, v_head_size, num_heads,
-                                     max_threads_per_block, 1, data.past_value, data.v, data.present_value));
+                                     batch_size, v_head_size, num_heads,
+                                     max_threads_per_block, 1, data.past_value, data.v, data.present_value));
      // Update pointers to present_k and present_v.
      data.k = data.present_key;
      data.v = data.present_value;
-    } else {  // nullptr == data.past_key && nullptr != data.present_key
+    } else {  // nullptr == data.past_key && nullptr != data.present_key
      if (data.k != data.present_key) {
        int64_t k_size = (int64_t)batch_size * num_heads * total_sequence_length * qk_head_size;
        cudaMemcpyAsync(data.present_key, data.k, k_size * sizeof(T), cudaMemcpyDeviceToDevice, stream);
@@ -889,7 +884,7 @@ Status PastPresentBufferShare(int batch_size, int num_heads, int qk_head_size, i
    return Status::OK();
  }

-  if (combined_key_value) {  // Attention op
+  if (combined_key_value) {  // Attention op
    assert(data.gemm_buffer != nullptr);

    if (data.present != data.past) {
@@ -924,9 +919,9 @@ Status PastPresentBufferShare(int batch_size, int num_heads, int qk_head_size, i
    constexpr bool is_past_kv_bnsh_format = true;
    constexpr bool is_new_kv_bnsh_format = true;
    ORT_RETURN_IF_ERROR(LaunchConcatKVInPlace(
-        batch_size, num_heads, qk_head_size, parameters.max_sequence_length,
-        data.seqlens_k_total, nullptr, parameters.sequence_length, data.k, data.v, data.present_key, data.present_value,
-        is_past_kv_bnsh_format, is_new_kv_bnsh_format, stream, max_threads_per_block));
+        batch_size, num_heads, qk_head_size, parameters.max_sequence_length,
+        data.seqlens_k_total, nullptr, parameters.sequence_length, data.k, data.v, data.present_key, data.present_value,
+        is_past_kv_bnsh_format, is_new_kv_bnsh_format, stream, max_threads_per_block));

    data.k = data.present_key;
    data.v = data.present_value;
@@ -981,13 +976,13 @@ Status QkvToContext(

  if (!parameters.past_present_share_buffer) {
    ORT_RETURN_IF_ERROR(ConcatPastToPresent<T>(batch_size, num_heads, qk_head_size, v_head_size,
-                                               sequence_length, total_sequence_length,
-                                               stream, max_threads_per_block, data));
+                                               sequence_length, total_sequence_length,
+                                               stream, max_threads_per_block, data));

  } else {  // past_present_share_buffer
    ORT_RETURN_IF_ERROR(PastPresentBufferShare<T>(batch_size, num_heads, qk_head_size, v_head_size,
-                                                  sequence_length, fused_runner,
-                                                  parameters, data, stream, max_threads_per_block));
+                                                  sequence_length, fused_runner,
+                                                  parameters, data, stream, max_threads_per_block));
  }

  // Q, K and V are ready now
@@ -1078,24 +1073,24 @@ template Status QkvToContext<half, float>(
    AttentionData<half>& data);

template Status LaunchDecoderMaskedMultiHeadAttention<float, float>(
-    const DecoderMaskedMultiHeadAttentionParameters& parameters,
-    cudaStream_t stream,
-    const int head_size);
+    const DecoderMaskedMultiHeadAttentionParameters& parameters,
+    cudaStream_t stream,
+    const int head_size);

template Status LaunchDecoderMaskedMultiHeadAttention<float, half>(
-    const DecoderMaskedMultiHeadAttentionParameters& parameters,
-    cudaStream_t stream,
-    const int head_size);
+    const DecoderMaskedMultiHeadAttentionParameters& parameters,
+    cudaStream_t stream,
+    const int head_size);

template Status LaunchDecoderMaskedMultiHeadAttention<uint16_t, float>(
-    const DecoderMaskedMultiHeadAttentionParameters& parameters,
-    cudaStream_t stream,
-    const int head_size);
+    const DecoderMaskedMultiHeadAttentionParameters& parameters,
+    cudaStream_t stream,
+    const int head_size);

template Status LaunchDecoderMaskedMultiHeadAttention<uint16_t, half>(
-    const DecoderMaskedMultiHeadAttentionParameters& parameters,
-    cudaStream_t stream,
-    const int head_size);
+    const DecoderMaskedMultiHeadAttentionParameters& parameters,
+    cudaStream_t stream,
+    const int head_size);

}  // namespace cuda
}  // namespace contrib