@@ -22,6 +22,7 @@
 from torch._dynamo import compiled_autograd, config
 from torch._dynamo.backends.debugging import aot_eager
 from torch._dynamo.device_interface import get_interface_for_device
+from torch._dynamo.testing import normalize_gm
 from torch._dynamo.utils import counters
 from torch._inductor import config as inductor_config
 from torch._inductor.test_case import run_tests, TestCase
@@ -2821,8 +2822,11 @@ def test_cudagraphs_cpu_scalar_used_in_cpp_custom_op(self, load_inline):
         opt_bwd()
 
         self.assertEqual(counters["compiled_autograd"]["captures"], 1)
-        # always safe to move, since we trace into the autograd::function bwd and can see if it's only used by aten ops
-        self.assertEqual(counters["inductor"]["cudagraph_skips"], 0)
+        # Compiled autograd's initial capture lifts custom C++ autograd::Function bwd instead of tracing
+        # into it. We must skip since we do not know if the cpu scalar will be used only in ATen/prim ops.
+        # In the future, we can consider having a cpu scalar movement pass sometime after we trace
+        # into the custom C++ autograd::Function (like in AOTDispatcher)
+        self.assertEqual(counters["inductor"]["cudagraph_skips"], 1)
 
     def test_logs(self):
         logs, ctx = logs_to_string(
@@ -2941,12 +2945,11 @@ def forward(model, x):
 
         expected_logs = [
             "code: CompiledFunctionBackward (NodeCall 2)",
+            "code: CompiledFunctionBackward0 (NodeCall 2)",
             "aot0_primals_3",
             "aot0_relu",
             "aot0_le",
             "aot0_permute_2",
-            "code: CompiledFunctionBackward0 (NodeCall 2)",
-            "aot0_tangents_1",
             "aot0_full_default",
             "aot0_where",
             "aot0_mm",
@@ -2996,20 +2999,17 @@ def f(x):
 
         expected_logs = [
             "CompiledFunctionBackward1",
-            "aot1_tangents_1",
             "aot1_sin_1",
-            "aot1_primals_2",
             "aot1_neg",
             "aot0_tangents_2",
             "aot1_cos_1",
-            "aot1_primals_1",
             "aot0_tangents_1",
             "CompiledFunctionBackward0",
+            "aot0_sin_1",
             "aot0_neg",
-            "aot0_sin",
             "aot0_mul",
+            "aot0_cos_1",
             "aot0_mul_1",
-            "aot0_cos",
             "aot0_add",
         ]
 
@@ -3154,6 +3154,120 @@ def fn():
 
         self.assertEqual(sum(1 for e in unexpected_logs if e in logs.getvalue()), 0)
 
+    def test_tensor_subclass_basic(self):
+        from torch.testing._internal.two_tensor import TwoTensor, TwoTensorMode
+
+        with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
+            lib.define("to_twotensor(Tensor a, Tensor b) -> Tensor")
+            lib.define("from_twotensor(Tensor c) -> (Tensor, Tensor)")
+
+            def to_twotensor_backward(ctx, grad):
+                return torch.ops.mylib.from_twotensor(grad)
+
+            def from_twotensor_backward(ctx, grad_a, grad_b):
+                raise AssertionError("shouldn't get hit")
+
+            torch.library.register_autograd(
+                "mylib::to_twotensor", to_twotensor_backward, lib=lib
+            )
+            torch.library.register_autograd(
+                "mylib::from_twotensor", from_twotensor_backward, lib=lib
+            )
+
+            @torch.library.register_torch_dispatch(
+                "mylib::to_twotensor", TwoTensorMode, lib=lib
+            )
+            def _(_0, _1, _2, args, kwargs):
+                assert not kwargs
+                a, b = args
+                return TwoTensor(a.clone(), b.clone())
+
+            @torch.library.register_torch_dispatch(
+                "mylib::from_twotensor", TwoTensor, lib=lib
+            )
+            def _(_0, _1, _2, args, kwargs):
+                assert not kwargs
+                (c,) = args
+                return c.a.clone(), c.b.clone()
+
+            @torch.compile(backend="aot_eager", fullgraph=True)
+            def fn(x):
+                return x * x + 2
+
+            param1 = torch.randn(4, 4, requires_grad=True)
+            param2 = torch.randn(4, 4, requires_grad=True)
+            with TwoTensorMode():
+                x = torch.ops.mylib.to_twotensor(param1, param2)
+
+            inner_compiler_fn = make_compiler_fn(fullgraph=True, backend="aot_eager")
+            graphs = []
+
+            def compiler_fn(gm):
+                graphs.append(gm)
+                return inner_compiler_fn(gm)
+
+            with compiled_autograd._enable(compiler_fn):
+                res = fn(x)
+                res.sum().backward()
+
+            self.assertEqual(param1.grad, 2 * param1)
+            self.assertEqual(param2.grad, 2 * param2)
+            self.assertEqual(len(graphs), 1)
+
+            graph_code = normalize_gm(graphs[0].print_readable(print_output=False))
+            # The graph should have make_subclass calls in it.
+            self.assertExpectedInline(
+                graph_code,
+                """\
+class CompiledAutograd0(torch.nn.Module):
+    def forward(self, inputs, sizes, scalars, hooks):
+        getitem = inputs[0]
+        getitem_1 = inputs[1]
+        getitem_2 = inputs[2]
+        getitem_3 = inputs[3]
+        getitem_4 = inputs[4]; inputs = None
+
+        validate_outputs = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem], [((None, None, device(type='cpu'), 6, 0, None), [], True)]); getitem = None
+        getitem_5 = validate_outputs[0]; validate_outputs = None
+
+        sum_backward0 = torch__dynamo_compiled_autograd_ops_SumBackward0([getitem_5], [True], [4, 4]); getitem_5 = None
+        getitem_6 = sum_backward0[0]; sum_backward0 = None
+        validate_outputs_1 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_6], [((None, None, device(type='cpu'), 6, 0, None), [4, 4], True)]); getitem_6 = None
+        getitem_7 = validate_outputs_1[0]; validate_outputs_1 = None
+
+        getitem_8 = hooks[0]; getitem_8 = None
+        call_aot_bwd_prologue = torch__dynamo_compiled_autograd_call_aot_bwd_prologue((getitem_1, getitem_2), [], getitem_7); getitem_1 = getitem_2 = getitem_7 = None
+        aot0_primals_1 = call_aot_bwd_prologue[0]
+        aot0_primals_2 = call_aot_bwd_prologue[1]
+        aot0_tangents_1 = call_aot_bwd_prologue[2]
+        aot0_tangents_2 = call_aot_bwd_prologue[3]; call_aot_bwd_prologue = None
+
+        aot0_mul_2 = torch.ops.aten.mul.Tensor(aot0_tangents_1, aot0_primals_1); aot0_tangents_1 = aot0_primals_1 = None
+        aot0_mul_3 = torch.ops.aten.mul.Tensor(aot0_tangents_2, aot0_primals_2); aot0_tangents_2 = aot0_primals_2 = None
+
+        aot0_add_2 = torch.ops.aten.add.Tensor(aot0_mul_2, aot0_mul_2); aot0_mul_2 = None
+        aot0_add_3 = torch.ops.aten.add.Tensor(aot0_mul_3, aot0_mul_3); aot0_mul_3 = None
+
+        make_subclass = torch__dynamo_compiled_autograd_make_subclass(aot0_add_2, aot0_add_3); aot0_add_2 = aot0_add_3 = None
+
+        getitem_13 = hooks[1]; hooks = None
+        call_backward = torch__dynamo_external_utils_call_backward(getitem_13, (), make_subclass); getitem_13 = make_subclass = None
+        getitem_16 = call_backward[0]
+        getitem_17 = call_backward[1]; call_backward = None
+        validate_outputs_2 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_16, getitem_17], [((None, None, device(type='cpu'), 6, 0, None), [4, 4], False), ((None, None, device(type='cpu'), 6, 0, None), [4, 4], False)]); getitem_16 = getitem_17 = None
+        getitem_19 = validate_outputs_2[0]
+
+        accumulate_grad__1 = torch.ops.inductor.accumulate_grad_.default(getitem_4, getitem_19); getitem_4 = getitem_19 = accumulate_grad__1 = None
+
+        getitem_20 = validate_outputs_2[1]; validate_outputs_2 = None
+
+        accumulate_grad_ = torch.ops.inductor.accumulate_grad_.default(getitem_3, getitem_20); getitem_3 = getitem_20 = accumulate_grad_ = None
+
+        _exec_final_callbacks_stub = torch__dynamo_external_utils__exec_final_callbacks_stub(); _exec_final_callbacks_stub = None
+        return []
+""",  # noqa: B950
+            )
+
     # https://github.com/pytorch/pytorch/issues/138920
     def test_compiled_autograd_does_not_specialize_on_bw_symints(self):
         class Mod(torch.nn.Module):
@@ -3247,7 +3361,7 @@ def inner_compiler(gm_, example_inputs_):
         # because we ignore all of these guards anyway in CA.
         # Once we stop using make_fx in CA, we won't have to worry about this specialization.
         view_nodes = graphs[1].graph.find_nodes(
-            op="call_function", target=torch.ops.aten.view.default
+            op="call_function", target=torch.ops.aten.reshape.default
         )
         # First 2 view nodes have a first argument that is a SymInt, not an int burned into the graph
         self.assertTrue(isinstance(view_nodes[0].args[1][0], torch.fx.Node))
@@ -3640,6 +3754,7 @@ def wrap_test_class(orig_cls):
     "test_tp_compile_comm_reordering",
     "test_unwrap_async_collective_tensor_tangent",
     # Uncategorized
+    "test_not_implemented_grad",  # Dynamo changes the types of exceptions
 }
 
 if not HAS_CUDA: