
Commit 49288c3

Damian Reeves authored and facebook-github-bot committed
Narrow scaleup probes to max cache sharding span (#2588)
Summary:
Pull Request resolved: #2588

When a large amount of cache scaleup budget is available, significantly more than the total memory needed to promote every table to HBM, many of the budget probes will attempt to cost a plan using more budget than the proposal can utilize. In these scenarios we tend to see only two distinct plan costs: 1) the min-working-set plan, which is costed first, and 2) every other proposal, which "clips" at the max scaleup limit (i.e. everything promoted to HBM). It is also plausible that the fully-promoted plan is more expensive than the min-working-set plan, due to the increased bin-packing difficulty of fitting the larger shards. In these cases the job runs only on the min-working-set proposal, even though lots of additional memory is available for larger caches (up to the point of diminishing returns from bin-packing overhead).

This diff narrows the search region when more memory is available than we can use, focusing the search effort on productive portions of the search space. This increases the likelihood that we discover a plan cheaper than both the min-working-set and fully-promoted plans.

Reviewed By: keyan

Differential Revision: D66419942

fbshipit-source-id: 8d5ad8b70179517193fa88e9acc041ffb171b822
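A minimal sketch of the narrowing described above, using the figures from the updated test comment; the GB constant and variable names here are illustrative, not the planner's real API:

    # Sketch: clamp the scale-up search span to what the plan can actually use.
    GB = 1024**3

    hbm_used_previously = 7.47 * GB  # min-working-set plan
    hbm_available = 92.53 * GB       # total scale-up budget from the topology
    peak_budget_need = 67.06 * GB    # get_hbm_ceiling(plan) - hbm_used_previously

    # Before this diff, probes spanned [0, hbm_available]; every probe beyond
    # peak_budget_need clipped to the fully-promoted plan, wasting iterations.
    search_budget = min(hbm_available, peak_budget_need)

    low = hbm_used_previously
    high = hbm_used_previously + search_budget
    print(f"exploring plans of size [{low / GB:.2f}, {high / GB:.2f}] GB")
    # -> exploring plans of size [7.47, 74.53] GB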
1 parent caa6773 commit 49288c3

File tree

2 files changed: +96, -17 lines


torchrec/distributed/planner/proposers.py

Lines changed: 66 additions & 15 deletions
@@ -12,7 +12,7 @@
 import logging
 from collections import OrderedDict
 from decimal import Decimal
-from typing import cast, Dict, List, Optional, Set, Tuple, Union
+from typing import cast, Dict, List, Optional, Set, Tuple, TypeVar, Union
 
 import torch
 
@@ -460,6 +460,15 @@ def feedback(
         self._current_proposal = -1
 
 
+_T = TypeVar("_T")
+
+
+def _none_throws(x: Optional[_T]) -> _T:
+    if x is None:
+        raise AssertionError("unexpected None")
+    return x
+
+
 class EmbeddingOffloadScaleupProposer(Proposer):
     def __init__(self, use_depth: bool = True) -> None:
         self.use_depth: bool = use_depth
@@ -535,6 +544,26 @@ def load(
         )
         self.proposal = copy.deepcopy(self.starting_proposal)
 
+    @staticmethod
+    def get_hbm_ceiling(
+        starting_proposal: List[ShardingOption], enumerator: Enumerator
+    ) -> int:
+        """returns total amount of memory scaleup could use."""
+        proposal = copy.deepcopy(starting_proposal)
+        cache_tables = EmbeddingOffloadScaleupProposer.get_scalable_sharding_options(
+            proposal
+        )
+        for sharding_option in cache_tables:
+            if (
+                sharding_option.compute_kernel
+                == EmbeddingComputeKernel.FUSED_UVM_CACHING.value
+            ):
+                assert sharding_option.cache_params  # appease pyre
+                sharding_option.cache_params.load_factor = None
+                sharding_option.compute_kernel = EmbeddingComputeKernel.FUSED.value
+        enumerator.populate_estimates(cache_tables)
+        return sum(sharding_option.total_storage.hbm for sharding_option in proposal)
+
     @staticmethod
     def promote_high_prefetch_overheaad_table_to_hbm(
         enumerator: Optional[Enumerator], proposal: List[ShardingOption]
@@ -621,11 +650,20 @@ def feedback(
             hbm_available = EmbeddingOffloadScaleupProposer.get_budget(
                 plan, storage_constraint
             )
+            # max scale up
+            peak_budget_need = (
+                EmbeddingOffloadScaleupProposer.get_hbm_ceiling(
+                    plan, _none_throws(self.enumerator)
+                )
+                - hbm_used_previously
+            )
+            search_budget = min(hbm_available, peak_budget_need)
             logger.info(
-                f"EmbeddingOffloadScaleupProposer - cache scale up budget={round(bytes_to_gb(hbm_available), 2)} GB, exploring [{round(bytes_to_gb(hbm_used_previously), 2)}, {round(bytes_to_gb(hbm_used_previously + hbm_available), 2)}] GB"
+                f"EmbeddingOffloadScaleupProposer - unscaled plan={round(bytes_to_gb(hbm_used_previously),2)} GB, cache scale up budget={round(bytes_to_gb(hbm_available), 2)} GB, peak scale up budget need={round(bytes_to_gb(peak_budget_need),2)} GB, exploring plans of size [{round(bytes_to_gb(hbm_used_previously), 2)}, {round(bytes_to_gb(hbm_used_previously + search_budget), 2)}] GB"
             )
             self.search = LuusJaakolaSearch(
-                0, hbm_available, max_iterations=16, left_cost=perf_rating
+                0, search_budget, max_iterations=16, left_cost=perf_rating
             )
 
             logger.info(
@@ -663,23 +701,16 @@ def get_budget(proposal: List[ShardingOption], storage_constraint: Topology) ->
         )
         return available_hbm - used_hbm
 
-    # Given an available budget of additional memory, and a provisional sharding plan,
-    # attempt to use the budget wisely to scale up caches that would most benefit from it.
     @staticmethod
-    def next_plan(
-        starting_proposal: List[ShardingOption],
-        budget: Optional[int],
-        enumerator: Optional[Enumerator],
-    ) -> Optional[List[ShardingOption]]:
-        if budget is None or enumerator is None:
-            return None
+    def get_scalable_sharding_options(
+        proposal: List[ShardingOption],
+    ) -> List[ShardingOption]:
+        """Return the subset of tables that we can scale."""
 
         def none_to_zero(x: Optional[float]) -> float:
             return x if x is not None else 0.0
 
-        proposal = copy.deepcopy(starting_proposal)
-        # This is the subset of tables that we can scale
-        cache_tables = [
+        return [
             sharding_option
             for sharding_option in proposal
             if sharding_option.compute_kernel
@@ -693,6 +724,26 @@ def none_to_zero(x: Optional[float]) -> float:
             * none_to_zero(sharding_option.cache_load_factor)
             > 0
         ]
+
+    # Given an available budget of additional memory, and a provisional sharding plan,
+    # attempt to use the budget wisely to scale up caches that would most benefit from it.
+    @staticmethod
+    def next_plan(
+        starting_proposal: List[ShardingOption],
+        budget: Optional[int],
+        enumerator: Optional[Enumerator],
+    ) -> Optional[List[ShardingOption]]:
+        if budget is None or enumerator is None:
+            return None
+
+        def none_to_zero(x: Optional[float]) -> float:
+            return x if x is not None else 0.0
+
+        proposal = copy.deepcopy(starting_proposal)
+        # This is the subset of tables that we can scale
+        cache_tables = EmbeddingOffloadScaleupProposer.get_scalable_sharding_options(
+            proposal
+        )
        # Nothing to scale
         if len(cache_tables) == 0:
             return None
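The _none_throws helper added above lets the type checker narrow Optional[Enumerator] to Enumerator at the call site (_none_throws(self.enumerator)) rather than relying on a bare assert. A self-contained sketch of the pattern; the maybe_budget value is hypothetical:

    from typing import Optional, TypeVar

    _T = TypeVar("_T")

    def _none_throws(x: Optional[_T]) -> _T:
        # Raise on None; otherwise return x with the Optional stripped,
        # so the checker narrows the type at the call site.
        if x is None:
            raise AssertionError("unexpected None")
        return x

    maybe_budget: Optional[int] = 1024  # hypothetical Optional value
    budget: int = _none_throws(maybe_budget)  # int, or AssertionError if None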

torchrec/distributed/planner/tests/test_proposers.py

Lines changed: 30 additions & 2 deletions
@@ -472,6 +472,34 @@ def test_dynamic_programming_three_table(self) -> None:
             num_proposals += 1
         self.assertEqual(2, num_proposals)
 
+    def test_get_scalable_sharding_options(self) -> None:
+        def make_so(
+            name: str, clf: Optional[float], stats: Optional[CacheStatistics]
+        ) -> ShardingOption:
+            so = make_sharding_option(name, 1, clf)
+            if clf:
+                assert so.cache_params
+                so.cache_params.stats = stats
+            return so
+
+        proposal = [
+            make_so("fused", None, None),
+            make_so("caching-no-stats", 0.5, None),
+            make_so(
+                "caching-stats",
+                0.5,
+                MockCacheStatistics(expected_lookups=1, cacheability=0.42),
+            ),
+            make_so(
+                "caching-stats-no-data",
+                0,
+                MockCacheStatistics(expected_lookups=0, cacheability=0),
+            ),
+        ]
+        got = EmbeddingOffloadScaleupProposer.get_scalable_sharding_options(proposal)
+        want = [proposal[-2]]
+        self.assertEqual(got, want)
+
     def test_allocate_budget(self) -> None:
         model = torch.tensor([[1.0, 0.0], [2.0, 3.0], [4.0, 5.0]])
         got = EmbeddingOffloadScaleupProposer.clf_to_bytes(
@@ -823,7 +851,7 @@ def test_budget_shrink(self, _) -> None:
             if initial_mem is None:
                 initial_mem = mem
         # Budget given constraints:
-        # cache scale up budget=92.53 GB, exploring [7.47, 100.0] GB
+        # unscaled plan=7.47 GB, cache scale up budget=92.53 GB, peak scale up budget need=67.06 GB, exploring plans of size [7.47, 74.53] GB
         #
         # Simple perf model, assume partitioner gives a lowest score at 7.9GB, and
         # anything larger than 8GB fails to partition. This is very hard to hit when
@@ -845,7 +873,7 @@ def test_budget_shrink(self, _) -> None:
         self.assertEqual(proposals, 16)
         self.assertNotEqual(initial_mem, best_plan, "couldn't find a better plan")
         # goal is 7.9, we get very close
-        self.assertEqual(best_plan, 7.960684550926089 * GB)
+        self.assertEqual(best_plan, 7.9028974287211895 * GB)
 
     def test_proposers_to_proposals_list(self) -> None:
         def make_mock_proposal(name: str) -> List[ShardingOption]: