Commit 19e93b8

drisspg authored and pytorchmergebot committed

Fixes last_dim stride check for singleton dimensions (pytorch#117001)

Fixes pytorch#116333
Pull Request resolved: pytorch#117001
Approved by: https://github.com/cpuhrsch

1 parent 8bcdde5 commit 19e93b8
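
For context, a minimal repro of the case being fixed, adapted from the regression test added in test/test_transformers.py below (a sketch assuming a CUDA build where flash attention is available; the tensors mirror the test, everything else is illustrative):

```python
import torch
import torch.nn.functional as F

# head_dim is 1, but the transpose leaves the query's last-dim stride at 2.
query = torch.tensor([[[[1, 2]]]], dtype=torch.float16, device="cuda").transpose(-1, -2)
key = torch.tensor([[[[1]]]], dtype=torch.float16, device="cuda")
value = torch.tensor([[[[1]]]], dtype=torch.float16, device="cuda")

# Restrict SDPA to the flash backend; before this fix the stride check
# rejected these inputs even though a size-1 dim has no meaningful stride.
with torch.backends.cuda.sdp_kernel(enable_math=False, enable_flash=True, enable_mem_efficient=False):
    out = F.scaled_dot_product_attention(query, key, value)
```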

File tree

4 files changed: +21 -3 lines

aten/src/ATen/native/transformers/cuda/sdp_utils.cpp

+2 -2
@@ -341,7 +341,7 @@ bool can_use_flash_attention(sdp_params const& params, bool debug) {
   constexpr auto dense_constraints = array_of<bool (*)(sdp_params const&, bool)>(
       check_batch_size_and_num_heads_dense,
       check_nonzero_sequence_lengths_dense,
-      check_last_dim_stride_equals_1_dense);
+      check_last_dim_stride_equals_1_dense<true /*ignore_singleton_dim=*/>);
   for (auto& constraint : dense_constraints) {
     if (!constraint(params, debug)) {
       return false;
@@ -399,7 +399,7 @@ bool can_use_mem_efficient_attention(sdp_params const& params, bool debug) {
   constexpr auto dense_constraints = array_of<bool (*)(sdp_params const&, bool)>(
       check_batch_size_and_num_heads_dense,
       check_nonzero_sequence_lengths_dense,
-      check_last_dim_stride_equals_1_dense);
+      check_last_dim_stride_equals_1_dense<false /*ignore_singleton_dim=*/>);
   for (auto& constraint : dense_constraints) {
     if (!constraint(params, debug)) {
       return false;

aten/src/ATen/native/transformers/sdp_utils_cpp.cpp

+1 -1
@@ -46,7 +46,7 @@ bool use_flash_attention_cpp(sdp_params const& params, bool debug) {
       check_attn_mask_shape,
       check_head_dim_size_cpp,
       check_nonzero_sequence_lengths_dense,
-      check_last_dim_stride_equals_1_dense);
+      check_last_dim_stride_equals_1_dense<false /*ignore_singleton_dim*/>);
   for (auto& constraint : constraints) {
     if (!constraint(params, debug)) {
       return false;

aten/src/ATen/native/transformers/sdp_utils_cpp.h

+8
@@ -431,6 +431,7 @@ inline bool check_nonzero_sequence_lengths_dense(sdp_params const& params, bool
   return true;
 }
 
+template<bool ignore_singleton_dim>
 inline bool check_last_dim_stride_equals_1_dense(sdp_params const& params, bool debug) {
   // The stride checking for NestedTensors is done within the kernel
   // And .contiguous will be called if needed
@@ -439,6 +440,13 @@ inline bool check_last_dim_stride_equals_1_dense(sdp_params const& params, bool
   // fused_attention have stride 1
   bool qkv_strides_equal_1 = params.query.sym_stride(-1) == 1 &&
       params.key.sym_stride(-1) == 1 && params.value.sym_stride(-1) == 1;
+
+  // https://github.com/pytorch/pytorch/issues/116333
+  // If the head_dim is size 1 the stride won't matter, but we
+  // check this condition before padding the head_dim to 1
+  if (ignore_singleton_dim){
+    qkv_strides_equal_1 = qkv_strides_equal_1 || params.query.sym_size(-1) == 1;
+  }
   bool mask_stride_equal_1 = params.attn_mask.has_value()
       ? params.attn_mask.value().sym_stride(-1) == 1
       : true;
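
As a side note, a small sketch (plain PyTorch, independent of this diff) of why a trailing dim of size 1 can legitimately report a stride other than 1, which is the situation the new ignore_singleton_dim branch tolerates:

```python
import torch

# A [1, 1, 1, 2] tensor transposed on its last two dims: head_dim becomes 1,
# but the reported last-dim stride is 2, so a bare "stride(-1) == 1" test fails.
q = torch.tensor([[[[1.0, 2.0]]]]).transpose(-1, -2)
print(q.shape)       # torch.Size([1, 1, 2, 1])
print(q.stride(-1))  # 2
# With a single element along that dim, the stride is never used to address
# memory, which is why the check can ignore singleton dims when asked to.
```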

test/test_transformers.py

+10
@@ -2121,6 +2121,16 @@ def test_mem_eff_attention_non_contig_mask_bug(self, device):
         max_diff = (out - out_contig).abs().mean()
         self.assertTrue(max_diff.item() < 1e-7)
 
+    @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Fused SDPA was not built for this system")
+    def test_singelton_head_dim_stride_ne_1(self, device):
+        query = torch.tensor([[[[1, 2]]]], dtype=torch.float16, device=device)
+        query = query.transpose(-1, -2)
+        key = torch.tensor([[[[1]]]], dtype=torch.float16, device=device)
+        value = torch.tensor([[[[1]]]], dtype=torch.float16, device=device)
+
+        with torch.backends.cuda.sdp_kernel(enable_math=False, enable_flash=True, enable_mem_efficient=False):
+            scaled_dot_product_attention(query, key, value)
+
     @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Fused SDPA was not built for this system")
     @parametrize("type", ["dense", "nested"])
     @parametrize("is_contiguous", [True, False])
