
Commit 6b37bc0

iamzainhuda authored and facebook-github-bot committed
TorchRec 2D Parallel (#2554)
Summary: Pull Request resolved: #2554

This diff introduces a new parallelism strategy for scaling recommendation model training, called 2D parallel. We scale model parallelism through data parallelism, hence the name. The new entry point, DMPCollection, subclasses DMP and is meant to be a drop-in replacement for integrating 2D parallelism into distributed training. By setting the total number of GPUs to train across and the number of GPUs to shard locally across (i.e. one replication group), users can keep the same training loop while training over a larger number of GPUs.

The current implementation shards the model such that, for a given shard, its replicated shards lie on the corresponding ranks within the node. This significantly improves the performance of the all-reduce communication (parameter sync) by utilizing intra-node bandwidth. Under this scheme the supported sharding types are RW, CW, and GRID. TWRW is not supported because it can no longer take advantage of intra-node bandwidth under the 2D scheme.

Example use case: consider a setup with 2 nodes, each with 4 GPUs. The sharding groups could be:
- Group 0, DMP 0: [0, 2, 4, 6]
- Group 1, DMP 1: [1, 3, 5, 7]

Each group receives an identical sharding plan for its local world size and ranks. If we have one table sharded in each DMP, with one shard on each rank in the group, each shard in DMP 0 has a duplicate shard on its corresponding rank in DMP 1. The replication groups would be: [0, 1], [2, 3], [4, 5], [6, 7].

NOTE: We have to pass the global process group to the DDPWrapper, otherwise some of the unsharded parameters will not have the optimizer applied to them, which results in numerically inaccurate results.

Reviewed By: dstaay-fb

Differential Revision: D61643328

fbshipit-source-id: 7e3e447210cabe9a28b72c1acc48ae06b153d95d
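To make the grouping above concrete, here is a minimal standalone sketch (not part of this diff; the helper name build_2d_groups is hypothetical) that reproduces the sharding and replication groups for 8 ranks with a sharding group size of 4:

def build_2d_groups(world_size: int, sharding_group_size: int):
    # Number of model replicas (sharding groups) = the data-parallel dimension.
    replication_factor = world_size // sharding_group_size
    # Sharding groups: ranks strided by the replication factor,
    # e.g. world_size=8, sharding_group_size=4 -> [0, 2, 4, 6] and [1, 3, 5, 7].
    sharding_groups = [
        [g + replication_factor * r for r in range(sharding_group_size)]
        for g in range(replication_factor)
    ]
    # Replication groups: the corresponding rank taken from each sharding group,
    # e.g. [0, 1], [2, 3], [4, 5], [6, 7]; these are the all-reduce (parameter sync) groups.
    replication_groups = [
        [sharding_groups[g][r] for g in range(replication_factor)]
        for r in range(sharding_group_size)
    ]
    return sharding_groups, replication_groups

sharding, replication = build_2d_groups(world_size=8, sharding_group_size=4)
print(sharding)     # [[0, 2, 4, 6], [1, 3, 5, 7]]
print(replication)  # [[0, 1], [2, 3], [4, 5], [6, 7]]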
1 parent 6f1a45d · commit 6b37bc0

12 files changed: +991 -47 lines

torchrec/distributed/batched_embedding_kernel.py

+2 -3

@@ -46,7 +46,7 @@
     PartiallyMaterializedTensor,
 )
 from torch import nn
-from torchrec.distributed.comm import get_local_rank, get_local_size
+from torchrec.distributed.comm import get_local_rank, get_node_group_size
 from torchrec.distributed.composable.table_batched_embedding_slice import (
     TableBatchedEmbeddingSlice,
 )
@@ -303,7 +303,7 @@ def get_optimizer_rowwise_shard_metadata_and_global_metadata(
             )
             # for grid sharding, the row dimension is replicated CW shard times
             grid_shard_nodes = (
-                len(table_global_shards_metadata) // get_local_size()
+                len(table_global_shards_metadata) // get_node_group_size()
                 if is_grid_sharded
                 else 1
             )
@@ -1445,7 +1445,6 @@ def __init__(
         fused_params = config.fused_params or {}
         if "cache_precision" not in fused_params:
            fused_params["cache_precision"] = weights_precision
-
        self._emb_module: SplitTableBatchedEmbeddingBagsCodegen = (
            SplitTableBatchedEmbeddingBagsCodegen(
                embedding_specs=list(
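
As a small hedged illustration of the grid-sharding hunk above (the shard count and node group size are made-up values): the optimizer's row-wise metadata divides the global shard count by the node group size, which under 2D must come from get_node_group_size() rather than the host-local get_local_size().

# Hypothetical numbers, only to illustrate the arithmetic in the hunk above.
table_global_shards_metadata = list(range(8))  # 8 global shards of a grid-sharded table
node_group_size = 4  # what get_node_group_size() would return under 2D
is_grid_sharded = True

grid_shard_nodes = (
    len(table_global_shards_metadata) // node_group_size if is_grid_sharded else 1
)
assert grid_shard_nodes == 2  # rows are replicated across 2 grid nodes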

torchrec/distributed/comm.py

+110

@@ -13,13 +13,19 @@
 
 import torch
 import torch.distributed as dist
+from torchrec.distributed.types import ShardingEnv2D
 
 logger: logging.Logger = logging.getLogger(__name__)
 
 # Global, only should be accessed via intra_and_cross_node_pg()
 _INTRA_PG: Optional[dist.ProcessGroup] = None
 _CROSS_PG: Optional[dist.ProcessGroup] = None
 
+# For 2D parallel
+_INTRA_PG_2D: Optional[dist.ProcessGroup] = None
+_CROSS_PG_2D: Optional[dist.ProcessGroup] = None
+_NODE_GROUP_SIZE_2D: Optional[int] = None
+
 
 def _env2int(env_list: List[str], default: int = -1) -> int:
     for e in env_list:
@@ -54,6 +60,15 @@ def get_local_size(world_size: Optional[int] = None) -> int:
     return local_size
 
 
+def get_node_group_size(world_size: Optional[int] = None) -> int:
+    """
+    Get the local world size, accounting for the 2D environment; if it is not set, fall back to the global environment.
+    """
+    if _NODE_GROUP_SIZE_2D is None:
+        return get_local_size(world_size)
+    return _NODE_GROUP_SIZE_2D
+
+
 def get_local_rank(world_size: Optional[int] = None, rank: Optional[int] = None) -> int:
     """
     Gets the local rank of the local processes (see https://pytorch.org/docs/stable/elastic/run.html)
@@ -151,3 +166,98 @@ def intra_and_cross_node_pg(
     dist.barrier()
 
     return _INTRA_PG, _CROSS_PG
+
+
+def intra_and_cross_node_pg_2D(
+    env: ShardingEnv2D,
+    device: Optional[torch.device] = None,
+) -> Tuple[Optional[dist.ProcessGroup], Optional[dist.ProcessGroup]]:
+    """
+    Creates sub process groups (intra and cross node) under a 2D parallelism scheme.
+    The concept of "intra" and "cross" node is lost under 2D parallelism because the
+    ranks within a sharding group have no guarantee of following the typical node
+    topology, and as such there is no guarantee that the "intra" group exploits
+    intra-node bandwidth.
+
+    NOTE:
+        These process groups are created for sharding schemes (i.e. GRID) that were designed to exploit
+        intra-node bandwidth for optimized comms. There will be future work to redesign the comms for GRID
+        sharding to be optimized under a 2D setup.
+
+    Example::
+        Here is what "intra" and "cross" groups look like in a 2D environment,
+        Sharding Groups:
+            Group 0: [0, 2, 4, 6]
+            Group 1: [1, 3, 5, 7]
+        devices_per_node = 2:
+            "intra" groups for each sharding group,
+                Group 0: [0, 2], [4, 6]
+                Group 1: [1, 3], [5, 7]
+            "cross" groups for each sharding group,
+                Group 0: [0, 4], [2, 6]
+                Group 1: [1, 5], [3, 7]
+
+        We can see, as this scales to real-world topologies, how the "intra" and "cross" node ideas in the
+        traditional sense are not applicable here.
+    """
+    if device is not None and device.type == "meta":
+        return None, None
+
+    global _INTRA_PG_2D
+    global _CROSS_PG_2D
+    global _NODE_GROUP_SIZE_2D
+
+    backend = dist.get_backend(env.sharding_pg)
+    my_rank = dist.get_rank()
+
+    sharding_group_size = dist.get_world_size(
+        env.sharding_pg
+    )  # Local replica group world size
+    world_size = dist.get_world_size()  # Global world size
+    step = world_size // sharding_group_size
+    devices_per_node = (
+        env.node_group_size if env.node_group_size else get_local_size(world_size)
+    )
+    _NODE_GROUP_SIZE_2D = devices_per_node
+
+    assert (
+        sharding_group_size % devices_per_node == 0
+    ), f"sharding group size is not divisible by node group size, {devices_per_node=}, {sharding_group_size=}"
+    intra_pg_groups: List[List[List[int]]] = [[] for _ in range(step)]
+
+    if _INTRA_PG_2D is None:
+        for group_rank in range(step):
+            sharding_pg_peers = [
+                step * r + group_rank for r in range(sharding_group_size)
+            ]
+            for group in range(len(sharding_pg_peers) // devices_per_node):
+                intra_pg_peers = sharding_pg_peers[
+                    group * devices_per_node : (group + 1) * devices_per_node
+                ]
+                intra_pg_groups[group_rank].append(intra_pg_peers)
+                curr_intra_pg = dist.new_group(backend=backend, ranks=intra_pg_peers)
+                if my_rank in intra_pg_peers:
+                    logger.warning(
+                        f"[Connection] 2D rank {my_rank} -> intra_pg_peers {intra_pg_peers}"
+                    )
+                    _INTRA_PG_2D = curr_intra_pg
+        assert _INTRA_PG_2D is not None, "INTRA_PG_2D is not initialized!"
+        dist.barrier()
+
+    if _CROSS_PG_2D is None:
+        for group_rank in range(step):
+            intra_pg_group = intra_pg_groups[group_rank]
+            for cross_group_rank in range(devices_per_node):
+                cross_pg_peers = [
+                    intra_pg_group[j][cross_group_rank]
+                    for j in range(len(intra_pg_group))
+                ]
+                curr_cross_pg = dist.new_group(backend=backend, ranks=cross_pg_peers)
+                if my_rank in cross_pg_peers:
+                    logger.warning(
+                        f"[Connection] 2D rank {my_rank} -> cross_pg_peers {cross_pg_peers}"
+                    )
+                    _CROSS_PG_2D = curr_cross_pg
+        assert _CROSS_PG_2D is not None, "CROSS_PG_2D is not initialized!"
+        dist.barrier()
+
+    return _INTRA_PG_2D, _CROSS_PG_2D
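
To see the peer-list construction above without creating any process groups, here is a small standalone sketch (not part of the diff) that mirrors the loops in intra_and_cross_node_pg_2D for the docstring's example of world_size=8, sharding_group_size=4, devices_per_node=2:

# Standalone illustration mirroring the peer-list math above (no process groups created).
world_size, sharding_group_size, devices_per_node = 8, 4, 2
step = world_size // sharding_group_size  # number of sharding groups (replicas)

for group_rank in range(step):
    sharding_pg_peers = [step * r + group_rank for r in range(sharding_group_size)]
    intra = [
        sharding_pg_peers[g * devices_per_node : (g + 1) * devices_per_node]
        for g in range(len(sharding_pg_peers) // devices_per_node)
    ]
    cross = [[node[i] for node in intra] for i in range(devices_per_node)]
    print(group_rank, sharding_pg_peers, intra, cross)
# 0 [0, 2, 4, 6] [[0, 2], [4, 6]] [[0, 4], [2, 6]]
# 1 [1, 3, 5, 7] [[1, 3], [5, 7]] [[1, 5], [3, 7]]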

torchrec/distributed/embeddingbag.py

+10 -4

@@ -65,6 +65,7 @@
     QuantizedCommCodecs,
     ShardedTensor,
     ShardingEnv,
+    ShardingEnv2D,
     ShardingType,
     ShardMetadata,
     TensorProperties,
@@ -149,6 +150,7 @@ def create_embedding_bag_sharding(
     EmbeddingShardingContext, KeyedJaggedTensor, torch.Tensor, torch.Tensor
 ]:
     sharding_type = sharding_infos[0].param_sharding.sharding_type
+
     if device is not None and device.type == "meta":
         replace_placement_with_meta_device(sharding_infos)
     if sharding_type == ShardingType.TABLE_WISE.value:
@@ -949,10 +951,14 @@ def _initialize_torch_state(self) -> None: # noqa
             )
 
             self._model_parallel_name_to_sharded_tensor[table_name] = (
-                ShardedTensor._init_from_local_shards_and_global_metadata(
-                    local_shards=local_shards,
-                    sharded_tensor_metadata=metadata,
-                    process_group=none_throws(self._env.process_group),
+                ShardedTensor._init_from_local_shards(
+                    local_shards,
+                    self._name_to_table_size[table_name],
+                    process_group=(
+                        self._env.sharding_pg
+                        if isinstance(self._env, ShardingEnv2D)
+                        else self._env.process_group
+                    ),
                 )
             )
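
The hunk above registers each sharded tensor against the per-replica sharding group rather than the global group, since in 2D training the global group contains duplicate shards of the same table. A minimal sketch of that selection logic, assuming only the attribute names visible in the diff (the helper name is hypothetical):

from typing import Optional

import torch.distributed as dist

from torchrec.distributed.types import ShardingEnv, ShardingEnv2D


def state_dict_process_group(env: ShardingEnv) -> Optional[dist.ProcessGroup]:
    # Hypothetical helper: under 2D parallelism, sharded-tensor state is built
    # over one replica's sharding process group; otherwise fall back to the
    # environment's (global) process group, mirroring the conditional above.
    if isinstance(env, ShardingEnv2D):
        return env.sharding_pg
    return env.process_group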
