@@ -292,6 +292,88 @@ auto fixSizeOneDimStrideSDPA(
292
292
}
293
293
return strides;
294
294
}
295
+
296
+ void alloc_with_matching_layout (
297
+ const Tensor& q,
298
+ Tensor& output,
299
+ const std::vector<int64_t >& shape) {
300
+ TORCH_INTERNAL_ASSERT (
301
+ shape.size () == q.sizes ().size (),
302
+ " cuDNN SDPA alloc_with_matching_layout got requested shape ndim != q ndim" );
303
+
304
+ if (std::equal (q.sizes ().begin (), q.sizes ().end (), shape.begin ())) {
305
+ output = at::empty_like (q);
306
+ return ;
307
+ }
308
+
309
+ // get the "fill order," which is just an argsort on the strides
310
+ std::vector<int > fill_order (shape.size ());
311
+ std::iota (fill_order.begin (), fill_order.end (), 0 );
312
+ const auto q_strides = q.strides ();
313
+ std::stable_sort (
314
+ fill_order.begin (), fill_order.end (), [&q_strides](int idx1, int idx2) {
315
+ return q_strides[idx1] < q_strides[idx2];
316
+ });
317
+ std::vector<int64_t > ordered_strides (shape.size ());
318
+ int64_t current_stride = 1 ;
319
+ for (const int dim_idx : fill_order) {
320
+ ordered_strides[dim_idx] = current_stride;
321
+ current_stride *= shape[dim_idx];
322
+ }
323
+ output = at::empty (at::IntArrayRef (shape), q.options ())
324
+ .as_strided (
325
+ at::IntArrayRef (shape), at::IntArrayRef (ordered_strides), 0 );
326
+ }
327
+
328
+ void permute_to_matching_layout (const Tensor& output, Tensor& grad_output) {
329
+ const int dims = output.sizes ().size ();
330
+ std::vector<int64_t > outer_to_inner (dims);
331
+ std::iota (outer_to_inner.begin (), outer_to_inner.end (), 0 );
332
+ const auto o_strides = output.strides ();
333
+ std::stable_sort (
334
+ outer_to_inner.begin (),
335
+ outer_to_inner.end (),
336
+ [&o_strides](int idx1, int idx2) {
337
+ return o_strides[idx1] > o_strides[idx2];
338
+ });
339
+ std::vector<int64_t > inverse (dims);
340
+ for (int d = 0 ; d < dims; d++) {
341
+ inverse[d] = std::find (outer_to_inner.begin (), outer_to_inner.end (), d) -
342
+ outer_to_inner.begin ();
343
+ }
344
+ grad_output = grad_output.permute (at::IntArrayRef (outer_to_inner))
345
+ .contiguous ()
346
+ .permute (at::IntArrayRef (inverse));
347
+ }
348
+
349
+ bool same_strides (const Tensor& t1, const Tensor& t2) {
350
+ std::vector<int > t1_strides_no_ones;
351
+ std::vector<int > t2_strides_no_ones;
352
+ const auto t1strides = t1.strides ();
353
+ const auto t2strides = t2.strides ();
354
+ const int dim = t1strides.size ();
355
+ if (dim != (int )t2strides.size ()) {
356
+ return false ;
357
+ }
358
+ const auto t1sizes = t1.sizes ();
359
+ const auto t2sizes = t2.sizes ();
360
+
361
+ // we are going through strides backward here, but if both are backward it's
362
+ // comparable
363
+ for (int i = 0 ; i < dim; i++) {
364
+ if (t1sizes[i] > 1 ) {
365
+ t1_strides_no_ones.push_back (t1strides[i]);
366
+ }
367
+ if (t2sizes[i] > 1 ) {
368
+ t2_strides_no_ones.push_back (t2strides[i]);
369
+ }
370
+ }
371
+ return std::equal (
372
+ t1_strides_no_ones.begin (),
373
+ t1_strides_no_ones.end (),
374
+ t2_strides_no_ones.begin (),
375
+ t2_strides_no_ones.end ());
376
+ }
295
377
} // namespace
296
378
297
379
auto build_graph_and_tensors (
@@ -553,7 +635,8 @@ void run_cudnn_SDP_fprop(
553
635
Tensor& dropoutoffset) {
554
636
cudnnHandle_t handle = getCudnnHandle ();
555
637
if (!o.defined ()) {
556
- o = at::empty ({b, h, s_q, d_v}, q.options ());
638
+ // q is passed to us in BHSD dim order
639
+ alloc_with_matching_layout (q, o, {b, h, s_q, d_v});
557
640
}
558
641
559
642
if (return_softmaxstats && !softmaxstats.defined ()) {
@@ -660,30 +743,14 @@ void run_cudnn_SDP_bprop(
660
743
}
661
744
662
745
Tensor dO_ = dO;
663
- if (!dO.strides ()[dO.strides ().size () - 1 ]) {
664
- TORCH_WARN (
665
- " cuDNN SDPA backward got an innermost stride of 0 in grad_out, which is unsupported."
666
- " Materializing a contiguous tensor which will increase memory usage..." );
667
- dO_ = dO.contiguous ();
668
- }
669
- if ( // handle trivial transposed case with a transposed dim of size 1
670
- // see also: https://github.com/pytorch/pytorch/issues/134001
671
- !(dO_.is_contiguous () && o.is_contiguous ()) &&
672
- !std::equal (
673
- o.strides ().begin (), o.strides ().end (), dO.strides ().begin ())) {
674
- TORCH_WARN (
746
+ if (!same_strides (o, dO)) {
747
+ TORCH_WARN_ONCE (
675
748
" cuDNN SDPA backward got grad_output.strides() != output.strides(), "
676
749
" attempting to materialize a grad_output with matching strides..." );
677
- if (o.is_contiguous ()) {
678
- dO_ = dO.contiguous ();
679
- } else {
680
- dO_ = dO.transpose (1 , 2 ).contiguous ().transpose (1 , 2 );
681
- }
750
+ permute_to_matching_layout (o, dO_);
682
751
}
683
752
TORCH_INTERNAL_ASSERT (
684
- (dO_.is_contiguous () && o.is_contiguous ()) ||
685
- std::equal (
686
- dO_.strides ().begin (), dO_.strides ().end (), o.strides ().begin ()),
753
+ same_strides (o, dO_),
687
754
" cuDNN SDPA expected grad_output.strides() == output.strides(), "
688
755
" the previous step probably failed to materialize a grad_output "
689
756
" with matching strides..." );
0 commit comments