@@ -21,26 +21,36 @@ class ShardMetadata:
21
21
original tensor.
22
22
placement(:class:`torch.distributed._remote_device`):
23
23
Specifies the placement of this shard.
24
+ bucket_id_offset: Optional[int] = None: This represents the bucket
25
+ offset from which the bucket ids are stored in this shard
26
+ num_buckets: Optional[int] = None: This represents the number of
27
+ buckets stored in this shard for bucket-wise sharding
24
28
"""
25
29
26
- __slots__ = ["shard_offsets" , "shard_sizes" , "placement" ]
30
+ __slots__ = ["shard_offsets" , "shard_sizes" , "placement" , "bucket_id_offset" , "num_buckets" ]
27
31
28
32
shard_offsets : list [int ]
29
33
shard_sizes : list [int ]
30
34
placement : Optional [_remote_device ]
35
+ bucket_id_offset : Optional [int ]
36
+ num_buckets : Optional [int ]
31
37
32
38
def __init__ (
33
39
self ,
34
40
shard_offsets : list [int ],
35
41
shard_sizes : list [int ],
36
42
placement : Optional [Union [str , _remote_device ]] = None ,
43
+ bucket_id_offset : Optional [int ] = None ,
44
+ num_buckets : Optional [int ] = None ,
37
45
):
38
46
self .shard_offsets = shard_offsets
39
47
self .shard_sizes = shard_sizes
40
48
if isinstance (placement , str ):
41
49
self .placement = _remote_device (placement )
42
50
else :
43
51
self .placement = placement
52
+ self .bucket_id_offset = bucket_id_offset
53
+ self .num_buckets = num_buckets
44
54
if len (self .shard_offsets ) != len (self .shard_sizes ):
45
55
raise ValueError (
46
56
f"shard_offsets and shard_sizes should have "
@@ -53,12 +63,23 @@ def __init__(
53
63
raise ValueError ("shard_offsets should be >=0" )
54
64
if self .shard_sizes [i ] < 0 :
55
65
raise ValueError ("shard_sizes should be >= 0" )
66
+
67
+ if self .bucket_id_offset :
68
+ if self .bucket_id_offset < 0 :
69
+ raise ValueError ("bucket_id_offset should be >=0 for all the shards" )
70
+ if not self .num_buckets :
71
+ raise ValueError ("num_buckets should be provided for bucket-wise sharding when bucket_offset is set" )
72
+ if self .num_buckets < 0 :
73
+ raise ValueError ("Numebr of bucket should be > 0 within each shard when bucket-wise sharding is enabled" )
56
74
57
75
def __hash__ (self ):
58
76
def _hash_reduce (a , b ):
59
77
return (a << 8 ) + hash (b )
60
78
61
79
res = reduce (_hash_reduce , self .shard_offsets , 37 )
62
80
res = reduce (_hash_reduce , self .shard_sizes , res )
81
+ if self .bucket_id_offset :
82
+ res = reduce (_hash_reduce , self .bucket_id_offset , res )
83
+ res = reduce (_hash_reduce , self .num_buckets , res )
63
84
res = _hash_reduce (res , self .placement )
64
85
return res
0 commit comments