
Commit 9c42b9b

[a2av] Add autograd support for token dispatch op (#1491)
Added class `TokenDispatcher`, which dispatches tokens to different experts, with backward support. Usage:

```
dispatcher = TokenDispatcher(
    group_name, align, max_inp_len, max_out_len, inp.shape[1:], world_size, ne, dtype, device
)
# inp, out, in_splits, out_splits_offsets must be symmetric tensors
output = dispatcher(inp, out, in_splits, out_splits_offsets)
```

Supports:

```
torch.compile(dispatcher)
```
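The test at the bottom of the new file shows the full setup. For quick reference, here is a condensed sketch of that flow (assuming `TokenDispatcher` from this file is in scope; the NVSHMEM backend selection, the symmetric allocations, and sizes such as `hid = 4096` are taken from that test and are illustrative, not requirements of the API):

```
import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem

dist.init_process_group()
rank, world_size = dist.get_rank(), dist.get_world_size()
device = torch.device("cuda", rank % torch.cuda.device_count())
torch.cuda.set_device(device)
symm_mem.set_backend("NVSHMEM")

group_name = dist.group.WORLD.group_name
symm_mem.enable_symm_mem_for_group(group_name)

ne, align, hid = 8, 8, 4096              # local experts, Grouped-GEMM alignment, hidden dim
dtype = torch.float
nsplits = ne * world_size
max_inp_len = 10 * nsplits               # per-rank upper bound, constant across ranks
max_out_len = max_inp_len * world_size   # worst case: one rank receives everything

# All communication tensors live in symmetric memory
inp = symm_mem.empty(max_inp_len, hid, dtype=dtype, device=device)
out = symm_mem.empty(max_out_len, hid, dtype=dtype, device=device)
in_splits = symm_mem.empty(nsplits, dtype=torch.int64, device=device)
out_splits_offsets = symm_mem.empty((2, nsplits), dtype=torch.int64, device=device)

dispatcher = TokenDispatcher(
    group_name, align, max_inp_len, max_out_len, inp.shape[1:],
    world_size, ne, dtype, device,
)
dispatcher = torch.compile(dispatcher, fullgraph=True)

# Tokens produced upstream; copying into the symmetric buffer keeps them in the autograd graph
in_splits.copy_(torch.randint(10, (nsplits,), device=device))
tokens = torch.randn(max_inp_len, hid, dtype=dtype, device=device, requires_grad=True)
inp.copy_(tokens)
output = dispatcher(inp, out, in_splits, out_splits_offsets)
output.sum().backward()                  # gradients are shuffled back into tokens.grad
```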
1 parent d14f1e3 commit 9c42b9b

1 file changed: 312 additions, 0 deletions

@@ -0,0 +1,312 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem


# Adding out-of-tree ops to the `symm_mem` library
lib = torch.library.Library("symm_mem", "FRAGMENT")  # noqa: TOR901

"""
all_to_all_vdev_2d_offset_copy:
Copy data from `input` to `symm_in_buf` and call `all_to_all_vdev_2d_offset` to shuffle data
"""
lib.define(
    "all_to_all_vdev_2d_offset_copy("
    "Tensor input, Tensor symm_in_buf, Tensor(a!) out, "
    "Tensor in_splits_offsets, Tensor(a!) out_splits_offsets, str group_name) -> ()",
    tags=[torch._C.Tag.needs_exact_strides],
)
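# NOTE: the op registered above is invoked as
# `torch.ops.symm_mem.all_to_all_vdev_2d_offset_copy` in the backward pass of
# `AllToAllVDev2d` below. The `needs_exact_strides` tag presumably keeps
# torch.compile from re-striding the op's inputs.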

@torch.library.impl(lib, "all_to_all_vdev_2d_offset_copy", "CUDA")
def _all_to_all_vdev_2d_offset_copy_cuda(
    input: torch.Tensor,
    symm_in_buf: torch.Tensor,
    out: torch.Tensor,
    in_splits_offsets: torch.Tensor,
    out_splits_offsets: torch.Tensor,
    group_name: str,
) -> None:
    if symm_in_buf.shape[0] < input.shape[0]:
        raise RuntimeError(
            f"symm_in_buf with dim-0 length {symm_in_buf.shape[0]} cannot fit input with dim-0 length {input.shape[0]}"
        )
    if symm_in_buf.shape[1:] != input.shape[1:]:
        raise RuntimeError(
            f"symm_in_buf non-0 dims do not match those of input: {symm_in_buf.shape[1:]} vs {input.shape[1:]}"
        )
    if symm_in_buf.dtype != input.dtype:
        raise RuntimeError(
            f"symm_in_buf dtype {symm_in_buf.dtype} does not match input dtype {input.dtype}"
        )

    symm_in_buf.narrow(0, 0, input.shape[0]).copy_(input)
    torch.ops.symm_mem.all_to_all_vdev_2d_offset(
        symm_in_buf,
        out,
        in_splits_offsets,
        out_splits_offsets,
        group_name,
    )


class AllToAllVDev2d(torch.autograd.Function):
    """
    Autograd function for `all_to_all_vdev_2d`.
    """

    @staticmethod
    def forward(  # type: ignore[no-untyped-def]
        ctx,
        input: torch.Tensor,
        out: torch.Tensor,
        in_splits: torch.Tensor,
        out_splits_offsets: torch.Tensor,
        group_name: str,
        major_align: int,
        # Buffers needed for backward pass
        grad_out_buf: torch.Tensor,
        grad_in_buf: torch.Tensor,
        grad_in_splits_offsets: torch.Tensor,
    ) -> torch.Tensor:
        """
        Functionality is the same as `all_to_all_vdev_2d`, but with functionalization.
        """
        # Shuffle input to output
        torch.ops.symm_mem.all_to_all_vdev_2d(
            input, out, in_splits, out_splits_offsets, group_name, major_align
        )

        # Output splits in forward are the input splits in backward
        ctx.save_for_backward(
            out_splits_offsets, grad_out_buf, grad_in_buf, grad_in_splits_offsets
        )
        ctx.group_name = group_name
        return out

    @staticmethod
    def backward(  # type: ignore[no-untyped-def]
        ctx,
        grad_output: torch.Tensor,
    ) -> tuple[torch.Tensor, None, None, None, None, None, None, None, None]:
        """
        Backward pass of `all_to_all_vdev_2d` is `all_to_all_vdev_2d_offset`.

        Args:
            `grad_output`: gradients of output passed back from downstream.

        Returns:
            `grad_input`: gradients of input.
        """
        # Splits info:
        # splits/offsets of grad_out are the same as the out splits/offsets in forward
        (
            grad_out_splits_offsets,
            grad_out_buf,
            grad_in_buf,
            grad_in_splits_offsets,
        ) = ctx.saved_tensors

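        # `grad_output` comes from regular autograd and is generally not a
        # symmetric tensor, so the `_copy` variant below first stages it into
        # the symmetric `grad_out_buf` before performing the shuffle.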
        # Shuffle gradients back to the input
        torch.ops.symm_mem.all_to_all_vdev_2d_offset_copy(
            grad_output,
            grad_out_buf,
            grad_in_buf,
            grad_out_splits_offsets,
            grad_in_splits_offsets,
            group_name=ctx.group_name,
        )
        return grad_in_buf, None, None, None, None, None, None, None, None


class TokenDispatcher(torch.nn.Module):
    """
    Dispatch tokens to different experts, with a backward pass that shuffles
    gradients back to the input.

    Args:
        `group_name`: name of the group to use for communication.
        `align`: alignment of the token offsets for each receiving expert. If
            using Grouped GEMM next, this should be the same as Grouped GEMM's
            alignment.
        `in_len`: length of the input.
        `out_len`: length of the output.
        `token_shape`: shape of the tokens.
        `num_ranks`: number of ranks in the group.
        `num_local_experts`: number of local experts.
        `dtype`: data type of the input/output.
        `device`: device to use for communication.
    """

    def __init__(
        self,
        group_name: str,
        align: int,
        in_len,
        out_len,
        token_shape,
        num_ranks,
        num_local_experts,
        dtype,
        device: torch.device,
    ) -> None:
        super().__init__()
        self.group_name = group_name
        self.align = align
        self.grad_out_buf = symm_mem.empty(
            out_len, *token_shape, dtype=dtype, device=device
        )
        self.grad_in_buf = symm_mem.empty(
            in_len, *token_shape, dtype=dtype, device=device
        )
        self.nsplits = num_ranks * num_local_experts
        self.grad_in_splits_offsets = symm_mem.empty(
            (2, self.nsplits), dtype=torch.int64, device=device
        )

    def forward(
        self,
        inp: torch.Tensor,
        out: torch.Tensor,
        in_splits: torch.Tensor,
        out_splits_offsets: torch.Tensor,
    ) -> torch.Tensor:
        """
        Args:
            `inp`: input tensor.
            `out`: buffer for output tensor.
            `in_splits`: splits of the input tensor.
            `out_splits_offsets`: splits and offsets of the output tensor.
        See `all_to_all_vdev_2d` for more details.

        Note:
            All tensor arguments must be symmetrically allocated, i.e.
            >>> inp = symm_mem.empty(max_inp_len, dtype=dtype, device=device)
            >>> out = symm_mem.empty(max_out_len, dtype=dtype, device=device)
            >>> in_splits = symm_mem.empty(
            ...     nsplits, dtype=torch.int64, device=device)
            >>> out_splits_offsets = symm_mem.empty(
            ...     (2, nsplits), dtype=torch.int64, device=device)
        """

        if in_splits.numel() != self.nsplits:
            raise ValueError(f"Expected {self.nsplits} splits, got {in_splits.numel()}")
        if out_splits_offsets.shape != (2, self.nsplits):
            raise ValueError(
                f"Expected shape (2, {self.nsplits}), got {out_splits_offsets.shape}"
            )

        return AllToAllVDev2d.apply(
            inp,
            out,
            in_splits,
            out_splits_offsets,
            self.group_name,
            self.align,
            self.grad_out_buf,
            self.grad_in_buf,
            self.grad_in_splits_offsets,
        )


def test_token_dispatch() -> None:
    # Init
    dist.init_process_group()
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    device_count = torch.cuda.device_count()
    device = torch.device("cuda", rank % device_count)

    # NVSHMEM backend specific
    torch.cuda.set_device(device)
    torch.empty(1, device=device)
    # Set NVSHMEM as SymmMem backend
    symm_mem.set_backend("NVSHMEM")

    # Mimics Grouped GEMM alignment
    align = 8
    torch.manual_seed(42 + rank)

    group_name = dist.group.WORLD.group_name
    symm_mem.enable_symm_mem_for_group(group_name)

    dtype = torch.float
    # Number of experts per rank
    ne = 8
    nsplits = ne * world_size

    # Number of elements for an expert is random, in [0, k)
    k = 10
    inp_splits = torch.randint(k, (nsplits,), dtype=torch.int64, device=device)

    # Max number of input elements (must be a constant across ranks for symmetric memory allocation)
    max_inp_len = k * nsplits
    # Max number of output elements (must be a constant across ranks for symmetric memory allocation)
    overflow_factor = world_size  # worst case: one rank receives all data
    max_out_len = max_inp_len * overflow_factor

    hid = 4096
    inp = symm_mem.empty(max_inp_len, hid, dtype=dtype, device=device)
    out = symm_mem.empty(max_out_len, hid, dtype=dtype, device=device)
    in_splits = symm_mem.empty(nsplits, dtype=torch.int64, device=device).copy_(
        inp_splits
    )
    # 2 rows: output splits, output offsets
    out_splits_offsets = symm_mem.empty((2, nsplits), dtype=torch.int64, device=device)

    dispatcher = TokenDispatcher(
        group_name,
        align,
        max_inp_len,
        max_out_len,
        inp.shape[1:],
        world_size,
        ne,
        dtype,
        device,
    )

    compiled_dispatcher = torch.compile(
        dispatcher,
        fullgraph=True,
    )

    # Perform a dot product with the output, so that gradients passed back from
    # different ranks are different
    weight = torch.empty(max_out_len, dtype=dtype, device=device).fill_(rank + 1)

    # Run a few iterations
    iters = 2
    for i in range(iters):
        # Test whether gradients are passed back from inp to tokens
        tokens = torch.randn(
            max_inp_len, hid, dtype=dtype, device=device
        ).requires_grad_(True)
        tokens.grad = None
        inp.copy_(tokens)
        output = compiled_dispatcher(inp, out, in_splits, out_splits_offsets)
        p = torch.matmul(weight, output)
        p.sum().backward()

        # Check gradients
        start = 0
        for i, split in enumerate(in_splits.tolist()):
            grad_chunk = tokens.grad[start : start + split]
            dst_rank = i // ne
            torch.testing.assert_close(
                grad_chunk,
                torch.empty(split, hid, device=device).fill_(dst_rank + 1),
            )
            start += split

    dist.destroy_process_group()
    print(f"Rank {rank} passed")


if __name__ == "__main__":
    # To run this test, use the following command:
    # torchrun --nproc-per-node 4 --standalone dispatch.py
    test_token_dispatch()
