Skip to content

Commit 464a0e9

Browse files
kausv and facebook-github-bot
authored and committed
Add ShardedQuantManagedCollisionEmbeddingCollection (#2649)
Summary: Pull Request resolved: #2649 Sharded MCEC is extended from Sharded EC to reuse the lookups of sharded embeddings. Reviewed By: emlin Differential Revision: D61388755 fbshipit-source-id: d222a9db8842ab3c5adc568d0083c53e768683ce
1 parent 5f607ff commit 464a0e9

8 files changed

+1064
-89
lines changed

torchrec/distributed/embedding_sharding.py

+18
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
4848
from torchrec.streamable import Multistreamable
4949

50+
5051
torch.fx.wrap("len")
5152

5253
CACHE_LOAD_FACTOR_STR: str = "cache_load_factor"
@@ -61,6 +62,15 @@ def _fx_wrap_tensor_to_device_dtype(
6162
return t.to(device=tensor_device_dtype.device, dtype=tensor_device_dtype.dtype)
6263

6364

65+
@torch.fx.wrap
66+
def _fx_wrap_optional_tensor_to_device_dtype(
67+
t: Optional[torch.Tensor], tensor_device_dtype: torch.Tensor
68+
) -> Optional[torch.Tensor]:
69+
if t is None:
70+
return None
71+
return t.to(device=tensor_device_dtype.device, dtype=tensor_device_dtype.dtype)
72+
73+
6474
@torch.fx.wrap
6575
def _fx_wrap_batch_size_per_feature(kjt: KeyedJaggedTensor) -> Optional[torch.Tensor]:
6676
return (
@@ -121,6 +131,7 @@ def _fx_wrap_seq_block_bucketize_sparse_features_inference(
121131
block_sizes: torch.Tensor,
122132
bucketize_pos: bool = False,
123133
block_bucketize_pos: Optional[List[torch.Tensor]] = None,
134+
total_num_blocks: Optional[torch.Tensor] = None,
124135
) -> Tuple[
125136
torch.Tensor,
126137
torch.Tensor,
@@ -142,6 +153,7 @@ def _fx_wrap_seq_block_bucketize_sparse_features_inference(
142153
bucketize_pos=bucketize_pos,
143154
sequence=True,
144155
block_sizes=block_sizes,
156+
total_num_blocks=total_num_blocks,
145157
my_size=num_buckets,
146158
weights=kjt.weights_or_none(),
147159
max_B=_fx_wrap_max_B(kjt),
@@ -289,6 +301,7 @@ def bucketize_kjt_inference(
289301
kjt: KeyedJaggedTensor,
290302
num_buckets: int,
291303
block_sizes: torch.Tensor,
304+
total_num_buckets: Optional[torch.Tensor] = None,
292305
bucketize_pos: bool = False,
293306
block_bucketize_row_pos: Optional[List[torch.Tensor]] = None,
294307
is_sequence: bool = False,
@@ -303,6 +316,7 @@ def bucketize_kjt_inference(
303316
Args:
304317
num_buckets (int): number of buckets to bucketize the values into.
305318
block_sizes: (torch.Tensor): bucket sizes for the keyed dimension.
319+
total_num_buckets (Optional[torch.Tensor]): number of buckets per feature, useful for two-level bucketization
306320
bucketize_pos (bool): output the changed position of the bucketized values or
307321
not.
308322
block_bucketize_row_pos (Optional[List[torch.Tensor]]): The offsets of shard size for each feature.
@@ -318,6 +332,9 @@ def bucketize_kjt_inference(
318332
f"Expecting block sizes for {num_features} features, but {block_sizes.numel()} received.",
319333
)
320334
block_sizes_new_type = _fx_wrap_tensor_to_device_dtype(block_sizes, kjt.values())
335+
total_num_buckets_new_type = _fx_wrap_optional_tensor_to_device_dtype(
336+
total_num_buckets, kjt.values()
337+
)
321338
unbucketize_permute = None
322339
bucket_mapping = None
323340
if is_sequence:
@@ -332,6 +349,7 @@ def bucketize_kjt_inference(
332349
kjt,
333350
num_buckets=num_buckets,
334351
block_sizes=block_sizes_new_type,
352+
total_num_blocks=total_num_buckets_new_type,
335353
bucketize_pos=bucketize_pos,
336354
block_bucketize_pos=block_bucketize_row_pos,
337355
)

0 commit comments

Comments
 (0)