13
13
14
14
import torch
15
15
import torch .distributed as dist
16
- from torchrec .distributed .comm import get_local_size , intra_and_cross_node_pg
16
+ from torch .distributed ._tensor import Shard
17
+ from torch .distributed .distributed_c10d import get_process_group_ranks
18
+ from torchrec .distributed .comm import (
19
+ get_local_size ,
20
+ intra_and_cross_node_pg ,
21
+ intra_and_cross_node_pg_2D ,
22
+ )
17
23
from torchrec .distributed .dist_data import (
18
24
KJTAllToAll ,
19
25
PooledEmbeddingsAllToAll ,
34
40
)
35
41
from torchrec .distributed .embedding_types import (
36
42
BaseGroupedFeatureProcessor ,
43
+ DTensorMetadata ,
37
44
EmbeddingComputeKernel ,
38
45
GroupedEmbeddingConfig ,
39
46
ShardedEmbeddingTable ,
44
51
QuantizedCommCodecs ,
45
52
ShardedTensorMetadata ,
46
53
ShardingEnv ,
54
+ ShardingEnv2D ,
47
55
ShardingType ,
48
56
ShardMetadata ,
49
57
)
@@ -71,14 +79,26 @@ def __init__(
71
79
) -> None :
72
80
super ().__init__ (qcomm_codecs_registry = qcomm_codecs_registry )
73
81
self ._env = env
74
- self ._pg : Optional [dist .ProcessGroup ] = self ._env .process_group
82
+ self ._is_2D_parallel : bool = isinstance (env , ShardingEnv2D )
83
+ self ._pg : Optional [dist .ProcessGroup ] = (
84
+ self ._env .sharding_pg # pyre-ignore[16]
85
+ if self ._is_2D_parallel
86
+ else self ._env .process_group
87
+ )
75
88
self ._world_size : int = self ._env .world_size
76
89
self ._rank : int = self ._env .rank
77
90
self ._device = device
78
91
self ._need_pos = need_pos
79
- intra_pg , cross_pg = intra_and_cross_node_pg (
80
- device , backend = dist .get_backend (self ._pg )
81
- )
92
+ if self ._is_2D_parallel :
93
+ intra_pg , cross_pg = intra_and_cross_node_pg_2D (
94
+ # pyre-fixme[6]
95
+ self ._env ,
96
+ device = device ,
97
+ )
98
+ else :
99
+ intra_pg , cross_pg = intra_and_cross_node_pg (
100
+ device , backend = dist .get_backend (self ._pg )
101
+ )
82
102
self ._intra_pg : Optional [dist .ProcessGroup ] = intra_pg
83
103
self ._cross_pg : Optional [dist .ProcessGroup ] = cross_pg
84
104
self ._local_size : int = (
@@ -112,11 +132,23 @@ def _shard(
112
132
world_size = self ._world_size
113
133
local_size = self ._local_size
114
134
tables_per_rank : List [List [ShardedEmbeddingTable ]] = [
115
- [] for i in range (world_size )
135
+ [] for _ in range (world_size )
116
136
]
137
+ peer_group = (
138
+ # pyre-ignore [6]
139
+ get_process_group_ranks (self ._pg )
140
+ if self ._is_2D_parallel
141
+ else None
142
+ )
117
143
for info in sharding_infos :
118
- # pyre-ignore [16]
119
- table_node = info .param_sharding .ranks [0 ] // local_size
144
+ # Under 2D parallelism we transform rank to the logical ordering in a regular parallelism scheme
145
+ rank = (
146
+ # pyre-ignore [16]
147
+ peer_group .index (info .param_sharding .ranks [0 ])
148
+ if peer_group is not None
149
+ else info .param_sharding .ranks [0 ]
150
+ )
151
+ table_node = rank // local_size
120
152
# pyre-fixme [16]
121
153
shards = info .param_sharding .sharding_spec .shards
122
154
@@ -131,6 +163,21 @@ def _shard(
131
163
),
132
164
)
133
165
166
+ dtensor_metadata = None
167
+ if info .fused_params .get ("output_dtensor" , False ): # pyre-ignore[16]
168
+ placements = (Shard (0 ),)
169
+ dtensor_metadata = DTensorMetadata (
170
+ mesh = self ._env .device_mesh ,
171
+ placements = placements ,
172
+ size = (
173
+ info .embedding_config .num_embeddings ,
174
+ info .embedding_config .embedding_dim ,
175
+ ),
176
+ stride = info .param .stride (),
177
+ )
178
+ # to not pass onto TBE
179
+ info .fused_params .pop ("output_dtensor" , None ) # pyre-ignore[16]
180
+
134
181
for rank in range (
135
182
table_node * local_size ,
136
183
(table_node + 1 ) * local_size ,
@@ -154,6 +201,7 @@ def _shard(
154
201
),
155
202
local_metadata = shards [rank_idx ],
156
203
global_metadata = global_metadata ,
204
+ dtensor_metadata = dtensor_metadata ,
157
205
weight_init_max = info .embedding_config .weight_init_max ,
158
206
weight_init_min = info .embedding_config .weight_init_min ,
159
207
fused_params = info .fused_params ,
0 commit comments