
Commit 371fc53

Damian Reeves authored and facebook-github-bot committed
Reclaim prefetch promotion budget and reapply during iterative scaleup (#2590)
Summary:
Pull Request resolved: #2590

In D56505315 we promote tables that consume less HBM when not using UVM_CACHING. This can happen when the input is large and the overheads of calculating the uniques and populating the cache (around 7x input size, or (1+6/12)x when using multi-pass prefetch; see shard_estimators.py:calculate_pipeline_io_cost) dominate the saving from having CLF < 1.0. That diff runs the promotion both on the starting proposal (using min-working-set) and again after the proposed scaleup. In the second run, because it happens after scaleup has completed, the saved memory is "wasted".

In this diff, we integrate the promotion logic directly into the iterative scaleup, so any memory saved via promotion is available to further scale hard-to-cache tables. This can improve plan quality. We still keep the original implementation running on the initial starting proposal: without that initial promotion step, the starting proposal may not be partitionable with the configured storage reservation, which would cause the planner to fail and never reach the scaleup.

The net result is that we will try to use all of the probe budget, rather than undershooting, when tables have large prefetch I/O costs.

Reviewed By: keyan

Differential Revision: D66435139

fbshipit-source-id: 27faf36542266d323d5747280f4c1053b610cdc6
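To make the tradeoff concrete, here is a rough back-of-the-envelope sketch (illustrative only; the table size, input size, and CLF below are made up, and this is not the exact formula in shard_estimators.py:calculate_pipeline_io_cost) of why a cached table can cost more HBM than promoting it outright to the FUSED kernel when prefetch I/O is large:

# Hypothetical sizes, for illustration only.
table_size = 10 * 2**30   # 10 GiB embedding table
input_size = 2 * 2**30    # 2 GiB of lookup indices per iteration
clf = 0.5                 # cache load factor under UVM_CACHING

# Prefetch/caching overhead quoted in the summary: ~7x input size
# (or (1 + 6 / 12)x with multi-pass prefetch).
prefetch_overhead = 7 * input_size

uvm_caching_hbm = clf * table_size + prefetch_overhead  # cache + pipeline I/O cost
fused_hbm = table_size                                  # whole table resident in HBM

# 0.5 * 10 GiB + 14 GiB = 19 GiB cached vs. 10 GiB fused: promotion both frees
# HBM and removes the cache-maintenance work.
print(uvm_caching_hbm / 2**30, fused_hbm / 2**30)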
1 parent 49288c3 commit 371fc53

File tree

2 files changed (+89, -24 lines)

torchrec/distributed/planner/proposers.py

Lines changed: 45 additions & 19 deletions
@@ -12,7 +12,7 @@
 import logging
 from collections import OrderedDict
 from decimal import Decimal
-from typing import cast, Dict, List, Optional, Set, Tuple, TypeVar, Union
+from typing import Callable, cast, Dict, List, Optional, Set, Tuple, TypeVar, Union
 
 import torch
 
@@ -687,10 +687,6 @@ def feedback(
         self.proposal = EmbeddingOffloadScaleupProposer.next_plan(
             self.starting_proposal, budget, self.enumerator
         )
-        if self.proposal is not None:
-            self.promote_high_prefetch_overheaad_table_to_hbm(
-                self.enumerator, self.proposal
-            )
 
     @staticmethod
     def get_budget(proposal: List[ShardingOption], storage_constraint: Topology) -> int:
@@ -748,8 +744,10 @@ def none_to_zero(x: Optional[float]) -> float:
         if len(cache_tables) == 0:
             return None
 
-        size_model = EmbeddingOffloadScaleupProposer.build_affine_storage_model(
-            cache_tables, enumerator
+        size_model, fused_hbm_ceiling = (
+            EmbeddingOffloadScaleupProposer.build_affine_storage_model(
+                cache_tables, enumerator
+            )
         )
         clfs = torch.tensor(
             [sharding_option.cache_load_factor for sharding_option in cache_tables]
@@ -772,6 +770,7 @@ def none_to_zero(x: Optional[float]) -> float:
         )
         new_clfs = EmbeddingOffloadScaleupProposer.allocate_budget(
             model=size_model,
+            fused_hbm_ceiling=fused_hbm_ceiling,
             clfs=clfs,
             budget=budget,
             allocation_priority=cooked_cacheability,
@@ -788,9 +787,10 @@ def none_to_zero(x: Optional[float]) -> float:
                 sharding_option.cache_params.load_factor = None
                 sharding_option.compute_kernel = EmbeddingComputeKernel.FUSED.value
                 num_promoted += 1
-        logger.info(
-            f"EmbeddingOffloadScaleupProposer - Promoted {num_promoted} tables to HBM because cache size is similar to table size."
-        )
+        if num_promoted > 0:
+            logger.info(
+                f"EmbeddingOffloadScaleupProposer - Promoted {num_promoted} tables to HBM because cache size is similar to table size."
+            )
         # recalculate cost estimates of modified tables
         enumerator.populate_estimates(cache_tables)
         return proposal
@@ -822,32 +822,42 @@ def get_expected_lookups(sharding_option: ShardingOption) -> Optional[float]:
     @staticmethod
     def build_affine_storage_model(
         uvm_caching_sharding_options: List[ShardingOption], enumerator: Enumerator
-    ) -> torch.Tensor:
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         plan: List[ShardingOption] = copy.deepcopy(uvm_caching_sharding_options)
 
-        def compute_hbm_sizes(clf: float) -> torch.Tensor:
+        def set_clf(sharding_option: ShardingOption, clf: float) -> None:
+            assert sharding_option.cache_params  # appease pyre
+            sharding_option.cache_params.load_factor = clf
+
+        def set_fused(sharding_option: ShardingOption) -> None:
+            assert sharding_option.cache_params  # appease pyre
+            sharding_option.cache_params.load_factor = None
+            sharding_option.compute_kernel = EmbeddingComputeKernel.FUSED.value
+
+        def compute_hbm_sizes(f: Callable[[ShardingOption], None]) -> torch.Tensor:
             for sharding_option in plan:
-                assert sharding_option.cache_params  # appease pyre
-                sharding_option.cache_params.load_factor = clf
+                f(sharding_option)
             enumerator.populate_estimates(plan)
             return torch.tensor(
                 [sharding_option.total_storage.hbm for sharding_option in plan]
             )
 
         low_clf, high_clf = 0.1, 0.9
-        low_hbms = compute_hbm_sizes(low_clf)
-        high_hbms = compute_hbm_sizes(high_clf)
+        low_hbms = compute_hbm_sizes(lambda so: set_clf(so, low_clf))
+        high_hbms = compute_hbm_sizes(lambda so: set_clf(so, high_clf))
+        fused_hbms = compute_hbm_sizes(set_fused)
 
         A = (high_hbms - low_hbms) / (high_clf - low_clf)
         B = low_hbms - A * low_clf
-        return torch.stack((A, B), dim=1)  # Nx2 (a,b)
+        caching_model = torch.stack((A, B), dim=1)  # Nx2 (a,b)
+        return caching_model, fused_hbms
 
     @staticmethod
     def clf_to_bytes(
         model: torch.Tensor, clfs: Union[float, torch.Tensor]
     ) -> torch.Tensor:
         # evaluate affine model AX + B
-        return (model[:, 0] * clfs + model[:, 1]).to(torch.int64)
+        return (model[:, 0] * clfs + model[:, 1]).to(torch.float64)
 
     # Given a model of an affine system, an existing configuration (clfs), available
     # budget, and an allocation policy, return new configuration that best uses the
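As a standalone illustration of the affine model above (the probe values below are hypothetical, not from a real enumerator run): probing a table's HBM estimate at CLF 0.1 and 0.9 and fitting a line recovers the per-table (slope, intercept) row that clf_to_bytes later evaluates.

import torch

low_clf, high_clf = 0.1, 0.9
low_hbm = torch.tensor([5_000_000.0])    # hypothetical estimate at clf=0.1
high_hbm = torch.tensor([29_000_000.0])  # hypothetical estimate at clf=0.9

A = (high_hbm - low_hbm) / (high_clf - low_clf)  # slope: bytes per unit of CLF
B = low_hbm - A * low_clf                        # intercept: fixed overheads
model = torch.stack((A, B), dim=1)               # Nx2 rows of (a, b)

# clf_to_bytes evaluates hbm(clf) = a * clf + b; here a=30e6 and b=2e6,
# so clf=0.5 costs 17e6 bytes.
hbm_at_half = model[:, 0] * 0.5 + model[:, 1]
print(model, hbm_at_half)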
@@ -856,6 +866,7 @@ def clf_to_bytes(
     @staticmethod
     def allocate_budget(
         model: torch.Tensor,
+        fused_hbm_ceiling: torch.Tensor,
         clfs: torch.Tensor,
         budget: int,
         allocation_priority: torch.Tensor,
@@ -882,7 +893,7 @@ def allocate_budget(
             if mask.sum() == 0:
                 break
 
-            logging.debug(
+            logger.debug(
                 f"[allocate_budget] pass={num_pass}, budget={budget}, #cache_tables={mask.sum()}"
             )
 
@@ -902,6 +913,21 @@ def allocate_budget(
             # to HBM vs spending that budget on improving hit rate on other tables in
             # next pass.
 
+            # Is any table over the size we'd get if we promoted to HBM? (promotion can
+            # be smaller if input size is large when using prefetch). If so, mark for
+            # promotion and reclaim budget to use on remaining tables.
+            promotes = mask & (min_size_bytes + cache_size_bytes > fused_hbm_ceiling)
+            if promotes.sum() > 0:
+                budget_reclaimed = torch.sum(
+                    ((min_size_bytes + cache_size_bytes) - fused_hbm_ceiling)[promotes]
+                ).item()
+                logger.debug(
+                    f"[allocate_budget] {promotes.sum()} tables exceeded ceiling, promoting to save {budget_reclaimed} bytes"
+                )
+                budget += budget_reclaimed
+                # force these tables to 1.0 to ensure promotion
+                cache_size_bytes[promotes] = max_cache_size_bytes[promotes]
+
             # cache_size_bytes are the new cache sizes we want to use. We convert them back
             # to clfs by dividing by max_cache_size_bytes, which has isolated the clf
             # portion of the table size from the fixed overheads.
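The reclaim step can also be walked through in isolation (a toy sketch with made-up tensors, not planner code): any table whose proposed cached footprint exceeds its fused ceiling is forced to CLF 1.0 and the per-table saving is handed back to the budget for the remaining tables.

import torch

# Two tables; table 0's cached footprint (min + cache) exceeds its fused ceiling.
min_size_bytes = torch.tensor([4_000_000, 4_000_000])
cache_size_bytes = torch.tensor([28_000_000, 10_000_000])
max_cache_size_bytes = torch.tensor([30_000_000, 30_000_000])
fused_hbm_ceiling = torch.tensor([25_000_000, 34_000_000])
mask = torch.tensor([True, True])
budget = 1_000_000

promotes = mask & (min_size_bytes + cache_size_bytes > fused_hbm_ceiling)
if promotes.sum() > 0:
    budget_reclaimed = torch.sum(
        ((min_size_bytes + cache_size_bytes) - fused_hbm_ceiling)[promotes]
    ).item()
    budget += budget_reclaimed                                   # 7_000_000 bytes returned
    cache_size_bytes[promotes] = max_cache_size_bytes[promotes]  # force CLF to 1.0

print(promotes, budget, cache_size_bytes)  # table 0 promoted; budget grows to 8_000_000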

torchrec/distributed/planner/tests/test_proposers.py

Lines changed: 44 additions & 5 deletions
@@ -505,17 +505,19 @@ def test_allocate_budget(self) -> None:
         got = EmbeddingOffloadScaleupProposer.clf_to_bytes(
             model, torch.tensor([0, 0.5, 1])
         )
-        torch.testing.assert_close(got, torch.tensor([0, 4, 9]))
+        torch.testing.assert_close(got, torch.tensor([0, 4, 9], dtype=torch.float64))
 
         # Scenario 1, enough budget to scale everything to 1.0
         model = torch.tensor(
             [[30_000_000, 2_000_000], [30_000_000, 2_000_000], [30_000_000, 2_000_000]]
         )
+        fused_hbm_ceiling = EmbeddingOffloadScaleupProposer.clf_to_bytes(model, 1.0)
         mins = torch.tensor([0.1, 0.1, 1])
         budget = 100_000_000
         got = EmbeddingOffloadScaleupProposer.allocate_budget(
             model,
-            clfs=torch.tensor(mins),
+            fused_hbm_ceiling=fused_hbm_ceiling,
+            clfs=mins,
             budget=budget,
             allocation_priority=torch.tensor([2, 2, 2]),
         )
@@ -530,10 +532,15 @@ def test_allocate_budget(self) -> None:
         model = torch.tensor(
             [[30_000_000, 2_000_000], [30_000_000, 2_000_000], [30_000_000, 2_000_000]]
         )
+        fused_hbm_ceiling = EmbeddingOffloadScaleupProposer.clf_to_bytes(model, 1.0)
         mins = torch.tensor([0.1, 0.1, 1])
         budget = 10_000_000
         got = EmbeddingOffloadScaleupProposer.allocate_budget(
-            model, clfs=mins, budget=budget, allocation_priority=torch.tensor([2, 2, 2])
+            model,
+            fused_hbm_ceiling=fused_hbm_ceiling,
+            clfs=mins,
+            budget=budget,
+            allocation_priority=torch.tensor([2, 2, 2]),
         )
         torch.testing.assert_close(got, torch.tensor([0.26667, 0.26667, 1.0]))
         increase = (
@@ -546,10 +553,15 @@ def test_allocate_budget(self) -> None:
         model = torch.tensor(
             [[30_000_000, 2_000_000], [30_000_000, 2_000_000], [30_000_000, 2_000_000]]
         )
+        fused_hbm_ceiling = EmbeddingOffloadScaleupProposer.clf_to_bytes(model, 1.0)
         mins = torch.tensor([0.1, 0.1, 1])
         budget = 10_000_000
         got = EmbeddingOffloadScaleupProposer.allocate_budget(
-            model, clfs=mins, budget=budget, allocation_priority=torch.tensor([2, 4, 2])
+            model,
+            fused_hbm_ceiling=fused_hbm_ceiling,
+            clfs=mins,
+            budget=budget,
+            allocation_priority=torch.tensor([2, 4, 2]),
         )
         # increase is twice as much for table 2 (started at 0.1)
         torch.testing.assert_close(
@@ -559,16 +571,18 @@ def test_allocate_budget(self) -> None:
             EmbeddingOffloadScaleupProposer.clf_to_bytes(model, got).sum()
             - EmbeddingOffloadScaleupProposer.clf_to_bytes(model, mins).sum()
         )
-        self.assertEqual(increase, budget)
+        self.assertEqual(int(increase), budget)
 
         # Scenario 4, multi-pass scale up
         model = torch.tensor(
             [[30_000_000, 2_000_000], [30_000_000, 2_000_000], [30_000_000, 2_000_000]]
         )
+        fused_hbm_ceiling = EmbeddingOffloadScaleupProposer.clf_to_bytes(model, 1.0)
         mins = torch.tensor([0.1, 0.3, 0.5])
         budget = 50_000_000
         got = EmbeddingOffloadScaleupProposer.allocate_budget(
             model,
+            fused_hbm_ceiling=fused_hbm_ceiling,
             clfs=mins,
             budget=budget,
             allocation_priority=torch.tensor([1, 2, 100]),
@@ -580,6 +594,31 @@ def test_allocate_budget(self) -> None:
         )
         self.assertEqual(increase, budget)
 
+        # Scenario 5, prefetch overhead causing early promotion
+        # like scenario 4, but we set fused size to 80%, which saves enough memory
+        # to promote all 3 to HBM inside the same budget.
+        model = torch.tensor(
+            [[30_000_000, 2_000_000], [30_000_000, 2_000_000], [30_000_000, 2_000_000]]
+        )
+        fused_hbm_ceiling = (
+            EmbeddingOffloadScaleupProposer.clf_to_bytes(model, 1.0) * 0.8
+        )
+        mins = torch.tensor([0.1, 0.3, 0.5])
+        budget = 50_000_000
+        got = EmbeddingOffloadScaleupProposer.allocate_budget(
+            model,
+            fused_hbm_ceiling=fused_hbm_ceiling,
+            clfs=mins,
+            budget=budget,
+            allocation_priority=torch.tensor([1, 2, 100]),
+        )
+        torch.testing.assert_close(got, torch.tensor([1.0, 1.0, 1.0]))
+        self.assertLessEqual(
+            fused_hbm_ceiling.sum().item(),
+            EmbeddingOffloadScaleupProposer.clf_to_bytes(model, mins).sum().item()
+            + budget,
+        )
+
     @unittest.mock.patch(
         "torchrec.distributed.planner.shard_estimators._calculate_storage_specific_sizes",
         side_effect=mock_calculate_storage_specific_sizes,
