+ diff --git a/examples/41_fused_multi_head_attention/kernel_forward.h b/examples/41_fused_multi_head_attention/kernel_forward.h
+ index 4c80f549..34327633 100644
+ --- a/examples/41_fused_multi_head_attention/kernel_forward.h
+ +++ b/examples/41_fused_multi_head_attention/kernel_forward.h
+ @@ -221,6 +221,8 @@ struct AttentionKernel {
+     int32_t num_batches = 0;
+     int32_t num_heads = 0;
+
+ +   bool use_smooth_softmax = false;
+ +
+     // dropout
+     bool use_dropout = false;
+     unsigned long long dropout_batch_head_rng_offset = 0;
+ @@ -897,7 +899,8 @@ struct AttentionKernel {
+           p.num_keys - iter_key_start,
+           iter_key_start == 0,
+           iteratorC_tile_offset,
+ -         kSupportsBias ? 1.0f : p.scale);
+ +         kSupportsBias ? 1.0f : p.scale,
+ +         p.use_smooth_softmax);
+
+         // Output results to shared-memory
+         int warp_idx_mn_0 = my_warp_id %
+ @@ -1166,7 +1169,8 @@ struct AttentionKernel {
+       int max_col,
+       bool is_first,
+       typename WarpIteratorC::TensorCoord const& tile_offset,
+ -     float scaling) {
+ +     float scaling,
+ +     bool use_smooth_softmax) {
+     /* Iterates on the accumulator and corresponding position on result matrix
+
+     (1) Update `mi[r]` to the max value of the row `r`
+ @@ -1257,7 +1261,7 @@ struct AttentionKernel {
+     accum_t mi_row, total_row;
+     LambdaIterator::iterateRows(
+         lane_offset,
+ -       [&](int accum_m) { mi_row = mi[accum_m]; },
+ +       [&](int accum_m) { mi_row = mi[accum_m];},
+         [&](int accum_m, int accum_n, int idx) {
+           frag[idx] =
+               (accum_n < max_col) ? exp2f(frag[idx] - mi_row) : accum_t(0.0);
+ @@ -1294,7 +1298,7 @@ struct AttentionKernel {
+       for (int i = 0; i < MM0::MmaCore::WarpCount::kN; ++i) {
+         total_row += addition_storage[id + kQueriesPerBlock * i];
+       }
+ -     s_prime[id] = total_row;
+ +     s_prime[id] = (use_smooth_softmax && (max_col <= kKeysPerBlock)) ? total_row + exp2f(-mi[id]) : total_row;
+     }
+   }
+
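The kernel_forward.h hunks thread a use_smooth_softmax flag from Params into the softmax epilogue: when set, the row sum s_prime[id] gains an extra exp2f(-mi[id]) term, i.e. the contribution of one virtual key with logit 0 (the kernel keeps scores in a base-2 domain, so exp2f(0 - mi) stands in for exp(0 - max)). The max_col <= kKeysPerBlock guard appears to restrict the addition to the final key block, so the term is counted once per row. Below is a minimal host-side float sketch of the resulting math; the smooth_softmax helper is illustrative only, not part of the patch:

// Reference for the "smooth softmax" the flag enables: the denominator
// gets one extra term exp(0 - m), mirroring total_row + exp2f(-mi[id]).
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

std::vector<float> smooth_softmax(std::vector<float> const& logits) {
  float m = logits.empty() ? 0.0f : logits[0];  // row max, like mi[]
  for (float x : logits) m = std::max(m, x);
  float denom = std::exp(0.0f - m);             // virtual key with logit 0
  for (float x : logits) denom += std::exp(x - m);
  std::vector<float> probs(logits.size());
  for (std::size_t i = 0; i < logits.size(); ++i) {
    probs[i] = std::exp(logits[i] - m) / denom;
  }
  return probs;
}

Compared with standard softmax, every probability is scaled down by the extra denominator term, so a row whose logits are all strongly negative can sum to well below 1 instead of being forced to sum to 1.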
diff --git a/include/cutlass/functional.h b/include/cutlass/functional.h
index 964d2ff3..b366bc14 100644
--- a/include/cutlass/functional.h
+++ b/include/cutlass/functional.h
@@ -39,6 +39,7 @@
#include "cutlass/numeric_types.h"
-
+
#include <cuda_runtime.h>
+ #include <cuda_fp16.h>
-
+
#if defined(CUTLASS_ARCH_WMMA_ENABLED)
#include <mma.h>
@@ -230,8 +231,12 @@ struct inverse_square_root<half_t> {
@@ -19,7 +70,7 @@ index 964d2ff3..b366bc14 100644
return reinterpret_cast<half_t const &>(result);
+ #else
+ return half_t::convert((rsqrtf(half_t::convert(lhs))));
- + #endif
+ + #endif
#else
return half_t(1.f / std::sqrt(half_t::convert(lhs)));
- #endif
+ #endif
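The functional.h section of the patch appears here almost entirely as context (the outer diff only touches whitespace on a few lines): it pulls in <cuda_fp16.h> and gives inverse_square_root<half_t> a fallback, half_t::convert(rsqrtf(half_t::convert(lhs))), for devices without native fp16 math. The opening architecture guard sits outside this excerpt; sm_53 is the usual cutoff for fp16 arithmetic and is assumed in the standalone sketch below, which uses the stock CUDA half intrinsics (hrsqrt, __half2float, __float2half) and a hypothetical half_rsqrt wrapper:

// Sketch: half-precision rsqrt with a float fallback for sm < 53 (assumed
// cutoff), where native fp16 arithmetic is unavailable.
#include <cuda_fp16.h>

__device__ __half half_rsqrt(__half x) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)
  return hrsqrt(x);                              // native fp16 reciprocal sqrt
#else
  return __float2half(rsqrtf(__half2float(x)));  // widen, rsqrtf, narrow back
#endif
}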