@@ -55,15 +55,17 @@ __device__ __forceinline__ void warp_reduce(acc_t* sum) {
55
55
// CUDA warp size is 32 for all existing GPU architectures, but there is no guarantee this will not change for future arch.
56
56
// ROCm warp size is 64 for all currently ROCm-supported GPU architectures, but this may change for future archs.
57
57
// is_log_softmax is a flag indicating whether SoftMax or LogSoftMax should be computed.
58
+ // is_masked is a flag indicating whether SoftMax or MaskedSoftMax should be computed.
58
59
// The template can be instantiated with any floating point type for the type arguments input_t, output_t and acc_t.
59
60
// This allows SoftMax to be fused with a cast immediately following the SoftMax.
61
+ // The mask should have the same shape as input, with a boolean indicate if the value is masked.
60
62
// For instance:
61
63
// input_t=half, acc_t=float, output_t=half => read half tensor, float accumulators, write half tensor.
62
64
// input_t=half, acc_t=float, output_t=float => read half tensor, float accumulators, write float tensor.
63
65
// input_t_float, acc_t=float, output_t=half => read float tensor, float accumulators, write half tensor.
64
66
65
- template <typename input_t , typename output_t , typename acc_t , int log2_elements, bool is_log_softmax>
66
- __global__ void softmax_warp_forward (output_t *dst, const input_t *src, int batch_size, int stride, int element_count)
67
+ template <typename input_t , typename output_t , typename acc_t , int log2_elements, bool is_log_softmax, bool is_masked = false >
68
+ __global__ void softmax_warp_forward (output_t *dst, const input_t *src, int batch_size, int stride, int element_count, const bool *mask = nullptr )
67
69
{
68
70
// WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and warp_size of method warp_softmax_forward_kernel.
69
71
constexpr int next_power_of_two = 1 << log2_elements;
@@ -84,7 +86,9 @@ __global__ void softmax_warp_forward(output_t *dst, const input_t *src, int batc
84
86
85
87
src += first_batch * stride + local_idx;
86
88
dst += first_batch * stride + local_idx;
87
-
89
+ if (is_masked) {
90
+ mask += first_batch * stride + local_idx;
91
+ }
88
92
// The nested loops over WARP_BATCH and then WARP_ITERATIONS can be simplified to one loop,
89
93
// but I think doing so would obfuscate the logic of the algorithm, thus I chose to keep
90
94
// the nested loops.
@@ -108,10 +112,23 @@ __global__ void softmax_warp_forward(output_t *dst, const input_t *src, int batc
108
112
acc_t max_value[WARP_BATCH];
109
113
#pragma unroll
110
114
for (int i = 0 ; i < WARP_BATCH; ++i) {
115
+ bool is_meaningful_max = false ;
111
116
max_value[i] = elements[i][0 ];
112
117
#pragma unroll
113
- for (int it = 1 ; it < WARP_ITERATIONS; ++it) {
114
- max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
118
+ for (int it = 0 ; it < WARP_ITERATIONS; ++it) {
119
+ if (is_masked) {
120
+ if (mask[i*element_count+it*WARP_SIZE]) {
121
+ max_value[i] = (is_meaningful_max && max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
122
+ is_meaningful_max = true ;
123
+ }
124
+ } else {
125
+ max_value[i] = max_value[i] > elements[i][it] ? max_value[i] : elements[i][it];
126
+ }
127
+ }
128
+ if (is_masked) {
129
+ if (!is_meaningful_max) {
130
+ max_value[i] = -std::numeric_limits<acc_t >::infinity ();
131
+ }
115
132
}
116
133
}
117
134
warp_reduce<acc_t , WARP_BATCH, WARP_SIZE, Max>(max_value);
@@ -121,11 +138,22 @@ __global__ void softmax_warp_forward(output_t *dst, const input_t *src, int batc
121
138
for (int i = 0 ; i < WARP_BATCH; ++i) {
122
139
#pragma unroll
123
140
for (int it = 0 ; it < WARP_ITERATIONS; ++it) {
124
- if (is_log_softmax) {
125
- sum[i] += std::exp (elements[i][it] - max_value[i]);
141
+ if (!is_masked) {
142
+ if (is_log_softmax) {
143
+ sum[i] += std::exp (elements[i][it] - max_value[i]);
144
+ } else {
145
+ elements[i][it] = std::exp (elements[i][it] - max_value[i]);
146
+ sum[i] += elements[i][it];
147
+ }
126
148
} else {
127
- elements[i][it] = std::exp (elements[i][it] - max_value[i]);
128
- sum[i] += elements[i][it];
149
+ if (mask[i*element_count+it*WARP_SIZE]) {
150
+ if (is_log_softmax) {
151
+ sum[i] += std::exp (elements[i][it] - max_value[i]);
152
+ } else {
153
+ elements[i][it] = std::exp (elements[i][it] - max_value[i]);
154
+ sum[i] += elements[i][it];
155
+ }
156
+ }
129
157
}
130
158
}
131
159
}
@@ -141,6 +169,12 @@ __global__ void softmax_warp_forward(output_t *dst, const input_t *src, int batc
141
169
for (int it = 0 ; it < WARP_ITERATIONS; ++it) {
142
170
int element_index = local_idx + it * WARP_SIZE;
143
171
if (element_index < element_count) {
172
+ if (is_masked) {
173
+ if (!mask[i*element_count+it*WARP_SIZE]) {
174
+ dst[i*element_count+it*WARP_SIZE] = 0 ;
175
+ continue ;
176
+ }
177
+ }
144
178
if (is_log_softmax) {
145
179
dst[i*element_count+it*WARP_SIZE] = elements[i][it] - max_value[i] - sum[i];
146
180
} else {
@@ -234,8 +268,8 @@ __global__ void softmax_warp_backward(output_t *gradInput, const input_t *grad,
234
268
235
269
} // end of anonymous namespace
236
270
237
- template <typename input_t , typename output_t , typename acc_t , bool is_log_softmax>
238
- void dispatch_softmax_forward (output_t *dst, const input_t *src, int softmax_elements, int softmax_elements_stride, int batch_count)
271
+ template <typename input_t , typename output_t , typename acc_t , bool is_log_softmax, bool is_masked = false >
272
+ void dispatch_softmax_forward (output_t *dst, const input_t *src, int softmax_elements, int softmax_elements_stride, int batch_count, const bool *mask = nullptr )
239
273
{
240
274
TORCH_INTERNAL_ASSERT ( softmax_elements >= 0 && softmax_elements <= 1024 );
241
275
if (softmax_elements == 0 ) {
@@ -260,9 +294,9 @@ void dispatch_softmax_forward(output_t *dst, const input_t *src, int softmax_ele
260
294
// Launch code would be more elegant if C++ supported FOR CONSTEXPR
261
295
switch (log2_elements) {
262
296
#define LAUNCH_SOFTMAX_WARP_FORWARD (L2E ) case L2E: \
263
- softmax_warp_forward<input_t , output_t , acc_t , L2E, is_log_softmax> \
297
+ softmax_warp_forward<input_t , output_t , acc_t , L2E, is_log_softmax, is_masked > \
264
298
<<<blocks, threads, 0 , at::cuda::getCurrentCUDAStream()>>> (dst, \
265
- src, batch_count, softmax_elements_stride, softmax_elements); \
299
+ src, batch_count, softmax_elements_stride, softmax_elements, mask ); \
266
300
C10_CUDA_KERNEL_LAUNCH_CHECK (); \
267
301
break ;
268
302
0 commit comments