|
7 | 7 | from dataclasses import dataclass
|
8 | 8 |
|
9 | 9 | import torch
|
| 10 | +import torch.nn.functional as F |
10 | 11 | import torch._dynamo as torchdynamo
|
11 | 12 | from functorch.experimental.control_flow import cond, map
|
12 | 13 | from torch import Tensor
|
|
35 | 36 | from torch.export import Constraint, Dim
|
36 | 37 | from torch.fx.experimental.proxy_tensor import make_fx
|
37 | 38 | from torch.testing import FileCheck
|
| 39 | +from torch.testing._internal.common_cuda import ( |
| 40 | +    PLATFORM_SUPPORTS_FLASH_ATTENTION, |
| 41 | +) |
| 42 | +from torch.testing._internal.common_device_type import ( |
| 43 | +    onlyCPU, |
| 44 | +    onlyCUDA, |
| 45 | +) |
38 | 46 | from torch.testing._internal.common_utils import run_tests
|
39 | 47 | from torch._dynamo.test_case import TestCase
|
40 | 48 | from torch.utils._pytree import (
|
@@ -2528,6 +2536,74 @@ def forward(self, x):
|
2528 | 2536 | # this doesn't work today
|
2529 | 2537 | gm_unflat_strict = unflatten(ep)
|
2530 | 2538 |
|
| 2539 | +@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo isn't supported") |
| 2540 | +class TestOneOffModelExportResult(TestCase): |
| 2541 | +    def test_scaled_dot_product_attention_cpu(self): |
| 2542 | +        """ |
| 2543 | +        This test makes sure we always get the same decomposition result for SDPA. |
| 2544 | +        As of now, _scaled_dot_product_flash_attention_for_cpu is expected to show up |
| 2545 | +        in the export() result. Some downstream backends then further decompose it into |
| 2546 | +        core ATen ops in torch/_decomp/decompositions.py (search for |
| 2547 | +        _scaled_dot_product_flash_attention_for_cpu). |
| 2548 | +
|
| 2549 | +        Export decomposes SDPA based on its CompositeImplicitAutograd kernel |
| 2550 | +        implementation. If this test fails, it means the kernel has been modified. In |
| 2551 | +        that case we strongly encourage you to update the decomposition rule in |
| 2552 | +        torch/_decomp/decompositions.py along with the kernel changes, so that |
| 2553 | +        downstream backends are not affected. |
| 2554 | +        """ |
| 2555 | +        class ScaledDotProductAttention(torch.nn.Module): |
| 2556 | +            def __init__(self): |
| 2557 | +                super().__init__() |
| 2558 | + |
| 2559 | +            def forward(self, q, k, v): |
| 2560 | +                attn_output = F.scaled_dot_product_attention( |
| 2561 | +                    q, k, v, None, dropout_p=0.0, is_causal=True |
| 2562 | +                ) |
| 2563 | +                return attn_output |
| 2564 | +        q = torch.randn(1, 1, 8, 8, device="cpu") |
| 2565 | +        k = torch.randn(1, 1, 8, 8, device="cpu") |
| 2566 | +        v = torch.randn(1, 1, 8, 8, device="cpu") |
| 2567 | + |
| 2568 | +        ep = torch.export.export(ScaledDotProductAttention(), (q, k, v)) |
| 2569 | +        self.assertExpectedInline(ep.graph_module.code.strip(), """\ |
| 2570 | +def forward(self, l_q_, l_k_, l_v_): |
| 2571 | +    _scaled_dot_product_flash_attention_for_cpu = torch.ops.aten._scaled_dot_product_flash_attention_for_cpu.default(l_q_, l_k_, l_v_, 0.0, True); l_q_ = l_k_ = l_v_ = None |
| 2572 | +    getitem = _scaled_dot_product_flash_attention_for_cpu[0]; _scaled_dot_product_flash_attention_for_cpu = None |
| 2573 | +    return (getitem,)""") |
| 2574 | + |
| 2575 | +    @unittest.skipIf( |
| 2576 | +        not PLATFORM_SUPPORTS_FLASH_ATTENTION, |
| 2577 | +        "Can't run fused SDPA on this platform", |
| 2578 | +    ) |
| 2579 | +    def test_scaled_dot_product_attention_cuda(self): |
| 2580 | +        """ |
| 2581 | +        This test makes sure we always get the same decomposition result for SDPA. |
| 2582 | +        As of now, _scaled_dot_product_flash_attention is expected to show up in the |
| 2583 | +        export() result when GPU tensors are given. Currently no downstream backend |
| 2584 | +        relies on this export result, so if this test fails, feel free to update it |
| 2585 | +        to the latest export() result. |
| 2586 | +        """ |
| 2587 | +        class ScaledDotProductAttention(torch.nn.Module): |
| 2588 | +            def __init__(self): |
| 2589 | +                super().__init__() |
| 2590 | + |
| 2591 | +            def forward(self, q, k, v): |
| 2592 | +                attn_output = F.scaled_dot_product_attention( |
| 2593 | +                    q, k, v, None, dropout_p=0.0, is_causal=True |
| 2594 | +                ) |
| 2595 | +                return attn_output |
| 2596 | +        q = torch.randn(1, 16, 16, 64, dtype=torch.bfloat16, device="cuda") |
| 2597 | +        k = torch.randn(1, 16, 16, 64, dtype=torch.bfloat16, device="cuda") |
| 2598 | +        v = torch.randn(1, 16, 16, 64, dtype=torch.bfloat16, device="cuda") |
| 2599 | + |
| 2600 | +        ep = torch.export.export(ScaledDotProductAttention(), (q, k, v)) |
| 2601 | +        self.assertExpectedInline(ep.graph_module.code.strip(), """\ |
| 2602 | +def forward(self, l_q_, l_k_, l_v_): |
| 2603 | +    _scaled_dot_product_flash_attention = torch.ops.aten._scaled_dot_product_flash_attention.default(l_q_, l_k_, l_v_, 0.0, True, scale = 0.125); l_q_ = l_k_ = l_v_ = None |
| 2604 | +    getitem = _scaled_dot_product_flash_attention[0]; _scaled_dot_product_flash_attention = None |
| 2605 | +    return (getitem,)""") |
| 2606 | + |
2531 | 2607 |
|
2532 | 2608 | if __name__ == '__main__':
|
2533 | 2609 | run_tests()
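For context on the decomposition the CPU test pins down: below is a minimal sketch of how a downstream backend could take the export() result and further lower _scaled_dot_product_flash_attention_for_cpu into core ATen ops via the rule in torch/_decomp/decompositions.py. It assumes ExportedProgram.run_decompositions() and torch._decomp.get_decompositions() are available in the PyTorch build being used; it is an illustration of the flow, not part of the committed test.

import torch
import torch.nn.functional as F
from torch._decomp import get_decompositions

class ScaledDotProductAttention(torch.nn.Module):
    def forward(self, q, k, v):
        return F.scaled_dot_product_attention(
            q, k, v, None, dropout_p=0.0, is_causal=True
        )

q = torch.randn(1, 1, 8, 8, device="cpu")
k = torch.randn(1, 1, 8, 8, device="cpu")
v = torch.randn(1, 1, 8, 8, device="cpu")

# export() keeps _scaled_dot_product_flash_attention_for_cpu in the graph,
# as asserted by test_scaled_dot_product_attention_cpu above.
ep = torch.export.export(ScaledDotProductAttention(), (q, k, v))

# Fetch the decomposition registered for that op (the rule in
# torch/_decomp/decompositions.py) and re-run the exported program through it,
# so the resulting graph contains only core ATen ops.
decomp_table = get_decompositions(
    [torch.ops.aten._scaled_dot_product_flash_attention_for_cpu.default]
)
core_aten_ep = ep.run_decompositions(decomp_table)
print(core_aten_ep.graph_module.code)

Keeping that decomposition rule in sync with the CompositeImplicitAutograd kernel is exactly what the docstring asks of anyone modifying the SDPA kernel.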
|