
Commit 401702a

isururanawaka authored and meta-codesync[bot] committed
Fix hardcoded local_world_size in dynamic resharding (meta-pytorch#4000)
Summary:
Pull Request resolved: meta-pytorch#4000

`_prepare_shard_distribution_comm_ops()` hardcoded `_is_intra_comm(src, dst, 8)` to classify P2P communication as intra-node vs. inter-node. This is wrong on hardware where `local_world_size != 8` (e.g. GB200_HP with `local_world_size=2`), causing inter-host P2P ops to be misclassified as intra-node.

Added a `local_world_size` parameter (default 8 for backward compatibility). Callers can now pass the actual value from their topology or process group.

Reviewed By: kausv

Differential Revision: D98986450

fbshipit-source-id: 66e04aea46e643cfbfe764f160a063364408ad10
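For context, a minimal sketch of the intra- vs. inter-node classification that the hardcoded constant breaks. The body of `_is_intra_comm` shown here is an assumption (consecutive ranks grouped into per-host blocks by integer division), not the actual torchrec implementation:

# Sketch only: assumes _is_intra_comm groups consecutive ranks into hosts
# of `local_world_size` ranks each; the real torchrec helper may differ.
def _is_intra_comm(src: int, dst: int, local_world_size: int) -> bool:
    # Ranks on the same host share the same per-host block index.
    return src // local_world_size == dst // local_world_size

# With local_world_size=2 (two ranks per host), ranks 0 and 3 live on
# different hosts, but the old hardcoded 8 classified them as intra-node.
assert _is_intra_comm(0, 3, 8) is True   # old behavior: misclassified
assert _is_intra_comm(0, 3, 2) is False  # fixed: correctly inter-node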
1 parent 21e7aaf commit 401702a

1 file changed: torchrec/distributed/sharding/dynamic_sharding.py (5 additions, 2 deletions)
--- a/torchrec/distributed/sharding/dynamic_sharding.py
+++ b/torchrec/distributed/sharding/dynamic_sharding.py
@@ -14,6 +14,7 @@

 import torch
 import torch.distributed as dist
+from torchrec.distributed.comm import get_local_size
 from torchrec.distributed.embedding_sharding import EmbeddingShardingInfo
 from torchrec.distributed.types import (
     EmbeddingModuleShardingPlan,
@@ -242,6 +243,7 @@ def _prepare_shard_distribution_comm_ops(
     new_opt_state: Optional[Dict[str, Dict[str, Dict[str, ShardedTensor]]]] = None,
     has_optimizer: bool = False,
     process_group: Optional[dist.ProcessGroup] = None,
+    local_world_size: int = 8,
 ) -> None:
     """
     Prepares the point-to-point (P2P) communication operations required to redistribute sharded tensors
@@ -294,7 +296,7 @@ def _prepare_shard_distribution_comm_ops(
             # Update the shard size with new size
             shard_size = [shard_size[0], split_offsets[1] - split_offsets[0], shard_id]

-            intra = _is_intra_comm(src_rank, dst_rank, 8)
+            intra = _is_intra_comm(src_rank, dst_rank, local_world_size)
             # Create unique tag for P2P communication (32-bit limit for PyTorch)

             tag = _generate_tag(
@@ -381,7 +383,7 @@ def _prepare_shard_distribution_comm_ops(

    tensor_col_offset = 0
    for _, src_rank, tag, tag_opt, shard_size in receiving_shards_metadata:
-       intra = _is_intra_comm(src_rank, rank, 8)
+       intra = _is_intra_comm(src_rank, rank, local_world_size)
        end_col_offset = tensor_col_offset + shard_size[1]
        receving_tensor_view = local_tensor_dst[:, tensor_col_offset:end_col_offset]
        copy = False
@@ -796,6 +798,7 @@ def prepare_comm_ops(
        extend_shard_name=extend_shard_name,
        has_optimizer=has_optimizer,
        process_group=pg,
+       local_world_size=get_local_size(world_size),
    )

    return comm_dict
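Call sites that do not pass the new argument keep the previous default of 8, so existing behavior is unchanged; the call site in `prepare_comm_ops` now derives the value with `torchrec.distributed.comm.get_local_size(world_size)`, as shown in the diff. A hedged sketch of how a caller elsewhere might obtain the per-host rank count; the environment-variable fallback below is an illustration, not part of this change:

import os

def resolve_local_world_size(default: int = 8) -> int:
    # torchrun exports LOCAL_WORLD_SIZE on every host; fall back to the
    # historical default of 8 when it is not set (illustrative only).
    return int(os.environ.get("LOCAL_WORLD_SIZE", default))

# e.g. on a GB200_HP-style host with two ranks per node this yields 2, so
# _prepare_shard_distribution_comm_ops(..., local_world_size=2) classifies
# cross-host sends and receives as inter-node rather than intra-node.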
