
Commit 4bbea91

Authored by oliviajain, tianleiwu, wangyems, Ori Levari, and hariharans29
Cherry Pick for ORT 1.7.0 (microsoft#6812)
Cherry-picked changes:

* Fix longformer parity and perf regression (microsoft#6760)
* Adding fp16 support for Einsum Cuda kernel (microsoft#6775)
* Update DirectML 1.4.1 to 1.4.2 for ORT 1.7 (microsoft#6780)
* Fix regression in constant folding optimizer (microsoft#6795)
* Update transformers benchmark for transformers 4.3.* and ORT 1.7 (microsoft#6796)
* Make keepdims to its default value when adding ReduceMin/ReduceMax fo… (microsoft#6788)
* Fix issues caused by quantize/calibrate changes (microsoft#6802)

6735 and 6728 are already in the release branch.

Co-authored-by: Tianlei Wu <[email protected]>
Co-authored-by: Ye Wang <[email protected]>
Co-authored-by: Ori Levari <[email protected]>
Co-authored-by: Hariharan Seshadri <[email protected]>
Co-authored-by: Chi Lo <[email protected]>
Co-authored-by: stevenlix <[email protected]>
Parent: 67c478e · Commit: 4bbea91


42 files changed (+1877 / -732 lines)

cmake/external/dml.cmake

Lines changed: 1 addition & 1 deletion
@@ -20,7 +20,7 @@ if (NOT onnxruntime_USE_CUSTOM_DIRECTML)
   set(NUGET_CONFIG ${PROJECT_SOURCE_DIR}/../NuGet.config)
   set(PACKAGES_CONFIG ${PROJECT_SOURCE_DIR}/../packages.config)
   get_filename_component(PACKAGES_DIR ${CMAKE_CURRENT_BINARY_DIR}/../packages ABSOLUTE)
-  set(DML_PACKAGE_DIR ${PACKAGES_DIR}/Microsoft.AI.DirectML.1.4.1)
+  set(DML_PACKAGE_DIR ${PACKAGES_DIR}/Microsoft.AI.DirectML.1.4.2)
   set(DML_SHARED_LIB DirectML.dll)

   # Restore nuget packages, which will pull down the DirectML redist package

onnxruntime/contrib_ops/cpu/bert/longformer_attention_base.h

Lines changed: 5 additions & 0 deletions
@@ -25,5 +25,10 @@ class LongformerAttentionBase {
   int window_;  // Attention windows length (W). It is half (one-sided) of total window size.
 };

+namespace longformer {
+// Environment variable to give a hint about choosing kernels for less memory or latency.
+constexpr const char* kUseCompactMemory = "ORT_LONGFORMER_COMPACT_MEMORY";
+}  // namespace longformer
+
 }  // namespace contrib
 }  // namespace onnxruntime
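The constant above is only the environment-variable name; the value itself is read once in the CUDA kernel's constructor (see longformer_attention.cc below) via ORT's ParseEnvironmentVariableWithDefault helper. As a rough stand-alone illustration only, an equivalent lookup with std::getenv (hypothetical helper name, simplified parsing) might look like this:

#include <cstdlib>
#include <cstring>

// Simplified, hypothetical analogue of reading the ORT_LONGFORMER_COMPACT_MEMORY hint.
// Unset (or anything other than "1") is treated as "do not request the compact-memory kernel",
// matching the default of false used in the constructor change below.
bool UseCompactMemoryHint() {
  const char* value = std::getenv("ORT_LONGFORMER_COMPACT_MEMORY");
  return value != nullptr && std::strcmp(value, "1") == 0;
}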

onnxruntime/contrib_ops/cuda/bert/longformer_attention.cc

Lines changed: 13 additions & 7 deletions
@@ -5,6 +5,7 @@
 #include "core/framework/tensorprotoutils.h"
 #include "core/providers/cuda/cuda_common.h"
 #include "core/providers/cuda/shared_inc/fpgeneric.h"
+#include "core/platform/env_var_utils.h"
 #include "longformer_global_impl.h"
 #include "longformer_attention_impl.h"

@@ -49,7 +50,9 @@ class AutoDestoryCudaEvent {
 };

 template <typename T>
-LongformerAttention<T>::LongformerAttention(const OpKernelInfo& info) : CudaKernel(info), LongformerAttentionBase(info) {}
+LongformerAttention<T>::LongformerAttention(const OpKernelInfo& info) : CudaKernel(info), LongformerAttentionBase(info) {
+  use_compact_memory_ = ParseEnvironmentVariableWithDefault<bool>(longformer::kUseCompactMemory, false);
+}

 template <typename T>
 Status LongformerAttention<T>::ComputeInternal(OpKernelContext* context) const {

@@ -80,6 +83,7 @@ Status LongformerAttention<T>::ComputeInternal(OpKernelContext* context) const {

   constexpr size_t element_size = sizeof(T);

+  // TODO: only calculate once per model.
   // Build Global Index
   auto global_index_buffer = GetScratchBuffer<int>(batch_size * sequence_length);
   auto batch_global_num_buffer = GetScratchBuffer<int>(batch_size);

@@ -148,10 +152,11 @@ Status LongformerAttention<T>::ComputeInternal(OpKernelContext* context) const {
     }
   }

-  // Cuda kernel implementation has a limitation of number of global tokens.
-  if (max_num_global > window_) {
-    ORT_THROW("LongformerAttention CUDA operator does not support number of global tokens > attention window.");
-  }
+  // Force the fast kernel in two situations:
+  // (1) global tokens > window size: the compact memory kernel cannot be used.
+  // (2) sequence_length == 2 * attention_window: use the fast kernel to work around a parity issue of the compact memory kernel.
+  // Otherwise, choose according to the user's environment variable setting (the default is the fast kernel).
+  bool use_fast_kernel = (max_num_global > window_ || sequence_length == 2 * window_ || !use_compact_memory_);

   // Fully connection for global projection.
   // Note that Q only need handle global query tokens if we split GEMM to global Q/K/V separately.

@@ -172,7 +177,7 @@ Status LongformerAttention<T>::ComputeInternal(OpKernelContext* context) const {
                        &one, reinterpret_cast<CudaT*>(global_gemm_buffer.get()), n, device_prop));
   }

-  size_t workSpaceSize = GetLongformerAttentionWorkspaceSize(element_size, batch_size, num_heads_, head_size, sequence_length, max_num_global, window_);
+  size_t workSpaceSize = GetLongformerAttentionWorkspaceSize(element_size, batch_size, num_heads_, head_size, sequence_length, max_num_global, window_, use_fast_kernel);
   auto workspace_buffer = GetScratchBuffer<void>(workSpaceSize);
   if (!LaunchLongformerAttentionKernel(
           device_prop,

@@ -193,7 +198,8 @@ Status LongformerAttention<T>::ComputeInternal(OpKernelContext* context) const {
           head_size,
           window_,
           max_num_global,
-          element_size)) {
+          element_size,
+          use_fast_kernel)) {
     // Get last error to reset it to cudaSuccess.
     CUDA_CALL(cudaGetLastError());
     return Status(common::ONNXRUNTIME, common::FAIL);
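Summarizing the change above: the old hard error on too many global tokens is replaced by a kernel-selection flag. A stand-alone sketch of that predicate (hypothetical function name, not ORT code) is shown here only to make the three conditions explicit:

#include <cassert>

// Sketch of the kernel-selection rule introduced above: the fast (non-compact) kernel is
// forced when the compact kernel cannot be used or hits a known parity issue; otherwise the
// environment-variable hint decides, with the fast kernel as the default.
bool SelectFastKernel(int max_num_global, int window, int sequence_length, bool use_compact_memory_hint) {
  return max_num_global > window            // compact kernel cannot handle this many global tokens
      || sequence_length == 2 * window      // parity issue of the compact kernel at S == 2W
      || !use_compact_memory_hint;          // user did not ask for the compact-memory kernel
}

int main() {
  assert(SelectFastKernel(/*max_num_global*/ 600, /*window*/ 512, /*sequence_length*/ 4096, true));
  assert(SelectFastKernel(300, 512, 1024, true));   // S == 2W forces the fast kernel
  assert(!SelectFastKernel(300, 512, 4096, true));  // compact kernel chosen only when hinted and safe
  return 0;
}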

onnxruntime/contrib_ops/cuda/bert/longformer_attention.h

Lines changed: 3 additions & 0 deletions
@@ -18,6 +18,9 @@ class LongformerAttention final : public CudaKernel, public LongformerAttentionBase {
  public:
   LongformerAttention(const OpKernelInfo& info);
   Status ComputeInternal(OpKernelContext* context) const override;
+
+ private:
+  bool use_compact_memory_;
 };

 }  // namespace cuda

onnxruntime/contrib_ops/cuda/bert/longformer_attention_impl.cu

Lines changed: 88 additions & 47 deletions
@@ -28,7 +28,7 @@ limitations under the License.
 #include "core/providers/cuda/cuda_common.h"
 #include "longformer_attention_impl.h"
 #include "attention_impl.h"
-#include "attention_softmax.h"
+#include "longformer_attention_softmax.h"

 using namespace onnxruntime::cuda;
 using namespace cub;

@@ -53,11 +53,8 @@ namespace cuda {
 // [SoftmaxSpace: see below] [Q:BxNxSxH] [K:BxNxSxH] [V:BxNxSxH] [Global_Q:BxNxSxH] [Global_K:BxNxSxH] [Global_V:BxNxSxH]
 // where Global_Q, Global_K and Global_V are optional. They are not allocated when there is no global token.
 //
-// It is feasible to use compact format for Global_Q with shape BxNxGxH to save space. We do not use compact format for now.
-//
 // SoftmaxSpace layout:
 //   [scratch1: (5S-3W)*W*N*B][scratch2: size_t 20]
-//
 // Scratch1 has 5 buffers for local and global attention calculation.
 // Scratch2 has 5 input pointers, 5 output pointers, 5 buffer sizes and 5 strides related to scratch1.

@@ -74,10 +71,17 @@ size_t GetLongformerSoftmaxWorkspaceSize(
     int batch_size,
     int num_heads,
     int sequence_length,
-    int window) {
-  size_t scratch1_size = GetScratch1Size(element_size, batch_size, num_heads, sequence_length, window);
-  size_t scratch2_size = 10 * (sizeof(void*) + sizeof(size_t));
-  return scratch1_size + scratch2_size;
+    int window,
+    bool use_fast_kernel) {
+  if (!use_fast_kernel) {
+    size_t scratch1_size = GetScratch1Size(element_size, batch_size, num_heads, sequence_length, window);
+    size_t scratch2_size = 10 * (sizeof(void*) + sizeof(size_t));
+    return scratch1_size + scratch2_size;
+  } else {
+    // Non-compact layout when environment variable ORT_LONGFORMER_COMPACT_MEMORY=0 is set.
+    //   [scratch1: BxNxSxS] [scratch2: BxNxSxS]
+    return 2 * GetAttentionScratchSize(element_size, batch_size, num_heads, sequence_length, sequence_length);
+  }
 }

@@ -87,8 +91,9 @@ size_t GetLongformerAttentionWorkspaceSize(
     int head_size,
     int sequence_length,
     int max_num_global,
-    int window) {
-  size_t softmax_size = GetLongformerSoftmaxWorkspaceSize(element_size, batch_size, num_heads, sequence_length, window);
+    int window,
+    bool use_fast_kernel) {
+  size_t softmax_size = GetLongformerSoftmaxWorkspaceSize(element_size, batch_size, num_heads, sequence_length, window, use_fast_kernel);
   size_t qkv_size = 3 * batch_size * sequence_length * num_heads * head_size * element_size;
   size_t global_qkv_size = max_num_global > 0 ? qkv_size : 0;
   return softmax_size + qkv_size + global_qkv_size;

@@ -100,6 +105,7 @@ size_t GetPinnedBufferSize(int batch_size) {
   return sizeof(int) * batch_size + GetScratch2Size();
 }

+// Softmax kernel for compact format
 template <typename T, int blockSize>
 __launch_bounds__(blockSize)
 __global__ void LongformerSoftmaxKernel(const int* global_attention,

@@ -354,7 +360,6 @@ bool launchSoftmaxKernel(
     cudaStream_t stream,
     cublasHandle_t cublas,
     void* workspace,
-    size_t softmax_workspace_size,
     const void* q,             // transposed Q with shape (B, N, S, H)
     const void* k,             // transposed K with shape (B, N, S, H)
     const void* v,             // transposed V with shape (B, N, S, H)

@@ -373,10 +378,7 @@ bool launchSoftmaxKernel(
     int num_heads,             // number of heads
     int head_size,             // hidden size per head
     int window,                // one sided window size
-    int max_num_global,        // maximum number of global tokens (G) in all batches
     size_t element_size) {     // size of element: 2 for half, and 4 for float
-  assert(max_num_global <= window);
-
   const int* global_count = reinterpret_cast<const int*>(pinned_buffer);

   bool is_fp16 = (element_size == 2);

@@ -605,7 +607,10 @@ bool launchSoftmaxKernel(
                              resultType,
                              algo));

-    void* global_q_batch = (char*)global_q + (i * elements_per_batch) * element_size;  // For compact format: replace elements_per_batch by num_heads * max_num_global * head_size
+    // It is feasible to use compact format for Global_Q with shape BxNxGxH to save space.
+    // In that case, elements_per_batch is num_heads * max_num_global * head_size, and stride_per_head is max_num_global * head_size.
+
+    void* global_q_batch = (char*)global_q + (i * elements_per_batch) * element_size;
     void* global_k_batch = (char*)global_k + (i * elements_per_batch) * element_size;
     qk_batch = (char*)input_pointers[4] + (i * buffer_sizes[4] * num_heads) * element_size;

@@ -625,7 +630,7 @@ bool launchSoftmaxKernel(
        global_q_batch,           // B
        Btype,                    // B type
        head_size,                // ldb
-       stride_per_head,          // strideB. For compact format: max_num_global * head_size.
+       stride_per_head,          // strideB.
        beta_0,                   // beta
        qk_batch,                 // C
        Ctype,                    // C type

@@ -827,8 +832,9 @@ bool LongformerQkvToContext(
     const T* global_input, const int* global_attention,
     const int* global_index, const int* batch_global_num, const int max_num_global,
     void* pinned_buffer, T* workspace,
-    T* output) {
-  size_t softmax_workspace_size = GetLongformerSoftmaxWorkspaceSize(element_size, batch_size, num_heads, sequence_length, window);
+    T* output,
+    size_t softmax_workspace_size,
+    bool use_fast_kernel) {
   T* qkv = reinterpret_cast<T*>((char*)workspace + softmax_workspace_size);

   // Number of elements in Q, K, V, Global_Q, Global_K or Global_V are same: BxNxSxH

@@ -862,33 +868,62 @@ bool LongformerQkvToContext(
   const float rsqrt_head_size = 1.f / sqrt(static_cast<float>(head_size));

   T* temp_output = qkv;  // Q will be overwritten
-  if (!launchSoftmaxKernel(
-          stream,
-          cublas,
-          workspace,
-          softmax_workspace_size,
-          q,                 // Transposed Q with shape B x N x S x H
-          k,                 // Transposed K with shape B x N x S x H
-          v,                 // Transposed V with shape B x N x S x H
-          attention_mask,    // Attention mask flags with shape B x S
-          global_q,          // Transposed global Q with shape B x N x S x H.
-          global_k,          // Transposed global K with shape B x N x S x H
-          global_v,          // Transposed global V with shape B x N x S x H
-          global_attention,  // Global attention flags with shape B x S
-          global_index,      // Global index with shape B x S
-          batch_global_num,  // Number of global token per batch with shape B x 1
-          pinned_buffer,     // Pinned Memory Buffer
-          temp_output,       // Output with shape B x N x S x H
-          rsqrt_head_size,   // Scaler
-          batch_size,        // Batch size
-          sequence_length,   // Sequence length
-          num_heads,         // Number of attention heads
-          head_size,         // Hidden size per head
-          window,            // Half (one-sided) window size
-          max_num_global,    // Maximum number of global tokens (G)
-          element_size)) {
-    return false;
+
+  if (use_fast_kernel) {
+    if (!launchSoftmaxFastKernel(
+            stream,
+            cublas,
+            workspace,         // softmax space
+            q,                 // transposed Q with shape (B, N, S, H)
+            k,                 // transposed K with shape (B, N, S, H)
+            v,                 // transposed V with shape (B, N, S, H)
+            attention_mask,    // attention mask with shape (B, S), with value 0.0 not masked, and -10000.0 masked.
+            global_q,          // Q for global tokens with shape (B, N, S, H)
+            global_k,          // K for global tokens with shape (B, N, S, H)
+            global_v,          // V for global tokens with shape (B, N, S, H)
+            global_attention,  // global attention with shape (B, S), with value 0 for local attention and 1 for global attention.
+            global_index,      // Global index with shape (B, S)
+            batch_global_num,  // Number of global tokens per batch with shape (B, 1)
+            pinned_buffer,     // Pinned memory in CPU. Number of global tokens per batch with shape (B, 1)
+            temp_output,       // output with shape (B, N, S, H)
+            rsqrt_head_size,   // scalar
+            batch_size,        // batch size
+            sequence_length,   // sequence length
+            num_heads,         // number of heads
+            head_size,         // hidden size per head
+            window,            // Half (one-sided) window size
+            element_size)) {
+      return false;
+    }
+  } else {
+    assert(max_num_global <= window);
+    if (!launchSoftmaxKernel(
+            stream,
+            cublas,
+            workspace,         // softmax space
+            q,                 // Transposed Q with shape B x N x S x H
+            k,                 // Transposed K with shape B x N x S x H
+            v,                 // Transposed V with shape B x N x S x H
+            attention_mask,    // Attention mask flags with shape B x S. Value -10000.0 means masked, and 0.0 not masked.
+            global_q,          // Transposed global Q with shape B x N x S x H.
+            global_k,          // Transposed global K with shape B x N x S x H
+            global_v,          // Transposed global V with shape B x N x S x H
+            global_attention,  // Global attention flags with shape B x S
+            global_index,      // Global index with shape B x S
+            batch_global_num,  // Number of global token per batch with shape B x 1
+            pinned_buffer,     // Pinned Memory Buffer
+            temp_output,       // Output with shape B x N x S x H
+            rsqrt_head_size,   // Scaler
+            batch_size,        // Batch size
+            sequence_length,   // Sequence length
+            num_heads,         // Number of attention heads
+            head_size,         // Hidden size per head
+            window,            // Half (one-sided) window size
+            element_size)) {
+      return false;
+    }
   }
+

   // The temp_output is BxNxSxH, transpose it to final output BxSxNxH
   return LaunchTransCtx(stream, sequence_length, batch_size, head_size, num_heads, temp_output, output);

@@ -913,8 +948,10 @@ bool LaunchLongformerAttentionKernel(
     int head_size,
     int window,
     int max_num_global,
-    const size_t element_size) {
+    const size_t element_size,
+    bool use_fast_kernel) {
   CublasMathModeSetter helper(device_prop, cublas, CUBLAS_TENSOR_OP_MATH);
+  size_t softmax_workspace_size = GetLongformerSoftmaxWorkspaceSize(element_size, batch_size, num_heads, sequence_length, window, use_fast_kernel);
   if (element_size == 2) {
     return LongformerQkvToContext(cublas, stream,
                                   batch_size, sequence_length, num_heads, head_size, window, element_size,

@@ -927,7 +964,9 @@ bool LaunchLongformerAttentionKernel(
                                   max_num_global,
                                   pinned_buffer,
                                   reinterpret_cast<half*>(workspace),
-                                  reinterpret_cast<half*>(output));
+                                  reinterpret_cast<half*>(output),
+                                  softmax_workspace_size,
+                                  use_fast_kernel);
   } else {
     return LongformerQkvToContext(cublas, stream,
                                   batch_size, sequence_length, num_heads, head_size, window, element_size,

@@ -940,7 +979,9 @@ bool LaunchLongformerAttentionKernel(
                                   max_num_global,
                                   pinned_buffer,
                                   reinterpret_cast<float*>(workspace),
-                                  reinterpret_cast<float*>(output));
+                                  reinterpret_cast<float*>(output),
+                                  softmax_workspace_size,
+                                  use_fast_kernel);
   }
 }
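For a sense of scale, the layout comments above allow a back-of-the-envelope comparison of the two softmax scratch layouts. The sketch below uses hypothetical helper names, counts elements according to those comments only (ignoring the small scratch2 pointer/size block and any alignment), and applies the element size separately; it is illustrative, not the exact ORT computation:

#include <cstddef>
#include <cstdio>

// Approximate softmax scratch sizes, following the layout comments in the diff above.
// Compact layout:     [scratch1: (5S - 3W) * W * N * B elements] (+ small scratch2 block, ignored here)
// Non-compact layout: [scratch1: B*N*S*S elements] [scratch2: B*N*S*S elements]
static size_t CompactSoftmaxElements(int B, int N, int S, int W) {
  return static_cast<size_t>(5 * S - 3 * W) * W * N * B;
}

static size_t FastSoftmaxElements(int B, int N, int S) {
  return 2ull * B * N * S * S;
}

int main() {
  // Example: a Longformer-base-like shape with fp16 elements (2 bytes).
  const int B = 1, N = 12, S = 4096, W = 512;
  const size_t element_size = 2;
  std::printf("compact kernel scratch ~ %zu MB\n",
              CompactSoftmaxElements(B, N, S, W) * element_size >> 20);
  std::printf("fast kernel scratch    ~ %zu MB\n",
              FastSoftmaxElements(B, N, S) * element_size >> 20);
  return 0;
}

For this shape the compact layout comes to about 222 MB of scratch versus 768 MB for the non-compact one, which is consistent with the new environment variable being described as a hint for trading memory against latency.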

onnxruntime/contrib_ops/cuda/bert/longformer_attention_impl.h

Lines changed: 5 additions & 3 deletions
@@ -18,9 +18,10 @@ size_t GetLongformerAttentionWorkspaceSize(
     int head_size,
     int sequence_length,
     int max_num_global,
-    int window);
+    int window,
+    bool use_fast_kernel);

-bool LaunchLongformerAttentionKernel(
+bool LaunchLongformerAttentionKernel(
     const cudaDeviceProp& device_prop,  // Device Properties
     cublasHandle_t& cublas,             // Cublas handle
     cudaStream_t stream,                // CUDA stream

@@ -39,7 +40,8 @@ size_t GetLongformerAttentionWorkspaceSize(
     int head_size,                      // Hidden layer size per head (H)
     int window,                         // One sided attention window (W)
     int max_num_global,                 // Maximum number of global tokens (G)
-    const size_t element_size           // Element size of input tensor
+    const size_t element_size,          // Element size of input tensor
+    bool use_fast_kernel                // Whether to use the fast (non-compact memory) kernel
 );

 }  // namespace cuda
