Fix attention bias broadcast #24017

Open · wants to merge 3 commits into base: main
onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h (2 changes: 1 addition & 1 deletion)

@@ -208,7 +208,7 @@ class AttentionCPUBase : public AttentionBase {
// Here we handle the broadcast of batch_size and num_heads dimensions.
ptrdiff_t attn_bias_offset = 0;
if (attn_bias_dims[0] != 1) {
- attn_bias_offset += SafeInt<ptrdiff_t>(batch_index) * num_heads_ * probs_matrix_size;
+ attn_bias_offset += SafeInt<ptrdiff_t>(batch_index) * attn_bias_dims[1] * probs_matrix_size;
}
if (attn_bias_dims[1] != 1) {
attn_bias_offset += head_index * probs_matrix_size;
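
Note on the CPU change above: the batch stride into the attention bias now follows the bias's own head dimension (attn_bias_dims[1]) instead of num_heads_, so a bias of shape (batch_size, 1, S, T) is no longer over-strided when only the head dimension broadcasts. A minimal standalone sketch of the corrected offset arithmetic, assuming probs_matrix_size = sequence_length * total_sequence_length; BiasOffset and the concrete shapes are illustrative only, not part of this PR:

#include <cstdint>
#include <cstdio>

// Offset into an attention bias of shape (dim0, dim1, S, T), where dim0 is
// batch_size or 1 and dim1 is num_heads or 1, flattened in row-major order.
int64_t BiasOffset(int64_t dim0, int64_t dim1,
                   int64_t batch_index, int64_t head_index,
                   int64_t probs_matrix_size) {
  int64_t offset = 0;
  if (dim0 != 1) {
    // The batch stride is dim1 matrices, not num_heads: when dim1 == 1 the
    // heads are broadcast and each batch holds exactly one (S, T) matrix.
    offset += batch_index * dim1 * probs_matrix_size;
  }
  if (dim1 != 1) {
    offset += head_index * probs_matrix_size;
  }
  return offset;
}

int main() {
  // Bias shape (2, 1, S, T): batch 1, any head -> starts one matrix in.
  std::printf("%lld\n", static_cast<long long>(BiasOffset(2, 1, 1, 3, 8)));  // 8
  // Bias shape (2, 4, S, T): batch 1, head 3 -> (1 * 4 + 3) * 8 = 56.
  std::printf("%lld\n", static_cast<long long>(BiasOffset(2, 4, 1, 3, 8)));  // 56
  return 0;
}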
onnxruntime/contrib_ops/cuda/bert/attention_softmax.cu (17 changes: 9 additions & 8 deletions)

@@ -49,14 +49,15 @@ namespace attention_softmax_cuda {
// grid size is (num_heads * sequence_length, batch_size, 1)
// input and output shape is (batch_size, num_heads, sequence_length, total_sequence_length)
// bias shape is (batch_size or 1, num_heads or 1, sequence_length, total_sequence_length)
- #define DECLARE_SOFTMAX_VARS() \
- [[maybe_unused]] const int s = blockIdx.x % sequence_length; \
- const int b = blockIdx.y; \
- int64_t offset = static_cast<int64_t>(b * gridDim.x + blockIdx.x) * static_cast<int64_t>(total_sequence_length); \
- [[maybe_unused]] int64_t bias_offset = 0; \
- if constexpr (HAS_BIAS) { \
- const int j = (broadcast_attn_bias_dim_0 ? 0 : (b * gridDim.x)) + (broadcast_attn_bias_dim_1 ? s : blockIdx.x); \
- bias_offset = static_cast<int64_t>(j) * static_cast<int64_t>(total_sequence_length); \
+ #define DECLARE_SOFTMAX_VARS() \
+ [[maybe_unused]] const int s = blockIdx.x % sequence_length; \
+ const int b = blockIdx.y; \
+ int64_t offset = static_cast<int64_t>(b * gridDim.x + blockIdx.x) * static_cast<int64_t>(total_sequence_length); \
+ [[maybe_unused]] int64_t bias_offset = 0; \
+ if constexpr (HAS_BIAS) { \
+ const int j = (broadcast_attn_bias_dim_0 ? 0 : (b * (broadcast_attn_bias_dim_1 ? sequence_length : gridDim.x))) + \
+ (broadcast_attn_bias_dim_1 ? s : blockIdx.x); \
+ bias_offset = static_cast<int64_t>(j) * static_cast<int64_t>(total_sequence_length); \
}

// This kernel is for non causal, attention mask 1D or None, and total_sequence_length > 1024.
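
Note on the CUDA macro change above: when the bias broadcasts over heads (broadcast_attn_bias_dim_1) but not over batch, the per-batch stride in rows of length total_sequence_length must be sequence_length, not gridDim.x (which the launch sets to num_heads * sequence_length per the comment above). A minimal host-side sketch of the corrected row index; BiasRow, grid_x, block_idx_x and the literal sizes are illustrative only:

#include <cassert>

long long BiasRow(bool broadcast_dim_0, bool broadcast_dim_1,
                  int b, int head, int s,
                  int num_heads, int sequence_length) {
  const int grid_x = num_heads * sequence_length;      // stands in for gridDim.x
  const int block_idx_x = head * sequence_length + s;  // stands in for blockIdx.x
  // One (S, T) matrix per head per batch, unless the head dimension broadcasts.
  const long long batch_stride = broadcast_dim_1 ? sequence_length : grid_x;
  return (broadcast_dim_0 ? 0 : b * batch_stride) +
         (broadcast_dim_1 ? s : block_idx_x);
}

int main() {
  // Bias shape (B, 1, S, T) with B = 2, N = 4, S = 3: batch 1, row s = 1
  // must land at row 1 * 3 + 1 = 4 of the flattened (B * S, T) bias.
  assert(BiasRow(false, true, /*b=*/1, /*head=*/2, /*s=*/1, 4, 3) == 4);
  // Bias shape (1, N, S, T): only the (head, s) part contributes.
  assert(BiasRow(true, false, 1, 2, 1, 4, 3) == 2 * 3 + 1);
  return 0;
}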
onnxruntime/contrib_ops/cuda/bert/attention_softmax.h (8 changes: 4 additions & 4 deletions)

@@ -9,14 +9,14 @@ namespace attention_softmax_cuda {

template <typename T>
Status ComputeSoftmax(cudaStream_t stream, const int all_sequence_length, const int sequence_length,
- const int batch_size, const int num_heads, const T* rel_pos_bias,
+ const int batch_size, const int num_heads, const T* attn_bias,
const bool broadcast_attn_bias_dim_0, const bool broadcast_attn_bias_dim_1,
T* input, T* output, bool causal);

template <typename T>
Status ComputeSoftmaxWithCumSeqLength(
const T* input,
- const T* rel_pos_bias,
+ const T* attn_bias,
const bool broadcast_attn_bias_dim_0,
const bool broadcast_attn_bias_dim_1,
const int32_t* cum_seq_length,
@@ -34,7 +34,7 @@ Status ComputeSoftmaxWithMask1D(cudaStream_t stream,
const int num_heads,
const int* mask_index,
const int* mask_start,
- const T* rel_pos_bias,
+ const T* attn_bias,
const bool broadcast_attn_bias_dim_0,
const bool broadcast_attn_bias_dim_1,
const T* input,
@@ -49,7 +49,7 @@ Status ComputeSoftmaxWithRawMask(Stream* ort_stream,
const int num_heads,
const int* attention_mask,
const bool* key_padding_mask,
- const T* rel_pos_bias,
+ const T* attn_bias,
const bool broadcast_attn_bias_dim_0,
const bool broadcast_attn_bias_dim_1,
const T* input,
onnxruntime/contrib_ops/rocm/bert/attention_softmax.h (66 changes: 33 additions & 33 deletions)

@@ -40,7 +40,7 @@
__device__ inline void Softmax(const int all_sequence_length,
const int valid_end,
const int valid_start,
- const T* add_before_softmax,
+ const T* attn_bias,
const T* input,
T* output) {
using BlockReduce = hipcub::BlockReduce<float, TPB>;
@@ -59,9 +59,9 @@
for (int i = threadIdx.x; i < valid_end; i += TPB) {
if (i >= valid_start) {
const int index = offset + i;
- float input_at_idx = add_before_softmax == nullptr
+ float input_at_idx = attn_bias == nullptr
? static_cast<float>(input[index])
- : static_cast<float>(input[index] + add_before_softmax[index]);
+ : static_cast<float>(input[index] + attn_bias[index]);
if (thread_data_max < input_at_idx) {
thread_data_max = input_at_idx;
}
@@ -80,7 +80,7 @@
for (int i = threadIdx.x; i < valid_end; i += TPB) {
if (i >= valid_start) {
const int index = offset + i;
- float val = add_before_softmax == nullptr ? input[index] : input[index] + add_before_softmax[index];
+ float val = attn_bias == nullptr ? input[index] : input[index] + attn_bias[index];
thread_data_sum += expf(val - max_block);
}
}
@@ -93,9 +93,9 @@

for (int i = threadIdx.x; i < all_sequence_length; i += TPB) {
const int index = offset + i;
- float input_at_idx = add_before_softmax == nullptr
+ float input_at_idx = attn_bias == nullptr
? static_cast<float>(input[index])
- : static_cast<float>(input[index] + add_before_softmax[index]);
+ : static_cast<float>(input[index] + attn_bias[index]);
const float val = (i >= valid_start && i < valid_end) ? expf(input_at_idx - max_block) * sum_reverse_block : 0.f;
output[index] = T(val);
}
@@ -106,7 +106,7 @@
const int sequence_length,
const int valid_end,
const int valid_start,
- const T* add_before_softmax,
+ const T* attn_bias,
const T* input,
T* output,
bool causal) {
@@ -137,9 +137,9 @@
// Infinity divided by Infinity is a NAN. Thus, softmax gets a NAN if one or more item are large enough.
// a math transform as below is leveraged to get a stable softmax:
// e^xi/(e^x1 + ...e^xn) = e^(xi - max) / (e^(x1 - max) + ... + e^(xn - max))
- float input_data = add_before_softmax == nullptr
+ float input_data = attn_bias == nullptr
? static_cast<float>(input[index])
- : static_cast<float>(input[index] + add_before_softmax[index]);
+ : static_cast<float>(input[index] + attn_bias[index]);
float thread_data_max = is_valid ? input_data : float(-ROCMRT_INF_F);
const auto max = BlockReduce(tmp_storage).Reduce(thread_data_max, hipcub::Max(), end);

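
The comment in the hunk above refers to the standard max-subtraction trick for a numerically stable softmax. A minimal standalone C++ sketch of the same transform (not ORT code):

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

std::vector<float> StableSoftmax(const std::vector<float>& x) {
  // Subtracting the row max keeps every exponent <= 0, so no term overflows
  // to infinity and the sum stays finite.
  const float max_val = *std::max_element(x.begin(), x.end());
  std::vector<float> out(x.size());
  float sum = 0.f;
  for (size_t i = 0; i < x.size(); ++i) {
    out[i] = std::exp(x[i] - max_val);  // e^(xi - max), always in (0, 1]
    sum += out[i];
  }
  for (float& v : out) v /= sum;
  return out;
}

int main() {
  // exp(100000.f) overflows float to inf, so the naive formula yields inf/inf = NaN;
  // the shifted form stays finite and prints roughly 0.268941 0.731059.
  for (float v : StableSoftmax({99999.f, 100000.f})) std::printf("%f ", v);
  std::printf("\n");
  return 0;
}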
@@ -178,7 +178,7 @@
const int3 attention_mask_strides,
const int* attention_mask, // 2D, 3D or 4D attention mask
const bool* key_padding_mask,
- const T* add_before_softmax,
+ const T* attn_bias,
const T* input,
T* output,
const bool causal,
@@ -225,8 +225,8 @@
}
}

- if (add_before_softmax != nullptr) {
- thread_data += float(add_before_softmax[index]);
+ if (attn_bias != nullptr) {
+ thread_data += float(attn_bias[index]);

[GitHub Actions / Optional Lint C++] cpplint warning at onnxruntime/contrib_ops/rocm/bert/attention_softmax.h:229: Using deprecated casting style. Use static_cast<float>(...) instead [readability/casting] [4]
}
}

@@ -264,50 +264,50 @@

template <typename T, unsigned TPB>
__global__ void SoftmaxKernelSmall(const int all_sequence_length, const int sequence_length,
- const T* add_before_softmax, const T* input, T* output, bool causal) {
+ const T* attn_bias, const T* input, T* output, bool causal) {
SoftmaxSmall<T, TPB>(all_sequence_length, sequence_length, all_sequence_length, 0,
- add_before_softmax, input, output, causal);
+ attn_bias, input, output, causal);
}

template <typename T, unsigned TPB>
- __global__ void SoftmaxKernel(const int all_sequence_length, const T* add_before_softmax, const T* input, T* output) {
- Softmax<T, TPB>(all_sequence_length, all_sequence_length, 0, add_before_softmax, input, output);
+ __global__ void SoftmaxKernel(const int all_sequence_length, const T* attn_bias, const T* input, T* output) {
+ Softmax<T, TPB>(all_sequence_length, all_sequence_length, 0, attn_bias, input, output);
}

template <typename T>
Status ComputeSoftmax(
hipStream_t stream,
const int all_sequence_length, const int sequence_length, const int batch_size, const int num_heads,
- const T* add_before_softmax, const T* input, T* output, bool causal) {
+ const T* attn_bias, const T* input, T* output, bool causal) {
const dim3 grid(sequence_length * num_heads, batch_size, 1);
if (all_sequence_length <= 32) {
const int blockSize = 32;
SoftmaxKernelSmall<T, blockSize><<<grid, blockSize, 0, stream>>>(
- all_sequence_length, sequence_length, add_before_softmax, input, output, causal);
+ all_sequence_length, sequence_length, attn_bias, input, output, causal);
} else if (all_sequence_length <= 64) {
const int blockSize = 64;
SoftmaxKernelSmall<T, blockSize><<<grid, blockSize, 0, stream>>>(
- all_sequence_length, sequence_length, add_before_softmax, input, output, causal);
+ all_sequence_length, sequence_length, attn_bias, input, output, causal);
} else if (all_sequence_length <= 128) {
const int blockSize = 128;
SoftmaxKernelSmall<T, blockSize><<<grid, blockSize, 0, stream>>>(
- all_sequence_length, sequence_length, add_before_softmax, input, output, causal);
+ all_sequence_length, sequence_length, attn_bias, input, output, causal);
} else if (all_sequence_length <= 256) {
const int blockSize = 256;
SoftmaxKernelSmall<T, blockSize><<<grid, blockSize, 0, stream>>>(
- all_sequence_length, sequence_length, add_before_softmax, input, output, causal);
+ all_sequence_length, sequence_length, attn_bias, input, output, causal);
} else if (all_sequence_length <= 512) {
const int blockSize = 512;
SoftmaxKernelSmall<T, blockSize><<<grid, blockSize, 0, stream>>>(
- all_sequence_length, sequence_length, add_before_softmax, input, output, causal);
+ all_sequence_length, sequence_length, attn_bias, input, output, causal);
} else if (all_sequence_length <= 1024) {
const int blockSize = 1024;
SoftmaxKernelSmall<T, blockSize><<<grid, blockSize, 0, stream>>>(
- all_sequence_length, sequence_length, add_before_softmax, input, output, causal);
+ all_sequence_length, sequence_length, attn_bias, input, output, causal);
} else if (!causal) {
const int blockSize = 1024;
SoftmaxKernel<T, blockSize><<<grid, blockSize, 0, stream>>>(
- all_sequence_length, add_before_softmax, input, output);
+ all_sequence_length, attn_bias, input, output);
} else {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Attention ROCM operator does not support total sequence length > 1024.");
}
@@ -318,7 +318,7 @@
template <typename T, unsigned TPB>
__global__ void MaskedSoftmaxKernelSmall(const int all_sequence_length, const int sequence_length,
const int* mask_end, const int* mask_start,
- const T* add_before_softmax, const T* input, T* output,
+ const T* attn_bias, const T* input, T* output,
bool causal) {
__shared__ int start_position;
__shared__ int end_position;
@@ -337,12 +337,12 @@
__syncthreads();

SoftmaxSmall<T, TPB>(all_sequence_length, sequence_length, end_position, start_position,
- add_before_softmax, input, output, causal);
+ attn_bias, input, output, causal);
}

template <typename T, unsigned TPB>
__global__ void MaskedSoftmaxKernel(const int all_sequence_length, const int* mask_end, const int* mask_start,
- const T* add_before_softmax, const T* input, T* output) {
+ const T* attn_bias, const T* input, T* output) {
__shared__ int start_position;
__shared__ int end_position;

@@ -359,21 +359,21 @@
}
__syncthreads();

- Softmax<T, TPB>(all_sequence_length, end_position, start_position, add_before_softmax, input, output);
+ Softmax<T, TPB>(all_sequence_length, end_position, start_position, attn_bias, input, output);
}

template <typename T>
Status ComputeSoftmaxWithMask1D(
hipStream_t stream,
const int all_sequence_length, const int sequence_length, const int batch_size, const int num_heads,
const int* mask_index, const int* mask_start,
- const T* add_before_softmax, const T* input, T* output, const bool causal) {
+ const T* attn_bias, const T* input, T* output, const bool causal) {
const dim3 grid(sequence_length * num_heads, batch_size, 1);

#define DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE(block_size) \
MaskedSoftmaxKernelSmall<T, block_size><<<grid, block_size, 0, stream>>>( \
all_sequence_length, sequence_length, mask_index, mask_start, \
- add_before_softmax, input, output, causal);
+ attn_bias, input, output, causal);

if (all_sequence_length <= 32) {
DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE(32);
Expand All @@ -391,7 +391,7 @@
const int blockSize = 1024;
MaskedSoftmaxKernel<T, blockSize><<<grid, blockSize, 0, stream>>>(
all_sequence_length, mask_index, mask_start,
- add_before_softmax, input, output);
+ attn_bias, input, output);
} else {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Attention ROCM operator does not support total sequence length > 1024.");
}
@@ -410,7 +410,7 @@
const int3 attention_mask_strides,
const int* attention_mask,
const bool* key_padding_mask,
- const T* add_before_softmax,
+ const T* attn_bias,
const T* input,
T* output,
const bool causal,
Expand All @@ -426,7 +426,7 @@
#define DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE(block_size) \
SoftmaxWithRawMaskSmallKernel<T, block_size><<<grid, block_size, 0, stream>>>( \
all_sequence_length, sequence_length, attention_mask_strides, \
- attention_mask, key_padding_mask, add_before_softmax, input, out, \
+ attention_mask, key_padding_mask, attn_bias, input, out, \
causal, rsqrt_head_size, \
use_persistent_softmax, mask_filter_value);
