
Commit 2f72635

bobrenjc93 authored and pytorchmergebot committed
automatic dynamic unspecialize float (pytorch#141647)
Pull Request resolved: pytorch#141647
Approved by: https://github.com/ezyang
1 parent e29dabb commit 2f72635

9 files changed: 95 additions, 29 deletions


test/dynamo/test_misc.py

Lines changed: 13 additions & 3 deletions
@@ -1398,7 +1398,7 @@ def fn(x, cfg):
         cfg2.val = 2.0
         v = opt_fn(v, cfg2)  # 7
         self.assertEqual(v[0], 7)
-        self.assertEqual(cnts.op_count, 8)
+        self.assertEqual(cnts.op_count, 9)

     def test_config_getattr_default(self):
         class Cfg:
@@ -3747,8 +3747,18 @@ def inner():
             result1, result2, _ = opt_fn()
             self.assertAlmostEqual(orig1 + 1 * i, result1)
             self.assertTrue(torch.allclose(orig2 + 10 * i, result2))
-            self.assertEqual(cnts.frame_count, 1)
-            self.assertEqual(cnts.op_count, 3)
+            if i == 1:
+                # No automatic dynamic
+                self.assertEqual(cnts.frame_count, 1)
+                self.assertEqual(cnts.op_count, 3)
+            elif i == 2:
+                # Automatic dynamic float arguments kicked in
+                self.assertEqual(cnts.frame_count, 1)
+                self.assertEqual(cnts.op_count, 6)
+            else:
+                # No more recompiles
+                self.assertEqual(cnts.frame_count, 0)
+                self.assertEqual(cnts.op_count, 0)
             cnts.clear()

     def test_closure_with_mutation_and_graph_break(self):

test/dynamo/test_unspec.py

Lines changed: 16 additions & 6 deletions
@@ -641,18 +641,19 @@ def f(x, y):
         cf = torch.compile(backend=cnts, fullgraph=True)(f)

         x = torch.randn(3)
-        self.assertEqual(f(x, 3.0), cf(x, 3.0))
+        self.assertEqual(f(x, 2.0), cf(x, 2.0))
+        self.assertEqual(f(x, 3.0), cf(x, 3.0))  # automatic dynamic kicks in here
         self.assertEqual(f(x, 4.0), cf(x, 4.0))
-        self.assertExpectedInline(cnts.frame_count, """1""")  # no recompile
+        self.assertExpectedInline(cnts.frame_count, """2""")  # no recompile
         self.assertEqual(f(x, 5.0), cf(x, 5.0))
-        self.assertExpectedInline(cnts.frame_count, """2""")  # guard worked
+        self.assertExpectedInline(cnts.frame_count, """3""")  # guard worked
         self.assertEqual(f(x, math.nan), cf(x, math.nan))
-        self.assertExpectedInline(cnts.frame_count, """3""")  # nan always recompiles
+        self.assertExpectedInline(cnts.frame_count, """4""")  # nan always recompiles

     @torch._dynamo.config.patch(specialize_float=False, capture_scalar_outputs=True)
     def test_unspecialized_float_multiply_precision(self):
         dtypes = [torch.bfloat16, torch.float16, torch.float32, torch.float64]
-        for dtype in dtypes:
+        for i, dtype in enumerate(dtypes):

             def fn(x, y):
                 return x * y
@@ -662,10 +663,19 @@ def fn(x, y):
             x = torch.randn(5, dtype=dtype, requires_grad=True)
             y1 = 1.00048828125
             y2 = 1.00048828126
+            y3 = 1.00048828127

             self.assertEqual(fn_opt(x, y1), fn(x, y1))
             self.assertEqual(fn_opt(x, y2), fn(x, y2))
-            self.assertEqual(cnt.frame_count, 1)
+            self.assertEqual(fn_opt(x, y3), fn(x, y3))
+            if i == 0:
+                # This is kind of quirky part of automatic dynamic,
+                # since it just uses source name + tx.f_code as the key
+                # subsequent recompilations will actually reuse the automatic
+                # dynamic choices.
+                self.assertEqual(cnt.frame_count, 2)
+            else:
+                self.assertEqual(cnt.frame_count, 1)

     @torch._dynamo.config.patch(specialize_float=False, assume_static_by_default=False)
     def test_unspec_float_input_f64(self):
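
Not part of the commit: a minimal standalone sketch of the recompile pattern the updated test_unspec.py assertions encode, assuming specialize_float=False (the new default outside fbcode).

import torch
from torch._dynamo.testing import CompileCounter

# Sketch only: the first float value is constant-specialized, the second
# distinct value triggers exactly one recompile that makes the float dynamic,
# and further values reuse the dynamic graph.
@torch._dynamo.config.patch(specialize_float=False)
def demo():
    cnts = CompileCounter()

    def f(x, y):
        return x * y

    cf = torch.compile(backend=cnts, fullgraph=True)(f)
    x = torch.randn(3)
    cf(x, 2.0)  # first compile: y specialized to 2.0
    cf(x, 3.0)  # automatic dynamic kicks in: one recompile with a dynamic y
    cf(x, 4.0)  # no further recompile
    return cnts.frame_count  # expected to be 2

demo()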

test/dynamo/test_utils.py

Lines changed: 1 addition & 1 deletion
@@ -270,7 +270,7 @@ def test_dynamo_timed(self, mock_time, mock_time_ns):
             'runtime_cudagraphify_time_us': None,
             'runtime_triton_autotune_time_us': None,
             'shape_env_guard_count': 0,
-            'specialize_float': True,
+            'specialize_float': False,
             'start_time': 0.0001,
             'start_time_us': 100,
             'structured_logging_overhead_s': 0.0,

test/inductor/test_cudagraph_trees.py

Lines changed: 6 additions & 0 deletions
@@ -607,6 +607,8 @@ def foo(x):
     @torch._functorch.config.patch("enable_autograd_cache", True)
     @torch._inductor.config.patch("fx_graph_cache", True)
     @torch._inductor.config.patch("fx_graph_remote_cache", False)
+    # Currently fx graph cache is turned off for specialize_float=False
+    @torch._dynamo.config.patch("specialize_float", True)
     def test_cache_hit_forward_miss_backward(self):
         # Test that we don't cache cudagraphs, skipping cudagraphs on backward on a cache miss

@@ -661,6 +663,8 @@ def foo(x):
     @torch._functorch.config.patch("enable_autograd_cache", True)
     @torch._inductor.config.patch("fx_graph_cache", True)
     @torch._inductor.config.patch("fx_graph_remote_cache", False)
+    # Currently fx graph cache is turned off for specialize_float=False
+    @torch._dynamo.config.patch("specialize_float", True)
     def test_backward_gets_cached_cudagraphs(self):
         # We pass cpu tensors to foo and save that into the cache
         # On a subsequent run in a new process, cudagraphs should be
@@ -705,6 +709,8 @@ def foo(x):
     @torch._functorch.config.patch("enable_autograd_cache", True)
     @torch._inductor.config.patch("fx_graph_cache", True)
     @torch._inductor.config.patch("fx_graph_remote_cache", False)
+    # Currently fx graph cache is turned off for specialize_float=False
+    @torch._dynamo.config.patch("specialize_float", True)
     def test_cached_forward_backward(self):
         counters.clear()
         AOTAutogradCache.clear()

test/inductor/test_torchinductor_dynamic_shapes.py

Lines changed: 29 additions & 6 deletions
@@ -969,7 +969,7 @@ def test_unspecialized_float_operations(self):
             "divide": operator.truediv,
         }

-        for name, op in operations.items():
+        for i, (name, op) in enumerate(operations.items()):
             with self.subTest(operation=name):

                 def fn(x, y):
@@ -981,7 +981,14 @@ def fn(x, y):
                 x = torch.arange(3)
                 self.assertEqual(fn(x, 2.0), fn_opt(x, 2.0))
                 self.assertEqual(fn(x, 3.0), fn_opt(x, 3.0))
-                self.assertEqual(cnt.frame_count, 1)
+                self.assertEqual(fn(x, 4.0), fn_opt(x, 4.0))
+                if i == 0:
+                    # Automatic dynamic state persists across
+                    # compiles so only the first compile
+                    # goes through the automatic dynamic step.
+                    self.assertEqual(cnt.frame_count, 2)
+                else:
+                    self.assertEqual(cnt.frame_count, 1)

     @torch._dynamo.config.patch(specialize_float=False)
     def test_unspecialized_float_fallback_specialization(self):
@@ -1005,8 +1012,25 @@ def fn(x, y, z):
         self.assertEqual(fn(x, 2.0, z), fn_opt(x, 2.0, z))
         self.assertEqual(fn(x, 3.0, z), fn_opt(x, 3.0, z))
         self.assertEqual(fn(x, 4.0, z), fn_opt(x, 4.0, z))
-        # We expect frame count to be 2 since we will have
-        # one sledgehammer restart.
+        # Automatic dynamic float arguments
+        self.assertEqual(cnt.frame_count, 2)
+
+    def test_unspecialized_float_softshrink(self):
+        # This test is particularly interesting since it exercises
+        # both standard operator replacements ie. torch.ops.aten.mul.Tensor
+        # as well as comparison replacements ie. torch.ops.aten.ge.Scalar
+        def fn(x, y):
+            return torch._C._nn.softshrink(x, lambd=y)
+
+        cnt = CompileCounterWithBackend("inductor")
+        fn_opt = torch._dynamo.optimize(cnt)(fn)
+        x = torch.randn(5, 5)
+
+        print(fn(x, 2.0), fn_opt(x, 2.0))
+
+        self.assertEqual(fn(x, 2.0), fn_opt(x, 2.0))
+        self.assertEqual(fn(x, 3.0), fn_opt(x, 3.0))
+        self.assertEqual(fn(x, 4.0), fn_opt(x, 4.0))
         self.assertEqual(cnt.frame_count, 2)

     @torch._dynamo.config.patch(specialize_float=False)
@@ -1021,8 +1045,7 @@ def fn(x, y):
         self.assertEqual(fn(2.0, y), fn_opt(2.0, y))
         self.assertEqual(fn(3.0, y), fn_opt(3.0, y))
         self.assertEqual(fn(4.0, y), fn_opt(4.0, y))
-        # We expect frame count to be N + 1 since we will have
-        # one sledgehammer restart for the first compile.
+        # N + 1 for automatic dynamic float arguments
         self.assertEqual(cnt.frame_count, 4)

     def test_sort_dynamic_shape_with_check(self, device):
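
For context on the new softshrink test (standalone illustration, not from the commit): per the test's own comment, softshrink is a good stress case for the tensorify pass because it mixes arithmetic on the Python float with comparisons against it, so both kinds of SUPPORTED_OPS replacements fire. A rough eager-mode reference:

import torch

# Comparisons against the float (gt/lt) mixed with arithmetic on it (sub/add).
def softshrink_reference(x: torch.Tensor, lambd: float) -> torch.Tensor:
    above = x > lambd
    below = x < -lambd
    return torch.where(above, x - lambd, torch.where(below, x + lambd, torch.zeros_like(x)))

x = torch.randn(5, 5)
assert torch.allclose(softshrink_reference(x, 2.0), torch._C._nn.softshrink(x, lambd=2.0))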

torch/_dynamo/config.py

Lines changed: 1 addition & 1 deletion
@@ -65,7 +65,7 @@
 # Whether or not to specialize on float inputs. Dynamo will always promote
 # float inputs into Tensor inputs, but at the moment, backends inconsistently
 # support codegen on float (this is to be fixed).
-specialize_float = True
+specialize_float = True if is_fbcode() else False

 # legacy config, does nothing now!
 dynamic_shapes = True
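
Usage note (not from the commit): the flag stays user-overridable, which is how the cudagraph-tree tests above pin the old behavior while fx graph caching remains disabled for unspecialized floats. For example:

import torch

# Pin the old specializing behavior for one test or function ...
@torch._dynamo.config.patch("specialize_float", True)
def run_specialized():
    fn = torch.compile(lambda x, y: x * y)
    return fn(torch.randn(3), 2.0)

# ... or opt into unspecialized floats explicitly, regardless of the fbcode default.
with torch._dynamo.config.patch(specialize_float=False):
    fn = torch.compile(lambda x, y: x + y)
    fn(torch.randn(3), 1.5)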

torch/_dynamo/variables/builder.py

Lines changed: 8 additions & 0 deletions
@@ -1904,6 +1904,13 @@ def wrap_symfloat(self, value):
         if self.name in self.tx.output.unspec_variable_map:
             return self.tx.output.unspec_variable_map[self.name]

+        frame_state_entry = process_automatic_dynamic(
+            self.tx,
+            self.source.name(),
+            FrameStateSizeEntry.make_scalar(value),
+            is_unspecialized_nn_module=self.source.guard_source().is_unspecialized_nn_module(),
+        )
+
         # NB: we specialize on nan input, because our guard modeling in
         # ShapeEnv cannot deal with nan
         if (
@@ -1918,6 +1925,7 @@ def wrap_symfloat(self, value):
             # python test/inductor/test_compiled_optimizers.py CompiledOptimizerTests.test_rmsprop_weight_decay_maximize_capturable_cuda # noqa: B950
             or torch._inductor.config.triton.cudagraphs
             or justknobs_check("pytorch/compiler:unspecialize_float_killswitch", False)
+            or frame_state_entry.scalar is not auto_dynamic
         ):
             self.install_guards(GuardBuilder.CONSTANT_MATCH)
             return ConstantVariable.create(value=value, source=self.source)
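
One way to observe this wrap_symfloat change from outside Dynamo (a sketch using standard torch logging APIs, not part of the commit): enable recompile logging and call a compiled function with distinct Python floats.

import torch

# With specialize_float=False, the first call installs a CONSTANT_MATCH guard
# on the float (its frame-state entry is still static); the guard failure
# logged on the second call is what flips the entry to auto_dynamic, so the
# third call reuses a SymFloat graph without recompiling.
torch._logging.set_logs(recompiles=True)

@torch.compile
def g(x, y):
    return x + y

x = torch.randn(4)
with torch._dynamo.config.patch(specialize_float=False):
    g(x, 1.5)
    g(x, 2.5)  # recompile logged: the guard on the float input failed
    g(x, 3.5)  # no further recompile expected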

torch/_functorch/_aot_autograd/autograd_cache.py

Lines changed: 0 additions & 5 deletions
@@ -205,11 +205,6 @@ def check_cacheable(gm: torch.fx.GraphModule):
             "Cannot cache a graph with compiled autograd enabled"
         )

-    if not torch._dynamo.config.specialize_float:
-        raise BypassAOTAutogradCache(
-            "Cannot cache a graph with specialize_float disabled"
-        )
-
     if not (
         torch._inductor.config.fx_graph_cache or should_use_remote_fx_graph_cache()
     ):

torch/fx/passes/_tensorify_python_scalars.py

Lines changed: 21 additions & 7 deletions
@@ -73,10 +73,16 @@


 SUPPORTED_OPS = {
-    torch.ops.aten.mul.Tensor,
-    torch.ops.aten.add.Tensor,
-    torch.ops.aten.sub.Tensor,
-    torch.ops.aten.div.Tensor,
+    torch.ops.aten.mul.Tensor: torch.ops.aten.mul.Tensor,
+    torch.ops.aten.add.Tensor: torch.ops.aten.add.Tensor,
+    torch.ops.aten.sub.Tensor: torch.ops.aten.sub.Tensor,
+    torch.ops.aten.div.Tensor: torch.ops.aten.div.Tensor,
+    torch.ops.aten.gt.Scalar: torch.ops.aten.gt.Tensor,
+    torch.ops.aten.lt.Scalar: torch.ops.aten.lt.Tensor,
+    torch.ops.aten.ge.Scalar: torch.ops.aten.ge.Tensor,
+    torch.ops.aten.le.Scalar: torch.ops.aten.le.Tensor,
+    torch.ops.aten.eq.Scalar: torch.ops.aten.eq.Tensor,
+    torch.ops.aten.ne.Scalar: torch.ops.aten.ne.Tensor,
 }


@@ -232,7 +238,9 @@ def _sympy_interp(expr: sympy.Expr) -> MetaProxy:
         should_restart = True

     # Look for functions to convert
-    if node.op == "call_function" and node.target in SUPPORTED_OPS:
+    if node.op == "call_function" and (
+        replacement_op := SUPPORTED_OPS.get(node.target)
+    ):
         args: List[Any] = []
         transform = False
         compute_dtype = get_computation_dtype(node.meta["val"].dtype)
@@ -253,7 +261,13 @@ def _sympy_interp(expr: sympy.Expr) -> MetaProxy:
                 # We use _expr instead of expr b/c we want the symbol not the replacement
                 tensorified_symbols.add(a.meta["val"].node._expr)

-                if proxy.node.meta["val"].dtype != compute_dtype:
+                # The upcasting is irrelevant when the compute dtype is bool. This happens
+                # in cases where we are tensorifying a comparison operator such as
+                # torch.ops.aten.gt.Tensor
+                if (
+                    compute_dtype != torch.bool
+                    and proxy.node.meta["val"].dtype != compute_dtype
+                ):
                     proxy = torch.ops.prims.convert_element_type.default(
                         proxy, compute_dtype
                     )
@@ -265,7 +279,7 @@ def _sympy_interp(expr: sympy.Expr) -> MetaProxy:
                 args.append(a)

             if transform:
-                replacement_proxy = node.target(*args)
+                replacement_proxy = replacement_op(*args)

                 if compute_dtype != node.meta["val"].dtype:
                     replacement_proxy = (
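
To illustrate the new Scalar-to-Tensor mapping in SUPPORTED_OPS (standalone example, not from the commit): once the float argument has been tensorified, a comparison overload like aten.ge.Scalar can be swapped for its aten.ge.Tensor counterpart, and because the result dtype is bool, the convert_element_type upcast is skipped, which is what the compute_dtype != torch.bool check above encodes.

import torch

x = torch.randn(5)
lambd = 0.5

# The Scalar overload compares a tensor against a Python number ...
scalar_form = torch.ops.aten.ge.Scalar(x, lambd)
# ... and the Tensor overload is its replacement once the float is a tensor.
tensor_form = torch.ops.aten.ge.Tensor(x, torch.scalar_tensor(lambd, dtype=x.dtype))

assert torch.equal(scalar_form, tensor_form)
assert scalar_form.dtype == torch.bool  # comparison output: no dtype upcast needed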
