Commit bc3833f

[a2av] Add autograd support for token combine op
1 parent 93a236c commit bc3833f

1 file changed: 332 additions, 0 deletions
@@ -0,0 +1,332 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional

import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem


# Adding out-of-tree ops to the `symm_mem` library
lib = torch.library.Library("symm_mem", "FRAGMENT")  # noqa: TOR901

"""
all_to_all_vdev_2d_copy:
Copy data from `input` to `symm_in_buf` and call `all_to_all_vdev_2d` to shuffle data
"""
lib.define(
    "all_to_all_vdev_2d_copy("
    "Tensor input, Tensor symm_in_buf, Tensor(a!) out, "
    "Tensor in_splits, Tensor(a!) out_splits_offsets, str group_name, int? major_align=None) -> ()",
    tags=[torch._C.Tag.needs_exact_strides],
)


@torch.library.impl(lib, "all_to_all_vdev_2d_copy", "CUDA")
def _all_to_all_vdev_2d_copy_cuda(
    input: torch.Tensor,
    symm_in_buf: torch.Tensor,
    out: torch.Tensor,
    in_splits: torch.Tensor,
    out_splits_offsets: torch.Tensor,
    group_name: str,
    major_align: Optional[int] = None,
) -> None:
    if symm_in_buf.shape[0] < input.shape[0]:
        raise RuntimeError(
            f"symm_in_buf with dim-0 length {symm_in_buf.shape[0]} cannot fit input with dim-0 length {input.shape[0]}"
        )
    if symm_in_buf.shape[1:] != input.shape[1:]:
        raise RuntimeError(
            f"symm_in_buf non-0 dims do not match those of input: {symm_in_buf.shape[1:]} vs {input.shape[1:]}"
        )
    if symm_in_buf.dtype != input.dtype:
        raise RuntimeError(
            f"symm_in_buf dtype {symm_in_buf.dtype} does not match input dtype {input.dtype}"
        )

    symm_in_buf.narrow(0, 0, input.shape[0]).copy_(input)
    torch.ops.symm_mem.all_to_all_vdev_2d(
        symm_in_buf,
        out,
        in_splits,
        out_splits_offsets,
        group_name,
        major_align,
    )
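

# Illustrative sketch only (not exercised by this commit): a direct call to the
# copy-variant op defined above. In the backward pass below, the upstream gradient
# is a regular tensor rather than a symmetric-memory buffer, so it is staged into
# `symm_in_buf` before the shuffle. This sketch assumes an initialized process group
# with NVSHMEM-backed symmetric memory, as set up in `test_token_combine` below;
# the helper name `_example_copy_shuffle` and the sizes used are hypothetical.
def _example_copy_shuffle(
    group_name: str, nsplits: int, hid: int, max_len: int, align: int
) -> torch.Tensor:
    device = torch.device("cuda", torch.cuda.current_device())
    grad_output = torch.randn(max_len, hid, device=device)  # plain (non-symmetric) tensor
    symm_in_buf = symm_mem.empty(max_len, hid, dtype=grad_output.dtype, device=device)
    out = symm_mem.empty(max_len, hid, dtype=grad_output.dtype, device=device)
    # Row 0: splits, row 1: offsets (offsets are filled by the shuffle)
    in_splits_offsets = symm_mem.empty((2, nsplits), dtype=torch.int64, device=device)
    out_splits_offsets = symm_mem.empty((2, nsplits), dtype=torch.int64, device=device)
    in_splits_offsets[0].zero_()  # hypothetical splits: send nothing in this sketch
    # Copies `grad_output` into `symm_in_buf`, then shuffles into `out`
    torch.ops.symm_mem.all_to_all_vdev_2d_copy(
        grad_output,
        symm_in_buf,
        out,
        in_splits_offsets[0],
        out_splits_offsets,
        group_name,
        align,
    )
    return out
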


class AllToAllVDev2dOffset(torch.autograd.Function):
    """
    Autograd function for `all_to_all_vdev_2d_offset`.
    """

    @staticmethod
    def forward(  # type: ignore[no-untyped-def]
        ctx,
        input: torch.Tensor,
        out: torch.Tensor,
        in_splits_offsets: torch.Tensor,
        out_splits_offsets: torch.Tensor,
        group_name: str,
        # Arguments below are needed only for the backward pass
        major_align: int,
        grad_out_buf: torch.Tensor,
        grad_in_buf: torch.Tensor,
        grad_in_splits_offsets: torch.Tensor,
    ) -> torch.Tensor:
        """
        Functionality is the same as `all_to_all_vdev_2d_offset`, but with functionalization.
        """
        # Shuffle input to output
        torch.ops.symm_mem.all_to_all_vdev_2d_offset(
            input, out, in_splits_offsets, out_splits_offsets, group_name
        )

        # Output splits in forward are the input splits in backward
        ctx.save_for_backward(
            out_splits_offsets, grad_out_buf, grad_in_buf, grad_in_splits_offsets
        )
        ctx.group_name = group_name
        ctx.major_align = major_align
        return out

    @staticmethod
    def backward(  # type: ignore[no-untyped-def]
        ctx,
        grad_output: torch.Tensor,
    ) -> tuple[torch.Tensor, None, None, None, None, None, None, None, None]:
        """
        The backward pass of `all_to_all_vdev_2d_offset` is an `all_to_all_vdev_2d`.

        Args:
            `grad_output`: gradients of the output, passed back from downstream.

        Returns:
            `grad_input`: gradients of the input.
        """
        # Splits info:
        # splits/offsets of grad_output are the same as the output splits/offsets in forward
        (
            grad_out_splits_offsets,
            grad_out_buf,
            grad_in_buf,
            grad_in_splits_offsets,
        ) = ctx.saved_tensors
        grad_out_splits = grad_out_splits_offsets[0]

        # Shuffle gradients back to the input
        # TODO: create an op that takes both in_splits_offsets and
        # out_splits_offsets, instead of taking alignment
        torch.ops.symm_mem.all_to_all_vdev_2d_copy(
            grad_output,
            grad_out_buf,
            grad_in_buf,
            grad_out_splits,
            grad_in_splits_offsets,
            ctx.group_name,
            ctx.major_align,
        )

        return grad_in_buf, None, None, None, None, None, None, None, None


class TokenCombiner(torch.nn.Module):
    """
    Combine tokens from different experts, with a backward pass that shuffles
    gradients back to the input.

    Args:
        `group_name`: name of the group to use for communication.
        `align`: alignment of the token offsets from each expert. If a Grouped GEMM
            follows, this should be the same as the Grouped GEMM's alignment.
        `in_len`: max dim-0 length of the input.
        `out_len`: max dim-0 length of the output.
        `token_shape`: shape of each token.
        `num_ranks`: number of ranks in the group.
        `num_local_experts`: number of local experts.
        `dtype`: data type of the input/output.
    """

    def __init__(
        self,
        group_name: str,
        align: int,
        in_len: int,
        out_len: int,
        token_shape: torch.Size,
        num_ranks: int,
        num_local_experts: int,
        dtype: torch.dtype,
    ) -> None:
        super().__init__()
        self.group_name = group_name
        self.align = align
        self.grad_out_buf = symm_mem.empty(out_len, *token_shape, dtype=dtype)
        self.grad_in_buf = symm_mem.empty(in_len, *token_shape, dtype=dtype)
        self.nsplits = num_ranks * num_local_experts
        self.grad_in_splits_offsets = symm_mem.empty(
            (2, self.nsplits), dtype=torch.int64
        )

    def forward(
        self,
        inp: torch.Tensor,
        out: torch.Tensor,
        in_splits_offsets: torch.Tensor,
        out_splits_offsets: torch.Tensor,
    ) -> torch.Tensor:
        """
        Args:
            `inp`: input tensor.
            `out`: buffer for the output tensor.
            `in_splits_offsets`: splits and offsets of the input tensor.
            `out_splits_offsets`: splits and offsets of the output tensor.
            See `all_to_all_vdev_2d_offset` for more details.

        Note:
            All tensor arguments must be symmetrically allocated, i.e.
            >>> inp = symm_mem.empty(max_inp_len, dtype=dtype, device=device)
            >>> out = symm_mem.empty(max_out_len, dtype=dtype, device=device)
            >>> in_splits_offsets = symm_mem.empty(
            ...     (2, nsplits), dtype=torch.int64, device=device)
            >>> out_splits_offsets = symm_mem.empty(
            ...     (2, nsplits), dtype=torch.int64, device=device)
        """

        if in_splits_offsets.shape != (2, self.nsplits):
            raise ValueError(
                f"Expected shape (2, {self.nsplits}), got {in_splits_offsets.shape}"
            )
        if out_splits_offsets.shape != (2, self.nsplits):
            raise ValueError(
                f"Expected shape (2, {self.nsplits}), got {out_splits_offsets.shape}"
            )

        return AllToAllVDev2dOffset.apply(
            inp,
            out,
            in_splits_offsets,
            out_splits_offsets,
            self.group_name,
            self.align,
            self.grad_out_buf,
            self.grad_in_buf,
            self.grad_in_splits_offsets,
        )
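

# Illustrative usage sketch (not called by the test): one way to drive a
# TokenCombiner after a dispatch. It assumes a process group and NVSHMEM symmetric
# memory are already set up, and that the buffers come from a prior
# `all_to_all_vdev_2d` dispatch as in `test_token_combine` below. The helper
# name `_example_combine` is hypothetical.
def _example_combine(
    combiner: TokenCombiner,
    dispatch_out: torch.Tensor,        # symmetric buffer holding the dispatch output
    combine_out: torch.Tensor,         # symmetric buffer to receive the combined tokens
    out_splits_offsets: torch.Tensor,  # splits/offsets produced by the dispatch
    in_splits_offsets: torch.Tensor,   # splits/offsets that fed the dispatch
) -> torch.Tensor:
    # Combine reverses the dispatch: the dispatch output becomes the combine input,
    # and the dispatch's output splits/offsets become the combine's input splits/offsets.
    dispatch_out.requires_grad_(True)
    combined = combiner(dispatch_out, combine_out, out_splits_offsets, in_splits_offsets)
    # Any scalar loss downstream routes gradients through the backward pass,
    # which shuffles them to `dispatch_out.grad` via `all_to_all_vdev_2d_copy`.
    combined.sum().backward()
    return dispatch_out.grad
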


def test_token_combine() -> None:
    # Init
    dist.init_process_group()
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    device_count = torch.cuda.device_count()
    device = torch.device("cuda", rank % device_count)

    # NVSHMEM backend specific
    torch.cuda.set_device(device)
    torch.empty(1, device=device)
    # Set NVSHMEM as the SymmMem backend
    symm_mem.set_backend("NVSHMEM")

    # Mimics the Grouped GEMM alignment
    align = 8
    torch.manual_seed(42 + rank)

    group_name = dist.group.WORLD.group_name
    symm_mem.enable_symm_mem_for_group(group_name)

    dtype = torch.float
    # Number of experts per rank
    ne = 8
    nsplits = ne * world_size

    # Number of elements for an expert is random in [0, k)
    k = 10
    inp_splits = torch.randint(k, (nsplits,), dtype=torch.int64, device=device)

    # Max number of input elements (must be a constant across ranks for symmetric memory allocation)
    max_inp_len = k * nsplits
    # Max number of output elements (must be a constant across ranks for symmetric memory allocation)
    overflow_factor = world_size  # worst case: one rank receives all data
    max_out_len = max_inp_len * overflow_factor

    # Use a dispatch to prepare the input for combine (this is just
    # preparation, not the test itself)
    # Buffers for dispatch
    hid = 4096
    inp = symm_mem.empty(max_inp_len, hid, dtype=dtype, device=device)
    out = symm_mem.empty(max_out_len, hid, dtype=dtype, device=device)
    # 2 rows: input splits, input offsets
    in_splits_offsets = symm_mem.empty((2, nsplits), dtype=torch.int64, device=device)
    # 2 rows: output splits, output offsets
    out_splits_offsets = symm_mem.empty((2, nsplits), dtype=torch.int64, device=device)

    # Dispatch the tokens first so that we have a nice input for combine
    in_splits_offsets[0].copy_(inp_splits)
    torch.ops.symm_mem.all_to_all_vdev_2d(
        inp,
        out,
        in_splits_offsets[0],
        out_splits_offsets,
        group_name,
        major_align=align,
    )

    with device:
        combiner = TokenCombiner(
            group_name,
            align,
            max_out_len,
            max_inp_len,
            out.shape[1:],
            world_size,
            ne,
            dtype,
        )

    compiled_combiner = torch.compile(
        combiner,
        fullgraph=True,
    )

    # Perform a dot product with the output, so that the gradients passed back
    # from different ranks are different
    weight = torch.empty(max_inp_len, dtype=dtype, device=device).fill_(rank + 1)

    # Now we start to test the autograd function

    # Require grad for the input of combine
    out.requires_grad_(True)

    combine_out = compiled_combiner(
        out,
        inp,
        out_splits_offsets,
        in_splits_offsets,
    )
    p = torch.matmul(weight, combine_out)
    p.sum().backward()
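
    # Why the expected gradient below is (dst_rank + 1): p = weight @ combine_out,
    # so d(p.sum())/d(combine_out[t, :]) = weight[t], and weight is filled with
    # (rank + 1) on the rank that received the combined tokens. Chunk i of `out`
    # is combined onto rank i % world_size, so after the backward shuffle,
    # `out.grad` over that chunk should equal (i % world_size) + 1.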

    # Check gradients
    # We also need to skip the padding in the input data
    out_splits = out_splits_offsets[0].tolist()
    out_offsets = out_splits_offsets[1].tolist()
    for i, (split, offset) in enumerate(zip(out_splits, out_offsets)):
        grad_chunk = out.grad[offset : offset + split]
        dst_rank = i % world_size
        torch.testing.assert_close(
            grad_chunk,
            torch.empty(split, hid, device=device).fill_(dst_rank + 1),
        )

    dist.destroy_process_group()
    print(f"Rank {rank} passed")


if __name__ == "__main__":
    # To run this test, use the following command:
    # torchrun --nproc-per-node 4 --standalone combine.py
    test_token_combine()
