
Commit 019a92d

iamzainhuda authored and facebook-github-bot committed
add 2D parallel and DTensor support to TWRW (#2629)
Summary: Pull Request resolved: #2629
Add 2D and DTensor support for TWRW
Differential Revision: D67145321
fbshipit-source-id: eaa264131628c89fa5900fd2df9dc10b1d3da893
1 parent 3928a1b commit 019a92d

2 files changed, +143 -8 lines


torchrec/distributed/sharding/twrw_sharding.py (+56 -8)
@@ -13,7 +13,13 @@
 
 import torch
 import torch.distributed as dist
-from torchrec.distributed.comm import get_local_size, intra_and_cross_node_pg
+from torch.distributed._tensor import Shard
+from torch.distributed.distributed_c10d import get_process_group_ranks
+from torchrec.distributed.comm import (
+    get_local_size,
+    intra_and_cross_node_pg,
+    intra_and_cross_node_pg_2D,
+)
 from torchrec.distributed.dist_data import (
     KJTAllToAll,
     PooledEmbeddingsAllToAll,
@@ -34,6 +40,7 @@
 )
 from torchrec.distributed.embedding_types import (
     BaseGroupedFeatureProcessor,
+    DTensorMetadata,
     EmbeddingComputeKernel,
     GroupedEmbeddingConfig,
     ShardedEmbeddingTable,
@@ -44,6 +51,7 @@
     QuantizedCommCodecs,
     ShardedTensorMetadata,
     ShardingEnv,
+    ShardingEnv2D,
     ShardingType,
     ShardMetadata,
 )
@@ -71,14 +79,26 @@ def __init__(
     ) -> None:
         super().__init__(qcomm_codecs_registry=qcomm_codecs_registry)
         self._env = env
-        self._pg: Optional[dist.ProcessGroup] = self._env.process_group
+        self._is_2D_parallel: bool = isinstance(env, ShardingEnv2D)
+        self._pg: Optional[dist.ProcessGroup] = (
+            self._env.sharding_pg  # pyre-ignore[16]
+            if self._is_2D_parallel
+            else self._env.process_group
+        )
         self._world_size: int = self._env.world_size
         self._rank: int = self._env.rank
         self._device = device
         self._need_pos = need_pos
-        intra_pg, cross_pg = intra_and_cross_node_pg(
-            device, backend=dist.get_backend(self._pg)
-        )
+        if self._is_2D_parallel:
+            intra_pg, cross_pg = intra_and_cross_node_pg_2D(
+                # pyre-fixme[6]
+                self._env,
+                device=device,
+            )
+        else:
+            intra_pg, cross_pg = intra_and_cross_node_pg(
+                device, backend=dist.get_backend(self._pg)
+            )
         self._intra_pg: Optional[dist.ProcessGroup] = intra_pg
         self._cross_pg: Optional[dist.ProcessGroup] = cross_pg
         self._local_size: int = (
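The constructor change above picks ShardingEnv2D.sharding_pg as the working process group under 2D parallelism. As a rough sketch of what such a layout looks like with plain PyTorch APIs (the correspondence to ShardingEnv2D and intra_and_cross_node_pg_2D is an assumption made here for illustration, not something this diff states):

# Minimal sketch of a replicate-by-shard 2D layout; not TorchRec code.
# Run under torchrun, e.g.: torchrun --nproc_per_node=4 sketch_2d_mesh.py
# Requires a recent PyTorch with named DeviceMesh dimensions.
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

dist.init_process_group("gloo")  # cpu backend keeps the sketch simple
world_size = dist.get_world_size()
assert world_size % 2 == 0, "sketch assumes an even number of ranks"

# Two replica groups, each sharding embeddings across world_size // 2 ranks.
mesh = init_device_mesh(
    "cpu", (2, world_size // 2), mesh_dim_names=("replicate", "shard")
)

sharding_pg = mesh.get_group("shard")      # assumed analogue of env.sharding_pg
replica_pg = mesh.get_group("replicate")   # ranks holding copies of the same shards
print(dist.get_rank(), dist.get_process_group_ranks(sharding_pg))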
@@ -112,11 +132,23 @@ def _shard(
         world_size = self._world_size
         local_size = self._local_size
         tables_per_rank: List[List[ShardedEmbeddingTable]] = [
-            [] for i in range(world_size)
+            [] for _ in range(world_size)
         ]
+        peer_group = (
+            # pyre-ignore [6]
+            get_process_group_ranks(self._pg)
+            if self._is_2D_parallel
+            else None
+        )
         for info in sharding_infos:
-            # pyre-ignore [16]
-            table_node = info.param_sharding.ranks[0] // local_size
+            # Under 2D parallelism we transform rank to the logical ordering in a regular parallelism scheme
+            rank = (
+                # pyre-ignore [16]
+                peer_group.index(info.param_sharding.ranks[0])
+                if peer_group is not None
+                else info.param_sharding.ranks[0]
+            )
+            table_node = rank // local_size
             # pyre-fixme [16]
             shards = info.param_sharding.sharding_spec.shards
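The hunk above first translates a table's global rank into its logical position inside the sharding process group, then derives the owning host from that logical rank. A tiny self-contained sketch of the same arithmetic (the peer ranks and local_size below are made-up values for illustration):

# Toy illustration of the rank remapping; the peer ranks are invented.
# Suppose this rank's sharding process group spans global ranks [4, 5, 6, 7]
# and each host contributes 2 ranks to the group (local_size = 2).
peer_group = [4, 5, 6, 7]
local_size = 2

def table_node_for(global_rank: int) -> int:
    # Same idea as peer_group.index(...) followed by // local_size in _shard().
    logical_rank = peer_group.index(global_rank)
    return logical_rank // local_size

print(table_node_for(4))  # 0 -> first host of this sharding group
print(table_node_for(7))  # 1 -> second host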

@@ -131,6 +163,21 @@ def _shard(
                 ),
             )
 
+            dtensor_metadata = None
+            if info.fused_params.get("output_dtensor", False):  # pyre-ignore[16]
+                placements = (Shard(0),)
+                dtensor_metadata = DTensorMetadata(
+                    mesh=self._env.device_mesh,
+                    placements=placements,
+                    size=(
+                        info.embedding_config.num_embeddings,
+                        info.embedding_config.embedding_dim,
+                    ),
+                    stride=info.param.stride(),
+                )
+            # to not pass onto TBE
+            info.fused_params.pop("output_dtensor", None)  # pyre-ignore[16]
+
             for rank in range(
                 table_node * local_size,
                 (table_node + 1) * local_size,
@@ -154,6 +201,7 @@ def _shard(
                        ),
                        local_metadata=shards[rank_idx],
                        global_metadata=global_metadata,
+                        dtensor_metadata=dtensor_metadata,
                        weight_init_max=info.embedding_config.weight_init_max,
                        weight_init_min=info.embedding_config.weight_init_min,
                        fused_params=info.fused_params,
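When output_dtensor is requested in fused_params, the table's DTensorMetadata is built with a Shard(0) placement, i.e. the embedding rows are split across the mesh. A standalone sketch of what a Shard(0) placement means, using plain DTensor APIs rather than TorchRec's DTensorMetadata (the toy table sizes are assumptions):

# Standalone sketch of a row-wise (Shard(0)) DTensor; not TorchRec code.
# Run under torchrun, e.g.: torchrun --nproc_per_node=4 sketch_shard0.py
import torch
import torch.distributed as dist
from torch.distributed._tensor import DTensor, Shard
from torch.distributed.device_mesh import init_device_mesh

dist.init_process_group("gloo")
world_size = dist.get_world_size()
mesh = init_device_mesh("cpu", (world_size,))

num_embeddings, embedding_dim = 16, 8      # toy table size (assumed)
assert num_embeddings % world_size == 0
local_rows = num_embeddings // world_size  # rows owned by this rank
local_shard = torch.randn(local_rows, embedding_dim)

# Shard(0): each rank holds a slice of rows of the global table.
table = DTensor.from_local(local_shard, mesh, placements=(Shard(0),))
print(table.shape)  # global shape (16, 8), even though each rank stores only its rows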

torchrec/distributed/tests/test_2d_sharding.py (+87)
@@ -402,3 +402,90 @@ def test_sharding_rw_2D(
             variable_batch_size=variable_batch_size,
             pooling=pooling,
         )
+
+    @unittest.skipIf(
+        torch.cuda.device_count() <= 7,
+        "Not enough GPUs, this test requires at least eight GPUs",
+    )
+    # pyre-fixme[56]
+    @given(
+        sharder_type=st.sampled_from(
+            [
+                SharderType.EMBEDDING_BAG_COLLECTION.value,
+            ]
+        ),
+        kernel_type=st.sampled_from(
+            [
+                EmbeddingComputeKernel.FUSED.value,
+                EmbeddingComputeKernel.FUSED_UVM_CACHING.value,
+                EmbeddingComputeKernel.FUSED_UVM.value,
+            ],
+        ),
+        qcomms_config=st.sampled_from(
+            [
+                # None,
+                QCommsConfig(
+                    forward_precision=CommType.FP16, backward_precision=CommType.BF16
+                ),
+            ]
+        ),
+        apply_optimizer_in_backward_config=st.sampled_from(
+            [
+                None,
+                {
+                    "embedding_bags": (
+                        torch.optim.SGD,
+                        {
+                            "lr": 0.01,
+                        },
+                    ),
+                },
+            ]
+        ),
+        pooling=st.sampled_from([PoolingType.SUM]),
+    )
+    @settings(verbosity=Verbosity.verbose, max_examples=1, deadline=None)
+    def test_sharding_twrw_2D(
+        self,
+        sharder_type: str,
+        kernel_type: str,
+        qcomms_config: Optional[QCommsConfig],
+        apply_optimizer_in_backward_config: Optional[
+            Dict[str, Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]]
+        ],
+        pooling: PoolingType,
+    ) -> None:
+        if (
+            self.device == torch.device("cpu")
+            and kernel_type != EmbeddingComputeKernel.FUSED.value
+        ):
+            self.skipTest("CPU does not support uvm.")
+
+        sharding_type = ShardingType.TABLE_ROW_WISE.value
+        assume(sharder_type == SharderType.EMBEDDING_BAG_COLLECTION.value)
+
+        self._test_sharding(
+            world_size=self.WORLD_SIZE,
+            local_size=self.WORLD_SIZE_2D // 2,
+            world_size_2D=self.WORLD_SIZE_2D,
+            sharders=[
+                cast(
+                    ModuleSharder[nn.Module],
+                    create_test_sharder(
+                        sharder_type,
+                        sharding_type,
+                        kernel_type,
+                        qcomms_config=qcomms_config,
+                        device=self.device,
+                    ),
+                ),
+            ],
+            qcomms_config=qcomms_config,
+            constraints={
+                table.name: ParameterConstraints(min_partition=2)
+                for table in self.tables
+            },
+            backend=self.backend,
+            apply_optimizer_in_backward_config=apply_optimizer_in_backward_config,
+            pooling=pooling,
+        )
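The new test drives _test_sharding with both a global world size and a 2D world size, with local_size set to half of the 2D group. A toy calculation of how those quantities relate (the concrete values are assumptions for illustration, not read from the test class):

# Toy arithmetic for the 2D test topology; WORLD_SIZE / WORLD_SIZE_2D values are assumed.
WORLD_SIZE = 8                    # total ranks (the skipIf above demands eight GPUs)
WORLD_SIZE_2D = 4                 # ranks in each sharding (model-parallel) group
local_size = WORLD_SIZE_2D // 2   # mirrors local_size=self.WORLD_SIZE_2D // 2 above
num_replica_groups = WORLD_SIZE // WORLD_SIZE_2D

print(num_replica_groups)  # 2 data-parallel replicas of the sharded tables
print(local_size)          # 2 ranks per node inside each sharding group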
