
Commit 72fd7ab

xmfan authored and pytorchmergebot committed
[ca] fix flex attention backward HOP capture in initial graph (pytorch#143155)
FIXES pytorch#142313

With previous HOPs, compiled autograd could simply inline into their bodies and get their post-dispatch ATen representation. That isn't possible with the flex attention HOP, which expects whatever proxy tracing mechanism is active to insert it into its graph. Since compiled autograd itself uses proxy tracing, we can do exactly that. This is safe because, aside from the reenter_make_fx call, the HOP used no other make_fx internals, and compiled autograd specializes on the AOT backward's saved symints, which should cover any changes in the shapes of the HOP's inputs.

One issue remains: Dynamo doesn't know how to handle `FlexAttentionBackwardHOP` and will graph break, so the flex attention backward still runs in eager as of this PR. The tlparse looks rough after the compiled autograd capture: https://manifold.edge.x2p.facebook.net/v0/read/tree/logs/.tmpMMHBEH/index.html?bucketName=tlparse_reports&apiKey=tlparse_reports-key&withPayload=1&timeoutMsec=10

Pull Request resolved: pytorch#143155
Approved by: https://github.com/drisspg
1 parent b4f4c75 commit 72fd7ab
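For orientation, the sketch below is not part of the PR; it is a minimal reproduction of the scenario described in the commit message, under the assumption that compiled autograd is driven through the torch._dynamo.compiled_autograd.enable context manager with an aot_eager backend (the same backend the new test uses). It needs a CUDA device, and the shapes/dtype mirror the new test's inputs.

# Minimal sketch (not from the PR): exercise flex attention's backward under
# compiled autograd. compiled_autograd.enable and the aot_eager backend are the
# assumed entry points here; shapes/dtype mirror the new test.
import torch
from torch._dynamo import compiled_autograd
from torch.nn.attention.flex_attention import flex_attention


def compiler_fn(gm):
    # Compile the backward graph produced by compiled autograd.
    return torch.compile(gm, backend="aot_eager")


q = torch.zeros(
    1, 1, 768, 64, dtype=torch.bfloat16, device="cuda", requires_grad=True
)

with compiled_autograd.enable(compiler_fn):
    # The backward call is where FlexAttentionBackwardHOP must be captured into
    # compiled autograd's initial graph via proxy tracing (this PR's change).
    flex_attention(q, q, q).sum().backward()

print(q.grad.shape)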


2 files changed: +33 -3 lines changed


test/inductor/test_compiled_autograd.py

Lines changed: 26 additions & 0 deletions
@@ -24,6 +24,7 @@
 from torch._dynamo.utils import counters
 from torch._inductor import config as inductor_config
 from torch._inductor.test_case import run_tests, TestCase
+from torch.nn.attention.flex_attention import flex_attention
 from torch.testing._internal.common_utils import (
     scoped_load_inline,
     skipIfWindows,
@@ -3216,6 +3217,31 @@ def inner_compiler(gm_, example_inputs_):
         self.assertTrue(isinstance(view_nodes[0].args[1][0], torch.fx.Node))
         self.assertTrue(isinstance(view_nodes[1].args[1][0], torch.fx.Node))
 
+    @unittest.skipIf(not HAS_CUDA, "requires cuda")
+    def test_flex_attention(self):
+        def fn():
+            @torch.compile(backend="aot_eager")
+            def fwd_bwd(x: torch.Tensor):
+                flex_attention(x, x, x).sum().backward()
+
+            for a, b in zip([12, 24, 48], [64, 128, 256]):
+                v = torch.zeros(
+                    1,
+                    1,
+                    a * b,
+                    b,
+                    dtype=torch.bfloat16,
+                    device="cuda",
+                    requires_grad=True,
+                )
+                fwd_bwd(v)
+                yield v.grad
+
+        # TODO: Dynamo graph breaks on torch.ops.higher_order.flex_attention_backward
+        self.check_output_and_recompiles(
+            fn, count=3, compiler_fn=make_compiler_fn(fullgraph=False)
+        )
+
     @unittest.expectedFailure
     def test_saved_tensor_unpack_hook_ordering(self):
         # not the correct behaviour, I'm just preventing this from changing silently
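For context on count=3: check_output_and_recompiles is an existing helper in this test file; roughly, it runs the generator eagerly, runs it again with compiled autograd enabled, compares the yielded gradients, and checks compiled autograd's counters. The sketch below is only an approximation written for illustration (the counter keys and exact behavior are assumptions, not taken from this diff). With three distinct input shapes, compiled autograd captures and compiles three backward graphs, hence count=3.

# Approximate sketch of what the harness checks (an assumption for illustration,
# not the real check_output_and_recompiles implementation).
import torch
from torch._dynamo import compiled_autograd
from torch._dynamo.utils import counters


def check_output_and_recompiles_sketch(test_case, fn, count, compiler_fn):
    counters["compiled_autograd"].clear()
    expected = list(fn())  # eager reference run
    with compiled_autograd.enable(compiler_fn):
        actual = list(fn())  # same generator, backward now captured/compiled
    test_case.assertEqual(expected, actual)
    # One capture/compile per distinct input shape -> count=3 for the three shapes above.
    test_case.assertEqual(counters["compiled_autograd"]["captures"], count)
    test_case.assertEqual(counters["compiled_autograd"]["compiles"], count)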

torch/_higher_order_ops/flex_attention.py

Lines changed: 7 additions & 3 deletions
@@ -7,6 +7,7 @@
 from torch._C import DispatchKey
 from torch._higher_order_ops.utils import (
     _has_potential_branch_input_mutation,
+    _maybe_reenter_make_fx,
     autograd_not_implemented,
     reenter_make_fx,
     save_tensors_and_symints_for_backward,
@@ -945,11 +946,14 @@ def trace_flex_attention_backward(
     mask_example_vals = [query.new_zeros((), dtype=torch.int) for _ in range(4)]
     mask_graph = block_mask[-1]
     with TransformGetItemToIndex():
-        fw_graph = reenter_make_fx(fw_graph)(*fw_example_vals, *score_mod_other_buffers)
-        joint_graph = reenter_make_fx(joint_graph)(
+        # There's no active make_fx during the compiled autograd graph's initial capture
+        fw_graph = _maybe_reenter_make_fx(fw_graph)(
+            *fw_example_vals, *score_mod_other_buffers
+        )
+        joint_graph = _maybe_reenter_make_fx(joint_graph)(
             *bw_example_vals, *score_mod_other_buffers
         )
-        mask_graph = reenter_make_fx(mask_graph)(
+        mask_graph = _maybe_reenter_make_fx(mask_graph)(
             *mask_example_vals, *mask_mod_other_buffers
         )
     assert isinstance(proxy_mode.tracer, torch.fx.Tracer)
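To make the swap above concrete: reenter_make_fx assumes an enclosing make_fx trace, which exists on the usual AOT path but not during compiled autograd's initial graph capture. The sketch below is a hypothetical stand-in illustrating the fallback the change relies on; the real _maybe_reenter_make_fx lives in torch._higher_order_ops.utils and may differ, and the _CURRENT_MAKE_FX_TRACER check and the plain make_fx fallback are assumptions for illustration.

# Hypothetical stand-in, not the real helper: re-enter an active make_fx trace when
# there is one, otherwise trace the subgraph standalone so compiled autograd's
# initial capture still gets an FX graph for score_mod / mask_mod.
import torch
import torch.fx.experimental.proxy_tensor as proxy_tensor
from torch.fx.experimental.proxy_tensor import make_fx
from torch._higher_order_ops.utils import reenter_make_fx


def maybe_reenter_make_fx_sketch(fn):
    if getattr(proxy_tensor, "_CURRENT_MAKE_FX_TRACER", None) is not None:
        # An enclosing make_fx is active: trace into the current graph, as before.
        return reenter_make_fx(fn)
    # No active make_fx (e.g. compiled autograd's initial graph capture):
    # fall back to a standalone trace.
    return make_fx(fn)


# Example: trace a trivial score_mod with the same scalar-example-value pattern
# the HOP uses (one score value plus four index values).
def score_mod(score, b, h, q_idx, kv_idx):
    return score


example_vals = [torch.zeros((), dtype=torch.float)] + [
    torch.zeros((), dtype=torch.int) for _ in range(4)
]
fw_graph = maybe_reenter_make_fx_sketch(score_mod)(*example_vals)
print(fw_graph)  # an fx.GraphModule for the score_mod subgraph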
