Commit a31e14c

kausv authored and facebook-github-bot committed
Proportional Uneven RW Inference Sharding (#2734)
Summary:
Pull Request resolved: #2734

Support bucketization-aware inference sharding in TGIF for ZCH bucket boundaries from training. A "best effort" sharding is performed across bucket boundaries, proportional to the memory list.

* Added bucketization awareness to RW sharding.
* TGIF sharding now ensures at most a one-bucket difference across equal-memory uneven shards, as opposed to the previous logic of assigning all remainder rows to the last shard.
* InferRWSparseDist now checks for customized embedding_shard_metadata for uneven shards before dividing evenly.

Reviewed By: dstaay-fb, emlin

Differential Revision: D69057627

fbshipit-source-id: 4f813a6a621bed9df31f26d28ea3c2379f6d3ea6
1 parent 23cf189 commit a31e14c
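
A minimal sketch of the "best effort" proportional split described in the summary, assuming a per-shard memory list and whole-bucket boundaries. This is an illustration only, not the torchrec implementation, and the function name is hypothetical:

from typing import List

def proportional_bucket_split(num_buckets: int, memory: List[int]) -> List[int]:
    # Ideal fractional bucket counts, proportional to each shard's memory.
    total = sum(memory)
    ideal = [num_buckets * m / total for m in memory]
    counts = [int(x) for x in ideal]
    # Hand out the leftover buckets by largest fractional remainder, so
    # equal-memory shards end up within one bucket of each other.
    order = sorted(range(len(memory)), key=lambda i: ideal[i] - counts[i], reverse=True)
    for i in order[: num_buckets - sum(counts)]:
        counts[i] += 1
    return counts

print(proportional_bucket_split(10, [1, 1, 1, 1]))  # [3, 3, 2, 2]
print(proportional_bucket_split(10, [2, 1, 1]))     # [5, 3, 2]

With equal memory, this reproduces the "at most one bucket difference" guarantee; the previous remainder-to-last-shard behavior would have produced [2, 2, 2, 4] for the first case.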

2 files changed: +5 lines, −2 lines

torchrec/distributed/quant_state.py

Lines changed: 1 addition & 0 deletions
@@ -441,6 +441,7 @@ def sharded_tbes_weights_spec(
         shard_sizes: List[int] = [table.local_rows, table.local_cols]
         shard_offsets: List[int] = table_metadata.shard_offsets
         s: str = "embedding_bags" if is_sqebc else "embeddings"
+        s = ("_embedding_module." if is_sqmcec else "") + s
         unsharded_fqn_weight: str = f"{module_fqn}.{s}.{table_name}.weight"

         sharded_fqn_weight: str = (
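
This one-line change prefixes "_embedding_module." when the sharded module is a quantized managed-collision embedding collection (is_sqmcec), so the unsharded weight FQN resolves through the inner embedding module. A small illustration, with made-up module and table names; the flags and f-string mirror the diff above:

# Illustrative values only, not taken from the codebase.
module_fqn = "model.sparse.mc_ec"  # hypothetical example
table_name = "t1"                  # hypothetical example
is_sqebc = False   # sharded quantized EmbeddingBagCollection?
is_sqmcec = True   # sharded quantized ManagedCollisionEmbeddingCollection?

s = "embedding_bags" if is_sqebc else "embeddings"
s = ("_embedding_module." if is_sqmcec else "") + s
print(f"{module_fqn}.{s}.{table_name}.weight")
# -> model.sparse.mc_ec._embedding_module.embeddings.t1.weight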

torchrec/distributed/sharding/rw_sharding.py

Lines changed: 4 additions & 2 deletions
@@ -659,11 +659,13 @@ def __init__(
         self._world_size: int = world_size
         self._num_features = num_features
         self._feature_total_num_buckets: Optional[List[int]] = feature_total_num_buckets
-
         self.feature_block_sizes: List[int] = []
         for i, hash_size in enumerate(feature_hash_sizes):
             block_divisor = self._world_size
-            if feature_total_num_buckets is not None:
+            if (
+                feature_total_num_buckets is not None
+                and embedding_shard_metadata is None
+            ):
                 assert feature_total_num_buckets[i] % self._world_size == 0
                 block_divisor = feature_total_num_buckets[i]
             self.feature_block_sizes.append(
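
The new guard means bucket-aligned block sizes apply only when no customized embedding_shard_metadata is present; with custom (uneven) shard metadata, the block size falls back to the even world-size division. A hedged sketch of the resulting logic, assuming a ceil-division block size (the append body is not shown in this hunk) and illustrative inputs:

from typing import List, Optional

def compute_feature_block_sizes(
    feature_hash_sizes: List[int],
    world_size: int,
    feature_total_num_buckets: Optional[List[int]] = None,
    embedding_shard_metadata: Optional[List[object]] = None,
) -> List[int]:
    sizes: List[int] = []
    for i, hash_size in enumerate(feature_hash_sizes):
        block_divisor = world_size
        # Bucket-aware division applies only when shards were not
        # customized via embedding_shard_metadata (the new condition).
        if feature_total_num_buckets is not None and embedding_shard_metadata is None:
            assert feature_total_num_buckets[i] % world_size == 0
            block_divisor = feature_total_num_buckets[i]
        sizes.append((hash_size + block_divisor - 1) // block_divisor)
    return sizes

# Even RW sharding: 1000 rows over 4 ranks -> block size 250.
print(compute_feature_block_sizes([1000], 4))
# Bucketized: 8 total buckets -> one block per bucket, 125 rows.
print(compute_feature_block_sizes([1000], 4, feature_total_num_buckets=[8]))
# Customized uneven shards -> fall back to the even division.
print(compute_feature_block_sizes([1000], 4, [8], embedding_shard_metadata=["<custom shard spec>"]))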
