@@ -13,13 +13,19 @@
 
 import torch
 import torch.distributed as dist
+from torchrec.distributed.types import ShardingEnv2D
 
 logger: logging.Logger = logging.getLogger(__name__)
 
 # Global, only should be accessed via intra_and_cross_node_pg()
 _INTRA_PG: Optional[dist.ProcessGroup] = None
 _CROSS_PG: Optional[dist.ProcessGroup] = None
 
+# For 2D parallel
+_INTRA_PG_2D: Optional[dist.ProcessGroup] = None
+_CROSS_PG_2D: Optional[dist.ProcessGroup] = None
+_NODE_GROUP_SIZE_2D: Optional[int] = None
+
 
 def _env2int(env_list: List[str], default: int = -1) -> int:
     for e in env_list:
@@ -54,6 +60,15 @@ def get_local_size(world_size: Optional[int] = None) -> int:
     return local_size
 
 
+def get_node_group_size(world_size: Optional[int] = None) -> int:
+    """
+    Gets the local world size accounting for the 2D environment; if it is not set, falls back to the global environment.
+    """
+    if _NODE_GROUP_SIZE_2D is None:
+        return get_local_size(world_size)
+    return _NODE_GROUP_SIZE_2D
+
+
 def get_local_rank(world_size: Optional[int] = None, rank: Optional[int] = None) -> int:
     """
     Gets the local rank of the local processes (see https://pytorch.org/docs/stable/elastic/run.html)
@@ -151,3 +166,98 @@ def intra_and_cross_node_pg(
     dist.barrier()
 
     return _INTRA_PG, _CROSS_PG
+
+
+def intra_and_cross_node_pg_2D(
+    env: ShardingEnv2D,
+    device: Optional[torch.device] = None,
+) -> Tuple[Optional[dist.ProcessGroup], Optional[dist.ProcessGroup]]:
+    """
+    Creates sub process groups (intra and cross node) under a 2D parallelism scheme.
+    The concept of "intra" and "cross" node is lost under a 2D parallelism scheme,
+    because the ranks within a sharding group are not guaranteed to follow the
+    typical node topology, so "intra" groups cannot be assumed to exploit intra-node bandwidth.
+
+    NOTE:
+        These process groups are created for sharding schemes (ie: GRID) that were designed to exploit
+        intra node bandwidth for optimized comms. There will be future work to redesign the comms for GRID
+        sharding to be optimized under a 2D setup.
+
+    Example::
+        Here is what "intra" and "cross" groups look like in a 2D environment,
+        Sharding Groups:
+            Group 0: [0, 2, 4, 6]
+            Group 1: [1, 3, 5, 7]
+        devices_per_node = 2:
+            "intra" groups for each sharding group,
+                Group 0: [0, 2], [4, 6]
+                Group 1: [1, 3], [5, 7]
+            "cross" groups for each sharding group,
+                Group 0: [0, 4], [2, 6]
+                Group 1: [1, 5], [3, 7]
+
+        As this scales to real-world topologies, we can see that the traditional
+        notions of "intra" and "cross" node do not apply.
+    """
+    if device is not None and device.type == "meta":
+        return None, None
+
+    global _INTRA_PG_2D
+    global _CROSS_PG_2D
+    global _NODE_GROUP_SIZE_2D
+
+    backend = dist.get_backend(env.sharding_pg)
+    my_rank = dist.get_rank()
+
+    sharding_group_size = dist.get_world_size(
+        env.sharding_pg
+    )  # Local replica group world size
+    world_size = dist.get_world_size()  # Global world size
+    step = world_size // sharding_group_size
+    devices_per_node = (
+        env.node_group_size if env.node_group_size else get_local_size(world_size)
+    )
+    _NODE_GROUP_SIZE_2D = devices_per_node
+
+    assert (
+        sharding_group_size % devices_per_node == 0
+    ), f"sharding group size is not divisible by node group size, {devices_per_node=}, {sharding_group_size=}"
+    intra_pg_groups: List[List[List[int]]] = [[] for _ in range(step)]
+
+    if _INTRA_PG_2D is None:
+        for group_rank in range(step):
+            sharding_pg_peers = [
+                step * r + group_rank for r in range(sharding_group_size)
+            ]
+            for group in range(len(sharding_pg_peers) // devices_per_node):
+                intra_pg_peers = sharding_pg_peers[
+                    group * devices_per_node : (group + 1) * devices_per_node
+                ]
+                intra_pg_groups[group_rank].append(intra_pg_peers)
+                curr_intra_pg = dist.new_group(backend=backend, ranks=intra_pg_peers)
+                if my_rank in intra_pg_peers:
+                    logger.warning(
+                        f"[Connection] 2D rank {my_rank} -> intra_pg_peers {intra_pg_peers}"
+                    )
+                    _INTRA_PG_2D = curr_intra_pg
+        assert _INTRA_PG_2D is not None, "INTRA_PG_2D is not initialized!"
+        dist.barrier()
+
+    if _CROSS_PG_2D is None:
+        for group_rank in range(step):
+            intra_pg_group = intra_pg_groups[group_rank]
+            for cross_group_rank in range(devices_per_node):
+                cross_pg_peers = [
+                    intra_pg_group[j][cross_group_rank]
+                    for j in range(len(intra_pg_group))
+                ]
+                curr_cross_pg = dist.new_group(backend=backend, ranks=cross_pg_peers)
+                if my_rank in cross_pg_peers:
+                    logger.warning(
+                        f"[Connection] 2D rank {my_rank} -> cross_pg_peers {cross_pg_peers}"
+                    )
+                    _CROSS_PG_2D = curr_cross_pg
+        assert _CROSS_PG_2D is not None, "CROSS_PG_2D is not initialized!"
+        dist.barrier()
+
+    return _INTRA_PG_2D, _CROSS_PG_2D
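
A quick way to sanity-check the grouping arithmetic above without a distributed setup: the sketch below mirrors the two loops in `intra_and_cross_node_pg_2D` but skips `dist.new_group`, so it runs without torch or any process groups. `build_2d_peer_groups` is an illustrative helper, not part of torchrec; it reproduces the docstring example of 8 ranks, 2 sharding groups, and 2 devices per node.

```python
from typing import List, Tuple


def build_2d_peer_groups(
    world_size: int, sharding_group_size: int, devices_per_node: int
) -> Tuple[List[List[int]], List[List[int]], List[List[int]]]:
    # Hypothetical helper mirroring the rank math in intra_and_cross_node_pg_2D.
    assert world_size % sharding_group_size == 0
    assert sharding_group_size % devices_per_node == 0
    step = world_size // sharding_group_size  # number of sharding groups
    sharding, intra, cross = [], [], []
    for group_rank in range(step):
        # Ranks in a sharding group are strided across the global world
        peers = [step * r + group_rank for r in range(sharding_group_size)]
        sharding.append(peers)
        # "intra" groups: contiguous chunks of devices_per_node peers
        chunks = [
            peers[i : i + devices_per_node]
            for i in range(0, len(peers), devices_per_node)
        ]
        intra.extend(chunks)
        # "cross" groups: one peer from each chunk at the same offset
        cross.extend(
            [[chunk[offset] for chunk in chunks] for offset in range(devices_per_node)]
        )
    return sharding, intra, cross


sharding, intra, cross = build_2d_peer_groups(8, 4, 2)
print(sharding)  # [[0, 2, 4, 6], [1, 3, 5, 7]]
print(intra)     # [[0, 2], [4, 6], [1, 3], [5, 7]]
print(cross)     # [[0, 4], [2, 6], [1, 5], [3, 7]]
```

The printed lists match the `Example::` block in the docstring, confirming that each "intra" group is a contiguous slice of a sharding group's strided peers and each "cross" group picks the same offset from every slice.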