From 5226307d79de28e40e7f6693db98e7ae36f9e079 Mon Sep 17 00:00:00 2001
From: Tianlei Wu
Date: Thu, 13 Mar 2025 02:41:18 +0000
Subject: [PATCH 1/3] fix attention bias broadcast on dim 1

---
 .../contrib_ops/cpu/bert/attention_cpu_base.h |   2 +-
 .../cuda/bert/attention_softmax.cu            |  17 +-
 .../test/python/transformers/test_mha.py      | 152 +++++++++---------
 3 files changed, 90 insertions(+), 81 deletions(-)

diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h b/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h
index 87938f3728750..4345675b7e966 100644
--- a/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h
+++ b/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h
@@ -208,7 +208,7 @@ class AttentionCPUBase : public AttentionBase {
         // Here we handle the broadcast of batch_size and num_heads dimensions.
         ptrdiff_t attn_bias_offset = 0;
         if (attn_bias_dims[0] != 1) {
-          attn_bias_offset += SafeInt<ptrdiff_t>(batch_index) * num_heads_ * probs_matrix_size;
+          attn_bias_offset += SafeInt<ptrdiff_t>(batch_index) * attn_bias_dims[1] * probs_matrix_size;
         }
         if (attn_bias_dims[1] != 1) {
           attn_bias_offset += head_index * probs_matrix_size;
diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_softmax.cu b/onnxruntime/contrib_ops/cuda/bert/attention_softmax.cu
index 52f94247a8b2b..04bb571f43fa3 100644
--- a/onnxruntime/contrib_ops/cuda/bert/attention_softmax.cu
+++ b/onnxruntime/contrib_ops/cuda/bert/attention_softmax.cu
@@ -49,14 +49,15 @@ namespace attention_softmax_cuda {
 // grid size is (num_heads * sequence_length, batch_size, 1)
 // input and output shape is (batch_size, num_heads, sequence_length, total_sequence_length)
 // bias shape is (batch_size or 1, num_heads or 1, sequence_length, total_sequence_length)
-#define DECLARE_SOFTMAX_VARS() \
-  [[maybe_unused]] const int s = blockIdx.x % sequence_length; \
-  const int b = blockIdx.y; \
-  int64_t offset = static_cast<int64_t>(b * gridDim.x + blockIdx.x) * static_cast<int64_t>(total_sequence_length); \
-  [[maybe_unused]] int64_t bias_offset = 0; \
-  if constexpr (HAS_BIAS) { \
-    const int j = (broadcast_attn_bias_dim_0 ? 0 : (b * gridDim.x)) + (broadcast_attn_bias_dim_1 ? s : blockIdx.x); \
-    bias_offset = static_cast<int64_t>(j) * static_cast<int64_t>(total_sequence_length); \
+#define DECLARE_SOFTMAX_VARS() \
+  [[maybe_unused]] const int s = blockIdx.x % sequence_length; \
+  const int b = blockIdx.y; \
+  int64_t offset = static_cast<int64_t>(b * gridDim.x + blockIdx.x) * static_cast<int64_t>(total_sequence_length); \
+  [[maybe_unused]] int64_t bias_offset = 0; \
+  if constexpr (HAS_BIAS) { \
+    const int j = (broadcast_attn_bias_dim_0 ? 0 : (b * (broadcast_attn_bias_dim_1 ? sequence_length : gridDim.x))) + \
+                  (broadcast_attn_bias_dim_1 ? s : blockIdx.x); \
+    bias_offset = static_cast<int64_t>(j) * static_cast<int64_t>(total_sequence_length); \
   }
 
 // This kernel is for non causal, attention mask 1D or None, and total_sequence_length > 1024.
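
For reference only (this note and the sketch below are not part of the patch, and the helper name and shapes are made up): the broadcast rule the CPU fix above implements can be checked with a few lines of NumPy. The point is that when the bias's batch dimension is not broadcast, the per-batch stride must be the bias's own head dimension (attn_bias_dims[1], possibly 1), not the model's num_heads_.

import numpy as np

def attn_bias_offset(bias_dims, batch_index, head_index, probs_matrix_size):
    # bias_dims is (B or 1, N or 1); probs_matrix_size = sequence_length * total_sequence_length.
    offset = 0
    if bias_dims[0] != 1:  # bias varies per batch
        offset += batch_index * bias_dims[1] * probs_matrix_size
    if bias_dims[1] != 1:  # bias varies per head
        offset += head_index * probs_matrix_size
    return offset

# Check against NumPy broadcasting for a bias of shape (B, 1, S, T): batch kept, heads broadcast.
B, N, S, T = 2, 3, 4, 5
bias = np.random.rand(B, 1, S, T).astype(np.float32)
expanded = np.broadcast_to(bias, (B, N, S, T))
flat = bias.reshape(-1)
for b in range(B):
    for n in range(N):
        off = attn_bias_offset((B, 1), b, n, S * T)
        assert np.allclose(flat[off:off + S * T].reshape(S, T), expanded[b, n])
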
diff --git a/onnxruntime/test/python/transformers/test_mha.py b/onnxruntime/test/python/transformers/test_mha.py
index dc19e3ec95243..8332f9cf7594e 100644
--- a/onnxruntime/test/python/transformers/test_mha.py
+++ b/onnxruntime/test/python/transformers/test_mha.py
@@ -41,15 +41,15 @@ def get_provider_support_info(provider: str, use_kv_cache: bool):
         device_id = torch.cuda.current_device()
         device = torch.device("cuda", device_id)
-        dtype = torch.float16
+        dtypes = [torch.float16, torch.float]
     else:
         assert provider == "CPUExecutionProvider"
         formats = [InputFormats.Q_K_V_BSNH_BSNH_BSNH]
         if not use_kv_cache:
             formats.append(InputFormats.Q_K_V_BSNH_BNSH_BNSH)
         device = torch.device("cpu")
-        dtype = torch.float
-    return device, dtype, formats
+        dtypes = [torch.float]
+    return device, dtypes, formats
 
 
 def get_bias_support(format: InputFormats):
@@ -211,10 +211,11 @@ def no_kv_cache_test_cases(provider: str, comprehensive: bool):
         return
         yield
 
+    # Lengths of arraies are prime numbers since modulo (% length) is used in non comprehensive mode.
     batch_sizes = [1, 2, 3]
-    sequence_lengths = [1, 16, 127, 128, 255, 256, 383, 384, 512]
-    heads = [1, 3, 4, 16]
-    head_sizes = [8, 16, 32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
+    sequence_lengths = [1, 16, 127, 128, 256, 384, 512]
+    heads = [1, 2, 3, 4, 16]
+    head_sizes = [8, 16, 32, 40, 64, 80, 96, 128, 160, 192, 256]
 
     mask_formats = [
         AttentionMaskFormat.Mask_None,
@@ -223,7 +224,7 @@ def no_kv_cache_test_cases(provider: str, comprehensive: bool):
     ]
     atten_bias_options = get_atten_bias_support()
 
-    device, dtype, formats = get_provider_support_info(provider, False)
+    device, dtypes, formats = get_provider_support_info(provider, False)
     if comprehensive:
         sequence_lengths = [*sequence_lengths, 2048]  # Large sequence length is slow and need a lot of memory
         for batch_size in batch_sizes:
@@ -239,30 +240,31 @@ def no_kv_cache_test_cases(provider: str, comprehensive: bool):
                                             broadcast_attn_bias_dim_0,
                                             broadcast_attn_bias_dim_1,
                                         ) in atten_bias_options:
-                                            config = MultiHeadAttentionConfig(
-                                                batch_size=batch_size,
-                                                sequence_length=sequence_length,
-                                                num_heads=num_heads,
-                                                head_size=head_size,
-                                                causal=causal,
-                                                past_sequence_length=0,
-                                                kv_sequence_length=sequence_length,
-                                                max_cache_sequence_length=None,
-                                                provider=provider,
-                                                device=device,
-                                                dtype=dtype,
-                                                use_kv_cache=False,
-                                                share_past_present_buffer=False,
-                                                input_format=format,
-                                                has_bias=has_bias,
-                                                mask_format=mask_format,
-                                                has_attn_bias=has_attn_bias,
-                                                broadcast_attn_bias_dim_0=broadcast_attn_bias_dim_0,
-                                                broadcast_attn_bias_dim_1=broadcast_attn_bias_dim_1,
-                                            )
-                                            yield config
+                                            for dtype in dtypes:
+                                                config = MultiHeadAttentionConfig(
+                                                    batch_size=batch_size,
+                                                    sequence_length=sequence_length,
+                                                    num_heads=num_heads,
+                                                    head_size=head_size,
+                                                    causal=causal,
+                                                    past_sequence_length=0,
+                                                    kv_sequence_length=sequence_length,
+                                                    max_cache_sequence_length=None,
+                                                    provider=provider,
+                                                    device=device,
+                                                    dtype=dtype,
+                                                    use_kv_cache=False,
+                                                    share_past_present_buffer=False,
+                                                    input_format=format,
+                                                    has_bias=has_bias,
+                                                    mask_format=mask_format,
+                                                    has_attn_bias=has_attn_bias,
+                                                    broadcast_attn_bias_dim_0=broadcast_attn_bias_dim_0,
+                                                    broadcast_attn_bias_dim_1=broadcast_attn_bias_dim_1,
+                                                )
+                                                yield config
     else:
-        test_cases = max(len(batch_sizes), len(sequence_lengths), len(heads), len(head_sizes))
+        test_cases = 2 * max(len(batch_sizes), len(sequence_lengths), len(heads), len(head_sizes))
         for i in range(test_cases):
             batch_size = batch_sizes[i % len(batch_sizes)]
             sequence_length = sequence_lengths[i % len(sequence_lengths)]
@@ -272,6 +274,7 @@ def no_kv_cache_test_cases(provider: str, comprehensive: bool):
             has_attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1 = atten_bias_options[
                 i % len(atten_bias_options)
             ]
+            dtype = dtypes[i % len(dtypes)]
             for format in formats:
                 for causal in get_causal_support(format):
                     for has_bias in get_bias_support(format):
@@ -304,11 +307,13 @@ def kv_cache_test_cases(provider: str, comprehensive: bool):
         return
         yield
 
+    # Lengths of arraies are prime numbers since modulo (% length) is used in non comprehensive mode.
     batch_sizes = [1, 2, 3]
-    sequence_lengths = [1, 15, 16, 255, 256, 512]
-    heads = [1, 3, 4, 16]
-    head_sizes = [8, 16, 32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
-    device, dtype, formats = get_provider_support_info(provider, True)
+    sequence_lengths = [1, 15, 16, 255, 256, 384, 512]
+    heads = [1, 2, 3, 4, 16]
+    head_sizes = [8, 16, 32, 40, 64, 80, 96, 128, 160, 224, 256]
+
+    device, dtypes, formats = get_provider_support_info(provider, True)
     mask_formats = [
         AttentionMaskFormat.Mask_None,
         AttentionMaskFormat.Mask_1D_Key_SeqLen,
@@ -328,38 +333,39 @@ def kv_cache_test_cases(provider: str, comprehensive: bool):
                                 for has_past_input in [True, False]:
                                     for mask_format in mask_formats:
                                         for has_bias in get_bias_support(format):
-                                            for (
-                                                has_attn_bias,
-                                                broadcast_attn_bias_dim_0,
-                                                broadcast_attn_bias_dim_1,
-                                            ) in atten_bias_options:
-                                                sequence_length = 1 if has_past_input else past_sequence_length
-                                                past_seq_len = past_sequence_length if has_past_input else 0
-                                                config = MultiHeadAttentionConfig(
-                                                    batch_size=batch_size,
-                                                    sequence_length=sequence_length,
-                                                    num_heads=num_heads,
-                                                    head_size=head_size,
-                                                    causal=causal,
-                                                    past_sequence_length=past_seq_len,
-                                                    kv_sequence_length=sequence_length,
-                                                    max_cache_sequence_length=None,
-                                                    provider=provider,
-                                                    device=device,
-                                                    dtype=dtype,
-                                                    use_kv_cache=True,
-                                                    has_past_input=has_past_input,
-                                                    share_past_present_buffer=False,
-                                                    input_format=format,
-                                                    has_bias=has_bias,
-                                                    mask_format=mask_format,
-                                                    has_attn_bias=has_attn_bias,
-                                                    broadcast_attn_bias_dim_0=broadcast_attn_bias_dim_0,
-                                                    broadcast_attn_bias_dim_1=broadcast_attn_bias_dim_1,
-                                                )
-                                                yield config
+                                            for dtype in dtypes:
+                                                for (
+                                                    has_attn_bias,
+                                                    broadcast_attn_bias_dim_0,
+                                                    broadcast_attn_bias_dim_1,
+                                                ) in atten_bias_options:
+                                                    sequence_length = 1 if has_past_input else past_sequence_length
+                                                    past_seq_len = past_sequence_length if has_past_input else 0
+                                                    config = MultiHeadAttentionConfig(
+                                                        batch_size=batch_size,
+                                                        sequence_length=sequence_length,
+                                                        num_heads=num_heads,
+                                                        head_size=head_size,
+                                                        causal=causal,
+                                                        past_sequence_length=past_seq_len,
+                                                        kv_sequence_length=sequence_length,
+                                                        max_cache_sequence_length=None,
+                                                        provider=provider,
+                                                        device=device,
+                                                        dtype=dtype,
+                                                        use_kv_cache=True,
+                                                        has_past_input=has_past_input,
+                                                        share_past_present_buffer=False,
+                                                        input_format=format,
+                                                        has_bias=has_bias,
+                                                        mask_format=mask_format,
+                                                        has_attn_bias=has_attn_bias,
+                                                        broadcast_attn_bias_dim_0=broadcast_attn_bias_dim_0,
+                                                        broadcast_attn_bias_dim_1=broadcast_attn_bias_dim_1,
+                                                    )
+                                                    yield config
     else:
-        test_cases = max(len(batch_sizes), len(sequence_lengths), len(heads), len(head_sizes))
+        test_cases = 2 * max(len(batch_sizes), len(sequence_lengths), len(heads), len(head_sizes))
         for i in range(test_cases):
             batch_size = batch_sizes[i % len(batch_sizes)]
             past_sequence_length = sequence_lengths[i % len(sequence_lengths)]
@@ -369,6 +375,8 @@ def kv_cache_test_cases(provider: str, comprehensive: bool):
             has_attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1 = atten_bias_options[
                 i % len(atten_bias_options)
             ]
+            dtype = dtypes[i % len(dtypes)]
+
             for format in formats:
                 for causal in get_causal_support(format):
                     for has_past_input in [True, False]:
@@ -401,7 +409,7 @@ def kv_cache_test_cases(provider: str, comprehensive: bool):
 
 
 def lean_attention_test_cases(provider: str, comprehensive: bool):
-    if provider == "CUDAExecutionProvider" and get_compute_capability() < 80:
+    if provider != "CUDAExecutionProvider" or get_compute_capability() < 80:
         return
         yield
@@ -409,7 +417,7 @@ def lean_attention_test_cases(provider: str, comprehensive: bool):
     sequence_lengths = [2, 15, 16, 255, 256, 512, 1024, 2048, 4096, 8192] if comprehensive else [2, 255, 512]
     heads = [1, 4, 16] if comprehensive else [1, 4]
     head_sizes = [64, 128]
-    device, dtype, formats = get_provider_support_info(provider, True)
+    device, dtypes, formats = get_provider_support_info(provider, True)
     mask_formats = [AttentionMaskFormat.Mask_None]
 
     sequence_lengths = [*sequence_lengths, 2048]  # Large sequence length is slow and need a lot of memory
@@ -433,7 +441,7 @@ def lean_attention_test_cases(provider: str, comprehensive: bool):
                                 max_cache_sequence_length=None,
                                 provider=provider,
                                 device=device,
-                                dtype=dtype,
+                                dtype=dtypes[0],
                                 use_kv_cache=True,
                                 has_past_input=True,
                                 share_past_present_buffer=False,
@@ -453,7 +461,7 @@ def no_kv_cache_multi_thread_test_cases(provider: str, comprehensive: bool):
     heads = [4]
     head_sizes = [8, 16, 32, 40, 64, 80, 96, 128, 160, 192, 224, 256] if comprehensive else [32, 64]
 
-    device, dtype, formats = get_provider_support_info(provider, False)
+    device, dtypes, formats = get_provider_support_info(provider, False)
 
     for format in formats:
         for causal in get_causal_support(format):
@@ -473,7 +481,7 @@ def no_kv_cache_multi_thread_test_cases(provider: str, comprehensive: bool):
                     max_cache_sequence_length=None,
                     provider=provider,
                     device=device,
-                    dtype=dtype,
+                    dtype=dtypes[0],
                     use_kv_cache=False,
                     share_past_present_buffer=False,
                     input_format=format,
@@ -493,7 +501,7 @@ def kv_cache_multi_thread_test_cases(provider: str, comprehensive: bool):
     head_sizes = [8, 16, 32, 40, 64, 80, 96, 128, 160, 192, 224, 256] if comprehensive else [32, 64]
     sequence_length = 1
 
-    device, dtype, formats = get_provider_support_info(provider, True)
+    device, dtypes, formats = get_provider_support_info(provider, True)
 
     for format in formats:
         for causal in get_causal_support(format):
@@ -513,7 +521,7 @@ def kv_cache_multi_thread_test_cases(provider: str, comprehensive: bool):
                     max_cache_sequence_length=None,
                     provider=provider,
                     device=device,
-                    dtype=dtype,
+                    dtype=dtypes[0],
                     use_kv_cache=True,
                     has_past_input=True,
                     share_past_present_buffer=False,

From 426b13751eaf7b2803efc9889447b0fc41383949 Mon Sep 17 00:00:00 2001
From: Tianlei Wu
Date: Thu, 13 Mar 2025 02:41:32 +0000
Subject: [PATCH 2/3] naming style

---
 .../contrib_ops/cuda/bert/attention_softmax.h |  8 +--
 .../contrib_ops/rocm/bert/attention_softmax.h | 66 +++++++++----------
 2 files changed, 37 insertions(+), 37 deletions(-)

diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_softmax.h b/onnxruntime/contrib_ops/cuda/bert/attention_softmax.h
index f7fab268b4607..126de15362761 100644
--- a/onnxruntime/contrib_ops/cuda/bert/attention_softmax.h
+++ b/onnxruntime/contrib_ops/cuda/bert/attention_softmax.h
@@ -9,14 +9,14 @@ namespace attention_softmax_cuda {
 template <typename T>
 Status ComputeSoftmax(cudaStream_t stream, const int all_sequence_length, const int sequence_length,
-                      const int batch_size, const int num_heads, const T* rel_pos_bias,
+                      const int batch_size, const int num_heads, const T* attn_bias,
                       const bool broadcast_attn_bias_dim_0, const bool broadcast_attn_bias_dim_1,
                       T* input, T* output, bool causal);
 
 template <typename T>
 Status ComputeSoftmaxWithCumSeqLength(
     const T* input,
-    const T* rel_pos_bias,
+    const T* attn_bias,
     const bool broadcast_attn_bias_dim_0,
     const bool broadcast_attn_bias_dim_1,
     const int32_t* cum_seq_length,
@@ -34,7 +34,7 @@ Status ComputeSoftmaxWithMask1D(cudaStream_t stream,
                                 const int num_heads,
                                 const int* mask_index,
                                 const int* mask_start,
-                                const T* rel_pos_bias,
+                                const T* attn_bias,
                                 const bool broadcast_attn_bias_dim_0,
                                 const bool broadcast_attn_bias_dim_1,
                                 const T* input,
@@ -49,7 +49,7 @@ Status ComputeSoftmaxWithRawMask(Stream* ort_stream,
                                  const int num_heads,
                                  const int* attention_mask,
                                  const bool* key_padding_mask,
-                                 const T* rel_pos_bias,
+                                 const T* attn_bias,
                                  const bool broadcast_attn_bias_dim_0,
                                  const bool broadcast_attn_bias_dim_1,
                                  const T* input,
diff --git a/onnxruntime/contrib_ops/rocm/bert/attention_softmax.h b/onnxruntime/contrib_ops/rocm/bert/attention_softmax.h
index 5bcd46f9b9ea8..9f2faa228cf79 100644
--- a/onnxruntime/contrib_ops/rocm/bert/attention_softmax.h
+++ b/onnxruntime/contrib_ops/rocm/bert/attention_softmax.h
@@ -40,7 +40,7 @@ template <typename T, unsigned TPB>
 __device__ inline void Softmax(const int all_sequence_length,
                                const int valid_end,
                                const int valid_start,
-                               const T* add_before_softmax,
+                               const T* attn_bias,
                                const T* input,
                                T* output) {
   using BlockReduce = hipcub::BlockReduce<float, TPB>;
@@ -59,9 +59,9 @@ __device__ inline void Softmax(const int all_sequence_length,
   for (int i = threadIdx.x; i < valid_end; i += TPB) {
     if (i >= valid_start) {
       const int index = offset + i;
-      float input_at_idx = add_before_softmax == nullptr
-                               ? static_cast<float>(input[index])
-                               : static_cast<float>(input[index] + add_before_softmax[index]);
+      float input_at_idx = attn_bias == nullptr
+                               ? static_cast<float>(input[index])
+                               : static_cast<float>(input[index] + attn_bias[index]);
       if (thread_data_max < input_at_idx) {
         thread_data_max = input_at_idx;
       }
@@ -80,7 +80,7 @@ __device__ inline void Softmax(const int all_sequence_length,
   for (int i = threadIdx.x; i < valid_end; i += TPB) {
     if (i >= valid_start) {
       const int index = offset + i;
-      float val = add_before_softmax == nullptr ? input[index] : input[index] + add_before_softmax[index];
+      float val = attn_bias == nullptr ? input[index] : input[index] + attn_bias[index];
       thread_data_sum += expf(val - max_block);
     }
   }
@@ -93,9 +93,9 @@ __device__ inline void Softmax(const int all_sequence_length,
 
   for (int i = threadIdx.x; i < all_sequence_length; i += TPB) {
     const int index = offset + i;
-    float input_at_idx = add_before_softmax == nullptr
-                             ? static_cast<float>(input[index])
-                             : static_cast<float>(input[index] + add_before_softmax[index]);
+    float input_at_idx = attn_bias == nullptr
+                             ? static_cast<float>(input[index])
+                             : static_cast<float>(input[index] + attn_bias[index]);
     const float val = (i >= valid_start && i < valid_end) ? expf(input_at_idx - max_block) * sum_reverse_block : 0.f;
     output[index] = T(val);
   }
@@ -106,7 +106,7 @@ __device__ inline void SoftmaxSmall(const int all_sequence_length,
                                     const int sequence_length,
                                     const int valid_end,
                                     const int valid_start,
-                                    const T* add_before_softmax,
+                                    const T* attn_bias,
                                     const T* input,
                                     T* output,
                                     bool causal) {
@@ -137,9 +137,9 @@ __device__ inline void SoftmaxSmall(const int all_sequence_length,
   // Infinity divided by Infinity is a NAN. Thus, softmax gets a NAN if one or more item are large enough.
   // a math transform as below is leveraged to get a stable softmax:
   // e^xi/(e^x1 + ...e^xn) = e^(xi - max) / (e^(x1 - max) + ... + e^(xn - max))
-  float input_data = add_before_softmax == nullptr
-                         ? static_cast<float>(input[index])
-                         : static_cast<float>(input[index] + add_before_softmax[index]);
+  float input_data = attn_bias == nullptr
+                         ? static_cast<float>(input[index])
+                         : static_cast<float>(input[index] + attn_bias[index]);
   float thread_data_max = is_valid ? input_data : float(-ROCMRT_INF_F);
 
   const auto max = BlockReduce(tmp_storage).Reduce(thread_data_max, hipcub::Max(), end);
@@ -178,7 +178,7 @@ __global__ void SoftmaxWithRawMaskSmallKernel(
     const int3 attention_mask_strides,
     const int* attention_mask,  // 2D, 3D or 4D attention mask
     const bool* key_padding_mask,
-    const T* add_before_softmax,
+    const T* attn_bias,
     const T* input,
     T* output,
     const bool causal,
@@ -225,8 +225,8 @@ __global__ void SoftmaxWithRawMaskSmallKernel(
       }
     }
 
-    if (add_before_softmax != nullptr) {
-      thread_data += float(add_before_softmax[index]);
+    if (attn_bias != nullptr) {
+      thread_data += float(attn_bias[index]);
     }
   }
 
@@ -264,50 +264,50 @@ __global__ void SoftmaxWithRawMaskSmallKernel(
 template <typename T, unsigned TPB>
 __global__ void SoftmaxKernelSmall(const int all_sequence_length, const int sequence_length,
-                                   const T* add_before_softmax, const T* input, T* output, bool causal) {
+                                   const T* attn_bias, const T* input, T* output, bool causal) {
   SoftmaxSmall<T, TPB>(all_sequence_length, sequence_length, all_sequence_length, 0,
-                       add_before_softmax, input, output, causal);
+                       attn_bias, input, output, causal);
 }
 
 template <typename T, unsigned TPB>
-__global__ void SoftmaxKernel(const int all_sequence_length, const T* add_before_softmax, const T* input, T* output) {
-  Softmax<T, TPB>(all_sequence_length, all_sequence_length, 0, add_before_softmax, input, output);
+__global__ void SoftmaxKernel(const int all_sequence_length, const T* attn_bias, const T* input, T* output) {
+  Softmax<T, TPB>(all_sequence_length, all_sequence_length, 0, attn_bias, input, output);
 }
 
 template <typename T>
 Status ComputeSoftmax(
     hipStream_t stream, const int all_sequence_length, const int sequence_length, const int batch_size, const int num_heads,
-    const T* add_before_softmax, const T* input, T* output, bool causal) {
+    const T* attn_bias, const T* input, T* output, bool causal) {
   const dim3 grid(sequence_length * num_heads, batch_size, 1);
   if (all_sequence_length <= 32) {
     const int blockSize = 32;
     SoftmaxKernelSmall<T, blockSize><<<grid, blockSize, 0, stream>>>(
-        all_sequence_length, sequence_length, add_before_softmax, input, output, causal);
+        all_sequence_length, sequence_length, attn_bias, input, output, causal);
   } else if (all_sequence_length <= 64) {
     const int blockSize = 64;
     SoftmaxKernelSmall<T, blockSize><<<grid, blockSize, 0, stream>>>(
-        all_sequence_length, sequence_length, add_before_softmax, input, output, causal);
+        all_sequence_length, sequence_length, attn_bias, input, output, causal);
   } else if (all_sequence_length <= 128) {
     const int blockSize = 128;
     SoftmaxKernelSmall<T, blockSize><<<grid, blockSize, 0, stream>>>(
-        all_sequence_length, sequence_length, add_before_softmax, input, output, causal);
+        all_sequence_length, sequence_length, attn_bias, input, output, causal);
   } else if (all_sequence_length <= 256) {
     const int blockSize = 256;
     SoftmaxKernelSmall<T, blockSize><<<grid, blockSize, 0, stream>>>(
-        all_sequence_length, sequence_length, add_before_softmax, input, output, causal);
+        all_sequence_length, sequence_length, attn_bias, input, output, causal);
   } else if (all_sequence_length <= 512) {
     const int blockSize = 512;
     SoftmaxKernelSmall<T, blockSize><<<grid, blockSize, 0, stream>>>(
-        all_sequence_length, sequence_length, add_before_softmax, input, output, causal);
+        all_sequence_length, sequence_length, attn_bias, input, output, causal);
   } else if (all_sequence_length <= 1024) {
     const int blockSize = 1024;
     SoftmaxKernelSmall<T, blockSize><<<grid, blockSize, 0, stream>>>(
-        all_sequence_length, sequence_length, add_before_softmax, input, output, causal);
+        all_sequence_length, sequence_length, attn_bias, input, output, causal);
   } else if (!causal) {
     const int blockSize = 1024;
     SoftmaxKernel<T, blockSize><<<grid, blockSize, 0, stream>>>(
-        all_sequence_length, add_before_softmax, input, output);
+        all_sequence_length, attn_bias, input, output);
   } else {
     return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Attention ROCM operator does not support total sequence length > 1024.");
   }
@@ -318,7 +318,7 @@ Status ComputeSoftmax(
 template <typename T, unsigned TPB>
 __global__ void MaskedSoftmaxKernelSmall(const int all_sequence_length, const int sequence_length,
                                          const int* mask_end, const int* mask_start,
-                                         const T* add_before_softmax, const T* input, T* output,
+                                         const T* attn_bias, const T* input, T* output,
                                          bool causal) {
   __shared__ int start_position;
   __shared__ int end_position;
@@ -337,12 +337,12 @@ __global__ void MaskedSoftmaxKernelSmall(const int all_sequence_length, const in
   __syncthreads();
 
   SoftmaxSmall<T, TPB>(all_sequence_length, sequence_length, end_position, start_position,
-                       add_before_softmax, input, output, causal);
+                       attn_bias, input, output, causal);
 }
 
 template <typename T, unsigned TPB>
 __global__ void MaskedSoftmaxKernel(const int all_sequence_length, const int* mask_end, const int* mask_start,
-                                    const T* add_before_softmax, const T* input, T* output) {
+                                    const T* attn_bias, const T* input, T* output) {
   __shared__ int start_position;
   __shared__ int end_position;
@@ -359,7 +359,7 @@ __global__ void MaskedSoftmaxKernel(const int all_sequence_length, const int* ma
   }
   __syncthreads();
 
-  Softmax<T, TPB>(all_sequence_length, end_position, start_position, add_before_softmax, input, output);
+  Softmax<T, TPB>(all_sequence_length, end_position, start_position, attn_bias, input, output);
 }
 
 template <typename T>
@@ -367,13 +367,13 @@ Status ComputeSoftmaxWithMask1D(
     hipStream_t stream, const int all_sequence_length, const int sequence_length, const int batch_size,
     const int num_heads, const int* mask_index, const int* mask_start,
-    const T* add_before_softmax, const T* input, T* output, const bool causal) {
+    const T* attn_bias, const T* input, T* output, const bool causal) {
   const dim3 grid(sequence_length * num_heads, batch_size, 1);
 
 #define DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE(block_size) \
   MaskedSoftmaxKernelSmall<T, block_size><<<grid, block_size, 0, stream>>>( \
       all_sequence_length, sequence_length, mask_index, mask_start, \
-      add_before_softmax, input, output, causal);
+      attn_bias, input, output, causal);
 
   if (all_sequence_length <= 32) {
     DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE(32);
@@ -391,7 +391,7 @@ Status ComputeSoftmaxWithMask1D(
     const int blockSize = 1024;
     MaskedSoftmaxKernel<T, blockSize><<<grid, blockSize, 0, stream>>>(
         all_sequence_length, mask_index, mask_start,
-        add_before_softmax, input, output);
+        attn_bias, input, output);
   } else {
    return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Attention ROCM operator does not support total sequence length > 1024.");
   }
@@ -410,7 +410,7 @@ Status ComputeSoftmaxWithRawMask(Stream* ort_stream,
                                  const int3 attention_mask_strides,
                                  const int* attention_mask,
                                  const bool* key_padding_mask,
-                                 const T* add_before_softmax,
+                                 const T* attn_bias,
                                  const T* input,
                                  T* output,
                                  const bool causal,
@@ -426,7 +426,7 @@ Status ComputeSoftmaxWithRawMask(Stream* ort_stream,
 
 #define DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE(block_size) \
   SoftmaxWithRawMaskSmallKernel<T, block_size><<<grid, block_size, 0, stream>>>( \
       all_sequence_length, sequence_length, attention_mask_strides, \
-      attention_mask, key_padding_mask, add_before_softmax, input, out, \
+      attention_mask, key_padding_mask, attn_bias, input, out, \
       causal, rsqrt_head_size, \
      use_persistent_softmax, mask_filter_value);

From de9798a80e30a1c44002caf3f3591acde1548760 Mon Sep 17 00:00:00 2001
From: Tianlei Wu
Date: Wed, 12 Mar 2025 21:13:19 -0700
Subject: [PATCH 3/3] review feedback

---
 onnxruntime/test/python/transformers/test_mha.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/onnxruntime/test/python/transformers/test_mha.py b/onnxruntime/test/python/transformers/test_mha.py
index 8332f9cf7594e..f6403636e79d9 100644
--- a/onnxruntime/test/python/transformers/test_mha.py
+++ b/onnxruntime/test/python/transformers/test_mha.py
@@ -211,7 +211,7 @@ def no_kv_cache_test_cases(provider: str, comprehensive: bool):
         return
         yield
 
-    # Lengths of arraies are prime numbers since modulo (% length) is used in non comprehensive mode.
+    # Lengths of arrays are prime numbers since modulo (% length) is used in non comprehensive mode.
     batch_sizes = [1, 2, 3]
     sequence_lengths = [1, 16, 127, 128, 256, 384, 512]
     heads = [1, 2, 3, 4, 16]
@@ -307,7 +307,7 @@ def kv_cache_test_cases(provider: str, comprehensive: bool):
         return
         yield
 
-    # Lengths of arraies are prime numbers since modulo (% length) is used in non comprehensive mode.
+    # Lengths of arrays are prime numbers since modulo (% length) is used in non comprehensive mode.
     batch_sizes = [1, 2, 3]
     sequence_lengths = [1, 15, 16, 255, 256, 384, 512]
     heads = [1, 2, 3, 4, 16]
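
The comment reworded in patch 3 points at a real property of the non-comprehensive mode: one counter indexes every option list modulo its length, so co-prime (here prime) list lengths keep consecutive iterations from repeating combinations, and the "2 * max(...)" factor added in patch 1 ensures both dtypes are exercised. A small illustrative sketch with shortened, made-up lists (not taken from test_mha.py):

batch_sizes = [1, 2, 3]          # length 3 (prime)
heads = [1, 2, 3, 4, 16]         # length 5 (prime)
dtypes = ["float16", "float32"]  # length 2 (prime)

# Mirror of the non-comprehensive loop: a single index i walks every list modulo its length.
num_cases = 2 * max(len(batch_sizes), len(heads))
combos = {
    (batch_sizes[i % len(batch_sizes)], heads[i % len(heads)], dtypes[i % len(dtypes)])
    for i in range(num_cases)
}
print(f"{len(combos)} distinct combinations from {num_cases} test cases")  # 10 from 10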