
Commit 629bd6f

drisspg authored and pytorchmergebot committed
Update FlexAttention with masking semantic (pytorch#133373)
Pull Request resolved: pytorch#133373 Approved by: https://github.com/yanboliang
1 parent e792980 commit 629bd6f
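In short, this change makes FlexAttention treat query rows whose every key position is masked out as producing an all-zero output row and a logsumexp of 0.0, instead of NaN/-inf, in both the eager and compiled paths. A minimal sketch of the new behavior, mirroring the tests added below (the tensor shapes are illustrative, not taken from the diff; a CUDA device is assumed):

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

B, H, S, D = 2, 4, 128, 64
q = torch.randn(B, H, S, D, device="cuda")
k = torch.randn(B, H, S, D, device="cuda")
v = torch.randn(B, H, S, D, device="cuda")

M = S // 2  # query rows at positions >= M can attend to no keys at all

def mask_mod(b, h, q_idx, kv_idx):
    return q_idx < M

block_mask = create_block_mask(mask_mod, 1, 1, S, S)
out, lse = flex_attention(q, k, v, block_mask=block_mask, return_lse=True)

# With this commit, fully masked rows yield zeros and lse == 0.0 rather than NaN.
assert out[:, :, M:, :].abs().sum() == 0
assert (lse[:, :, M:] == 0.0).all()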

File tree

5 files changed: +85 -18 lines


test/inductor/test_flex_attention.py

+28 -10

@@ -17,6 +17,7 @@
     _create_empty_block_mask,
     _DEFAULT_SPARSE_BLOCK_SIZE,
     _identity,
+    _score_mod_signature,
     and_masks,
     BlockMask,
     create_block_mask,
@@ -212,8 +213,7 @@ def _check_equal(
     ):
         compiled_error = (golden_out - compiled_out).abs().mean()
         ref_error = (golden_out - ref_out).abs().mean()
-        # TODO: Make this check stricter after updating eager SDPA masked_softmax semantics
-        if torch.isnan(compiled_error).any() and not torch.isnan(ref_error).any():
+        if torch.isnan(compiled_error).any() or torch.isnan(ref_error).any():
            self.assertTrue(False, "Output/Grad with NaN")
         if compiled_error > ref_error * fudge_factor:
             name = tensor_name if tensor_name is not None else ""
@@ -263,7 +263,7 @@ def _check_out_and_grad(
 
     def run_test(
         self,
-        score_mod: Callable,
+        score_mod: _score_mod_signature,
         dtype: torch.dtype = torch.float16,
         Q_B: int = B,
         Q_H: int = H,
@@ -273,6 +273,7 @@ def run_test(
         KV_H: int = H,
         KV_S: int = S,
         KV_D: int = D,
+        block_mask: Optional[BlockMask] = None,
     ):
         q = torch.randn(
             (Q_B, Q_H, Q_S, Q_D), dtype=dtype, device="cuda", requires_grad=True
@@ -285,7 +286,6 @@ def run_test(
         )
         q_ref, k_ref, v_ref = query_key_value_clones(q, k, v)
         q_gold, k_gold, v_gold = query_key_value_clones(q, k, v, torch.float64)
-        block_mask = None
         sdpa_partial = create_attention(
             score_mod, block_mask, enable_gqa=(not Q_H == KV_H)
         )
@@ -1437,7 +1437,8 @@ def mask_mod(b, h, q, kv):
         out.sum().backward()
 
     @supported_platform
-    def test_fully_masked_out_rows(self):
+    @common_utils.parametrize("compile", [True, False])
+    def test_fully_masked_out_rows_0_check(self, compile: bool):
         # Ensure fully masked out rows won't cause NaNs.
         query = torch.randn(
             (B, H, S, D), dtype=torch.float32, device="cuda", requires_grad=True
@@ -1448,23 +1449,40 @@ def test_fully_masked_out_rows(self):
         value = torch.randn(
             (B, H, S, D), dtype=torch.float32, device="cuda", requires_grad=True
         )
-        do = torch.randn((B, H, S, D), dtype=torch.float32, device="cuda")
 
         M = S // 2
 
         def mask_mod(b, h, q, kv):
             return q < M
 
         block_mask = create_block_mask(mask_mod, 1, 1, S, S)
-        out = torch.compile(flex_attention, dynamic=False)(
-            query, key, value, block_mask=block_mask
+
+        flex = (
+            torch.compile(flex_attention, dynamic=False) if compile else flex_attention
        )
-        # TODO: Switch to self.run_test_with_call after updating eager SDPA masked_softmax semantics
+        out, lse = flex(query, key, value, block_mask=block_mask, return_lse=True)
         self.assertEqual(out[:, :, M:, :].sum(), 0)
+        self.assertTrue((lse[:, :, M:] == 0.0).all())
 
-        out.backward(do)
+        loss = out.sum() + lse.sum()
+        loss.backward()
         self.assertEqual(query.grad[:, :, M:, :].sum(), 0)
 
+    @supported_platform
+    @common_utils.parametrize("compile", [True, False])
+    def test_fully_masked_out_rows(self, compile: bool):
+        M = S // 2
+
+        def mask_mod(b, h, q, kv):
+            return q < M
+
+        block_mask = create_block_mask(mask_mod, 1, 1, S, S)
+
+        def noop_mod(score, b, h, q_idx, kv_idx):
+            return score
+
+        self.run_test(noop_mod, torch.float32, B, H, S, D, B, H, S, D, block_mask)
+
     @supported_platform
     def test_comparison_vs_sdpa(self):
         def causal(score, b, h, q_idx, kv_idx):

test/inductor/test_flex_decoding.py

+40 -3

@@ -284,15 +284,20 @@ def run_test(
             score_mod, block_mask, enable_gqa=(not Q_H == KV_H)
         )
         compiled_sdpa = torch.compile(sdpa_partial)
-        golden_out = sdpa_partial(q_gold, k_gold, v_gold)
-        ref_out = sdpa_partial(q_ref, k_ref, v_ref)
-        compiled_out = compiled_sdpa(q, k, v)
+        golden_out, gold_lse = sdpa_partial(q_gold, k_gold, v_gold, return_lse=True)
+        ref_out, ref_lse = sdpa_partial(q_ref, k_ref, v_ref, return_lse=True)
+        compiled_out, compiled_lse = compiled_sdpa(q, k, v, return_lse=True)
 
         self._check_out(
             golden_out,
             ref_out,
             compiled_out,
         )
+        self._check_out(
+            gold_lse,
+            ref_lse,
+            compiled_lse,
+        )
 
     def run_test_with_call(
         self,
@@ -762,6 +767,38 @@ def bias_mod(score, batch, head, token_q, token_kv):
 
         self.run_test(bias_mod)
 
+    @supported_platform
+    def test_fully_masked_out_rows_0_check_gqa(self):
+        # Ensure fully masked out rows won't cause NaNs.
+        query = torch.randn(
+            (B, Hq, S, D), dtype=torch.float32, device="cuda", requires_grad=True
+        )
+        key = torch.randn(
+            (B, Hkv, S, D), dtype=torch.float32, device="cuda", requires_grad=True
+        )
+        value = torch.randn(
+            (B, Hkv, S, D), dtype=torch.float32, device="cuda", requires_grad=True
+        )
+
+        M = S // 2
+
+        def mask_mod(b, h, q, kv):
+            return q < M
+
+        block_mask = create_block_mask(mask_mod, 1, 1, S, S)
+
+        flex = torch.compile(flex_attention, dynamic=False)
+
+        out, lse = flex(
+            query, key, value, block_mask=block_mask, enable_gqa=True, return_lse=True
+        )
+        self.assertEqual(out[:, :, M:, :].sum(), 0)
+        self.assertTrue((lse[:, :, M:] == 0.0).all())
+
+        loss = out.sum() + lse.sum()
+        loss.backward()
+        self.assertEqual(query.grad[:, :, M:, :].sum(), 0)
+
     @supported_platform
     def test_windowed_no_mask_vs_sdpa(self):
         score_mod = _generate_windowed(1000)

torch/_higher_order_ops/flex_attention.py

+4 -3

@@ -204,11 +204,12 @@ def math_attention(
         mask_mod_other_buffers,
     )
 
-    # TODO Unconditionally return logsumexp for backwards
-    # if any(t.requires_grad for t in (query, key, value)):
+    # Set fully masked rows' sumexp to 0.0
     logsumexp = post_mod_scores.logsumexp(dim=-1)
+    masked_rows = torch.all(post_mod_scores == -float("inf"), dim=-1)
+    logsumexp = torch.where(masked_rows, 0.0, logsumexp)
 
-    post_mod_scores = post_mod_scores.softmax(dim=-1)
+    post_mod_scores = torch._safe_softmax(post_mod_scores, dim=-1)
 
     return post_mod_scores.to(query.dtype) @ value, logsumexp / math.log(2)
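The eager reference (math_attention) now matches that semantic: torch._safe_softmax returns an all-zero row wherever every score in the row is -inf (a plain softmax would return NaN there), and the logsumexp of such rows is forced to 0.0. A hedged illustration of the row-level behavior using ordinary tensor ops (this is the semantics only, not the code above):

import torch

scores = torch.tensor([
    [0.5, 1.0, -2.0],        # a normal row
    [float("-inf")] * 3,     # a fully masked row
])

lse = scores.logsumexp(dim=-1)                              # second entry is -inf
masked_rows = torch.all(scores == float("-inf"), dim=-1)
lse = torch.where(masked_rows, torch.zeros_like(lse), lse)  # masked row -> 0.0

probs = scores.softmax(dim=-1)                              # plain softmax: NaN row
probs = torch.where(masked_rows.unsqueeze(-1), torch.zeros_like(probs), probs)
# torch._safe_softmax produces this directly: all-(-inf) rows become all zeros,
# so the attention output for those rows is exactly zero.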

torch/_inductor/kernel/flex_attention.py

+7 -2

@@ -302,8 +302,13 @@ def get_offset_for_next_block(loop_iter, col_indices, total_blocks, SPARSE_BLOCK
     )
 
-    # Store output and logsumexp
-    l_i = tl.where(l_i == 0, 1, l_i)
+    # [Note] Handle fully masked out rows:
+    # l_i will be sum(e^(-inf)) == 0.0 for masked out rows, and m_i will be -inf.
+    # We set l_i to 1.0, which gives lse/out = 0.0 after the log(l_i) + m_i (0.0) step.
+    l_i = tl.where(l_i == 0.0, 1, l_i)
+    masked_out_rows = (m_i == float("-inf"))
+    m_i = tl.where(masked_out_rows, 0, m_i)
+
     acc = acc / l_i[:, None]
     idx_z = tl.program_id(1) // HQ
     idx_hq = tl.program_id(1) % HQ
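Why setting l_i to 1 and m_i to 0 works: for a fully masked row every exp(score) term is exp(-inf) = 0, so the accumulator acc and the denominator l_i are both 0 while the running max m_i stays -inf. The epilogue arithmetic, sketched in plain Python (the kernel works in base 2; natural log is used here purely for illustration):

import math

acc = 0.0               # sum over the row of exp(score - m_i) * v: every term was 0
l_i = 0.0               # softmax denominator for the row
m_i = float("-inf")     # running max of the (all-masked) scores

l_i = 1.0 if l_i == 0.0 else l_i              # avoid 0 / 0 in the output
m_i = 0.0 if m_i == float("-inf") else m_i    # avoid -inf leaking into the lse

out = acc / l_i              # 0.0 / 1.0 == 0.0, not NaN
lse = m_i + math.log(l_i)    # 0.0 + log(1.0) == 0.0, not -inf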

torch/_inductor/kernel/flex_decoding.py

+6

@@ -524,11 +524,17 @@ def create_flex_decoding_kernel(*args, **kwargs):
     # Reduction
 
     g_M = lowerings[aten.max](buf_M, dim=1, keepdim=True)[0]
+    # See [Note] Handle fully masked out rows:
+    # g_M is the global max among split kv blocks.
+    masked_rows = lowerings[aten.eq](g_M, -float("inf"))
+    g_M = lowerings[aten.where](masked_rows, 0.0, g_M)
     adj_M = lowerings[aten.sub](buf_M, g_M)
     alpha = lowerings[aten.exp2](adj_M)
 
     buf_L = lowerings[aten.mul](buf_L, alpha)
     g_L = lowerings[aten.sum](buf_L, axis=1)
+    masked_rows_squeezed = lowerings[aten.squeeze](masked_rows, dim=1)
+    g_L = lowerings[aten.where](masked_rows_squeezed, 1.0, g_L)
     logsumexp = lowerings[aten.log2](g_L)
     logsumexp = lowerings[aten.add](logsumexp, lowerings[aten.squeeze](g_M, dim=1))
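The decoding path splits the KV sequence across blocks and reduces the per-split (buf_M, buf_L) pairs into a global max g_M and a global denominator g_L. For a fully masked row every split reports buf_M == -inf and buf_L == 0, so the two guards above clamp g_M to 0 and g_L to 1, giving logsumexp = log2(1) + 0 = 0 instead of -inf. A hedged sketch of that reduction in plain PyTorch (the split count and reduction dimension here are illustrative, not the lowering's actual shapes):

import torch

num_splits = 4
buf_M = torch.full((num_splits,), float("-inf"))  # per-split running max (fully masked row)
buf_L = torch.zeros(num_splits)                   # per-split softmax denominators

g_M = buf_M.max(dim=0, keepdim=True).values
masked_rows = g_M == float("-inf")
g_M = torch.where(masked_rows, torch.zeros_like(g_M), g_M)            # guard the global max

alpha = torch.exp2(buf_M - g_M)                                       # rescale each split
g_L = (buf_L * alpha).sum(dim=0)
g_L = torch.where(masked_rows.squeeze(0), torch.ones_like(g_L), g_L)  # guard the denominator

logsumexp = torch.log2(g_L) + g_M.squeeze(0)   # log2(1.0) + 0.0 == 0.0 for masked rows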
