
Commit 9c506aa

ColinPeppler authored and pytorchmergebot committed
[aotinductor] add option to disable runtime assertions (pytorch#146462)
A recent user experience went like this:

* The user runs AOTI lowering, and it succeeds.
* They take the AOTI model and run it with some sample inputs. Everything runs well.
* Then they boot up a serving test that loads the AOTI model and runs it with a set of sample requests.
* They see that some of the requests fail. The logs show them this:
  * `AOTInductorModel run failed with input spec: [1, 32]:c10::BFloat16, [2]:long ...`
  * `Error: u45 >= 2`
* To the untrained eye, "AOTInductorModel run failed" is all they see. But the true reason is `Error: u45 >= 2`.

However, the assertion isn't always correct. In fact, u45 can actually be 0. So why did AOTI say u45 >= 2? It's a two-piece combo:

* With 0/1 specialization, the ShapeEnv creates symbolic shapes (e.g. s0) with a default value range of [2, inf].
* In the graph, Dynamo traces torch.mul(A, B) where A is [s0, ...] and B is [u45, ...]. So Dynamo learns Eq(s0, u45).
* Therefore, u45 also has a range of [2, inf]. Hence the incorrect runtime assertion.

So the motivation for this PR is to add an option to disable runtime assertions if you run into a situation like this. Another way to avoid it is to call `mark_unbacked()` on all the dynamic dims, as sketched below.

@diff-train-skip-merge

Pull Request resolved: pytorch#146462
Approved by: https://github.com/desertfire, https://github.com/22quinn
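As a minimal sketch of that `mark_unbacked()` workaround (the toy module `M`, the input shapes, and the choice of dim 0 are hypothetical, and the exact export/packaging entry points depend on your workflow):

```python
import torch
import torch._dynamo
import torch._inductor

# Hypothetical toy module standing in for the user's model.
class M(torch.nn.Module):
    def forward(self, a, b):
        # Ops like this are where Dynamo learns Eq(s0, u45)-style equalities.
        return torch.mul(a, b)

a = torch.randn(8, 4)
b = torch.randn(8, 4)

# mark_unbacked allocates an unbacked symbol for dim 0, sidestepping the
# default [2, inf] value range that 0/1 specialization would otherwise imply.
torch._dynamo.decorators.mark_unbacked(a, 0)
torch._dynamo.decorators.mark_unbacked(b, 0)

ep = torch.export.export(M(), (a, b))
pkg = torch._inductor.aoti_compile_and_package(ep)  # AOTI lowering
```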
1 parent 26358fa · commit 9c506aa

3 files changed: +10 −1 lines changed

test/inductor/test_aot_inductor.py (+6)

@@ -4141,6 +4141,12 @@ def forward(self, a, b, c):
         unexpected_inputs = (torch.ones(0, device=self.device), b, c)
         compiled(*unexpected_inputs)
 
+        # Try it again without runtime assertions.
+        with config.patch({"scalar_asserts": False}):
+            AOTIRunnerUtil.run_multiple(
+                self.device, model, [example_inputs, unexpected_inputs]
+            )
+
     def test_none_args_aot_codegen(self):
         if self.device != GPU_TYPE:
             raise unittest.SkipTest("requires GPU")

torch/_inductor/config.py (+1)

@@ -148,6 +148,7 @@ def prologue_fusion_enabled() -> bool:
 # put correctness assertions in generated code
 size_asserts = os.environ.get("TORCHINDUCTOR_SIZE_ASSERTS", "1") == "1"
 nan_asserts = os.environ.get("TORCHINDUCTOR_NAN_ASSERTS") == "1"
+scalar_asserts = os.environ.get("TORCHINDUCTOR_SCALAR_ASSERTS", "1") == "1"
 
 # enable loop reordering based on input orders
 pick_loop_orders = True
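For reference, a minimal sketch of the two ways this new toggle can be flipped (the config name and env var come from the diff above; the surrounding lowering code is left as a placeholder):

```python
import torch._inductor.config as inductor_config

# Option 1: process-wide, via the environment variable read above. It must
# be set before torch._inductor.config is first imported:
#   TORCHINDUCTOR_SCALAR_ASSERTS=0 python lower_model.py

# Option 2: scoped, as the new test does -- patch the config around lowering.
with inductor_config.patch({"scalar_asserts": False}):
    # Run AOTI lowering here; the generated code omits the scalar runtime
    # asserts whose codegen is guarded in the ir.py change below.
    pass
```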

torch/_inductor/ir.py (+3 −1)

@@ -6333,14 +6333,16 @@ def get_unbacked_symbol_uses(self):  # type: ignore[no-untyped-def]
         return free_unbacked_symbols(self.scalar)
 
     def codegen(self, wrapper) -> None:  # type: ignore[no-untyped-def]
+        if not config.scalar_asserts:
+            return
         # NB: It is EXTREMELY important not to simplify the scalar under assertion here,
         # because simplify is done with respect to runtime asserts. So if you have
         # "u0 == 0" in the runtime asserts, if you subsequently try to
         # simplify(u0 == 0), you will get True (because we've already runtime assert'ed
         # that it's true). But we're code generating the actual runtime assert here!!
         symbol = next(iter(self.get_unbacked_symbol_uses()))
-        symbol_str = f"std::to_string({symbol})"
         if V.graph.cpp_wrapper:
+            symbol_str = f"std::to_string({symbol})"
             sizevar = V.graph.wrapper_code.codegen_cpp_sizevar(
                 self.scalar, simplify=False
             )
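The NB comment is the subtle part of this method. As a toy illustration of the pitfall it describes (a standalone sympy sketch, not Inductor's actual ShapeEnv machinery):

```python
import sympy

u0 = sympy.Symbol("u0", integer=True)
runtime_asserts = {sympy.Eq(u0, 0)}  # already recorded as "known true"

def simplify_wrt_asserts(expr, asserts):
    # Mimics simplification "with respect to runtime asserts": any
    # expression that has already been asserted folds to True.
    return sympy.true if expr in asserts else expr

expr = sympy.Eq(u0, 0)  # the very assert we are about to code-generate
print(simplify_wrt_asserts(expr, runtime_asserts))  # True -> assert vanishes
print(expr)                                         # Eq(u0, 0) -> must be emitted
```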
