@@ -2538,6 +2538,7 @@ def test_strided_backwards(self):
            (1, 0, 2, 3),  # Reverse order
            (0, 2, 1, 3),  # Mixed order
            (2, 0, 1, 3),  # Another mixed order
+            (0, 1, 3, 2),  # Non contiguous last dim
        ],
    )
    @common_utils.parametrize("shape", [(2, 1, 128, 16), (4, 2, 64, 16)])
@@ -2586,12 +2587,7 @@ def test_flex_attention_stride_ordering(self, mode, permute_order, shape):
    @common_utils.parametrize("mode", ["eager", "inductor"])
    @common_utils.parametrize(
        "permute_order",
-        [
-            (0, 1, 2, 3),
-            (1, 0, 2, 3),
-            (0, 2, 1, 3),
-            (2, 0, 1, 3),
-        ],
+        [(0, 1, 2, 3), (1, 0, 2, 3), (0, 2, 1, 3), (2, 0, 1, 3), (0, 1, 3, 2)],
    )
    @common_utils.parametrize("shape", [(2, 5, 128, 16), (4, 2, 64, 16)])
    def test_flex_attention_backward_stride_ordering(self, mode, permute_order, shape):
@@ -2637,6 +2633,70 @@ def test_flex_attention_backward_stride_ordering(self, mode, permute_order, shap
                f"Mode: {mode}, Stride order mismatch for {name}: grad {input_stride_order}, input {orig_stride_order}.",
            )

+    @supported_platform
+    def test_non_contiguous_last_dim(self):
+        """Test flex_attention with tensors having non contiguous last dimension."""
+        B, H, D = 4, 8, 64
+        device = "cuda"
+        dtype = torch.float16 if device == "cuda" else torch.float32
+        for S in [16, 64]:
+
+            def column_major_tensor():
+                tensor = torch.randn(
+                    (B, H, S, D),
+                    dtype=dtype,
+                    device=device,
+                )
+                # Column major in last 2 dims
+                return tensor.transpose(-1, -2).contiguous().transpose(-1, -2)
+
+            q = column_major_tensor()
+            k = column_major_tensor()
+            v = column_major_tensor()
+
+            requires_grad = device in DEVICE_SUPPORTS_BACKWARDS
+            if requires_grad:
+                q.requires_grad_(True)
+                k.requires_grad_(True)
+                v.requires_grad_(True)
+
+            self.assertNotEqual(q.stride()[-1], 1)
+            self.assertNotEqual(k.stride()[-1], 1)
+            self.assertNotEqual(v.stride()[-1], 1)
+
+            q_ref, k_ref, v_ref = query_key_value_clones(q, k, v)
+            q_gold, k_gold, v_gold = query_key_value_clones(q, k, v, torch.float64)
+
+            golden_out = flex_attention(q_gold, k_gold, v_gold)
+            ref_out = flex_attention(q_ref, k_ref, v_ref)
+
+            flex_compiled = torch.compile(flex_attention, fullgraph=True, dynamic=True)
+            compiled_out = flex_compiled(q, k, v)
+
+            self._check_out(golden_out, ref_out, compiled_out)
+
+            if requires_grad:
+                backward_grad = torch.randn_like(ref_out)
+
+                golden_out.backward(backward_grad.to(torch.float64))
+                ref_out.backward(backward_grad)
+                compiled_out.backward(backward_grad)
+
+                self._check_out_and_grad(
+                    golden_out,
+                    ref_out,
+                    compiled_out,
+                    q_gold,
+                    q_ref,
+                    q,
+                    k_gold,
+                    k_ref,
+                    k,
+                    v_gold,
+                    v_ref,
+                    v,
+                )
+
    @supported_platform
    @common_utils.parametrize("compile", [True, False])
    def test_fully_masked_out_rows_0_check(self, compile: bool):
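
Note on the layout the new test exercises: the column_major_tensor helper uses a standard PyTorch trick, transposing the last two dims, calling .contiguous(), then transposing back, which keeps the logical (B, H, S, D) shape but gives the last dimension a stride of S instead of 1. A minimal standalone sketch (not part of this PR; shapes chosen to mirror the test, plain torch only):

import torch

# Illustrative shapes mirroring the new test; any (B, H, S, D) behaves the same.
B, H, S, D = 4, 8, 16, 64

t = torch.randn(B, H, S, D)
print(t.stride())   # (H*S*D, S*D, D, 1): last dim contiguous

# Column major in the last two dims: contiguous under the transposed view,
# then viewed back as (B, H, S, D).
cm = t.transpose(-1, -2).contiguous().transpose(-1, -2)
print(cm.shape)     # torch.Size([4, 8, 16, 64]): logical shape unchanged
print(cm.stride())  # (H*S*D, S*D, 1, S): last-dim stride is S, not 1

# This is the property the test checks via assertNotEqual(q.stride()[-1], 1).
assert cm.stride()[-1] != 1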