Commit 1a85451

larryliu0820 authored and pytorchmergebot committed
[export] Add unit test for SDPA export result (pytorch#117390)
Summary: A follow-up to pytorch#117097. In that PR I didn't add `_scaled_dot_product_flash_attention_for_cpu` to the core_aten_decomposition table. This PR does that and also adds a unit test.

Test Plan: python test/export/test_export.py -k test_scaled_dot_product_attention

Pull Request resolved: pytorch#117390
Approved by: https://github.com/drisspg
1 parent bf27dd6 commit 1a85451
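
The Summary above leaves the mechanics implicit, so here is a minimal sketch (not part of this commit) of what the new table entry is for: export an SDPA module on CPU, observe `_scaled_dot_product_flash_attention_for_cpu` in the exported graph, then lower it to core ATen ops with the table returned by `core_aten_decompositions()`. The `SDPA` module name, tensor shapes, and printouts are illustrative assumptions, mirroring the unit test added below.

# Sketch only: exercise the decomposition entry added by this commit.
# Module name, shapes, and prints are illustrative, not from the diff.
import torch
import torch.nn.functional as F
from torch._decomp import core_aten_decompositions


class SDPA(torch.nn.Module):
    def forward(self, q, k, v):
        return F.scaled_dot_product_attention(
            q, k, v, None, dropout_p=0.0, is_causal=True
        )


q, k, v = (torch.randn(1, 1, 8, 8) for _ in range(3))
ep = torch.export.export(SDPA(), (q, k, v))
# The exported graph is expected to contain _scaled_dot_product_flash_attention_for_cpu.
print(ep.graph_module.code)

# Applying the core ATen table should further decompose that op into core ATen ops.
core_ep = ep.run_decompositions(core_aten_decompositions())
print(core_ep.graph_module.code)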

File tree

4 files changed: +78 -3 lines changed

test/expect/HasDecompTest.test_aten_core_operators.expect (-1)

@@ -15,7 +15,6 @@ aten::_native_batch_norm_legit
 aten::_native_batch_norm_legit.no_stats
 aten::_native_batch_norm_legit_functional
 aten::_native_batch_norm_legit_no_training
-aten::_scaled_dot_product_flash_attention_for_cpu
 aten::_softmax
 aten::_softmax.out
 aten::_to_copy

test/export/test_export.py (+76)

@@ -7,6 +7,7 @@
 from dataclasses import dataclass
 
 import torch
+import torch.nn.functional as F
 import torch._dynamo as torchdynamo
 from functorch.experimental.control_flow import cond, map
 from torch import Tensor
@@ -35,6 +36,13 @@
 from torch.export import Constraint, Dim
 from torch.fx.experimental.proxy_tensor import make_fx
 from torch.testing import FileCheck
+from torch.testing._internal.common_cuda import (
+    PLATFORM_SUPPORTS_FLASH_ATTENTION,
+)
+from torch.testing._internal.common_device_type import (
+    onlyCPU,
+    onlyCUDA,
+)
 from torch.testing._internal.common_utils import run_tests
 from torch._dynamo.test_case import TestCase
 from torch.utils._pytree import (
@@ -2528,6 +2536,74 @@ def forward(self, x):
         # this doesn't work today
         gm_unflat_strict = unflatten(ep)
 
+@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo isn't support")
+class TestOneOffModelExportResult(TestCase):
+    def test_scaled_dot_product_attention_cpu(self):
+        """
+        This test makes sure we are always getting the same decomposition result for SDPA.
+        As of now _scaled_dot_product_flash_attention_for_cpu is expected to show up in
+        export() result. Some downstream backend then further decompose it into core ATen
+        ops in torch/_decomp/decompositions.py (search for
+        _scaled_dot_product_flash_attention_for_cpu).
+
+        Export is decomposing based on the CompositeImplicitAutograd kernel implementation
+        of SDPA. If this test fails, it means the kernel is being modified. In this case
+        we strongly encourage you to change the decomposition rule under
+        torch/_decomp/decompositions.py along with the kernel changes, so all of the
+        downstream backends are not being affected.
+        """
+        class ScaledDotProductAttention(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, q, k, v):
+                attn_output = F.scaled_dot_product_attention(
+                    q, k, v, None, dropout_p=0.0, is_causal=True
+                )
+                return attn_output
+        q = torch.randn(1, 1, 8, 8, device="cpu")
+        k = torch.randn(1, 1, 8, 8, device="cpu")
+        v = torch.randn(1, 1, 8, 8, device="cpu")
+
+        ep = torch.export.export(ScaledDotProductAttention(), (q, k, v))
+        self.assertExpectedInline(ep.graph_module.code.strip(), """\
+def forward(self, l_q_, l_k_, l_v_):
+    _scaled_dot_product_flash_attention_for_cpu = torch.ops.aten._scaled_dot_product_flash_attention_for_cpu.default(l_q_, l_k_, l_v_, 0.0, True); l_q_ = l_k_ = l_v_ = None
+    getitem = _scaled_dot_product_flash_attention_for_cpu[0]; _scaled_dot_product_flash_attention_for_cpu = None
+    return (getitem,)""")
+
+    @unittest.skipIf(
+        not PLATFORM_SUPPORTS_FLASH_ATTENTION,
+        "Can't run fused SDPA on this platform",
+    )
+    def test_scaled_dot_product_attention_cuda(self):
+        """
+        This test makes sure we are always getting the same decomposition result for SDPA.
+        As of now _scaled_dot_product_flash_attention is expected to show up in
+        export() result (GPU tensors are given). Currently there's no downstream
+        backend relies on this export result so if this test fails, feel free to
+        change it to the latest export() result.
+        """
+        class ScaledDotProductAttention(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, q, k, v):
+                attn_output = F.scaled_dot_product_attention(
+                    q, k, v, None, dropout_p=0.0, is_causal=True
+                )
+                return attn_output
+        q = torch.randn(1, 16, 16, 64, dtype=torch.bfloat16, device="cuda")
+        k = torch.randn(1, 16, 16, 64, dtype=torch.bfloat16, device="cuda")
+        v = torch.randn(1, 16, 16, 64, dtype=torch.bfloat16, device="cuda")
+
+        ep = torch.export.export(ScaledDotProductAttention(), (q, k, v))
+        self.assertExpectedInline(ep.graph_module.code.strip(), """\
+def forward(self, l_q_, l_k_, l_v_):
+    _scaled_dot_product_flash_attention = torch.ops.aten._scaled_dot_product_flash_attention.default(l_q_, l_k_, l_v_, 0.0, True, scale=0.125); l_q_ = l_k_ = l_v_ = None
+    getitem = _scaled_dot_product_flash_attention[0]; _scaled_dot_product_flash_attention = None
+    return (getitem,)""")
+
 
 if __name__ == '__main__':
     run_tests()
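
The docstring of the new CPU test points downstream backends at the decomposition registered in torch/_decomp/decompositions.py. As a hedged illustration (not part of the diff), a backend could fetch just that rule with `torch._decomp.get_decompositions` and merge it into its own table:

# Sketch under the assumption that a backend only needs the SDPA-for-CPU rule.
import torch
from torch._decomp import get_decompositions

aten = torch.ops.aten
sdpa_table = get_decompositions(
    [aten._scaled_dot_product_flash_attention_for_cpu.default]
)
# Maps the overload to the Python decomposition from torch/_decomp/decompositions.py;
# a backend can update its own decomposition table with this dict.
print(list(sdpa_table.keys()))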

torch/_decomp/__init__.py (+1 -1)

@@ -389,7 +389,7 @@ def core_aten_decompositions() -> Dict[torch._ops.OperatorBase, Callable]:
             aten.rrelu_with_noise_,
             aten.rsub.Scalar,
             aten.rsub.Tensor,
-            aten._scaled_dot_product_flash_attention.default,
+            aten._scaled_dot_product_flash_attention_for_cpu.default,
             aten.select_backward,
             aten.select_scatter,
             aten.sgn,
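
A quick way to read the one-line swap above: after this change the CPU flash-attention overload is in the core ATen decomposition table, while the CUDA overload `_scaled_dot_product_flash_attention.default` no longer is. A small sanity check (illustrative, not from the commit) might look like:

# Assumes table membership follows directly from the edited list above.
import torch
from torch._decomp import core_aten_decompositions

aten = torch.ops.aten
table = core_aten_decompositions()
assert aten._scaled_dot_product_flash_attention_for_cpu.default in table
assert aten._scaled_dot_product_flash_attention.default not in table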

torch/_inductor/decomposition.py (+1 -1)

@@ -73,7 +73,7 @@
 # the Inductor decomp table.
 decomps_to_exclude = [
     aten._unsafe_index,
-    aten._scaled_dot_product_flash_attention.default,  # See comments in torch/_decomp/decompositions.py
+    aten._scaled_dot_product_flash_attention_for_cpu.default,  # See comments in torch/_decomp/decompositions.py
     aten.clamp_max,
     aten.clamp_min,
     aten.glu,  # inductor lowers this directly
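
Because `decomps_to_exclude` strips these entries out of the Inductor table, the CPU SDPA op should remain intact when compiling with Inductor even though it now has a core ATen decomposition. A hedged check, assuming `torch._inductor.decomposition.decompositions` is the post-exclusion table:

# Sketch only: the import path and dict name are an assumption about Inductor internals.
import torch
from torch._decomp import core_aten_decompositions
from torch._inductor.decomposition import decompositions as inductor_decomps

op = torch.ops.aten._scaled_dot_product_flash_attention_for_cpu.default
assert op in core_aten_decompositions()  # decomposed for export / downstream backends
assert op not in inductor_decomps        # excluded here; see the comment in the diff above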
