
Commit 1a790f5

tugsbayasgalan authored and pytorchmergebot committed
[RELAND] Error grad mode op in export API (pytorch#117420)
Summary: Title
Test Plan: CI
Differential Revision: D52706691
Pull Request resolved: pytorch#117420
Approved by: https://github.com/angelayi
1 parent d6847c5 commit 1a790f5

File tree

4 files changed: +148 −6 lines changed

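In practice, the behavior this commit adds (per the new tests below): exporting a function that changes the global grad state while grad is enabled now raises a RuntimeError, and wrapping the export call in torch.no_grad() (as the updated ONNX test does) avoids it. A minimal sketch, assuming a PyTorch build that includes this change:

# Sketch of the user-visible behavior added by this commit; assumes a PyTorch
# build that includes this change. Mirrors the new test/export/test_safeguard.py.
import torch
from torch.export import export


def f(a):
    with torch.no_grad():  # changes the global grad state during tracing
        return a + a


a = torch.randn(10)

# With grad enabled, export now fails:
# RuntimeError: Encountered autograd state manager op ... while exporting ...
try:
    with torch.enable_grad():
        export(f, (a,))
except RuntimeError as e:
    print(e)

# Exporting under torch.no_grad() succeeds: the grad-state op is a no-op there,
# so dynamo eliminates it and no state change is captured.
with torch.no_grad():
    exported_program = export(f, (a,))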

test/export/test_safeguard.py

+87
@@ -0,0 +1,87 @@
# Owner(s): ["module: dynamo"]
import unittest

import torch
import torch._dynamo as torchdynamo
from torch.export import export
from torch.testing._internal.common_utils import run_tests, TestCase


@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo isn't supported")
class TestSafeguard(TestCase):
    # If the autograd state doesn't change, dynamo eliminates the autograd state
    # manager op and export can succeed. Otherwise, the op is preserved in the
    # produced graph and export fails.
    def test_global_autograd(self):
        def f1(a):
            with torch.no_grad():
                b = a + a
            return b

        def f2(a):
            with torch.enable_grad():
                b = a + a
            return b

        def f3(a):
            with torch.set_grad_enabled(False):
                b = a + a
            return b

        def f4(a):
            with torch.set_grad_enabled(True):
                b = a + a
            return b

        a = torch.randn(10)
        with torch.no_grad():
            export(f1, (a,))
            export(f2, (a,))
            export(f3, (a,))
            export(f4, (a,))

        with torch.enable_grad():
            export(f2, (a,))
            export(f4, (a,))

            with self.assertRaisesRegex(
                RuntimeError, "Encountered autograd state manager op.*"
            ):
                export(f1, (a,))

            with self.assertRaisesRegex(
                RuntimeError, "Encountered autograd state manager op.*"
            ):
                export(f3, (a,))

    def test_tensor_autograd(self):
        # dynamo errors when Tensor.requires_grad_ changes the autograd state
        def f1(a):
            a.requires_grad_(True)
            b = a + a
            return b

        # dynamo errors when Tensor.requires_grad_ changes the autograd state
        def f2(a):
            a.requires_grad_(False)
            b = a + a
            return b

        # dynamo always errors on Tensor.requires_grad
        def f3(a):
            a.requires_grad = False
            b = a + a
            return b

        export(f1, (torch.randn(10, requires_grad=True),))
        export(f2, (torch.randn(10, requires_grad=False),))

        with self.assertRaises(RuntimeError):
            export(f1, (torch.randn(10, requires_grad=False),))
        with self.assertRaises(RuntimeError):
            export(f2, (torch.randn(10, requires_grad=True),))
        with self.assertRaises(RuntimeError):
            export(f3, (torch.randn(10, requires_grad=False),))


if __name__ == "__main__":
    run_tests()

test/onnx/test_fx_to_onnx.py

+6 −5
@@ -669,11 +669,12 @@ def forward(self, x):
                 return self.normal.sample(x.shape)
 
         x = torch.randn(2, 3)
-        exported_program = torch.export.export(Model(), args=(x,))
-        _ = torch.onnx.dynamo_export(
-            exported_program,
-            x,
-        )
+        with torch.no_grad():
+            exported_program = torch.export.export(Model(), args=(x,))
+            _ = torch.onnx.dynamo_export(
+                exported_program,
+                x,
+            )
 
     def test_aten_linalg_vector_norm_with_reducel2(self):
         class Net(nn.Module):

torch/export/_safeguard.py

+42
@@ -0,0 +1,42 @@
import torch
from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode
from torch.overrides import TorchFunctionMode


class AutogradStateOpsFailSafeguard(TorchFunctionMode):
    """
    Detect grad state ops while exporting the graph and fail the process by
    raising an error, to avoid unexpected behavior. Those grad mode ops could be:
    `torch.no_grad`
    `torch.enable_grad`
    `torch.set_grad_enabled`

    Export with predispatch mode is exempted.
    """

    def __torch_function__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        unsupported_grad_mode_ops = [
            torch._C._set_grad_enabled,
        ]
        # Only enforced while tracing, confirmed by an active PROXY torch dispatch
        # mode. This allows the autograd ops to run outside of tracing.
        current_state = torch._C.is_grad_enabled()
        if func in unsupported_grad_mode_ops:
            assert len(args) == 1
            changed_state = args[0]
            mode = torch._C._get_dispatch_mode(torch._C._TorchDispatchModeKey.PROXY)
            # Check that this is not the pre_dispatch mode. It's allowed to use
            # autograd ops in pre_dispatch mode, e.g. `torch.no_grad`.
            if (
                mode
                and isinstance(mode, ProxyTorchDispatchMode)
                and not mode.pre_dispatch
                and changed_state != current_state
            ):
                raise RuntimeError(
                    f"Encountered autograd state manager op {func} trying to change global autograd state "
                    "while exporting. This is unsafe because we don't capture this op in torch.export "
                    "today, hence we can't reflect the user intention soundly."
                )
        return func(*args, **kwargs)
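For context, the safeguard relies on the standard TorchFunctionMode mechanism: while the mode is active, overridable torch-level calls are routed through __torch_function__ before they run. A minimal standalone illustration of that mechanism (it logs torch.add calls instead of checking grad-state ops, and is not part of this commit):

# Minimal illustration of the TorchFunctionMode mechanism the safeguard builds on.
# This logs intercepted calls instead of raising; it is not the safeguard itself.
import torch
from torch.overrides import TorchFunctionMode


class LogAdds(TorchFunctionMode):
    def __torch_function__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func is torch.add:
            print(f"intercepted torch.add with arg types {[type(a) for a in args]}")
        # Always forward to the original function, as the safeguard does.
        return func(*args, **kwargs)


with LogAdds():
    torch.add(torch.ones(2), torch.ones(2))  # routed through __torch_function__ first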

torch/export/_trace.py

+13 −1
@@ -4,6 +4,7 @@
 import logging
 import re
 from collections import OrderedDict
+from contextlib import nullcontext
 from typing import Any, Callable, Dict, List, Optional, Tuple
 
 import torch
@@ -30,6 +31,8 @@
 from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo
 from torch.utils._sympy.value_ranges import ValueRangeError
 
+from ._safeguard import AutogradStateOpsFailSafeguard
+
 from .dynamic_shapes import _process_constraints, Constraint
 from .exported_program import (
     _disable_prexisiting_fake_mode,
@@ -380,10 +383,19 @@ def _export_non_strict(
     transform=lambda x: x,  # TODO(zhxchen17) Revisit if this is needed later.
     pre_dispatch=False,
 ):
+    # [NOTE] If the user is exporting under training mode, we want to detect any change
+    # to the global autograd state and error out. If the user is exporting under inference
+    # mode, we don't care.
+    is_grad_enabled = torch._C.is_grad_enabled()
+    grad_safe_guard = (
+        AutogradStateOpsFailSafeguard() if is_grad_enabled else nullcontext()
+    )
     # This _reparametrize_module makes sure inputs and module.params/buffers have the same fake_mode,
     # otherwise aot_export_module will error out because it sees a mix of fake_modes.
     # And we want aot_export_module to use the fake_tensor mode in dynamo to keep the pipeline easy to reason about.
-    with torch.nn.utils.stateless._reparametrize_module(mod, fake_params_buffers):
+    with torch.nn.utils.stateless._reparametrize_module(
+        mod, fake_params_buffers
+    ), grad_safe_guard:  # type: ignore[attr-defined]
         gm, graph_signature = transform(aot_export_module)(
             mod,
             (*fake_args, *fake_kwargs.values()),
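The nullcontext() fallback keeps the with statement unconditional: when grad is globally disabled, the safeguard is simply replaced by a no-op context manager. A small generic sketch of that pattern (not export-specific):

# Generic sketch of the nullcontext pattern used above: pick a real context
# manager or a no-op one, then use a single `with` statement either way.
from contextlib import contextmanager, nullcontext


@contextmanager
def noisy_guard():
    print("guard entered")
    yield
    print("guard exited")


def run(task, guarded: bool):
    guard = noisy_guard() if guarded else nullcontext()
    with guard:
        task()


run(lambda: print("work"), guarded=True)   # guard entered / work / guard exited
run(lambda: print("work"), guarded=False)  # work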
