
Commit 1565eba

eqy authored and pytorchmergebot committed
[cuDNN][SDPA] Match query's memory layout ordering for output in cuDNN SDPA (pytorch#138354)
For pytorch#138340. ~~We might consider more sophisticated logic here, but the corresponding logic in other backends doesn't seem to do anything fancy for non-BSHD/BHSD cases: https://github.com/pytorch/pytorch/blob/ea8ea2f33fc65b33dc562f4b0430f8c79eb81d8d/aten/src/ATen/native/transformers/cuda/attention.cu#L1145~~ Ended up going with a more general approach that matches more or less arbitrary layouts.

Pull Request resolved: pytorch#138354
Approved by: https://github.com/drisspg
1 parent a678eaf commit 1565eba
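At the usage level, the change means the cuDNN backend's output now follows the query's stride ordering instead of always being allocated BHSD-contiguous. A minimal, hypothetical sketch of the expected behavior (assumes a CUDA device where the cuDNN attention backend is available):

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# q/k/v are logically B, H, S, D but laid out B, S, H, D in memory
q = torch.randn(4, 256, 16, 64, device="cuda", dtype=torch.bfloat16).transpose(1, 2)
k = torch.randn(4, 512, 16, 64, device="cuda", dtype=torch.bfloat16).transpose(1, 2)
v = torch.randn(4, 512, 16, 64, device="cuda", dtype=torch.bfloat16).transpose(1, 2)

with sdpa_kernel(SDPBackend.CUDNN_ATTENTION):
    out = F.scaled_dot_product_attention(q, k, v)

# with this commit, the output's memory layout follows the query's
print(out.transpose(1, 2).is_contiguous())  # expected: True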

File tree

3 files changed: +131 -25 lines changed

aten/src/ATen/native/cudnn/MHA.cpp

Lines changed: 88 additions & 21 deletions
@@ -292,6 +292,88 @@ auto fixSizeOneDimStrideSDPA(
   }
   return strides;
 }
+
+void alloc_with_matching_layout(
+    const Tensor& q,
+    Tensor& output,
+    const std::vector<int64_t>& shape) {
+  TORCH_INTERNAL_ASSERT(
+      shape.size() == q.sizes().size(),
+      "cuDNN SDPA alloc_with_matching_layout got requested shape ndim != q ndim");
+
+  if (std::equal(q.sizes().begin(), q.sizes().end(), shape.begin())) {
+    output = at::empty_like(q);
+    return;
+  }
+
+  // get the "fill order," which is just an argsort on the strides
+  std::vector<int> fill_order(shape.size());
+  std::iota(fill_order.begin(), fill_order.end(), 0);
+  const auto q_strides = q.strides();
+  std::stable_sort(
+      fill_order.begin(), fill_order.end(), [&q_strides](int idx1, int idx2) {
+        return q_strides[idx1] < q_strides[idx2];
+      });
+  std::vector<int64_t> ordered_strides(shape.size());
+  int64_t current_stride = 1;
+  for (const int dim_idx : fill_order) {
+    ordered_strides[dim_idx] = current_stride;
+    current_stride *= shape[dim_idx];
+  }
+  output = at::empty(at::IntArrayRef(shape), q.options())
+               .as_strided(
+                   at::IntArrayRef(shape), at::IntArrayRef(ordered_strides), 0);
+}
+
+void permute_to_matching_layout(const Tensor& output, Tensor& grad_output) {
+  const int dims = output.sizes().size();
+  std::vector<int64_t> outer_to_inner(dims);
+  std::iota(outer_to_inner.begin(), outer_to_inner.end(), 0);
+  const auto o_strides = output.strides();
+  std::stable_sort(
+      outer_to_inner.begin(),
+      outer_to_inner.end(),
+      [&o_strides](int idx1, int idx2) {
+        return o_strides[idx1] > o_strides[idx2];
+      });
+  std::vector<int64_t> inverse(dims);
+  for (int d = 0; d < dims; d++) {
+    inverse[d] = std::find(outer_to_inner.begin(), outer_to_inner.end(), d) -
+        outer_to_inner.begin();
+  }
+  grad_output = grad_output.permute(at::IntArrayRef(outer_to_inner))
+                    .contiguous()
+                    .permute(at::IntArrayRef(inverse));
+}
+
+bool same_strides(const Tensor& t1, const Tensor& t2) {
+  std::vector<int> t1_strides_no_ones;
+  std::vector<int> t2_strides_no_ones;
+  const auto t1strides = t1.strides();
+  const auto t2strides = t2.strides();
+  const int dim = t1strides.size();
+  if (dim != (int)t2strides.size()) {
+    return false;
+  }
+  const auto t1sizes = t1.sizes();
+  const auto t2sizes = t2.sizes();
+
+  // we are going through strides backward here, but if both are backward it's
+  // comparable
+  for (int i = 0; i < dim; i++) {
+    if (t1sizes[i] > 1) {
+      t1_strides_no_ones.push_back(t1strides[i]);
+    }
+    if (t2sizes[i] > 1) {
+      t2_strides_no_ones.push_back(t2strides[i]);
+    }
+  }
+  return std::equal(
+      t1_strides_no_ones.begin(),
+      t1_strides_no_ones.end(),
+      t2_strides_no_ones.begin(),
+      t2_strides_no_ones.end());
+}
 } // namespace
 
 auto build_graph_and_tensors(
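For illustration, here is a hypothetical Python rendering of the "fill order" idea that alloc_with_matching_layout implements: argsort the query's strides from innermost outward, then lay the requested shape out in that same order. This is a sketch for reading the diff, not the actual implementation.

import torch

def alloc_like_layout(q, shape):
    # "fill order": dims sorted from innermost (smallest stride) outward;
    # Python's sort is stable, like the std::stable_sort above
    fill_order = sorted(range(len(shape)), key=lambda d: q.stride(d))
    strides = [0] * len(shape)
    current = 1
    for d in fill_order:
        strides[d] = current
        current *= shape[d]
    return torch.empty(shape, dtype=q.dtype, device=q.device).as_strided(shape, strides)

# BHSD view of a BSHD-contiguous tensor
q = torch.randn(4, 256, 16, 64).permute(0, 2, 1, 3)
out = alloc_like_layout(q, list(q.shape))
assert out.stride() == q.stride()  # same memory ordering as the query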
@@ -553,7 +635,8 @@ void run_cudnn_SDP_fprop(
     Tensor& dropoutoffset) {
   cudnnHandle_t handle = getCudnnHandle();
   if (!o.defined()) {
-    o = at::empty({b, h, s_q, d_v}, q.options());
+    // q is passed to us in BHSD dim order
+    alloc_with_matching_layout(q, o, {b, h, s_q, d_v});
   }
 
   if (return_softmaxstats && !softmaxstats.defined()) {
@@ -660,30 +743,14 @@ void run_cudnn_SDP_bprop(
   }
 
   Tensor dO_ = dO;
-  if (!dO.strides()[dO.strides().size() - 1]) {
-    TORCH_WARN(
-        "cuDNN SDPA backward got an innermost stride of 0 in grad_out, which is unsupported."
-        " Materializing a contiguous tensor which will increase memory usage...");
-    dO_ = dO.contiguous();
-  }
-  if ( // handle trivial transposed case with a transposed dim of size 1
-       // see also: https://github.com/pytorch/pytorch/issues/134001
-      !(dO_.is_contiguous() && o.is_contiguous()) &&
-      !std::equal(
-          o.strides().begin(), o.strides().end(), dO.strides().begin())) {
-    TORCH_WARN(
+  if (!same_strides(o, dO)) {
+    TORCH_WARN_ONCE(
         "cuDNN SDPA backward got grad_output.strides() != output.strides(), "
        "attempting to materialize a grad_output with matching strides...");
-    if (o.is_contiguous()) {
-      dO_ = dO.contiguous();
-    } else {
-      dO_ = dO.transpose(1, 2).contiguous().transpose(1, 2);
-    }
+    permute_to_matching_layout(o, dO_);
  }
  TORCH_INTERNAL_ASSERT(
-      (dO_.is_contiguous() && o.is_contiguous()) ||
-          std::equal(
-              dO_.strides().begin(), dO_.strides().end(), o.strides().begin()),
+      same_strides(o, dO_),
      "cuDNN SDPA expected grad_output.strides() == output.strides(), "
      "the previous step probably failed to materialize a grad_output "
      "with matching strides...");

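The two helpers used in this backward path can be pictured roughly as follows; this is a hedged Python sketch mirroring the C++ above, not the actual code. same_strides compares strides while ignoring size-1 dimensions, and permute_to_matching_layout re-materializes grad_output so its stride ordering matches output's.

import torch

def same_strides(t1, t2):
    if t1.dim() != t2.dim():
        return False
    # size-1 dims contribute nothing to the layout, so skip them
    s1 = [t1.stride(i) for i in range(t1.dim()) if t1.size(i) > 1]
    s2 = [t2.stride(i) for i in range(t2.dim()) if t2.size(i) > 1]
    return s1 == s2

def permute_to_matching_layout(output, grad_output):
    dims = output.dim()
    # dims of `output` ordered outermost (largest stride) to innermost
    outer_to_inner = sorted(range(dims), key=lambda d: output.stride(d), reverse=True)
    inverse = [outer_to_inner.index(d) for d in range(dims)]
    # copy grad_output into that layout, then view it back in the original dim order
    return grad_output.permute(outer_to_inner).contiguous().permute(inverse)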
aten/src/ATen/native/transformers/cuda/sdp_utils.cpp

Lines changed: 10 additions & 1 deletion
@@ -56,12 +56,21 @@ namespace {
 // TODO(eqy): more benchmarking to determine whether this should include sm86/89
 // Needs to be kept in-sync with test_fused_chocie in test_transformers.py
 bool check_prefer_cudnn_attention() {
-#if defined(CUDNN_VERSION) && CUDNN_VERSION >= 90000
+  // TODO(eqy): Re-enable by default after upgrading to a release later than 9.5.0
+  // see context: https://github.com/pytorch/pytorch/issues/138340
+  // return false;
+#if defined(CUDNN_VERSION)
+
+#if CUDNN_VERSION > 90000
   auto dprops = at::cuda::getCurrentDeviceProperties();
   return dprops->major >= 9;
 #else
   return false;
 #endif
+
+#else
+  return false;
+#endif
 }
 
 // flash_attention V2 is universally faster than efficient_attention and Math
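The same gate described in this hunk (cuDNN present and newer than 9.0, running on an sm90+ device) can be approximated from Python; a hypothetical quick check, for orientation only:

import torch

cudnn_version = torch.backends.cudnn.version()  # e.g. 90100 for cuDNN 9.1.0, or None
sm_major = torch.cuda.get_device_properties(0).major if torch.cuda.is_available() else 0
prefer_cudnn = cudnn_version is not None and cudnn_version > 90000 and sm_major >= 9
print(prefer_cudnn)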

test/test_transformers.py

Lines changed: 33 additions & 3 deletions
@@ -2529,9 +2529,9 @@ def test_cudnn_attention_trivial_output_transpose(self, device):
     def test_cudnn_attention_nonmodulo64seqlen(self, device):
         # see also: https://github.com/pytorch/pytorch/issues/137347
         mask = torch.randint(0, 2, (2, 1, 157, 6404)).to(device="cuda", dtype=torch.bool)
-        q = torch.randn(2, 32, 157, 128, device='cuda', dtype=torch.bfloat16, requires_grad=True)
-        k = torch.randn(2, 32, 6404, 128, device='cuda', dtype=torch.bfloat16, requires_grad=True)
-        v = torch.randn(2, 32, 6404, 128, device='cuda', dtype=torch.bfloat16, requires_grad=True)
+        q = torch.randn(2, 32, 157, 128, device='cuda', dtype=torch.float16, requires_grad=True)
+        k = torch.randn(2, 32, 6404, 128, device='cuda', dtype=torch.float16, requires_grad=True)
+        v = torch.randn(2, 32, 6404, 128, device='cuda', dtype=torch.float16, requires_grad=True)
         q_cpu = q.detach().clone().cpu()
         k_cpu = k.detach().clone().cpu()
         v_cpu = v.detach().clone().cpu()
@@ -2564,6 +2564,36 @@ def test_cudnn_attention_nonmodulo64seqlen(self, device):
         torch.testing.assert_close(k.grad, k_cpu.grad.cuda(), atol=3e-3, rtol=2e-3)
         torch.testing.assert_close(v.grad, v_cpu.grad.cuda(), atol=3e-3, rtol=2e-3)
 
+    @skipIfRocm
+    @unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cudnn Attention is not supported on this system")
+    def test_cudnn_attention_preserves_query_layout(self, device):
+
+        def test_attention(backend: SDPBackend, permute_order: List[List[int]]):
+            BHSqD = [4, 16, 256, 64]
+            BHSkvD = [4, 16, 512, 64]
+
+            shape_q = [BHSqD[idx] for idx in permute_order]
+            shape_kv = [BHSkvD[idx] for idx in permute_order]
+            reverse = [permute_order.index(idx) for idx in range(4)]
+            q = torch.randn(*shape_q, dtype=torch.bfloat16, device='cuda', requires_grad=True).permute(reverse)
+            k = torch.randn(*shape_kv, dtype=torch.bfloat16, device='cuda', requires_grad=True).permute(reverse)
+            v = torch.randn(*shape_kv, dtype=torch.bfloat16, device='cuda', requires_grad=True).permute(reverse)
+            self.assertEqual(q.shape, BHSqD)
+            self.assertEqual(k.shape, BHSkvD)
+            self.assertEqual(v.shape, BHSkvD)
+
+            with sdpa_kernel(backend):
+                out = F.scaled_dot_product_attention(q, k, v)
+                self.assertTrue(out.permute(permute_order).is_contiguous())
+                out.sum().backward()
+
+        permute_orders = list()
+        permutable = [0, 1, 2]
+        permute_orders = itertools.permutations(permutable)
+
+        for permute_order in permute_orders:
+            test_attention(SDPBackend.CUDNN_ATTENTION, list(permute_order) + [3])
+
     @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Fused SDPA was not built for this system")
     @parametrize("mask_dim", [1, 2, 3, 4])
     def test_mem_efficient_attention_mask_variants(self, device, mask_dim: List[int]):
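For reference, the permute_order / reverse construction in the new test builds tensors whose logical shape is always B, H, S, D while their physical memory follows permute_order. A small standalone illustration with hypothetical values:

import torch

BHSqD = [4, 16, 256, 64]
permute_order = [2, 0, 1, 3]                      # memory order: S, B, H, D
shape_q = [BHSqD[i] for i in permute_order]       # physical allocation shape
reverse = [permute_order.index(i) for i in range(4)]

q = torch.randn(*shape_q).permute(reverse)        # logical B, H, S, D view
assert list(q.shape) == BHSqD
assert q.permute(permute_order).is_contiguous()   # memory still follows permute_order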
