     _create_empty_block_mask,
     _DEFAULT_SPARSE_BLOCK_SIZE,
     _identity,
+    _score_mod_signature,
     and_masks,
     BlockMask,
     create_block_mask,
@@ -212,8 +213,7 @@ def _check_equal(
     ):
         compiled_error = (golden_out - compiled_out).abs().mean()
         ref_error = (golden_out - ref_out).abs().mean()
-        # TODO: Make this check stricter after updating eager SDPA masked_softmax semantics
-        if torch.isnan(compiled_error).any() and not torch.isnan(ref_error).any():
+        if torch.isnan(compiled_error).any() or torch.isnan(ref_error).any():
             self.assertTrue(False, "Output/Grad with NaN")
         if compiled_error > ref_error * fudge_factor:
             name = tensor_name if tensor_name is not None else ""
@@ -263,7 +263,7 @@ def _check_out_and_grad(

     def run_test(
         self,
-        score_mod: Callable,
+        score_mod: _score_mod_signature,
         dtype: torch.dtype = torch.float16,
         Q_B: int = B,
         Q_H: int = H,
@@ -273,6 +273,7 @@ def run_test(
         KV_H: int = H,
         KV_S: int = S,
         KV_D: int = D,
+        block_mask: Optional[BlockMask] = None,
     ):
         q = torch.randn(
             (Q_B, Q_H, Q_S, Q_D), dtype=dtype, device="cuda", requires_grad=True
@@ -285,7 +286,6 @@ def run_test(
         )
         q_ref, k_ref, v_ref = query_key_value_clones(q, k, v)
         q_gold, k_gold, v_gold = query_key_value_clones(q, k, v, torch.float64)
-        block_mask = None
         sdpa_partial = create_attention(
             score_mod, block_mask, enable_gqa=(not Q_H == KV_H)
         )
@@ -1437,7 +1437,8 @@ def mask_mod(b, h, q, kv):
         out.sum().backward()

     @supported_platform
-    def test_fully_masked_out_rows(self):
+    @common_utils.parametrize("compile", [True, False])
+    def test_fully_masked_out_rows_0_check(self, compile: bool):
         # Ensure fully masked out rows won't cause NaNs.
         query = torch.randn(
             (B, H, S, D), dtype=torch.float32, device="cuda", requires_grad=True
@@ -1448,23 +1449,40 @@ def test_fully_masked_out_rows(self):
         value = torch.randn(
             (B, H, S, D), dtype=torch.float32, device="cuda", requires_grad=True
         )
-        do = torch.randn((B, H, S, D), dtype=torch.float32, device="cuda")

         M = S // 2

         def mask_mod(b, h, q, kv):
             return q < M

         block_mask = create_block_mask(mask_mod, 1, 1, S, S)
-        out = torch.compile(flex_attention, dynamic=False)(
-            query, key, value, block_mask=block_mask
+
+        flex = (
+            torch.compile(flex_attention, dynamic=False) if compile else flex_attention
         )
-        # TODO: Switch to self.run_test_with_call after updating eager SDPA masked_softmax semantics
+        out, lse = flex(query, key, value, block_mask=block_mask, return_lse=True)
         self.assertEqual(out[:, :, M:, :].sum(), 0)
+        self.assertTrue((lse[:, :, M:] == 0.0).all())

-        out.backward(do)
+        loss = out.sum() + lse.sum()
+        loss.backward()
         self.assertEqual(query.grad[:, :, M:, :].sum(), 0)

+    @supported_platform
+    @common_utils.parametrize("compile", [True, False])
+    def test_fully_masked_out_rows(self, compile: bool):
+        M = S // 2
+
+        def mask_mod(b, h, q, kv):
+            return q < M
+
+        block_mask = create_block_mask(mask_mod, 1, 1, S, S)
+
+        def noop_mod(score, b, h, q_idx, kv_idx):
+            return score
+
+        self.run_test(noop_mod, torch.float32, B, H, S, D, B, H, S, D, block_mask)
+
     @supported_platform
     def test_comparison_vs_sdpa(self):
         def causal(score, b, h, q_idx, kv_idx):
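
Outside the test harness, the behavior these tests lock in can be sketched roughly as follows. This is a minimal standalone example, not part of the PR: the tensor sizes, device, and variable names are illustrative assumptions, and it only restates what the assertions above check, namely that a block mask which fully masks the bottom half of the query rows produces zero outputs, zero logsumexp entries, and zero query gradients for those rows instead of NaNs.

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

# Illustrative sizes (assumptions, not the constants used by the test file).
B, H, S, D = 2, 4, 128, 64
M = S // 2  # query rows at index >= M attend to nothing

query = torch.randn(B, H, S, D, device="cuda", requires_grad=True)
key = torch.randn(B, H, S, D, device="cuda")
value = torch.randn(B, H, S, D, device="cuda")

def mask_mod(b, h, q_idx, kv_idx):
    # Keep only the top half of the query rows; the rest are fully masked out.
    return q_idx < M

block_mask = create_block_mask(mask_mod, 1, 1, S, S, device="cuda")

# return_lse=True additionally returns the per-row logsumexp of the scores.
out, lse = flex_attention(query, key, value, block_mask=block_mask, return_lse=True)

# With the semantics exercised by the tests above (assuming a PyTorch build that
# includes this change), fully masked rows yield zero output and zero logsumexp.
assert out[:, :, M:, :].abs().sum().item() == 0.0
assert (lse[:, :, M:] == 0.0).all().item()

# Gradients for the fully masked query rows should also come out as zero.
(out.sum() + lse.sum()).backward()
assert query.grad[:, :, M:, :].abs().sum().item() == 0.0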