Commit cf4de26

[a2av] Add autograd support for token combine op (#1511)

Added class `TokenCombiner`, which combines tokens from different experts, with backward support. Usage:

```python
combiner = TokenCombiner(
    group_name, align, max_inp_len, max_out_len, inp.shape[1:], world_size, ne, dtype
)
# inp, out, in_splits_offsets, out_splits_offsets must be symmetric tensors
output = combiner(inp, out, in_splits_offsets, out_splits_offsets)
```

Supports:

```python
torch.compile(combiner)
```
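The tensor arguments in the snippet above must live in symmetric memory. Below is a minimal allocation sketch, assuming the process group and SymmMem backend are already initialized (as in the test at the bottom of the file); the sizes `ne`, `hid`, `max_inp_len`, and `max_out_len` are illustrative placeholders:

```python
import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem

# Illustrative sizes only; real values depend on the model and routing.
world_size = dist.get_world_size()
ne = 8                                   # local experts per rank
nsplits = ne * world_size                # one (split, offset) entry per (rank, expert)
hid = 4096                               # token hidden dimension
max_inp_len = 10 * nsplits               # upper bound on tokens this rank sends
max_out_len = max_inp_len * world_size   # worst case: all tokens land on one rank
dtype, device = torch.float, torch.device("cuda")

# All four tensor arguments must be symmetric-memory tensors of identical
# size on every rank.
inp = symm_mem.empty(max_inp_len, hid, dtype=dtype, device=device)
out = symm_mem.empty(max_out_len, hid, dtype=dtype, device=device)
in_splits_offsets = symm_mem.empty((2, nsplits), dtype=torch.int64, device=device)
out_splits_offsets = symm_mem.empty((2, nsplits), dtype=torch.int64, device=device)
```

The included test sizes its buffers the same way, using worst-case bounds, since symmetric allocations must have the same size on every rank.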
1 parent 9c42b9b commit cf4de26

1 file changed: +337 −0

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional

import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem


# Adding out-of-tree ops to the `symm_mem` library
lib = torch.library.Library("symm_mem", "FRAGMENT")  # noqa: TOR901

"""
all_to_all_vdev_2d_copy:
Copy data from `input` to `symm_in_buf` and call `all_to_all_vdev_2d` to shuffle data
"""
lib.define(
    "all_to_all_vdev_2d_copy("
    "Tensor input, Tensor symm_in_buf, Tensor(a!) out, "
    "Tensor in_splits, Tensor(a!) out_splits_offsets, str group_name, int? major_align=None) -> ()",
    tags=[torch._C.Tag.needs_exact_strides],
)


@torch.library.impl(lib, "all_to_all_vdev_2d_copy", "CUDA")
def _all_to_all_vdev_2d_copy_cuda(
    input: torch.Tensor,
    symm_in_buf: torch.Tensor,
    out: torch.Tensor,
    in_splits: torch.Tensor,
    out_splits_offsets: torch.Tensor,
    group_name: str,
    major_align: Optional[int] = None,
) -> None:
    if symm_in_buf.shape[0] < input.shape[0]:
        raise RuntimeError(
            f"symm_in_buf with dim-0 length {symm_in_buf.shape[0]} cannot fit input with dim-0 length {input.shape[0]}"
        )
    if symm_in_buf.shape[1:] != input.shape[1:]:
        raise RuntimeError(
            f"symm_in_buf non-0 dims do not match that of input: {symm_in_buf.shape[1:]} vs {input.shape[1:]}"
        )
    if symm_in_buf.dtype != input.dtype:
        raise RuntimeError(
            f"symm_in_buf dtype {symm_in_buf.dtype} does not match input dtype {input.dtype}"
        )

    symm_in_buf.narrow(0, 0, input.shape[0]).copy_(input)
    torch.ops.symm_mem.all_to_all_vdev_2d(
        symm_in_buf,
        out,
        in_splits,
        out_splits_offsets,
        group_name,
        major_align,
    )
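
# Note: `all_to_all_vdev_2d_copy` stages `input` into the symmetric buffer
# `symm_in_buf` before shuffling, so the backward pass below can hand a regular
# (non-symmetric) gradient tensor to the symmetric-memory all-to-all.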


class AllToAllVDev2dOffset(torch.autograd.Function):
    """
    Autograd function for `all_to_all_vdev_2d_offset`
    """

    @staticmethod
    def forward(  # type: ignore[no-untyped-def]
        ctx,
        input: torch.Tensor,
        out: torch.Tensor,
        in_splits_offsets: torch.Tensor,
        out_splits_offsets: torch.Tensor,
        group_name: str,
        # Buffers needed for backward pass
        major_align: int,
        grad_out_buf: torch.Tensor,
        grad_in_buf: torch.Tensor,
        grad_in_splits_offsets: torch.Tensor,
    ) -> torch.Tensor:
        """
        Functionality is the same as `all_to_all_vdev_2d_offset` but with functionalization.
        """
        # Shuffle input to output
        torch.ops.symm_mem.all_to_all_vdev_2d_offset(
            input, out, in_splits_offsets, out_splits_offsets, group_name
        )

        # Output splits in forward are the input splits in backward
        ctx.save_for_backward(
            out_splits_offsets, grad_out_buf, grad_in_buf, grad_in_splits_offsets
        )
        ctx.group_name = group_name
        ctx.major_align = major_align
        return out

    @staticmethod
    def backward(  # type: ignore[no-untyped-def]
        ctx,
        grad_output: torch.Tensor,
    ) -> tuple[torch.Tensor, None, None, None, None, None, None, None, None]:
        """
        Backward pass of `all_to_all_vdev_2d_offset` is `all_to_all_vdev_2d`.

        Args:
            `grad_output`: gradients of output passed back from the downstream.

        Returns:
            `grad_input`: gradients of input.
        """
        # Splits info
        # Splits/offsets of grad_out are the same as out splits/offsets in forward
        (
            grad_out_splits_offsets,
            grad_out_buf,
            grad_in_buf,
            grad_in_splits_offsets,
        ) = ctx.saved_tensors
        grad_out_splits = grad_out_splits_offsets[0]

        # Shuffle gradients back to the input
        # TODO: create an op that takes both in_splits_offsets and
        # out_splits_offsets, instead of taking alignment
        torch.ops.symm_mem.all_to_all_vdev_2d_copy(
            grad_output,
            grad_out_buf,
            grad_in_buf,
            grad_out_splits,
            grad_in_splits_offsets,
            ctx.group_name,
            ctx.major_align,
        )

        return grad_in_buf, None, None, None, None, None, None, None, None


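# Note: the gradient buffers (`grad_out_buf`, `grad_in_buf`,
# `grad_in_splits_offsets`) are arguments to `forward` rather than being
# allocated in `backward`, since the backward shuffle operates on
# symmetric-memory buffers; `TokenCombiner` below pre-allocates them with
# `symm_mem.empty` and reuses them across calls.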
class TokenCombiner(torch.nn.Module):
    """
    Combine tokens from different experts, with backward pass to shuffle gradients back to the input.
    Args:
        `group_name`: name of the group to use for communication.
        `align`: alignment of the token offsets from each expert. If using
        Grouped Gemm next, this should be the same as Grouped Gemm's alignment.
        `in_len`: length of the input.
        `out_len`: length of the output.
        `token_shape`: shape of the tokens.
        `num_ranks`: number of ranks in the group.
        `num_local_experts`: number of local experts.
        `dtype`: data type of the input/output.
        `device`: device on which to allocate the gradient buffers.
    """

    def __init__(
        self,
        group_name: str,
        align: int,
        in_len,
        out_len,
        token_shape,
        num_ranks,
        num_local_experts,
        dtype,
        device: torch.device,
    ) -> None:
        super().__init__()
        self.group_name = group_name
        self.align = align
        self.grad_out_buf = symm_mem.empty(
            out_len, *token_shape, dtype=dtype, device=device
        )
        self.grad_in_buf = symm_mem.empty(
            in_len, *token_shape, dtype=dtype, device=device
        )
        self.nsplits = num_ranks * num_local_experts
        self.grad_in_splits_offsets = symm_mem.empty(
            (2, self.nsplits), dtype=torch.int64, device=device
        )

    def forward(
        self,
        inp: torch.Tensor,
        out: torch.Tensor,
        in_splits_offsets: torch.Tensor,
        out_splits_offsets: torch.Tensor,
    ) -> torch.Tensor:
        """
        Args:
            `inp`: input tensor.
            `out`: buffer for output tensor.
            `in_splits_offsets`: splits and offsets of the input tensor.
            `out_splits_offsets`: splits and offsets of the output tensor.
            See `all_to_all_vdev_2d_offset` for more details.
        Note:
            All tensor arguments must be symmetrically allocated, i.e.
            >>> inp = symm_mem.empty(max_inp_len, dtype=dtype, device=device)
            >>> out = symm_mem.empty(max_out_len, dtype=dtype, device=device)
            >>> in_splits_offsets = symm_mem.empty(
            ...     (2, nsplits), dtype=torch.int64, device=device)
            >>> out_splits_offsets = symm_mem.empty(
            ...     (2, nsplits), dtype=torch.int64, device=device)
        """

        if in_splits_offsets.shape != (2, self.nsplits):
            raise ValueError(
                f"Expected shape (2, {self.nsplits}), got {in_splits_offsets.shape}"
            )
        if out_splits_offsets.shape != (2, self.nsplits):
            raise ValueError(
                f"Expected shape (2, {self.nsplits}), got {out_splits_offsets.shape}"
            )

        return AllToAllVDev2dOffset.apply(
            inp,
            out,
            in_splits_offsets,
            out_splits_offsets,
            self.group_name,
            self.align,
            self.grad_out_buf,
            self.grad_in_buf,
            self.grad_in_splits_offsets,
        )


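# The test below first dispatches tokens with `all_to_all_vdev_2d` (to produce
# a realistic, padded input for the combine step), then runs the combiner under
# `torch.compile` and checks that each input chunk's gradient matches the value
# contributed by its destination rank.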
def test_token_combine() -> None:
    # Init
    dist.init_process_group()
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    device_count = torch.cuda.device_count()
    device = torch.device("cuda", rank % device_count)

    # NVSHMEM backend specific
    torch.cuda.set_device(device)
    torch.empty(1, device=device)
    # Set NVSHMEM as SymmMem backend
    symm_mem.set_backend("NVSHMEM")

    # Mimics Group GEMM alignment
    align = 8
    torch.manual_seed(42 + rank)

    group_name = dist.group.WORLD.group_name
    symm_mem.enable_symm_mem_for_group(group_name)

    dtype = torch.float
    # Number of experts per rank
    ne = 8
    nsplits = ne * world_size

    # Number of elements for an expert is random between [0, k)
    k = 10
    inp_splits = torch.randint(k, (nsplits,), dtype=torch.int64, device=device)

    # Max number of input elements (must be a constant across ranks for symmetric memory allocation)
    max_inp_len = k * nsplits
    # Max number of output elements (must be a constant across ranks for symmetric memory allocation)
    overflow_factor = world_size  # worst case: one rank receives all data
    max_out_len = max_inp_len * overflow_factor

    # Use a dispatch to prepare the input for combine (this is just a
    # preparation, not the test itself)
    # Buffers for dispatch
    hid = 4096
    inp = symm_mem.empty(max_inp_len, hid, dtype=dtype, device=device)
    out = symm_mem.empty(max_out_len, hid, dtype=dtype, device=device)
    # 2 rows: input splits, input offsets
    in_splits_offsets = symm_mem.empty((2, nsplits), dtype=torch.int64, device=device)
    # 2 rows: output splits, output offsets
    out_splits_offsets = symm_mem.empty((2, nsplits), dtype=torch.int64, device=device)

    # Dispatch the tokens first so that we have a nice input for combine
    in_splits_offsets[0].copy_(inp_splits)
    torch.ops.symm_mem.all_to_all_vdev_2d(
        inp,
        out,
        in_splits_offsets[0],
        out_splits_offsets,
        group_name,
        major_align=align,
    )

    combiner = TokenCombiner(
        group_name,
        align,
        max_out_len,
        max_inp_len,
        out.shape[1:],
        world_size,
        ne,
        dtype,
        device,
    )

    compiled_combiner = torch.compile(
        combiner,
        fullgraph=True,
    )

    # Perform a dot product with the output, so that gradients passed back from
    # different ranks are different
    weight = torch.empty(max_inp_len, dtype=dtype, device=device).fill_(rank + 1)

    # Now we start to test the autograd function

    # Requires grad for input of combine
    out.requires_grad_(True)

    combine_out = compiled_combiner(
        out,
        inp,
        out_splits_offsets,
        in_splits_offsets,
    )
    p = torch.matmul(weight, combine_out)
    p.sum().backward()

    # Check gradients
    # We also need to skip the padding in the input data
    out_splits = out_splits_offsets[0].tolist()
    out_offsets = out_splits_offsets[1].tolist()
    for i, (split, offset) in enumerate(zip(out_splits, out_offsets)):
        grad_chunk = out.grad[offset : offset + split]
        dst_rank = i % world_size
        torch.testing.assert_close(
            grad_chunk,
            torch.empty(split, hid, device=device).fill_(dst_rank + 1),
        )

    dist.destroy_process_group()
    print(f"Rank {rank} passed")


if __name__ == "__main__":
    # To run this test, use the following command:
    # torchrun --nproc-per-node 4 --standalone combine.py
    test_token_combine()
```
