
Commit bf8855b

Support Smooth Softmax in fmha (#21885)

Authored by wangyems and Your Name
1 parent ef073fd

### Description

Refer to #21867.

Co-authored-by: Your Name <[email protected]>
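For orientation: "smooth softmax" (see #21867) appears to be softmax with an extra 1 added to the denominator, as if each row attended to one additional virtual key whose logit is 0; the `exp2f(-mi[id])` term added to `s_prime` in the CUTLASS patch below is that virtual key's contribution. A minimal CPU-side sketch under that assumption (the function name and the inclusion of the zero logit in the running max are illustrative choices, not code from this commit):

```cpp
// Reference smooth softmax: standard softmax with one implicit extra key of logit 0.
// CPU illustration only; not the kernel code from this commit.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

std::vector<float> smooth_softmax(const std::vector<float>& logits) {
  // Include the virtual zero logit in the max for numerical stability.
  float m = 0.0f;
  for (float x : logits) m = std::max(m, x);
  float denom = std::exp(0.0f - m);  // contribution of the virtual key
  std::vector<float> p(logits.size());
  for (size_t i = 0; i < logits.size(); ++i) {
    p[i] = std::exp(logits[i] - m);
    denom += p[i];
  }
  for (float& v : p) v /= denom;
  return p;
}

int main() {
  for (float v : smooth_softmax({1.0f, 2.0f, 3.0f})) std::printf("%.6f\n", v);
}
```

With this variant the probabilities sum to less than one, which lets a head place some of its mass on "nothing" when no key is relevant.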

9 files changed: +66 −9

CUTLASS patch file (+55 −4). The kernel_forward.h hunks below are newly added to the patch by this commit; the functional.h hunks that follow were already in the patch and see only whitespace changes.
```diff
diff --git a/examples/41_fused_multi_head_attention/kernel_forward.h b/examples/41_fused_multi_head_attention/kernel_forward.h
index 4c80f549..34327633 100644
--- a/examples/41_fused_multi_head_attention/kernel_forward.h
+++ b/examples/41_fused_multi_head_attention/kernel_forward.h
@@ -221,6 +221,8 @@ struct AttentionKernel {
     int32_t num_batches = 0;
     int32_t num_heads = 0;
 
+    bool use_smooth_softmax = false;
+
     // dropout
     bool use_dropout = false;
     unsigned long long dropout_batch_head_rng_offset = 0;
@@ -897,7 +899,8 @@ struct AttentionKernel {
           p.num_keys - iter_key_start,
           iter_key_start == 0,
           iteratorC_tile_offset,
-          kSupportsBias ? 1.0f : p.scale);
+          kSupportsBias ? 1.0f : p.scale,
+          p.use_smooth_softmax);
 
       // Output results to shared-memory
       int warp_idx_mn_0 = my_warp_id %
@@ -1166,7 +1169,8 @@ struct AttentionKernel {
       int max_col,
       bool is_first,
       typename WarpIteratorC::TensorCoord const& tile_offset,
-      float scaling) {
+      float scaling,
+      bool use_smooth_softmax) {
     /* Iterates on the accumulator and corresponding position on result matrix
 
     (1) Update `mi[r]` to the max value of the row `r`
@@ -1257,7 +1261,7 @@ struct AttentionKernel {
         accum_t mi_row, total_row;
         LambdaIterator::iterateRows(
             lane_offset,
-            [&](int accum_m) { mi_row = mi[accum_m]; },
+            [&](int accum_m) { mi_row = mi[accum_m];},
             [&](int accum_m, int accum_n, int idx) {
               frag[idx] =
                   (accum_n < max_col) ? exp2f(frag[idx] - mi_row) : accum_t(0.0);
@@ -1294,7 +1298,7 @@ struct AttentionKernel {
         for (int i = 0; i < MM0::MmaCore::WarpCount::kN; ++i) {
           total_row += addition_storage[id + kQueriesPerBlock * i];
         }
-        s_prime[id] = total_row;
+        s_prime[id] = (use_smooth_softmax && (max_col <= kKeysPerBlock)) ? total_row + exp2f(-mi[id]) : total_row;
       }
     }
```
```diff
diff --git a/include/cutlass/functional.h b/include/cutlass/functional.h
index 964d2ff3..b366bc14 100644
--- a/include/cutlass/functional.h
+++ b/include/cutlass/functional.h
@@ -39,6 +39,7 @@
 #include "cutlass/numeric_types.h"
 
 #include <cuda_runtime.h>
+#include <cuda_fp16.h>
 
 #if defined(CUTLASS_ARCH_WMMA_ENABLED)
 #include <mma.h>
@@ -230,8 +231,12 @@ struct inverse_square_root<half_t> {
 ...
   return reinterpret_cast<half_t const &>(result);
+#else
+  return half_t::convert((rsqrtf(half_t::convert(lhs))));
+#endif
 #else
   return half_t(1.f / std::sqrt(half_t::convert(lhs)));
 #endif
```
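Relating the patched accumulator update to the definition above: the kernel keeps row maxima `mi[]` and exponentials in base 2 (`exp2f`), so the virtual key's term becomes `exp2f(0 - mi[id])`; the `max_col <= kKeysPerBlock` guard presumably restricts the addition to the iteration that covers the tail of the key range, so the term is counted once rather than once per key block. A standalone sketch of just that update, under those assumptions (the function and parameter layout here are illustrative, not the kernel's signature):

```cpp
// Sketch of the s_prime update from the patch, lifted out of the kernel.
// Assumptions: logits were pre-scaled so exp2f stands in for expf, and
// max_col <= kKeysPerBlock holds exactly once, on the final key block.
#include <cmath>

float update_s_prime(float total_row,         // running sum of exp2f(logit - mi) for the row
                     float mi,                // running row maximum (base-2 scale)
                     int max_col,             // number of keys left in this iteration
                     int kKeysPerBlock,       // keys handled per block
                     bool use_smooth_softmax) {
  if (use_smooth_softmax && max_col <= kKeysPerBlock) {
    // Smooth softmax: add the virtual zero-logit contribution once.
    return total_row + std::exp2(-mi);
  }
  return total_row;
}
```

If the term were added on every iteration, a long sequence would inflate the denominator by one virtual key per key block instead of one overall.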

onnxruntime/contrib_ops/cuda/bert/attention_impl.cu (+1)

```diff
@@ -415,6 +415,7 @@ Status EfficientAttention(
   p.v_head_size = parameters.v_head_size;
   p.causal = parameters.is_unidirectional;
   p.scale = scale;
+  p.use_smooth_softmax = false;
 
   if (nullptr == data.mask_index) {
     p.seqlen_k_ptr = nullptr;
```

onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h (+2)

```diff
@@ -220,6 +220,8 @@ void LaunchCutlassFmha(const MemoryEfficientAttentionParams& params) {
     p.bias_strideM = 0;
     p.bias_strideB = 0;
   }
+
+  p.use_smooth_softmax = params.use_smooth_softmax;
 }
 
 auto kernel_fn = attention_kernel_batched_impl<Attention>;
```

onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h (+1)

```diff
@@ -25,6 +25,7 @@ struct MemoryEfficientAttentionParams {
   int32_t qk_head_size;
   int32_t v_head_size;
   bool causal;
+  bool use_smooth_softmax;
 
   float scale;
```

onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc (−1)

```diff
@@ -153,7 +153,6 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* context) const {
 #if USE_MEMORY_EFFICIENT_ATTENTION
   int sm = (device_prop.major * 10) + device_prop.minor;
   bool use_memory_efficient_attention =
-      !use_smooth_softmax_ &&
       !use_flash_attention &&
       !disable_memory_efficient_attention_ &&
       local_window_size_ == -1 &&
```

onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu (+3 −2)

```diff
@@ -678,8 +678,8 @@ Status FlashAttention(
       reinterpret_cast<void*>(data.softmax_lse), seqlens_k, cos_cache, sin_cache, /*block_table*/ nullptr,
       batch_size, num_heads, kv_num_heads, head_size, sequence_length,
       parameters.seqlen_present_kv_cache, kv_sequence_length, parameters.rotary_dim,
-      scale, is_causal, is_bf16, parameters.use_smooth_softmax, past_bsnh, parameters.num_splits,
-      reinterpret_cast<void*>(data.softmax_lse_accum), reinterpret_cast<void*>(data.out_accum),
+      scale, is_causal, is_bf16, parameters.use_smooth_softmax, past_bsnh, parameters.num_splits,
+      reinterpret_cast<void*>(data.softmax_lse_accum), reinterpret_cast<void*>(data.out_accum),
       parameters.local_window_size, parameters.rotary_interleaved, parameters.is_packed_qkv));
 
   // if (parameters.left_padding && parameters.is_prompt) {
@@ -843,6 +843,7 @@ Status EfficientAttention(
                        : nullptr;
   p.stream = stream;
   p.has_custom_right_padding = true;
+  p.use_smooth_softmax = parameters.use_smooth_softmax;
   run_memory_efficient_attention(p);
 
   DUMP_TENSOR("efficient attention output", data.output, batch_size, sequence_length, num_heads, head_size);
```

onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu (+1)

```diff
@@ -515,6 +515,7 @@ Status FusedScaledDotProductAttentionCutlass(
   p.qk_head_size = parameters.head_size;
   p.v_head_size = parameters.v_head_size;
   p.causal = false;
+  p.use_smooth_softmax = false;
   p.scale = parameters.scale == 0.0f ? 1.f / sqrt(static_cast<float>(qk_head_size))
                                      : parameters.scale;
   p.seqlen_k_ptr = nullptr;
```

onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu (+1)

```diff
@@ -693,6 +693,7 @@ Status FusedAttentionCutlass(
   p.qk_head_size = parameters.head_size;
   p.v_head_size = parameters.v_head_size;
   p.causal = false;
+  p.use_smooth_softmax = false;
   p.scale = parameters.scale == 0.0f ? 1.f / sqrt(static_cast<float>(qk_head_size))
                                      : parameters.scale;
   p.seqlen_k_ptr = nullptr;
```

onnxruntime/test/python/transformers/test_flash_attn_cuda.py (+2 −2)

```diff
@@ -2219,7 +2219,7 @@ def test_gqa_no_past_memory_efficient(self, _, config, rotary, rotary_interleave
         rotary=rotary,
         rotary_interleaved=rotary_interleaved,
         packed=packed,
-        use_smooth_softmax=False,
+        use_smooth_softmax=True,
     )
 
     @parameterized.expand(gqa_no_past_flash_attention_test_cases())
@@ -2263,7 +2263,7 @@ def test_gqa_past_memory_efficient(self, _, config, rotary, rotary_interleaved,
         rotary=rotary,
         rotary_interleaved=rotary_interleaved,
         packed=packed,
-        use_smooth_softmax=False,
+        use_smooth_softmax=True,
     )
     parity_check_gqa_past_no_buff(
         config,
```
