
Commit 22e9f32

[FlexAttention] Remove Old Constraint on lastdim strides
1 parent 20d62a8 commit 22e9f32

File tree: 2 files changed, +80 -11 lines changed
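For context on what the removed constraint means in practice, here is a minimal sketch of the call pattern this commit unblocks: query/key/value tensors whose last dimension is not stride-1 passed to the compiled flex_attention. The shapes, dtype, and the column-major construction below are illustrative only, mirroring the new test added in this commit.

import torch
from torch.nn.attention.flex_attention import flex_attention

def column_major(b, h, s, d, device="cuda", dtype=torch.float16):
    # transpose -> contiguous -> transpose back gives a tensor that is
    # column major in its last two dims, so the last dim is not stride-1
    t = torch.randn(b, h, s, d, device=device, dtype=dtype)
    return t.transpose(-1, -2).contiguous().transpose(-1, -2)

q, k, v = (column_major(4, 8, 64, 64) for _ in range(3))
assert q.stride(-1) != 1  # this layout used to trip an assert in the inductor lowering

out = torch.compile(flex_attention, fullgraph=True)(q, k, v)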

test/inductor/test_flex_attention.py

Lines changed: 66 additions & 6 deletions
@@ -2538,6 +2538,7 @@ def test_strided_backwards(self):
             (1, 0, 2, 3),  # Reverse order
             (0, 2, 1, 3),  # Mixed order
             (2, 0, 1, 3),  # Another mixed order
+            (0, 1, 3, 2),  # Non contiguous last dim
         ],
     )
     @common_utils.parametrize("shape", [(2, 1, 128, 16), (4, 2, 64, 16)])
@@ -2586,12 +2587,7 @@ def test_flex_attention_stride_ordering(self, mode, permute_order, shape):
     @common_utils.parametrize("mode", ["eager", "inductor"])
     @common_utils.parametrize(
         "permute_order",
-        [
-            (0, 1, 2, 3),
-            (1, 0, 2, 3),
-            (0, 2, 1, 3),
-            (2, 0, 1, 3),
-        ],
+        [(0, 1, 2, 3), (1, 0, 2, 3), (0, 2, 1, 3), (2, 0, 1, 3), (0, 1, 3, 2)],
     )
     @common_utils.parametrize("shape", [(2, 5, 128, 16), (4, 2, 64, 16)])
     def test_flex_attention_backward_stride_ordering(self, mode, permute_order, shape):
@@ -2637,6 +2633,70 @@ def test_flex_attention_backward_stride_ordering(self, mode, permute_order, shap
                 f"Mode: {mode}, Stride order mismatch for {name}: grad {input_stride_order}, input {orig_stride_order}.",
             )

+    @supported_platform
+    def test_non_contiguous_last_dim(self):
+        """Test flex_attention with tensors having non contiguous last dimension."""
+        B, H, D = 4, 8, 64
+        device = "cuda"
+        dtype = torch.float16 if device == "cuda" else torch.float32
+        for S in [16, 64]:
+
+            def column_major_tensor():
+                tensor = torch.randn(
+                    (B, H, S, D),
+                    dtype=dtype,
+                    device=device,
+                )
+                # Column major in last 2 dims
+                return tensor.transpose(-1, -2).contiguous().transpose(-1, -2)
+
+            q = column_major_tensor()
+            k = column_major_tensor()
+            v = column_major_tensor()
+
+            requires_grad = device in DEVICE_SUPPORTS_BACKWARDS
+            if requires_grad:
+                q.requires_grad_(True)
+                k.requires_grad_(True)
+                v.requires_grad_(True)
+
+            self.assertNotEqual(q.stride()[-1], 1)
+            self.assertNotEqual(k.stride()[-1], 1)
+            self.assertNotEqual(v.stride()[-1], 1)
+
+            q_ref, k_ref, v_ref = query_key_value_clones(q, k, v)
+            q_gold, k_gold, v_gold = query_key_value_clones(q, k, v, torch.float64)
+
+            golden_out = flex_attention(q_gold, k_gold, v_gold)
+            ref_out = flex_attention(q_ref, k_ref, v_ref)
+
+            flex_compiled = torch.compile(flex_attention, fullgraph=True, dynamic=True)
+            compiled_out = flex_compiled(q, k, v)
+
+            self._check_out(golden_out, ref_out, compiled_out)
+
+            if requires_grad:
+                backward_grad = torch.randn_like(ref_out)
+
+                golden_out.backward(backward_grad.to(torch.float64))
+                ref_out.backward(backward_grad)
+                compiled_out.backward(backward_grad)
+
+                self._check_out_and_grad(
+                    golden_out,
+                    ref_out,
+                    compiled_out,
+                    q_gold,
+                    q_ref,
+                    q,
+                    k_gold,
+                    k_ref,
+                    k,
+                    v_gold,
+                    v_ref,
+                    v,
+                )
+
     @supported_platform
     @common_utils.parametrize("compile", [True, False])
     def test_fully_masked_out_rows_0_check(self, compile: bool):
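A quick illustrative check of why the new (0, 1, 3, 2) permute order and the column_major_tensor helper above exercise the removed constraint: both yield tensors whose last-dimension stride is not 1. The snippet below is a sketch using one of the parametrized shapes; the strides shown are what contiguous PyTorch tensors of that shape produce.

import torch

x = torch.randn(2, 1, 128, 16)            # contiguous, strides (2048, 2048, 16, 1)
y = x.permute(0, 1, 3, 2)                 # shape (2, 1, 16, 128), strides (2048, 2048, 1, 16)
z = x.transpose(-1, -2).contiguous().transpose(-1, -2)
print(y.stride()[-1], z.stride()[-1])     # 16 and 128: neither last dim is stride-1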

torch/_inductor/kernel/flex_attention.py

Lines changed: 14 additions & 5 deletions
@@ -930,6 +930,15 @@ def check_cpu_supported():
     return supported


+def contiguous_last_dim(x):
+    """Ensure that the realized IR node has a contiguous stride in the last dimension."""
+    strides = x.maybe_get_stride()
+    if strides and strides[-1] != 1:
+        contiguous_stride_order = list(reversed(range(len(x.get_size()))))
+        return ExternKernel.require_stride_order(x, contiguous_stride_order)
+    return x
+
+
 def lower_cpu(
     query,
     key,
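The new contiguous_last_dim helper operates on realized inductor IR nodes, using ExternKernel.require_stride_order to force a row-major layout when the innermost stride is not 1. For intuition only, here is a rough eager-tensor analogue; the function name is made up and this is not part of the commit.

import torch

def contiguous_last_dim_eager(t: torch.Tensor) -> torch.Tensor:
    # Requiring stride order [ndim-1, ..., 1, 0] is the IR-level counterpart of
    # materializing a row-major copy when the innermost stride is not 1.
    if t.stride(-1) != 1:
        return t.contiguous()
    return t

q = torch.randn(4, 8, 64, 64).transpose(-1, -2).contiguous().transpose(-1, -2)
assert q.stride(-1) != 1
assert contiguous_last_dim_eager(q).stride(-1) == 1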
@@ -1092,6 +1101,9 @@ def convert_mask_graph_module(mask_graph):
         if isinstance(item, TensorBox):
             fake_buffers.append(item.data.data)  # type: ignore[attr-defined]

+    # CPU kernel requires last dim to be contiguous
+    query, key, value = map(contiguous_last_dim, [query, key, value])
+
     (
         query,
         key,
@@ -1258,7 +1270,6 @@ def set_head_dim_values(
     )


-# TODO: We probably also need a layout constraint?
 @register_lowering(torch.ops.higher_order.flex_attention, type_promotion_kind=None)
 def flex_attention(
     query,
@@ -1413,11 +1424,9 @@ def flex_attention(
     else:
         kernel_options.setdefault("IS_DIVISIBLE", True)

-    # Reuse query strides for output layout despite different last dimension.
-    # This works because only the last dim differs and we check it is contiguous.
+    # NB it is okay that the v_head_dim is different
+    # We are using these to match fill order of the output.
     q_strides = query.get_stride()
-    assert q_strides[-1] == 1, "Query must be contiguous in the last dimension"
-
     # Construct output layout with strides matching the query.
     out_size = [B, Hq, seq_len_q, v_head_dim]
     out_strides = infer_dense_strides(out_size, q_strides)