|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# All rights reserved. |
| 3 | +# |
| 4 | +# This source code is licensed under the BSD-style license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | + |
| 7 | +from typing import Optional |
| 8 | + |
| 9 | +import torch |
| 10 | +import torch.distributed as dist |
| 11 | +import torch.distributed._symmetric_memory as symm_mem |
| 12 | + |
| 13 | + |
| 14 | +# Adding out-of-tree ops to the `symm_mem` library |
| 15 | +lib = torch.library.Library("symm_mem", "FRAGMENT") # noqa: TOR901 |
| 16 | + |
| 17 | +""" |
| 18 | +all_to_all_vdev_2d_copy: |
| 19 | +Copy data from `input` to `symm_in_buf` and call `all_to_all_vdev_2d` to shuffle data |
| 20 | +""" |
| 21 | +lib.define( |
| 22 | + "all_to_all_vdev_2d_copy(" |
| 23 | + "Tensor input, Tensor symm_in_buf, Tensor(a!) out, " |
| 24 | + "Tensor in_splits, Tensor(a!) out_splits_offsets, str group_name, int? major_align=None) -> ()", |
| 25 | + tags=[torch._C.Tag.needs_exact_strides], |
| 26 | +) |
| 27 | + |
| 28 | + |
| 29 | +@torch.library.impl(lib, "all_to_all_vdev_2d_copy", "CUDA") |
| 30 | +def _all_to_all_vdev_2d_copy_cuda( |
| 31 | + input: torch.Tensor, |
| 32 | + symm_in_buf: torch.Tensor, |
| 33 | + out: torch.Tensor, |
| 34 | + in_splits: torch.Tensor, |
| 35 | + out_splits_offsets: torch.Tensor, |
| 36 | + group_name: str, |
| 37 | + major_align: Optional[int] = None, |
| 38 | +) -> None: |
| 39 | + if symm_in_buf.shape[0] < input.shape[0]: |
| 40 | + raise RuntimeError( |
| 41 | + f"symm_in_buf with dim-0 length {symm_in_buf.shape[0]} cannot fit input with dim-0 length {input.shape[0]}" |
| 42 | + ) |
| 43 | + if symm_in_buf.shape[1:] != input.shape[1:]: |
| 44 | + raise RuntimeError( |
| 45 | + f"symm_in_buf non-0 dims do not match that of input: {symm_in_buf.shape[1:]} vs {input.shape[1:]}" |
| 46 | + ) |
| 47 | + if symm_in_buf.dtype != input.dtype: |
| 48 | + raise RuntimeError( |
| 49 | + f"symm_in_buf dtype {symm_in_buf.dtype} does not match input dtype {input.dtype}" |
| 50 | + ) |
| 51 | + |
| 52 | + symm_in_buf.narrow(0, 0, input.shape[0]).copy_(input) |
| 53 | + torch.ops.symm_mem.all_to_all_vdev_2d( |
| 54 | + symm_in_buf, |
| 55 | + out, |
| 56 | + in_splits, |
| 57 | + out_splits_offsets, |
| 58 | + group_name, |
| 59 | + major_align, |
| 60 | + ) |
| 61 | + |
| 62 | + |
| 63 | +class AllToAllVDev2dOffset(torch.autograd.Function): |
| 64 | + """ |
| 65 | + Autograd function for `all_to_all_vdev_2d_offset` |
| 66 | + """ |
| 67 | + |
| 68 | + @staticmethod |
| 69 | + def forward( # type: ignore[no-untyped-def] |
| 70 | + ctx, |
| 71 | + input: torch.Tensor, |
| 72 | + out: torch.Tensor, |
| 73 | + in_splits_offsets: torch.Tensor, |
| 74 | + out_splits_offsets: torch.Tensor, |
| 75 | + group_name: str, |
| 76 | + # Buffers needed for backward pass |
| 77 | + major_align: int, |
| 78 | + grad_out_buf: torch.Tensor, |
| 79 | + grad_in_buf: torch.Tensor, |
| 80 | + grad_in_splits_offsets: torch.Tensor, |
| 81 | + ) -> torch.Tensor: |
| 82 | + """ |
| 83 | + Functionality is the same as `all_to_all_vdev_2d_offset` but with functionalization. |
| 84 | + """ |
| 85 | + # Shuffle input to output |
| 86 | + torch.ops.symm_mem.all_to_all_vdev_2d_offset( |
| 87 | + input, out, in_splits_offsets, out_splits_offsets, group_name |
| 88 | + ) |
| 89 | + |
| 90 | + # Output splits in forward is the input splits in backward |
| 91 | + ctx.save_for_backward( |
| 92 | + out_splits_offsets, grad_out_buf, grad_in_buf, grad_in_splits_offsets |
| 93 | + ) |
| 94 | + ctx.group_name = group_name |
| 95 | + ctx.major_align = major_align |
| 96 | + return out |
| 97 | + |
| 98 | + @staticmethod |
| 99 | + def backward( # type: ignore[no-untyped-def] |
| 100 | + ctx, |
| 101 | + grad_output: torch.Tensor, |
| 102 | + ) -> tuple[torch.Tensor, None, None, None, None, None, None, None, None]: |
| 103 | + """ |
| 104 | + Backward pass of `all_to_all_vdev_2d_offset` is `all_to_all_vdev_2d`. |
| 105 | +
|
| 106 | + Args: |
| 107 | + `grad_output`: gradients of output passed back from the downstream. |
| 108 | +
|
| 109 | + Returns: |
| 110 | + `grad_input`: gradients of input. |
| 111 | + """ |
| 112 | + # Splits info |
| 113 | + # Splits/offsets of grad_out is the same as out splits/offsets in forward |
| 114 | + ( |
| 115 | + grad_out_splits_offsets, |
| 116 | + grad_out_buf, |
| 117 | + grad_in_buf, |
| 118 | + grad_in_splits_offsets, |
| 119 | + ) = ctx.saved_tensors |
| 120 | + grad_out_splits = grad_out_splits_offsets[0] |
| 121 | + |
| 122 | + # Shuffle gradients back to the input |
| 123 | + # TODO: create an op that takes both in_splits_offsets and |
| 124 | + # out_splits_offsets, instead of taking alignment |
| 125 | + torch.ops.symm_mem.all_to_all_vdev_2d_copy( |
| 126 | + grad_output, |
| 127 | + grad_out_buf, |
| 128 | + grad_in_buf, |
| 129 | + grad_out_splits, |
| 130 | + grad_in_splits_offsets, |
| 131 | + ctx.group_name, |
| 132 | + ctx.major_align, |
| 133 | + ) |
| 134 | + |
| 135 | + return grad_in_buf, None, None, None, None, None, None, None, None |
| 136 | + |
| 137 | + |
| 138 | +class TokenCombiner(torch.nn.Module): |
| 139 | + """ |
| 140 | + Combine tokens from different experts, with backward pass to shuffle gradients back to the input. |
| 141 | + Args: |
| 142 | + `group_name`: name of the group to use for communication. |
| 143 | + `align`: alignment of the token offsets from each expert. If using |
| 144 | + Grouped Gemm next, this should be the same as Grouped Gemm's alignment. |
| 145 | + `in_len`: length of the input. |
| 146 | + `out_len`: length of the output. |
| 147 | + `token_shape`: shape of the tokens. |
| 148 | + `num_ranks`: number of ranks in the group. |
| 149 | + `num_local_experts`: number of local experts. |
| 150 | + `dtype`: data type of the input/output. |
| 151 | + """ |
| 152 | + |
| 153 | + def __init__( |
| 154 | + self, |
| 155 | + group_name: str, |
| 156 | + align: int, |
| 157 | + in_len, |
| 158 | + out_len, |
| 159 | + token_shape, |
| 160 | + num_ranks, |
| 161 | + num_local_experts, |
| 162 | + dtype, |
| 163 | + ) -> None: |
| 164 | + super().__init__() |
| 165 | + self.group_name = group_name |
| 166 | + self.align = align |
| 167 | + self.grad_out_buf = symm_mem.empty(out_len, *token_shape, dtype=dtype) |
| 168 | + self.grad_in_buf = symm_mem.empty(in_len, *token_shape, dtype=dtype) |
| 169 | + self.nsplits = num_ranks * num_local_experts |
| 170 | + self.grad_in_splits_offsets = symm_mem.empty( |
| 171 | + (2, self.nsplits), dtype=torch.int64 |
| 172 | + ) |
| 173 | + |
| 174 | + def forward( |
| 175 | + self, |
| 176 | + inp: torch.Tensor, |
| 177 | + out: torch.Tensor, |
| 178 | + in_splits_offsets: torch.Tensor, |
| 179 | + out_splits_offsets: torch.Tensor, |
| 180 | + ) -> torch.Tensor: |
| 181 | + """ |
| 182 | + Args: |
| 183 | + `inp`: input tensor. |
| 184 | + `out`: buffer for output tensor. |
| 185 | + `in_splits_offsets`: splits and offsets of the input tensor. |
| 186 | + `out_splits_offsets`: splits and offsets of the output tensor. |
| 187 | + See `all_to_all_vdev_2d_offset` for more details. |
| 188 | + Note: |
| 189 | + All tensor arguments must be symmetrically allocated, i.e. |
| 190 | + >>> inp = symm_mem.empty(max_inp_len, dtype=dtype, device=device) |
| 191 | + >>> out = symm_mem.empty(max_out_len, dtype=dtype, device=device) |
| 192 | + >>> in_splits_offsets = symm_mem.empty( |
| 193 | + ... (2, nsplits), dtype=torch.int64, device=device) |
| 194 | + >>> out_splits_offsets = symm_mem.empty( |
| 195 | + ... (2, nsplits), dtype=torch.int64, device=device) |
| 196 | + """ |
| 197 | + |
| 198 | + if in_splits_offsets.shape != (2, self.nsplits): |
| 199 | + raise ValueError( |
| 200 | + f"Expected shape (2, {self.nsplits}), got {in_splits_offsets.shape}" |
| 201 | + ) |
| 202 | + if out_splits_offsets.shape != (2, self.nsplits): |
| 203 | + raise ValueError( |
| 204 | + f"Expected shape (2, {self.nsplits}), got {out_splits_offsets.shape}" |
| 205 | + ) |
| 206 | + |
| 207 | + return AllToAllVDev2dOffset.apply( |
| 208 | + inp, |
| 209 | + out, |
| 210 | + in_splits_offsets, |
| 211 | + out_splits_offsets, |
| 212 | + self.group_name, |
| 213 | + self.align, |
| 214 | + self.grad_out_buf, |
| 215 | + self.grad_in_buf, |
| 216 | + self.grad_in_splits_offsets, |
| 217 | + ) |
| 218 | + |
| 219 | + |
| 220 | +def test_token_combine() -> None: |
| 221 | + # Init |
| 222 | + dist.init_process_group() |
| 223 | + rank = dist.get_rank() |
| 224 | + world_size = dist.get_world_size() |
| 225 | + device_count = torch.cuda.device_count() |
| 226 | + device = torch.device("cuda", rank % device_count) |
| 227 | + |
| 228 | + # NVSHMEM backend specific |
| 229 | + torch.cuda.set_device(device) |
| 230 | + torch.empty(1, device=device) |
| 231 | + # Set NVSHMEM as SymmMem backend |
| 232 | + symm_mem.set_backend("NVSHMEM") |
| 233 | + |
| 234 | + # Mimics Group GEMM alignment |
| 235 | + align = 8 |
| 236 | + torch.manual_seed(42 + rank) |
| 237 | + |
| 238 | + group_name = dist.group.WORLD.group_name |
| 239 | + symm_mem.enable_symm_mem_for_group(group_name) |
| 240 | + |
| 241 | + dtype = torch.float |
| 242 | + # Number of experts per rank |
| 243 | + ne = 8 |
| 244 | + nsplits = ne * world_size |
| 245 | + |
| 246 | + # Number of elements for an expert is random between [0, k) |
| 247 | + k = 10 |
| 248 | + inp_splits = torch.randint(k, (nsplits,), dtype=torch.int64, device=device) |
| 249 | + |
| 250 | + # Max number of input elements (must be a constant across ranks for symmetric memory allocation) |
| 251 | + max_inp_len = k * nsplits |
| 252 | + # Max number of output elements (must be a constant across ranks for symmetric memory allocation) |
| 253 | + overflow_factor = world_size # worst case: one rank receives all data |
| 254 | + max_out_len = max_inp_len * overflow_factor |
| 255 | + |
| 256 | + # Use a dispatch to prepare the input for combine (this is just a |
| 257 | + # preparation, not the test itself) |
| 258 | + # Buffers for dispatch |
| 259 | + hid = 4096 |
| 260 | + inp = symm_mem.empty(max_inp_len, hid, dtype=dtype, device=device) |
| 261 | + out = symm_mem.empty(max_out_len, hid, dtype=dtype, device=device) |
| 262 | + # 2 rows: input splits, input offsets |
| 263 | + in_splits_offsets = symm_mem.empty((2, nsplits), dtype=torch.int64, device=device) |
| 264 | + # 2 rows: output splits, output offsets |
| 265 | + out_splits_offsets = symm_mem.empty((2, nsplits), dtype=torch.int64, device=device) |
| 266 | + |
| 267 | + # Dispatch the tokens first so that we have a nice input for combine |
| 268 | + in_splits_offsets[0].copy_(inp_splits) |
| 269 | + torch.ops.symm_mem.all_to_all_vdev_2d( |
| 270 | + inp, |
| 271 | + out, |
| 272 | + in_splits_offsets[0], |
| 273 | + out_splits_offsets, |
| 274 | + group_name, |
| 275 | + major_align=align, |
| 276 | + ) |
| 277 | + |
| 278 | + with device: |
| 279 | + combiner = TokenCombiner( |
| 280 | + group_name, |
| 281 | + align, |
| 282 | + max_out_len, |
| 283 | + max_inp_len, |
| 284 | + out.shape[1:], |
| 285 | + world_size, |
| 286 | + ne, |
| 287 | + dtype, |
| 288 | + ) |
| 289 | + |
| 290 | + compiled_combiner = torch.compile( |
| 291 | + combiner, |
| 292 | + fullgraph=True, |
| 293 | + ) |
| 294 | + |
| 295 | + # Perform a Dot product with output, so that gradients passed back from |
| 296 | + # different ranks are different |
| 297 | + weight = torch.empty(max_inp_len, dtype=dtype, device=device).fill_(rank + 1) |
| 298 | + |
| 299 | + # Now we start to test the autograd function |
| 300 | + |
| 301 | + # Requires grad for input of combine |
| 302 | + out.requires_grad_(True) |
| 303 | + |
| 304 | + combine_out = compiled_combiner( |
| 305 | + out, |
| 306 | + inp, |
| 307 | + out_splits_offsets, |
| 308 | + in_splits_offsets, |
| 309 | + ) |
| 310 | + p = torch.matmul(weight, combine_out) |
| 311 | + p.sum().backward() |
| 312 | + |
| 313 | + # Check gradients |
| 314 | + # We also need to skip the padding in the input data |
| 315 | + out_splits = out_splits_offsets[0].tolist() |
| 316 | + out_offsets = out_splits_offsets[1].tolist() |
| 317 | + for i, (split, offset) in enumerate(zip(out_splits, out_offsets)): |
| 318 | + grad_chunk = out.grad[offset : offset + split] |
| 319 | + dst_rank = i % world_size |
| 320 | + torch.testing.assert_close( |
| 321 | + grad_chunk, |
| 322 | + torch.empty(split, hid, device=device).fill_(dst_rank + 1), |
| 323 | + ) |
| 324 | + |
| 325 | + dist.destroy_process_group() |
| 326 | + print(f"Rank {rank} passed") |
| 327 | + |
| 328 | + |
| 329 | +if __name__ == "__main__": |
| 330 | + # To run this test, use the following command: |
| 331 | + # torchrun --nproc-per-node 4 --standalone combine.py |
| 332 | + test_token_combine() |
0 commit comments