Skip to content

Commit 9bb327b

Browse files
Revert "[AC] Backward Pass Aware AC - adding hooks to partitioner to pass callable (pytorch#137785)"
This reverts commit a8b912f. Reverted pytorch#137785 on behalf of https://github.com/ezyang due to breaks lint ([comment](pytorch#137785 (comment)))
1 parent 02dd3b8 commit 9bb327b

File tree

1 file changed

+7
-33
lines changed

1 file changed

+7
-33
lines changed

torch/_functorch/partitioners.py

Lines changed: 7 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1506,12 +1506,9 @@ def dp_knapsack(
15061506

15071507

15081508
def _optimize_runtime_with_given_memory(
1509-
joint_graph: fx.Graph,
15101509
memory: List[float],
15111510
runtimes: List[float],
15121511
max_memory: float,
1513-
node_info: NodeInfo,
1514-
all_recomputable_banned_nodes: List[fx.Node],
15151512
) -> Tuple[float, List[int], List[int]]:
15161513
SOLVER = config.activation_memory_budget_solver
15171514
if SOLVER == "greedy":
@@ -1520,11 +1517,6 @@ def _optimize_runtime_with_given_memory(
15201517
return ilp_knapsack(memory, runtimes, max_memory)
15211518
elif SOLVER == "dp":
15221519
return dp_knapsack(memory, runtimes, max_memory)
1523-
elif callable(SOLVER):
1524-
saved_node_idx, recomp_node_idx = SOLVER(
1525-
memory, joint_graph, max_memory, node_info, all_recomputable_banned_nodes
1526-
)
1527-
return (0.0, saved_node_idx, recomp_node_idx)
15281520
else:
15291521
raise RuntimeError(f"Not aware of memory budget knapsack solver: {SOLVER}")
15301522

@@ -1580,9 +1572,7 @@ def realize_symbol(d):
15801572

15811573

15821574
def choose_saved_values_set(
1583-
joint_graph: fx.Graph,
1584-
node_info: NodeInfo,
1585-
memory_budget=1,
1575+
joint_graph: fx.Graph, node_info: NodeInfo, memory_budget=1
15861576
) -> List[fx.Node]:
15871577
if memory_budget > 1 or memory_budget < 0:
15881578
raise RuntimeError(
@@ -1690,28 +1680,18 @@ def get_recomputable_banned_nodes(banned_nodes: List[fx.Node]) -> List[fx.Node]:
16901680
]
16911681
from torch.utils._mode_utils import no_dispatch
16921682

1693-
def get_saved_values_knapsack(memory_budget, node_info, joint_graph):
1683+
def get_saved_values_knapsack(memory_budget):
16941684
with no_dispatch():
16951685
(
16961686
expected_runtime,
16971687
saved_node_idxs,
16981688
recomputable_node_idxs,
16991689
) = _optimize_runtime_with_given_memory(
1700-
joint_graph,
1701-
memories_banned_nodes,
1702-
runtimes_banned_nodes,
1703-
max(memory_budget, 0),
1704-
node_info,
1705-
all_recomputable_banned_nodes,
1690+
memories_banned_nodes, runtimes_banned_nodes, max(memory_budget, 0)
17061691
)
17071692
dont_ban = set()
17081693
for idx in recomputable_node_idxs:
1709-
# if idx in all_recomputable_banned_nodes:
1710-
try:
1711-
dont_ban.add(all_recomputable_banned_nodes[idx])
1712-
except:
1713-
pass
1714-
1694+
dont_ban.add(all_recomputable_banned_nodes[idx])
17151695
assert dont_ban.issubset(all_recomputable_banned_nodes)
17161696

17171697
saved_values, _ = solve_min_cut(
@@ -1726,7 +1706,7 @@ def get_saved_values_knapsack(memory_budget, node_info, joint_graph):
17261706
options = []
17271707
for sweep_memory_budget in range(100, -1, -5):
17281708
saved_values, expected_runtime = get_saved_values_knapsack(
1729-
sweep_memory_budget / 100, node_info=node_info, joint_graph=joint_graph
1709+
sweep_memory_budget / 100
17301710
)
17311711
options.append(
17321712
(
@@ -1771,7 +1751,7 @@ def get_saved_values_knapsack(memory_budget, node_info, joint_graph):
17711751
# tensors we actually banned from recompute, but there may be other
17721752
# tensors that we choose to save.
17731753

1774-
return get_saved_values_knapsack(memory_budget=memory_budget, node_info=node_info, joint_graph=joint_graph)[0]
1754+
return get_saved_values_knapsack(memory_budget=memory_budget)[0]
17751755

17761756

17771757
def min_cut_rematerialization_partition(
@@ -1897,9 +1877,7 @@ def classify_nodes(joint_module):
18971877
break
18981878
# print("Memory Budget: ", memory_budget)
18991879
saved_values = choose_saved_values_set(
1900-
joint_graph,
1901-
node_info,
1902-
memory_budget=memory_budget,
1880+
joint_graph, node_info, memory_budget=memory_budget
19031881
)
19041882
# save_for_backward on tensors and stashes symints in autograd .ctx
19051883
saved_sym_nodes = list(filter(is_sym_node, saved_values))
@@ -1921,14 +1899,10 @@ def classify_nodes(joint_module):
19211899
bw_module = reordering_to_mimic_autograd_engine(bw_module)
19221900

19231901
if AOT_PARTITIONER_DEBUG:
1924-
from torch._inductor.fx_utils import get_node_storage
1925-
1926-
storages = {get_node_storage(node) for node in saved_values}
19271902
print(
19281903
"Theoretical Activations Stored: ",
19291904
sum(_size_of(i) for i in saved_values) / 1e9,
19301905
)
1931-
sorted_sizes = sorted([(_size_of(i), str(i)) for i in saved_values])
19321906
fw_module_nodes = {
19331907
node.name for node in fw_module.graph.nodes if node.op == "call_function"
19341908
}

0 commit comments

Comments
 (0)