@@ -1506,12 +1506,9 @@ def dp_knapsack(
1506
1506
1507
1507
1508
1508
def _optimize_runtime_with_given_memory (
1509
- joint_graph : fx .Graph ,
1510
1509
memory : List [float ],
1511
1510
runtimes : List [float ],
1512
1511
max_memory : float ,
1513
- node_info : NodeInfo ,
1514
- all_recomputable_banned_nodes : List [fx .Node ],
1515
1512
) -> Tuple [float , List [int ], List [int ]]:
1516
1513
SOLVER = config .activation_memory_budget_solver
1517
1514
if SOLVER == "greedy" :
@@ -1520,11 +1517,6 @@ def _optimize_runtime_with_given_memory(
1520
1517
return ilp_knapsack (memory , runtimes , max_memory )
1521
1518
elif SOLVER == "dp" :
1522
1519
return dp_knapsack (memory , runtimes , max_memory )
1523
- elif callable (SOLVER ):
1524
- saved_node_idx , recomp_node_idx = SOLVER (
1525
- memory , joint_graph , max_memory , node_info , all_recomputable_banned_nodes
1526
- )
1527
- return (0.0 , saved_node_idx , recomp_node_idx )
1528
1520
else :
1529
1521
raise RuntimeError (f"Not aware of memory budget knapsack solver: { SOLVER } " )
1530
1522
@@ -1580,9 +1572,7 @@ def realize_symbol(d):
1580
1572
1581
1573
1582
1574
def choose_saved_values_set (
1583
- joint_graph : fx .Graph ,
1584
- node_info : NodeInfo ,
1585
- memory_budget = 1 ,
1575
+ joint_graph : fx .Graph , node_info : NodeInfo , memory_budget = 1
1586
1576
) -> List [fx .Node ]:
1587
1577
if memory_budget > 1 or memory_budget < 0 :
1588
1578
raise RuntimeError (
@@ -1690,28 +1680,18 @@ def get_recomputable_banned_nodes(banned_nodes: List[fx.Node]) -> List[fx.Node]:
1690
1680
]
1691
1681
from torch .utils ._mode_utils import no_dispatch
1692
1682
1693
- def get_saved_values_knapsack (memory_budget , node_info , joint_graph ):
1683
+ def get_saved_values_knapsack (memory_budget ):
1694
1684
with no_dispatch ():
1695
1685
(
1696
1686
expected_runtime ,
1697
1687
saved_node_idxs ,
1698
1688
recomputable_node_idxs ,
1699
1689
) = _optimize_runtime_with_given_memory (
1700
- joint_graph ,
1701
- memories_banned_nodes ,
1702
- runtimes_banned_nodes ,
1703
- max (memory_budget , 0 ),
1704
- node_info ,
1705
- all_recomputable_banned_nodes ,
1690
+ memories_banned_nodes , runtimes_banned_nodes , max (memory_budget , 0 )
1706
1691
)
1707
1692
dont_ban = set ()
1708
1693
for idx in recomputable_node_idxs :
1709
- # if idx in all_recomputable_banned_nodes:
1710
- try :
1711
- dont_ban .add (all_recomputable_banned_nodes [idx ])
1712
- except :
1713
- pass
1714
-
1694
+ dont_ban .add (all_recomputable_banned_nodes [idx ])
1715
1695
assert dont_ban .issubset (all_recomputable_banned_nodes )
1716
1696
1717
1697
saved_values , _ = solve_min_cut (
@@ -1726,7 +1706,7 @@ def get_saved_values_knapsack(memory_budget, node_info, joint_graph):
1726
1706
options = []
1727
1707
for sweep_memory_budget in range (100 , - 1 , - 5 ):
1728
1708
saved_values , expected_runtime = get_saved_values_knapsack (
1729
- sweep_memory_budget / 100 , node_info = node_info , joint_graph = joint_graph
1709
+ sweep_memory_budget / 100
1730
1710
)
1731
1711
options .append (
1732
1712
(
@@ -1771,7 +1751,7 @@ def get_saved_values_knapsack(memory_budget, node_info, joint_graph):
1771
1751
# tensors we actually banned from recompute, but there may be other
1772
1752
# tensors that we choose to save.
1773
1753
1774
- return get_saved_values_knapsack (memory_budget = memory_budget , node_info = node_info , joint_graph = joint_graph )[0 ]
1754
+ return get_saved_values_knapsack (memory_budget = memory_budget )[0 ]
1775
1755
1776
1756
1777
1757
def min_cut_rematerialization_partition (
@@ -1897,9 +1877,7 @@ def classify_nodes(joint_module):
1897
1877
break
1898
1878
# print("Memory Budget: ", memory_budget)
1899
1879
saved_values = choose_saved_values_set (
1900
- joint_graph ,
1901
- node_info ,
1902
- memory_budget = memory_budget ,
1880
+ joint_graph , node_info , memory_budget = memory_budget
1903
1881
)
1904
1882
# save_for_backward on tensors and stashes symints in autograd .ctx
1905
1883
saved_sym_nodes = list (filter (is_sym_node , saved_values ))
@@ -1921,14 +1899,10 @@ def classify_nodes(joint_module):
1921
1899
bw_module = reordering_to_mimic_autograd_engine (bw_module )
1922
1900
1923
1901
if AOT_PARTITIONER_DEBUG :
1924
- from torch ._inductor .fx_utils import get_node_storage
1925
-
1926
- storages = {get_node_storage (node ) for node in saved_values }
1927
1902
print (
1928
1903
"Theoretical Activations Stored: " ,
1929
1904
sum (_size_of (i ) for i in saved_values ) / 1e9 ,
1930
1905
)
1931
- sorted_sizes = sorted ([(_size_of (i ), str (i )) for i in saved_values ])
1932
1906
fw_module_nodes = {
1933
1907
node .name for node in fw_module .graph .nodes if node .op == "call_function"
1934
1908
}
0 commit comments