@@ -28,7 +28,7 @@ limitations under the License.
#include "core/providers/cuda/cuda_common.h"
#include "longformer_attention_impl.h"
#include "attention_impl.h"
- #include "attention_softmax.h"
+ #include "longformer_attention_softmax.h"

using namespace onnxruntime::cuda;
using namespace cub;
@@ -53,11 +53,8 @@ namespace cuda {
// [SoftmaxSpace: see below] [Q:BxNxSxH] [K:BxNxSxH] [V:BxNxSxH] [Global_Q:BxNxSxH] [Global_K:BxNxSxH] [Global_V:BxNxSxH]
// where Global_Q, Global_K and Global_V are optional. They are not allocated when there is no global token.
//
- // It is feasible to use compact format for Global_Q with shape BxNxGxH to save space. We do not use compact format for now.
- //
// SoftmaxSpace layout:
// [scratch1: (5S-3W)*W*N*B][scratch2: size_t 20]
- //
// Scratch1 has 5 buffers for local and global attention calculation.
// Scratch2 has 5 input pointers, 5 output pointers, 5 buffer sizes and 5 strides related to scratch1.

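For orientation, here is a minimal sketch of the byte sizes implied by the layout comments above. The helper names are hypothetical and this is not the actual GetScratch1Size/GetScratch2Size implementation; the formulas are read off the comments.

#include <cstddef>

// Scratch1 holds (5S - 3W) * W elements per head per batch across its five
// local/global attention buffers, as described in the layout comment.
size_t Scratch1Bytes(size_t element_size, int batch_size, int num_heads,
                     int sequence_length, int window) {
  size_t elements_per_head =
      static_cast<size_t>(5 * sequence_length - 3 * window) * window;
  return elements_per_head * num_heads * batch_size * element_size;
}

// Scratch2 holds 5 input pointers, 5 output pointers, 5 buffer sizes and
// 5 strides: 10 pointers plus 10 size_t values (20 size_t-sized slots on x64).
size_t Scratch2Bytes() {
  return 10 * (sizeof(void*) + sizeof(size_t));
}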
@@ -74,10 +71,17 @@ size_t GetLongformerSoftmaxWorkspaceSize(
    int batch_size,
    int num_heads,
    int sequence_length,
-     int window) {
-   size_t scratch1_size = GetScratch1Size(element_size, batch_size, num_heads, sequence_length, window);
-   size_t scratch2_size = 10 * (sizeof(void*) + sizeof(size_t));
-   return scratch1_size + scratch2_size;
+     int window,
+     bool use_fast_kernel) {
+   if (!use_fast_kernel) {
+     size_t scratch1_size = GetScratch1Size(element_size, batch_size, num_heads, sequence_length, window);
+     size_t scratch2_size = 10 * (sizeof(void*) + sizeof(size_t));
+     return scratch1_size + scratch2_size;
+   } else {
+     // Non-compact layout when environment variable ORT_LONGFORMER_COMPACT_MEMORY=0 is set.
+     // [scratch1: BxNxSxS] [scratch2: BxNxSxS]
+     return 2 * GetAttentionScratchSize(element_size, batch_size, num_heads, sequence_length, sequence_length);
+   }
}

size_t GetLongformerAttentionWorkspaceSize(
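To make the memory trade-off between the two branches concrete, a rough comparison for B=1, N=12, S=4096, W=512 and fp16 (element_size=2), assuming GetAttentionScratchSize returns roughly element_size*B*N*S*S as the layout comment implies:

#include <cstdio>
#include <cstddef>

int main() {
  const size_t element_size = 2;  // fp16
  const size_t B = 1, N = 12, S = 4096, W = 512;

  // Compact layout: (5S - 3W) * W * N * B elements plus a small scratch2.
  size_t compact = (5 * S - 3 * W) * W * N * B * element_size +
                   10 * (sizeof(void*) + sizeof(size_t));

  // Non-compact (fast) layout: two BxNxSxS buffers, ignoring any alignment
  // padding the real GetAttentionScratchSize may add.
  size_t fast = 2 * element_size * B * N * S * S;

  printf("compact: %zu bytes\n", compact);  // ~222 MiB on a 64-bit build
  printf("fast:    %zu bytes\n", fast);     // 768 MiB
  return 0;
}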
@@ -87,8 +91,9 @@ size_t GetLongformerAttentionWorkspaceSize(
    int head_size,
    int sequence_length,
    int max_num_global,
-     int window) {
-   size_t softmax_size = GetLongformerSoftmaxWorkspaceSize(element_size, batch_size, num_heads, sequence_length, window);
+     int window,
+     bool use_fast_kernel) {
+   size_t softmax_size = GetLongformerSoftmaxWorkspaceSize(element_size, batch_size, num_heads, sequence_length, window, use_fast_kernel);
  size_t qkv_size = 3 * batch_size * sequence_length * num_heads * head_size * element_size;
  size_t global_qkv_size = max_num_global > 0 ? qkv_size : 0;
  return softmax_size + qkv_size + global_qkv_size;
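Continuing the same example (head_size = 64, at least one global token), an illustrative breakdown of the total workspace computed above; the numbers are only a worked example, not measured values:

#include <cstdio>
#include <cstddef>

int main() {
  const size_t element_size = 2;  // fp16
  const size_t B = 1, N = 12, S = 4096, H = 64;

  size_t qkv_size = 3 * B * S * N * H * element_size;  // Q, K and V: 18,874,368 bytes
  size_t global_qkv_size = qkv_size;                   // allocated because max_num_global > 0

  // Softmax workspace sizes from the previous sketch (fp16, 64-bit build).
  size_t softmax_compact = 232784032;
  size_t softmax_fast = 805306368;

  printf("compact total: %zu bytes\n", softmax_compact + qkv_size + global_qkv_size);  // ~258 MiB
  printf("fast total:    %zu bytes\n", softmax_fast + qkv_size + global_qkv_size);     // ~804 MiB
  return 0;
}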
@@ -100,6 +105,7 @@ size_t GetPinnedBufferSize(int batch_size) {
  return sizeof(int) * batch_size + GetScratch2Size();
}

+ // Softmax kernel for compact format
template <typename T, int blockSize>
__launch_bounds__(blockSize)
__global__ void LongformerSoftmaxKernel(const int* global_attention,
@@ -354,7 +360,6 @@ bool launchSoftmaxKernel(
    cudaStream_t stream,
    cublasHandle_t cublas,
    void* workspace,
-     size_t softmax_workspace_size,
    const void* q,            // transposed Q with shape (B, N, S, H)
    const void* k,            // transposed K with shape (B, N, S, H)
    const void* v,            // transposed V with shape (B, N, S, H)
@@ -373,10 +378,7 @@ bool launchSoftmaxKernel(
    int num_heads,            // number of heads
    int head_size,            // hidden size per head
    int window,               // one sided window size
-     int max_num_global,       // maximum number of global tokens (G) in all batches
    size_t element_size) {    // size of element: 2 for half, and 4 for float
-   assert(max_num_global <= window);
-
  const int* global_count = reinterpret_cast<const int*>(pinned_buffer);

  bool is_fp16 = (element_size == 2);
@@ -605,7 +607,10 @@ bool launchSoftmaxKernel(
            resultType,
            algo));

-   void* global_q_batch = (char*)global_q + (i * elements_per_batch) * element_size;  // For compact format: replace elements_per_batch by num_heads * max_num_global * head_size
+   // It is feasible to use compact format for Global_Q with shape BxNxGxH to save space.
+   // In that case, elements_per_batch is num_heads * max_num_global * head_size, and stride_per_head is max_num_global * head_size.
+
+   void* global_q_batch = (char*)global_q + (i * elements_per_batch) * element_size;
  void* global_k_batch = (char*)global_k + (i * elements_per_batch) * element_size;
  qk_batch = (char*)input_pointers[4] + (i * buffer_sizes[4] * num_heads) * element_size;

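A small sketch of the per-batch offset arithmetic the comment above describes; the compact variant is hypothetical, since this file keeps the non-compact BxNxSxH layout for Global_Q:

#include <cstddef>

// Byte offset of batch i within Global_Q for the non-compact BxNxSxH layout
// used here: every batch holds num_heads * sequence_length * head_size elements.
size_t GlobalQOffset(size_t i, size_t num_heads, size_t sequence_length,
                     size_t head_size, size_t element_size) {
  size_t elements_per_batch = num_heads * sequence_length * head_size;
  return i * elements_per_batch * element_size;
}

// The same offset if Global_Q were stored compactly as BxNxGxH: only
// max_num_global rows per head, and stride_per_head becomes G * head_size.
size_t GlobalQOffsetCompact(size_t i, size_t num_heads, size_t max_num_global,
                            size_t head_size, size_t element_size) {
  size_t elements_per_batch = num_heads * max_num_global * head_size;
  return i * elements_per_batch * element_size;
}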
@@ -625,7 +630,7 @@ bool launchSoftmaxKernel(
            global_q_batch,   // B
            Btype,            // B type
            head_size,        // ldb
-             stride_per_head,  // strideB. For compact format: max_num_global * head_size.
+             stride_per_head,  // strideB.
            beta_0,           // beta
            qk_batch,         // C
            Ctype,            // C type
@@ -827,8 +832,9 @@ bool LongformerQkvToContext(
    const T* global_input, const int* global_attention,
    const int* global_index, const int* batch_global_num, const int max_num_global,
    void* pinned_buffer, T* workspace,
-     T* output) {
-   size_t softmax_workspace_size = GetLongformerSoftmaxWorkspaceSize(element_size, batch_size, num_heads, sequence_length, window);
+     T* output,
+     size_t softmax_workspace_size,
+     bool use_fast_kernel) {
  T* qkv = reinterpret_cast<T*>((char*)workspace + softmax_workspace_size);

  // Number of elements in Q, K, V, Global_Q, Global_K or Global_V are same: BxNxSxH
@@ -862,33 +868,62 @@ bool LongformerQkvToContext(
  const float rsqrt_head_size = 1.f / sqrt(static_cast<float>(head_size));

  T* temp_output = qkv;  // Q will be overwritten
-   if (!launchSoftmaxKernel(
-           stream,
-           cublas,
-           workspace,
-           softmax_workspace_size,
-           q,                 // Transposed Q with shape B x N x S x H
-           k,                 // Transposed K with shape B x N x S x H
-           v,                 // Transposed V with shape B x N x S x H
-           attention_mask,    // Attention mask flags with shape B x S
-           global_q,          // Transposed global Q with shape B x N x S x H.
-           global_k,          // Transposed global K with shape B x N x S x H
-           global_v,          // Transposed global V with shape B x N x S x H
-           global_attention,  // Global attention flags with shape B x S
-           global_index,      // Global index with shape B x S
-           batch_global_num,  // Number of global token per batch with shape B x 1
-           pinned_buffer,     // Pinned Memory Buffer
-           temp_output,       // Output with shape B x N x S x H
-           rsqrt_head_size,   // Scaler
-           batch_size,        // Batch size
-           sequence_length,   // Sequence length
-           num_heads,         // Number of attention heads
-           head_size,         // Hidden size per head
-           window,            // Half (one-sided) window size
-           max_num_global,    // Maximum number of global tokens (G)
-           element_size)) {
-     return false;
+
+   if (use_fast_kernel) {
+     if (!launchSoftmaxFastKernel(
+             stream,
+             cublas,
+             workspace,         // softmax space
+             q,                 // transposed Q with shape (B, N, S, H)
+             k,                 // transposed K with shape (B, N, S, H)
+             v,                 // transposed V with shape (B, N, S, H)
+             attention_mask,    // attention mask with shape (B, S), with value 0.0 not masked, and -10000.0 masked.
+             global_q,          // Q for global tokens with shape (B, N, S, H)
+             global_k,          // K for global tokens with shape (B, N, S, H)
+             global_v,          // V for global tokens with shape (B, N, S, H)
+             global_attention,  // global attention with shape (B, S), with value 0 for local attention and 1 for global attention.
+             global_index,      // Global index with shape (B, S)
+             batch_global_num,  // Number of global tokens per batch with shape (B, 1)
+             pinned_buffer,     // Pinned memory in CPU. Number of global tokens per batch with shape (B, 1)
+             temp_output,       // output with shape (B, N, S, H)
+             rsqrt_head_size,   // scalar
+             batch_size,        // batch size
+             sequence_length,   // sequence length
+             num_heads,         // number of heads
+             head_size,         // hidden size per head
+             window,            // Half (one-sided) window size
+             element_size)) {
+       return false;
+     }
+   } else {
+     assert(max_num_global <= window);
+     if (!launchSoftmaxKernel(
+             stream,
+             cublas,
+             workspace,         // softmax space
+             q,                 // Transposed Q with shape B x N x S x H
+             k,                 // Transposed K with shape B x N x S x H
+             v,                 // Transposed V with shape B x N x S x H
+             attention_mask,    // Attention mask flags with shape B x S. Value -10000.0 means masked, and 0.0 not masked.
+             global_q,          // Transposed global Q with shape B x N x S x H.
+             global_k,          // Transposed global K with shape B x N x S x H
+             global_v,          // Transposed global V with shape B x N x S x H
+             global_attention,  // Global attention flags with shape B x S
+             global_index,      // Global index with shape B x S
+             batch_global_num,  // Number of global token per batch with shape B x 1
+             pinned_buffer,     // Pinned Memory Buffer
+             temp_output,       // Output with shape B x N x S x H
+             rsqrt_head_size,   // Scaler
+             batch_size,        // Batch size
+             sequence_length,   // Sequence length
+             num_heads,         // Number of attention heads
+             head_size,         // Hidden size per head
+             window,            // Half (one-sided) window size
+             element_size)) {
+       return false;
+     }
  }
+

  // The temp_output is BxNxSxH, transpose it to final output BxSxNxH
  return LaunchTransCtx(stream, sequence_length, batch_size, head_size, num_heads, temp_output, output);
@@ -913,8 +948,10 @@ bool LaunchLongformerAttentionKernel(
    int head_size,
    int window,
    int max_num_global,
-     const size_t element_size) {
+     const size_t element_size,
+     bool use_fast_kernel) {
  CublasMathModeSetter helper(device_prop, cublas, CUBLAS_TENSOR_OP_MATH);
+   size_t softmax_workspace_size = GetLongformerSoftmaxWorkspaceSize(element_size, batch_size, num_heads, sequence_length, window, use_fast_kernel);
  if (element_size == 2) {
    return LongformerQkvToContext(cublas, stream,
                                  batch_size, sequence_length, num_heads, head_size, window, element_size,
@@ -927,7 +964,9 @@ bool LaunchLongformerAttentionKernel(
                                  max_num_global,
                                  pinned_buffer,
                                  reinterpret_cast<half*>(workspace),
-                                   reinterpret_cast<half*>(output));
+                                   reinterpret_cast<half*>(output),
+                                   softmax_workspace_size,
+                                   use_fast_kernel);
  } else {
    return LongformerQkvToContext(cublas, stream,
                                  batch_size, sequence_length, num_heads, head_size, window, element_size,
@@ -940,7 +979,9 @@ bool LaunchLongformerAttentionKernel(
                                  max_num_global,
                                  pinned_buffer,
                                  reinterpret_cast<float*>(workspace),
-                                   reinterpret_cast<float*>(output));
+                                   reinterpret_cast<float*>(output),
+                                   softmax_workspace_size,
+                                   use_fast_kernel);
  }
}

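The diff above only threads use_fast_kernel through the call chain. A minimal sketch of how a caller might derive the flag from the ORT_LONGFORMER_COMPACT_MEMORY environment variable mentioned in GetLongformerSoftmaxWorkspaceSize; the operator code is not part of this change, so the helper name and exact policy are hypothetical:

#include <cstdlib>
#include <cstring>

// Hypothetical helper: use the non-compact (fast) kernel only when the user
// explicitly sets ORT_LONGFORMER_COMPACT_MEMORY=0; default to compact memory.
bool UseFastLongformerKernel() {
  const char* value = std::getenv("ORT_LONGFORMER_COMPACT_MEMORY");
  return value != nullptr && std::strcmp(value, "0") == 0;
}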