Skip to content

Commit 818769f

Browse files
faran928facebook-github-bot
authored andcommitted
[ZCH vNext] Bucket offsets and sizes in torchrec shard metadata for bucket wise sharding (pytorch#151192)
Summary: X-link: pytorch/torchrec#2885 X-link: pytorch/torchrec#2884 Bucket offsets and sizes in torchrec shard metadata for bucket wise sharding for ZCH v.Next Test Plan: buck test torchrec/distributed/tests:test_sharding_plan Differential Revision: D72921209
1 parent ddfc14b commit 818769f

File tree

1 file changed

+22
-1
lines changed

1 file changed

+22
-1
lines changed

torch/distributed/_shard/metadata.py

+22-1
Original file line numberDiff line numberDiff line change
@@ -21,26 +21,36 @@ class ShardMetadata:
2121
original tensor.
2222
placement(:class:`torch.distributed._remote_device`):
2323
Specifies the placement of this shard.
24+
bucket_id_offset: Optional[int] = None: This represents the bucket
25+
offset from which the bucket ids are stored in this shard
26+
num_buckets: Optional[int] = None: This represents the number of
27+
buckets stored in this shard for bucket-wise sharding
2428
"""
2529

26-
__slots__ = ["shard_offsets", "shard_sizes", "placement"]
30+
__slots__ = ["shard_offsets", "shard_sizes", "placement", "bucket_id_offset", "num_buckets"]
2731

2832
shard_offsets: list[int]
2933
shard_sizes: list[int]
3034
placement: Optional[_remote_device]
35+
bucket_id_offset: Optional[int]
36+
num_buckets: Optional[int]
3137

3238
def __init__(
3339
self,
3440
shard_offsets: list[int],
3541
shard_sizes: list[int],
3642
placement: Optional[Union[str, _remote_device]] = None,
43+
bucket_id_offset: Optional[int] = None,
44+
num_buckets: Optional[int] = None,
3745
):
3846
self.shard_offsets = shard_offsets
3947
self.shard_sizes = shard_sizes
4048
if isinstance(placement, str):
4149
self.placement = _remote_device(placement)
4250
else:
4351
self.placement = placement
52+
self.bucket_id_offset = bucket_id_offset
53+
self.num_buckets = num_buckets
4454
if len(self.shard_offsets) != len(self.shard_sizes):
4555
raise ValueError(
4656
f"shard_offsets and shard_sizes should have "
@@ -53,12 +63,23 @@ def __init__(
5363
raise ValueError("shard_offsets should be >=0")
5464
if self.shard_sizes[i] < 0:
5565
raise ValueError("shard_sizes should be >= 0")
66+
67+
if self.bucket_id_offset:
68+
if self.bucket_id_offset < 0:
69+
raise ValueError("bucket_id_offset should be >=0 for all the shards")
70+
if not self.num_buckets:
71+
raise ValueError("num_buckets should be provided for bucket-wise sharding when bucket_offset is set")
72+
if self.num_buckets < 0:
73+
raise ValueError("Numebr of bucket should be > 0 within each shard when bucket-wise sharding is enabled")
5674

5775
def __hash__(self):
5876
def _hash_reduce(a, b):
5977
return (a << 8) + hash(b)
6078

6179
res = reduce(_hash_reduce, self.shard_offsets, 37)
6280
res = reduce(_hash_reduce, self.shard_sizes, res)
81+
if self.bucket_id_offset:
82+
res = reduce(_hash_reduce, self.bucket_id_offset, res)
83+
res = reduce(_hash_reduce, self.num_buckets, res)
6384
res = _hash_reduce(res, self.placement)
6485
return res

0 commit comments

Comments
 (0)