
Commit aab67c6

zrphercule authored and facebook-github-bot committed
Add native masked_softmax (pytorch#69268)
Summary:
Pull Request resolved: pytorch#69268

This diff adds a native masked softmax on CUDA and extends the existing warp_softmax to accept a mask. The mask must have the same shape as the input and must be contiguous. A follow-up diff will add the encoder mask layout, where the input is BHDD and the mask is BD.

Test Plan: buck build mode/opt -c fbcode.enable_gpu_sections=true caffe2/test:nn && buck-out/gen/caffe2/test/nn\#binary.par -r test_masked_softmax

Reviewed By: ngimel

Differential Revision: D32338419

fbshipit-source-id: 48c3fde793ad4535725d9dae712db42e2bdb8a49
1 parent a5996a6 commit aab67c6
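
A minimal usage sketch (not part of the commit), exercising the new private op the same way the added test does. The names scores and keep below are made up; in this version of the op a True mask value means the element takes part in the softmax, while False positions are written out as 0:

    import torch

    # Hypothetical tensors; the mask must be boolean, contiguous, and the same shape as the input.
    scores = torch.randn(2, 4, 8, 8, device="cuda")
    keep = torch.randint(0, 2, (2, 4, 8, 8), device="cuda").bool()

    # Softmax over the last dimension; positions where keep is False come out as 0.
    out = torch._masked_softmax(scores, keep)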

File tree

5 files changed (+162 -23 lines)


aten/src/ATen/native/SoftMax.cpp

Lines changed: 58 additions & 6 deletions
@@ -125,8 +125,12 @@ TORCH_META_FUNC(_log_softmax_backward_data)
 namespace native {
 namespace {

-template <typename scalar_t, bool LogSoftMax>
-void host_softmax(Tensor output, const Tensor& input, const int64_t dim) {
+template <typename scalar_t, bool LogSoftMax, bool MaskedSoftMax = false>
+void host_softmax(
+    Tensor output,
+    const Tensor& input,
+    const int64_t dim,
+    bool* mask = nullptr) {
   int64_t outer_size = 1;
   int64_t dim_size = input.size(dim);
   int64_t inner_size = 1;
@@ -140,6 +144,7 @@ void host_softmax(Tensor output, const Tensor& input, const int64_t dim) {
   int64_t outer_stride = dim_size * dim_stride;
   scalar_t* input_data_base = input.data_ptr<scalar_t>();
   scalar_t* output_data_base = output.data_ptr<scalar_t>();
+  bool* mask_data_base = mask;
   int64_t grain_size = std::min(internal::GRAIN_SIZE / dim_size, (int64_t)1);
   parallel_for(
       0, outer_size * inner_size, grain_size,
@@ -151,14 +156,38 @@ void host_softmax(Tensor output, const Tensor& input, const int64_t dim) {
               input_data_base + outer_idx * outer_stride + inner_idx;
           scalar_t* output_data =
               output_data_base + outer_idx * outer_stride + inner_idx;
+          bool* mask_data = nullptr;
+          if (MaskedSoftMax) {
+            mask_data = mask_data_base + outer_idx * outer_stride + inner_idx;
+          }
+
+          // Calc max in softmax dim
+          bool is_meaningful_max = false;
           scalar_t max_input = input_data[0];
-          for (const auto d : c10::irange(1, dim_size)) {
-            max_input = std::max(max_input, input_data[d * dim_stride]);
+          if (!MaskedSoftMax) {
+            for (const auto d : c10::irange(1, dim_size)) {
+              max_input = std::max(max_input, input_data[d * dim_stride]);
+            }
+          } else {
+            for (const auto d : c10::irange(0, dim_size)) {
+              if (mask_data[d * dim_stride]) {
+                max_input = is_meaningful_max
+                    ? std::max(max_input, input_data[d * dim_stride])
+                    : input_data[d * dim_stride];
+                is_meaningful_max = true;
+              }
+            }
           }

+          // Calc sum in softmax dim
           acc_type<scalar_t, false> tmpsum = 0;
           for (const auto d : c10::irange(dim_size)) {
-            scalar_t z = std::exp(input_data[d * dim_stride] - max_input);
+            scalar_t z{};
+            if (!MaskedSoftMax || mask_data[d * dim_stride]) {
+              z = std::exp(input_data[d * dim_stride] - max_input);
+            } else {
+              z = 0;
+            }
             if (!LogSoftMax) {
               output_data[d * dim_stride] = z;
             }
@@ -171,7 +200,9 @@ void host_softmax(Tensor output, const Tensor& input, const int64_t dim) {
             tmpsum = 1 / tmpsum;
           }

+          // update output
           for (const auto d : c10::irange(dim_size)) {
+            // LogSoftMax and MaskedSoftMax should not both be true
             if (LogSoftMax) {
               output_data[d * dim_stride] =
                   input_data[d * dim_stride] - max_input - tmpsum;
@@ -294,7 +325,10 @@ TORCH_IMPL_FUNC(log_softmax_cpu_out)
   } else {
     AT_DISPATCH_FLOATING_TYPES_AND(
         at::ScalarType::BFloat16, input_.scalar_type(), "log_softmax", [&] {
-          host_softmax<scalar_t, true>(output, input_, dim_);
+          host_softmax<
+              scalar_t,
+              true /* LogSoftMax */,
+              false /* MaskedSoftMax */>(output, input_, dim_);
         });
   }
 }
@@ -431,5 +465,23 @@ Tensor log_softmax(const Tensor& self, Dimname dim, optional<ScalarType> dtype)
   return at::log_softmax(self, dimname_to_position(self, dim), dtype);
 }

+Tensor masked_softmax_cpu(const Tensor& input, const Tensor& mask) {
+  Tensor output = at::empty_like(input, input.options());
+  TORCH_CHECK(
+      input.sizes() == mask.sizes(), "Mask shape should match input shape");
+  TORCH_CHECK(mask.is_contiguous(), "Mask should always be contiguous");
+  TORCH_CHECK(
+      mask.scalar_type() == ScalarType::Bool,
+      "Mask should be a boolean tensor");
+  AT_DISPATCH_FLOATING_TYPES_AND(
+      at::ScalarType::BFloat16, input.scalar_type(), "log_softmax", [&] {
+        host_softmax<
+            scalar_t,
+            false /* LogSoftMax */,
+            true /* MaskedSoftMax */>(
+            output, input, input.dim() - 1, mask.data_ptr<bool>());
+      });
+  return output;
+}
 }
 }
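
As a reading aid, a rough Python equivalent of the masked branch of host_softmax above (a sketch that assumes every row keeps at least one element; masked_softmax_reference, scores and keep are made-up names, not code from this commit):

    import torch

    def masked_softmax_reference(scores: torch.Tensor, keep: torch.Tensor) -> torch.Tensor:
        # keep is a boolean tensor with the same shape as scores; softmax is taken over the last dim.
        neg_fill = torch.finfo(scores.dtype).min
        row_max = scores.masked_fill(~keep, neg_fill).amax(dim=-1, keepdim=True)
        z = torch.exp(scores - row_max) * keep  # masked-out positions contribute exactly 0
        return z / z.sum(dim=-1, keepdim=True)

This mirrors the slow_masked_softmax reference used in the test below, with the usual max subtraction added for numerical stability.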

aten/src/ATen/native/cuda/PersistentSoftmax.cuh

Lines changed: 47 additions & 13 deletions
@@ -55,15 +55,17 @@ __device__ __forceinline__ void warp_reduce(acc_t* sum) {
 // CUDA warp size is 32 for all existing GPU architectures, but there is no guarantee this will not change for future arch.
 // ROCm warp size is 64 for all currently ROCm-supported GPU architectures, but this may change for future archs.
 // is_log_softmax is a flag indicating whether SoftMax or LogSoftMax should be computed.
+// is_masked is a flag indicating whether SoftMax or MaskedSoftMax should be computed.
 // The template can be instantiated with any floating point type for the type arguments input_t, output_t and acc_t.
 // This allows SoftMax to be fused with a cast immediately following the SoftMax.
+// The mask should have the same shape as input, with a boolean indicating if the value is masked.
 // For instance:
 // input_t=half, acc_t=float, output_t=half => read half tensor, float accumulators, write half tensor.
 // input_t=half, acc_t=float, output_t=float => read half tensor, float accumulators, write float tensor.
 // input_t_float, acc_t=float, output_t=half => read float tensor, float accumulators, write half tensor.

-template <typename input_t, typename output_t, typename acc_t, int log2_elements, bool is_log_softmax>
-__global__ void softmax_warp_forward(output_t *dst, const input_t *src, int batch_size, int stride, int element_count)
+template <typename input_t, typename output_t, typename acc_t, int log2_elements, bool is_log_softmax, bool is_masked = false>
+__global__ void softmax_warp_forward(output_t *dst, const input_t *src, int batch_size, int stride, int element_count, const bool *mask = nullptr)
 {
     // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and warp_size of method warp_softmax_forward_kernel.
     constexpr int next_power_of_two = 1 << log2_elements;
@@ -84,7 +86,9 @@ __global__ void softmax_warp_forward(output_t *dst, const input_t *src, int batc

     src += first_batch * stride + local_idx;
     dst += first_batch * stride + local_idx;
-
+    if (is_masked) {
+        mask += first_batch * stride + local_idx;
+    }
     // The nested loops over WARP_BATCH and then WARP_ITERATIONS can be simplified to one loop,
     // but I think doing so would obfuscate the logic of the algorithm, thus I chose to keep
     // the nested loops.
@@ -108,10 +112,23 @@ __global__ void softmax_warp_forward(output_t *dst, const input_t *src, int batc
     acc_t max_value[WARP_BATCH];
     #pragma unroll
     for (int i = 0; i < WARP_BATCH; ++i) {
+        bool is_meaningful_max = false;
         max_value[i] = elements[i][0];
         #pragma unroll
-        for (int it = 1; it < WARP_ITERATIONS; ++it) {
-            max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
+        for (int it = 0; it < WARP_ITERATIONS; ++it) {
+            if (is_masked) {
+                if (mask[i*element_count+it*WARP_SIZE]) {
+                    max_value[i] = (is_meaningful_max && max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
+                    is_meaningful_max = true;
+                }
+            } else {
+                max_value[i] = max_value[i] > elements[i][it] ? max_value[i] : elements[i][it];
+            }
+        }
+        if (is_masked) {
+            if (!is_meaningful_max) {
+                max_value[i] = -std::numeric_limits<acc_t>::infinity();
+            }
         }
     }
     warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Max>(max_value);
@@ -121,11 +138,22 @@ __global__ void softmax_warp_forward(output_t *dst, const input_t *src, int batc
     for (int i = 0; i < WARP_BATCH; ++i) {
         #pragma unroll
         for (int it = 0; it < WARP_ITERATIONS; ++it) {
-            if (is_log_softmax) {
-                sum[i] += std::exp(elements[i][it] - max_value[i]);
+            if (!is_masked) {
+                if (is_log_softmax) {
+                    sum[i] += std::exp(elements[i][it] - max_value[i]);
+                } else {
+                    elements[i][it] = std::exp(elements[i][it] - max_value[i]);
+                    sum[i] += elements[i][it];
+                }
             } else {
-                elements[i][it] = std::exp(elements[i][it] - max_value[i]);
-                sum[i] += elements[i][it];
+                if (mask[i*element_count+it*WARP_SIZE]) {
+                    if (is_log_softmax) {
+                        sum[i] += std::exp(elements[i][it] - max_value[i]);
+                    } else {
+                        elements[i][it] = std::exp(elements[i][it] - max_value[i]);
+                        sum[i] += elements[i][it];
+                    }
+                }
             }
         }
     }
@@ -141,6 +169,12 @@ __global__ void softmax_warp_forward(output_t *dst, const input_t *src, int batc
         for (int it = 0; it < WARP_ITERATIONS; ++it) {
             int element_index = local_idx + it * WARP_SIZE;
             if (element_index < element_count) {
+                if (is_masked) {
+                    if (!mask[i*element_count+it*WARP_SIZE]) {
+                        dst[i*element_count+it*WARP_SIZE] = 0;
+                        continue;
+                    }
+                }
                 if (is_log_softmax) {
                     dst[i*element_count+it*WARP_SIZE] = elements[i][it] - max_value[i] - sum[i];
                 } else {
@@ -234,8 +268,8 @@ __global__ void softmax_warp_backward(output_t *gradInput, const input_t *grad,

 } // end of anonymous namespace

-template<typename input_t, typename output_t, typename acc_t, bool is_log_softmax>
-void dispatch_softmax_forward(output_t *dst, const input_t *src, int softmax_elements, int softmax_elements_stride, int batch_count)
+template<typename input_t, typename output_t, typename acc_t, bool is_log_softmax, bool is_masked = false>
+void dispatch_softmax_forward(output_t *dst, const input_t *src, int softmax_elements, int softmax_elements_stride, int batch_count, const bool *mask = nullptr)
 {
     TORCH_INTERNAL_ASSERT( softmax_elements >= 0 && softmax_elements <= 1024 );
     if (softmax_elements == 0) {
@@ -260,9 +294,9 @@ void dispatch_softmax_forward(output_t *dst, const input_t *src, int softmax_ele
         // Launch code would be more elegant if C++ supported FOR CONSTEXPR
         switch (log2_elements) {
             #define LAUNCH_SOFTMAX_WARP_FORWARD(L2E) case L2E:                    \
-            softmax_warp_forward<input_t, output_t, acc_t, L2E, is_log_softmax>   \
+            softmax_warp_forward<input_t, output_t, acc_t, L2E, is_log_softmax, is_masked>   \
                 <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst,   \
-                    src, batch_count, softmax_elements_stride, softmax_elements); \
+                    src, batch_count, softmax_elements_stride, softmax_elements, mask); \
                 C10_CUDA_KERNEL_LAUNCH_CHECK();                                   \
                 break;

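
The heart of the masked changes above is the max reduction: the running maximum can no longer be seeded from elements[i][0], because that element may itself be masked. A plain-Python restatement of the is_meaningful_max logic (an illustrative sketch, not code from this commit):

    import math

    def masked_row_max(values, keep):
        # Only unmasked elements participate in the max; if the whole row is
        # masked, the kernel falls back to -inf (and the masked positions are
        # later written out as 0 regardless).
        row_max = -math.inf
        seen = False
        for v, k in zip(values, keep):
            if k:
                row_max = v if not seen else max(row_max, v)
                seen = True
        return row_max

    print(masked_row_max([1.0, 5.0, 3.0], [True, False, True]))  # 3.0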

aten/src/ATen/native/cuda/SoftMax.cu

Lines changed: 31 additions & 4 deletions
@@ -713,8 +713,8 @@ Tensor host_softmax(const Tensor & input_, const int64_t dim_, const bool half_t
       int64_t remaining = outer_size;
       int64_t chunk_size = (1L << 30L) / dim_size;
       while(remaining > 0) {
-        dispatch_softmax_forward<scalar_t, scalar_t, accscalar_t, is_log_softmax>(
-            output_ptr, input_ptr, dim_size, dim_size, std::min<int64_t>(remaining, chunk_size));
+        dispatch_softmax_forward<scalar_t, scalar_t, accscalar_t, is_log_softmax, false>(
+            output_ptr, input_ptr, dim_size, dim_size, std::min<int64_t>(remaining, chunk_size), nullptr/* not masked */);
         input_ptr += chunk_size * dim_size;
         output_ptr += chunk_size * dim_size;
         remaining -= chunk_size;
@@ -734,8 +734,8 @@ Tensor host_softmax(const Tensor & input_, const int64_t dim_, const bool half_t
       int64_t remaining = outer_size;
       int64_t chunk_size = (1<<30) / dim_size;
       while(remaining > 0) {
-        dispatch_softmax_forward<scalar_t, accscalar_t, accscalar_t, is_log_softmax>(
-            output_ptr, input_ptr, dim_size, dim_size, std::min<int64_t>(remaining, chunk_size));
+        dispatch_softmax_forward<scalar_t, accscalar_t, accscalar_t, is_log_softmax, false>(
+            output_ptr, input_ptr, dim_size, dim_size, std::min<int64_t>(remaining, chunk_size), nullptr/* not masked */);
         input_ptr += chunk_size * dim_size;
         output_ptr += chunk_size * dim_size;
         remaining -= chunk_size;
@@ -941,5 +941,32 @@ TORCH_IMPL_FUNC(softmax_backward_cuda_out)
   Tensor tmp = grad * output;
   host_softmax_backward<SoftMaxBackwardEpilogue,false>(tmp, output, dim, half_to_float, grad_input);
 }
+
+Tensor masked_softmax_cuda(const Tensor& input, const Tensor& mask) {
+  TORCH_CHECK(mask.scalar_type() == ScalarType::Bool, "Mask should be a boolean tensor");
+  TORCH_CHECK(mask.is_contiguous(), "Mask should always be contiguous");
+  // Always do masked softmax on last dim
+  int softmax_elements = input.size(input.dim() - 1);
+  TORCH_CHECK(softmax_elements <= 1024, "TODO: Masked softmax only support softmax elements <= 1024");
+  Tensor output = at::empty_like(input, input.options());
+  int batch_count = input.numel() / softmax_elements;
+  AT_DISPATCH_FLOATING_TYPES_AND2(
+    ScalarType::Half,
+    ScalarType::BFloat16,
+    input.scalar_type(),
+    "masked_softmax",
+    [&] {
+      using accscalar_t = acc_type<scalar_t, true>;
+      dispatch_softmax_forward<scalar_t, scalar_t, accscalar_t, false, true>(
+        output.data_ptr<scalar_t>(),    // dst
+        input.data_ptr<scalar_t>(),     // src
+        softmax_elements,
+        softmax_elements,
+        batch_count,
+        mask.data_ptr<bool>()
+      );
+    });
+  return output;
+}
 }
 }
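
A rough Python restatement of the preconditions masked_softmax_cuda enforces above (a reader-side sketch; check_masked_softmax_args is a made-up helper, not part of the commit):

    import torch

    def check_masked_softmax_args(x: torch.Tensor, mask: torch.Tensor) -> int:
        # Mirrors the TORCH_CHECKs: boolean, contiguous mask; softmax over the
        # last dim; and the warp kernel only covers rows of up to 1024 elements.
        assert mask.dtype == torch.bool, "Mask should be a boolean tensor"
        assert mask.is_contiguous(), "Mask should always be contiguous"
        softmax_elements = x.size(-1)
        assert softmax_elements <= 1024, "masked softmax only supports rows of <= 1024 elements"
        return x.numel() // softmax_elements  # batch_count handed to dispatch_softmax_forward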

aten/src/ATen/native/native_functions.yaml

Lines changed: 5 additions & 0 deletions
@@ -5892,6 +5892,11 @@
 - func: masked_scatter(Tensor self, Tensor mask, Tensor source) -> Tensor
   variants: function, method

+- func: _masked_softmax(Tensor self, Tensor mask) -> Tensor
+  dispatch:
+    CUDA: masked_softmax_cuda
+    CPU: masked_softmax_cpu
+
 - func: view(Tensor(a) self, int[] size) -> Tensor(a)
   variants: method
   device_check: NoCheck

test/test_nn.py

Lines changed: 21 additions & 0 deletions
@@ -15560,6 +15560,27 @@ def embedding_bag_check(indices, weights, mode, sparse, padding_idx):
                 rtol = None
             self.assertEqual(grad, grad_check, msg=msg, atol=atol, rtol=rtol)

+    def test_masked_softmax(self, device):
+        B = 10
+        num_heads = 8
+        L = 512
+        input = torch.randn((B, num_heads, L, L))
+        mask = torch.randint(0, 2, (B, L))
+        if (self.device_type == "cuda"):
+            input = input.cuda()
+            mask = mask.cuda()
+        mask = mask.reshape(B, 1, 1, L).expand(B, num_heads, L, L).bool()
+        native_res = torch._masked_softmax(input, mask)
+        mask = mask.float()
+
+        def slow_masked_softmax(input, mask):
+            exp = torch.exp(input)
+            exp = exp * mask
+            s = exp.sum(dim=3, keepdim=True).expand(exp.size())
+            return exp / s
+        pt_res = slow_masked_softmax(input, mask)
+        self.assertEqual(pt_res, native_res, exact_dtype=True)
+
     # Test fails on Vg20
     @skipCUDAIfRocm
     @dtypesIfCUDA(torch.half, torch.float)
