
Commit ea141d8

zou3519 authored and pytorchmergebot committed
functional compiled autograd (pytorch#144707)
This PR squashes together the following commits: pytorch#144115, pytorch#143417, pytorch#143405, pytorch#143387, pytorch#143304, pytorch#143296.

This is a refactor of compiled autograd to use "functional autograd". The end goal is to get compiled autograd's initial capture to stop specializing on Tensor metadata, thereby allowing compiled autograd to better handle Tensor subclasses. For more information, please read the commit messages for each PR.

Pull Request resolved: pytorch#144707
Approved by: https://github.com/bdhirsh, https://github.com/xmfan, https://github.com/jansel
1 parent 87fdadd commit ea141d8

28 files changed, +1809 -223 lines
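The tests added below drive the new functional capture through the compiled autograd compiler hook. The following is a minimal usage sketch modeled on the test code in this diff: the private compiled_autograd._enable API and the "aot_eager" backend are taken from those tests, while the function, shapes, and the inline torch.compile call (standing in for the tests' make_compiler_fn helper) are illustrative assumptions, not part of this commit.

    # Minimal sketch, assuming the private compiled_autograd._enable API used by
    # the tests in this commit; fn and the tensor shapes are illustrative only.
    import torch
    from torch._dynamo import compiled_autograd

    graphs = []

    def compiler_fn(gm):
        # Record each compiled-autograd graph before compiling it further,
        # mirroring what the tests' compiler_fn does via make_compiler_fn.
        graphs.append(gm)
        return torch.compile(gm, backend="aot_eager", fullgraph=True)

    @torch.compile(backend="aot_eager", fullgraph=True)
    def fn(x):
        return x * x + 2

    x = torch.randn(4, 4, requires_grad=True)
    with compiled_autograd._enable(compiler_fn):
        fn(x).sum().backward()

    # graphs[0] holds the functional initial capture; the expected graphs in the
    # test diffs below show the validate_outputs / call_aot_bwd_prologue nodes it
    # now contains instead of metadata-specialized code.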

aten/src/ATen/TensorGeometry.h (+10)
@@ -37,6 +37,16 @@ struct TORCH_API TensorGeometry {
         has_symbolic_sizes_strides_(
             t.unsafeGetTensorImpl()->has_symbolic_sizes_strides()) {}

+  explicit TensorGeometry(
+      std::vector<at::SymInt> sizes,
+      std::vector<at::SymInt> strides,
+      at::SymInt storage_offset)
+      : sizes_(std::move(sizes)),
+        strides_(std::move(strides)),
+        storage_offset_(std::move(storage_offset)) {
+    recompute();
+  }
+
   // true if the tensor is contiguous
   bool is_contiguous() const;

build_variables.bzl (+1)
@@ -138,6 +138,7 @@ core_trainer_sources = [
     "torch/csrc/autograd/variable.cpp",
     "torch/csrc/autograd/utils/warnings.cpp",
     "torch/csrc/autograd/jit_decomp_interface.cpp",
+    "torch/csrc/dynamo/compiled_autograd.cpp",
     "torch/csrc/jit/frontend/name_mangler.cpp",
     "torch/csrc/jit/ir/type_hashing.cpp",
     "torch/csrc/jit/serialization/pickler.cpp",

test/dynamo/test_backward_higher_order_ops.py (+34 -20)
@@ -121,23 +121,30 @@ def fn(x, y):
         out.backward(grad_out)
         actual = normalize_gm(graph.print_readable(False))
         self.assertEqual(x.grad, grad_out * grad_out)
-        self.assertExpectedInline(
-            actual,
-            """\
+        if backend in ["aot_eager", "inductor"]:
+            self.assertExpectedInline(
+                actual,
+                """\
 class GraphModule(torch.nn.Module):
     def forward(self, L_inputs_ : list):
         l_inputs_ = L_inputs_

-        getitem: "f32[s0]" = l_inputs_[0]; l_inputs_ = None
+        getitem: "f32[2]" = l_inputs_[0]; l_inputs_ = None

-        new_grad: "f32[s0]" = torch.clone(getitem)
+        validate_outputs = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem], [((None, None, device(type='cpu'), 6, 0, None), [2], False)]); getitem = None
+        getitem_3: "f32[2]" = validate_outputs[0]; validate_outputs = None

-        result: "f32[s0]" = getitem * getitem; getitem = None
+        call_aot_bwd_prologue = torch__dynamo_compiled_autograd_call_aot_bwd_prologue((), [], getitem_3); getitem_3 = None
+        getitem_5: "f32[2]" = call_aot_bwd_prologue[0]; call_aot_bwd_prologue = None

-        new_grad_1: "f32[s0]" = torch.clone(result); result = None
+        new_grad: "f32[2]" = torch.clone(getitem_5)
+
+        result: "f32[2]" = getitem_5 * getitem_5; getitem_5 = None
+
+        new_grad_1: "f32[2]" = torch.clone(result); result = None
         return (new_grad, new_grad_1)
 """,
-        )
+            )

         graph = None
@@ -162,7 +169,7 @@ def inner_compiler(gm_, example_inputs_):
             gm, backend=inner_compiler, fullgraph=True, dynamic=True
         )

-        for backend in ["eager", "aot_eager", "inductor"]:
+        for backend in ["inductor"]:
             torch._dynamo.reset()
             x = torch.tensor([0.5, 0.5], requires_grad=True)
             y = torch.tensor([0.5, 0.5], requires_grad=True)
@@ -187,26 +194,33 @@ def fn(x, y):
         actual = normalize_gm(graph.print_readable(False))
         self.assertEqual(obj.counter, 1)
         self.assertEqual(x.grad, grad_out + grad_out)
-        self.assertExpectedInline(
-            actual,
-            """\
+        if backend in ["aot_eager", "inductor"]:
+            self.assertExpectedInline(
+                actual,
+                """\
 class GraphModule(torch.nn.Module):
-    def forward(self, L_inputs_ : list, L_hooks_0_keywords_fn_keywords_obj_counter: "Sym(s1)"):
+    def forward(self, L_inputs_ : list, L_hooks_1_keywords_fn_keywords_obj_counter: "Sym(s1)"):
         l_inputs_ = L_inputs_
-        l_hooks_0_keywords_fn_keywords_obj_counter = L_hooks_0_keywords_fn_keywords_obj_counter
+        l_hooks_1_keywords_fn_keywords_obj_counter = L_hooks_1_keywords_fn_keywords_obj_counter

-        getitem: "f32[s0]" = l_inputs_[0]; l_inputs_ = None
+        getitem: "f32[2]" = l_inputs_[0]; l_inputs_ = None

-        new_grad: "f32[s0]" = torch.clone(getitem)
+        validate_outputs = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem], [((None, None, device(type='cpu'), 6, 0, None), [2], False)]); getitem = None
+        getitem_3: "f32[2]" = validate_outputs[0]; validate_outputs = None

-        add: "Sym(s1 + 1)" = l_hooks_0_keywords_fn_keywords_obj_counter + 1; l_hooks_0_keywords_fn_keywords_obj_counter = None
+        call_aot_bwd_prologue = torch__dynamo_compiled_autograd_call_aot_bwd_prologue((), [], getitem_3); getitem_3 = None
+        getitem_5: "f32[2]" = call_aot_bwd_prologue[0]; call_aot_bwd_prologue = None

-        result: "f32[s0]" = getitem * getitem; getitem = None
+        new_grad: "f32[2]" = torch.clone(getitem_5)

-        new_grad_1: "f32[s0]" = torch.clone(result); result = None
+        add: "Sym(s1 + 1)" = l_hooks_1_keywords_fn_keywords_obj_counter + 1; l_hooks_1_keywords_fn_keywords_obj_counter = None
+
+        result: "f32[2]" = getitem_5 * getitem_5; getitem_5 = None
+
+        new_grad_1: "f32[2]" = torch.clone(result); result = None
         return (new_grad, new_grad_1, add)
 """,
-        )
+            )

         out = fn(x, y)
         out.backward(grad_out)
test/inductor/test_compiled_autograd.py (+125 -10)
@@ -22,6 +22,7 @@
 from torch._dynamo import compiled_autograd, config
 from torch._dynamo.backends.debugging import aot_eager
 from torch._dynamo.device_interface import get_interface_for_device
+from torch._dynamo.testing import normalize_gm
 from torch._dynamo.utils import counters
 from torch._inductor import config as inductor_config
 from torch._inductor.test_case import run_tests, TestCase
@@ -2821,8 +2822,11 @@ def test_cudagraphs_cpu_scalar_used_in_cpp_custom_op(self, load_inline):
         opt_bwd()

         self.assertEqual(counters["compiled_autograd"]["captures"], 1)
-        # always safe to move, since we trace into the autograd::function bwd and can see if it's only used by aten ops
-        self.assertEqual(counters["inductor"]["cudagraph_skips"], 0)
+        # Compiled autograd's initial capture lifts custom C++ autograd::Function bwd instead of tracing
+        # into it. We must skip since we do not know if the cpu scalar will be used only in ATen/prim ops.
+        # In the future, we can consider having a cpu scalar movement pass sometime after we trace
+        # into the custom C++ autograd::Function (like in AOTDispatcher)
+        self.assertEqual(counters["inductor"]["cudagraph_skips"], 1)

     def test_logs(self):
         logs, ctx = logs_to_string(
@@ -2941,12 +2945,11 @@ def forward(model, x):
         expected_logs = [
             "code: CompiledFunctionBackward (NodeCall 2)",
+            "code: CompiledFunctionBackward0 (NodeCall 2)",
             "aot0_primals_3",
             "aot0_relu",
             "aot0_le",
             "aot0_permute_2",
-            "code: CompiledFunctionBackward0 (NodeCall 2)",
-            "aot0_tangents_1",
             "aot0_full_default",
             "aot0_where",
             "aot0_mm",
@@ -2996,20 +2999,17 @@ def f(x):
         expected_logs = [
             "CompiledFunctionBackward1",
-            "aot1_tangents_1",
             "aot1_sin_1",
-            "aot1_primals_2",
             "aot1_neg",
             "aot0_tangents_2",
             "aot1_cos_1",
-            "aot1_primals_1",
             "aot0_tangents_1",
             "CompiledFunctionBackward0",
+            "aot0_sin_1",
             "aot0_neg",
-            "aot0_sin",
             "aot0_mul",
+            "aot0_cos_1",
             "aot0_mul_1",
-            "aot0_cos",
             "aot0_add",
         ]

@@ -3154,6 +3154,120 @@ def fn():
         self.assertEqual(sum(1 for e in unexpected_logs if e in logs.getvalue()), 0)

+    def test_tensor_subclass_basic(self):
+        from torch.testing._internal.two_tensor import TwoTensor, TwoTensorMode
+
+        with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
+            lib.define("to_twotensor(Tensor a, Tensor b) -> Tensor")
+            lib.define("from_twotensor(Tensor c) -> (Tensor, Tensor)")
+
+            def to_twotensor_backward(ctx, grad):
+                return torch.ops.mylib.from_twotensor(grad)
+
+            def from_twotensor_backward(ctx, grad_a, grad_b):
+                raise AssertionError("shouldn't get hit")
+
+            torch.library.register_autograd(
+                "mylib::to_twotensor", to_twotensor_backward, lib=lib
+            )
+            torch.library.register_autograd(
+                "mylib::from_twotensor", from_twotensor_backward, lib=lib
+            )
+
+            @torch.library.register_torch_dispatch(
+                "mylib::to_twotensor", TwoTensorMode, lib=lib
+            )
+            def _(_0, _1, _2, args, kwargs):
+                assert not kwargs
+                a, b = args
+                return TwoTensor(a.clone(), b.clone())
+
+            @torch.library.register_torch_dispatch(
+                "mylib::from_twotensor", TwoTensor, lib=lib
+            )
+            def _(_0, _1, _2, args, kwargs):
+                assert not kwargs
+                (c,) = args
+                return c.a.clone(), c.b.clone()
+
+            @torch.compile(backend="aot_eager", fullgraph=True)
+            def fn(x):
+                return x * x + 2
+
+            param1 = torch.randn(4, 4, requires_grad=True)
+            param2 = torch.randn(4, 4, requires_grad=True)
+            with TwoTensorMode():
+                x = torch.ops.mylib.to_twotensor(param1, param2)
+
+            inner_compiler_fn = make_compiler_fn(fullgraph=True, backend="aot_eager")
+            graphs = []
+
+            def compiler_fn(gm):
+                graphs.append(gm)
+                return inner_compiler_fn(gm)
+
+            with compiled_autograd._enable(compiler_fn):
+                res = fn(x)
+                res.sum().backward()
+
+            self.assertEqual(param1.grad, 2 * param1)
+            self.assertEqual(param2.grad, 2 * param2)
+            self.assertEqual(len(graphs), 1)
+
+            graph_code = normalize_gm(graphs[0].print_readable(print_output=False))
+            # The graph should have make_subclass calls in it.
+            self.assertExpectedInline(
+                graph_code,
+                """\
+class CompiledAutograd0(torch.nn.Module):
+    def forward(self, inputs, sizes, scalars, hooks):
+        getitem = inputs[0]
+        getitem_1 = inputs[1]
+        getitem_2 = inputs[2]
+        getitem_3 = inputs[3]
+        getitem_4 = inputs[4]; inputs = None
+
+        validate_outputs = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem], [((None, None, device(type='cpu'), 6, 0, None), [], True)]); getitem = None
+        getitem_5 = validate_outputs[0]; validate_outputs = None
+
+        sum_backward0 = torch__dynamo_compiled_autograd_ops_SumBackward0([getitem_5], [True], [4, 4]); getitem_5 = None
+        getitem_6 = sum_backward0[0]; sum_backward0 = None
+        validate_outputs_1 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_6], [((None, None, device(type='cpu'), 6, 0, None), [4, 4], True)]); getitem_6 = None
+        getitem_7 = validate_outputs_1[0]; validate_outputs_1 = None
+
+        getitem_8 = hooks[0]; getitem_8 = None
+        call_aot_bwd_prologue = torch__dynamo_compiled_autograd_call_aot_bwd_prologue((getitem_1, getitem_2), [], getitem_7); getitem_1 = getitem_2 = getitem_7 = None
+        aot0_primals_1 = call_aot_bwd_prologue[0]
+        aot0_primals_2 = call_aot_bwd_prologue[1]
+        aot0_tangents_1 = call_aot_bwd_prologue[2]
+        aot0_tangents_2 = call_aot_bwd_prologue[3]; call_aot_bwd_prologue = None
+
+        aot0_mul_2 = torch.ops.aten.mul.Tensor(aot0_tangents_1, aot0_primals_1); aot0_tangents_1 = aot0_primals_1 = None
+        aot0_mul_3 = torch.ops.aten.mul.Tensor(aot0_tangents_2, aot0_primals_2); aot0_tangents_2 = aot0_primals_2 = None
+
+        aot0_add_2 = torch.ops.aten.add.Tensor(aot0_mul_2, aot0_mul_2); aot0_mul_2 = None
+        aot0_add_3 = torch.ops.aten.add.Tensor(aot0_mul_3, aot0_mul_3); aot0_mul_3 = None
+
+        make_subclass = torch__dynamo_compiled_autograd_make_subclass(aot0_add_2, aot0_add_3); aot0_add_2 = aot0_add_3 = None
+
+        getitem_13 = hooks[1]; hooks = None
+        call_backward = torch__dynamo_external_utils_call_backward(getitem_13, (), make_subclass); getitem_13 = make_subclass = None
+        getitem_16 = call_backward[0]
+        getitem_17 = call_backward[1]; call_backward = None
+        validate_outputs_2 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_16, getitem_17], [((None, None, device(type='cpu'), 6, 0, None), [4, 4], False), ((None, None, device(type='cpu'), 6, 0, None), [4, 4], False)]); getitem_16 = getitem_17 = None
+        getitem_19 = validate_outputs_2[0]
+
+        accumulate_grad__1 = torch.ops.inductor.accumulate_grad_.default(getitem_4, getitem_19); getitem_4 = getitem_19 = accumulate_grad__1 = None
+
+        getitem_20 = validate_outputs_2[1]; validate_outputs_2 = None
+
+        accumulate_grad_ = torch.ops.inductor.accumulate_grad_.default(getitem_3, getitem_20); getitem_3 = getitem_20 = accumulate_grad_ = None
+
+        _exec_final_callbacks_stub = torch__dynamo_external_utils__exec_final_callbacks_stub(); _exec_final_callbacks_stub = None
+        return []
+""", # noqa: B950
+            )
+
     # https://github.com/pytorch/pytorch/issues/138920
     def test_compiled_autograd_does_not_specialize_on_bw_symints(self):
         class Mod(torch.nn.Module):
@@ -3247,7 +3361,7 @@ def inner_compiler(gm_, example_inputs_):
             # because we ignore all of these guards anyway in CA.
             # Once we stop using make_fx in CA, we won't have to worry about this specialization.
             view_nodes = graphs[1].graph.find_nodes(
-                op="call_function", target=torch.ops.aten.view.default
+                op="call_function", target=torch.ops.aten.reshape.default
             )
             # First 2 view nodes have a first argument that is a SymInt, not an int burned into the graph
             self.assertTrue(isinstance(view_nodes[0].args[1][0], torch.fx.Node))
@@ -3640,6 +3754,7 @@ def wrap_test_class(orig_cls):
     "test_tp_compile_comm_reordering",
     "test_unwrap_async_collective_tensor_tangent",
     # Uncategorized
+    "test_not_implemented_grad", # Dynamo changes the types of exceptions
 }

 if not HAS_CUDA:

test/inductor/test_distributed_patterns.py (+3 -1)
@@ -337,7 +337,9 @@ def test_module_backward_hooks_eager(self):
         self.assertEqual(fw_cnt.frame_count, 1)
         self.assertEqual(fw_cnt.op_count, 5)
         self.assertEqual(bw_cnt.frame_count, 2)  # grad=None and grad!=None
-        self.assertEqual(bw_cnt.op_count, 48)
+        self.assertEqual(
+            bw_cnt.op_count, 72
+        )  # Number of ops in the Dynamo-produced graphs

     def test_module_backward_hooks_aot(self):
         m1, inp1 = init_module_bw_hooks(True)
