Skip to content

Commit 31bb65d

Browse files
sijiac authored and pytorchmergebot committed
[Inductor] Fix conditional codegen (pytorch#129492)
Summary: We have a cache that guarantees each `sym` is codegenned only once; see the following code:

    def ensure_size_computed(self, sym: sympy.Symbol):
        if isinstance(sym, sympy.Symbol) and symbol_is_type(sym, SymT.PRECOMPUTED_SIZE):
            if sym in self.computed_sizes:
                return
            self.computed_sizes.add(sym)
            expr = V.graph.sizevars.inv_precomputed_replacements[sym]
            self.writeline(
                f"{self.declare}{sym} = {self.expr_printer(expr)}{self.ending}"
            )

However, this did not handle the case where the same `sym`s need to be codegenned in both branches of a conditional (the true branch and the false branch), which caused `undefined symbols` errors: P1441378833.

To fix the issue, we use a stack to capture the cache's state before codegenning each conditional branch and restore that state after the codegen is done.

Test Plan:

    TORCH_LOGS="+inductor" buck2 run mode/dev-nosan -c fbcode.nvcc_arch=h100 -c fbcode.enable_gpu_sections=true --config 'cxx.extra_cxxflags=-g1' -c fbcode.platform010_cuda_version=12 //scripts/hhh:repro_cond_torch_compile

    PYTORCH_TEST_FBCODE=1 TORCH_COMPILE_DEBUG=1 buck2 run mode/opt -c=python.package_style=inplace -c fbcode.enable_gpu_sections=true -c fbcode.platform=platform010 -c fbcode.split-dwarf=true //caffe2/test/inductor:control_flow -- -r test_cond_control_flow_with_precomputed_size

Differential Revision: D58973730

Pull Request resolved: pytorch#129492
Approved by: https://github.com/aakhundov
1 parent c5c9dbe commit 31bb65d

File tree

2 files changed

+51
-0
lines changed

2 files changed

+51
-0
lines changed

test/inductor/test_control_flow.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,42 @@ def test_cond_simple_control_flow(self, device, dynamic):
221221
dynamic=dynamic,
222222
)
223223

224+
@requires_gpu
def test_cond_control_flow_with_precomputed_size(self):
    """Regression test for pytorch#129492: torch.cond where both branches
    require the same precomputed size symbols must not emit undefined
    symbols in the generated wrapper code."""

    class TestModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv2d = torch.nn.Conv2d(
                512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
            )
            self.threshold = 20

        def forward(self, x: torch.Tensor, index) -> torch.Tensor:
            def true_fn(x: torch.Tensor):
                return self.conv2d(x)

            def false_fn(x: torch.Tensor):
                return self.conv2d(x)

            # Both branches run the same convolution, so inductor must
            # codegen the same precomputed size symbols in each branch.
            return torch.cond(
                index < self.threshold and index >= 0, true_fn, false_fn, (x,)
            )

    main_model = TestModel().cuda()
    x1 = torch.rand(2, 512, 128, 72).cuda()
    x2 = torch.rand(2, 512, 96, 96).cuda()

    opt_model = torch.compile(main_model)

    # Exercise the true branch (0 <= 1 < 20): eager and compiled agree.
    out1 = main_model(x1, 1)
    opt_out1 = opt_model(x1, 1)
    self.assertTrue(torch.allclose(out1, opt_out1, atol=1e-5))

    # Exercise the false branch (30 >= 20) with a different input shape.
    out2 = main_model(x2, 30)
    opt_out2 = opt_model(x2, 30)
    self.assertTrue(torch.allclose(out2, opt_out2, atol=1e-5))
259+
224260
@requires_gpu
225261
@parametrize("device", ["cpu", GPU_TYPE])
226262
@parametrize("dynamic", [False, True])

torch/_inductor/codegen/wrapper.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -269,6 +269,9 @@ class EnterSubgraphLine(WrapperLine):
269269
wrapper: WrapperCodeGen
270270
graph: GraphLowering
271271

272+
def __post_init__(self) -> None:
    # Entering a subgraph: snapshot the wrapper's computed-sizes cache so
    # the matching ExitSubgraphLine can restore it afterwards. Without
    # this, a size codegenned in one branch of a conditional would be
    # considered "already emitted" in the other branch.
    wrapper = self.wrapper
    wrapper.push_computed_sizes(wrapper.computed_sizes)
274+
272275
def codegen(self, code: IndentedBuffer) -> None:
273276
self.wrapper.push_codegened_graph(self.graph)
274277
code.do_indent()
@@ -278,6 +281,9 @@ def codegen(self, code: IndentedBuffer) -> None:
278281
class ExitSubgraphLine(WrapperLine):
279282
wrapper: WrapperCodeGen
280283

284+
def __post_init__(self) -> None:
    # Leaving a subgraph: restore the computed-sizes snapshot taken by the
    # matching EnterSubgraphLine, discarding sizes recorded inside the
    # subgraph so sibling branches re-emit them.
    self.wrapper.computed_sizes = self.wrapper.pop_computed_sizes()
286+
281287
def codegen(self, code: IndentedBuffer) -> None:
282288
self.wrapper.pop_codegened_graph()
283289
code.do_unindent()
@@ -488,6 +494,7 @@ def __init__(self):
488494
# including the graph instance into a cache key to avoid cross-graph
489495
# caching during lowering of nested subgraphs
490496
self.codegened_graph_stack = []
497+
self.computed_sizes_stack = []
491498

492499
self.write_header()
493500
self.write_prefix()
@@ -680,6 +687,14 @@ def push_codegened_graph(self, graph):
680687
def pop_codegened_graph(self):
681688
return self.codegened_graph_stack.pop()
682689

690+
def push_computed_sizes(self, computed_sizes):
    """Push a snapshot of *computed_sizes* onto the stack.

    A deep copy is stored so later mutation of the live set does not
    retroactively change the saved snapshot.
    """
    from copy import deepcopy

    # Fix: the original wrote ``return self.computed_sizes_stack.append(...)``,
    # but list.append always returns None, so the ``return`` was misleading
    # dead weight. Callers observe no change (the method returned None
    # before and still does).
    self.computed_sizes_stack.append(deepcopy(computed_sizes))
694+
695+
def pop_computed_sizes(self):
    """Pop and return the most recently pushed computed-sizes snapshot."""
    stack = self.computed_sizes_stack
    return stack.pop()
697+
683698
def next_kernel_suffix(self) -> str:
    """Return the next value from ``self._names_iter`` as a string,
    used to uniquify generated kernel names."""
    suffix = next(self._names_iter)
    return f"{suffix}"
685700

0 commit comments

Comments
 (0)