Commit 9854452

[a2av] Add autograd support for token combine op
1 parent 93a236c commit 9854452

1 file changed: +323 -0

Lines changed: 323 additions & 0 deletions
@@ -0,0 +1,323 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem

from typing import Optional


# Adding out-of-tree ops to the `symm_mem` library
lib = torch.library.Library("symm_mem", "FRAGMENT")  # noqa: TOR901

"""
all_to_all_vdev_2d_copy:
Copy data from `input` to `symm_in_buf` and call `all_to_all_vdev_2d` to shuffle data
"""
lib.define(
    "all_to_all_vdev_2d_copy("
    "Tensor input, Tensor symm_in_buf, Tensor(a!) out, Tensor in_splits, Tensor(a!) out_splits_offsets, str group_name, int? major_align=None) -> ()",
    tags=[torch._C.Tag.needs_exact_strides],
)


@torch.library.impl(lib, "all_to_all_vdev_2d_copy", "CUDA")
def _all_to_all_vdev_2d_copy_cuda(
    input: torch.Tensor,
    symm_in_buf: torch.Tensor,
    out: torch.Tensor,
    in_splits: torch.Tensor,
    out_splits_offsets: torch.Tensor,
    group_name: str,
    major_align: Optional[int] = None,
) -> None:
    if symm_in_buf.shape[0] < input.shape[0]:
        raise RuntimeError(
            f"symm_in_buf with dim-0 length {symm_in_buf.shape[0]} cannot fit input with dim-0 length {input.shape[0]}"
        )
    if symm_in_buf.shape[1:] != input.shape[1:]:
        raise RuntimeError(
            f"symm_in_buf non-0 dims do not match those of input: {symm_in_buf.shape[1:]} vs {input.shape[1:]}"
        )
    if symm_in_buf.dtype != input.dtype:
        raise RuntimeError(
            f"symm_in_buf dtype {symm_in_buf.dtype} does not match input dtype {input.dtype}"
        )

    symm_in_buf.narrow(0, 0, input.shape[0]).copy_(input)
    torch.ops.symm_mem.all_to_all_vdev_2d(
        symm_in_buf,
        out,
        in_splits,
        out_splits_offsets,
        group_name,
        major_align,
    )
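
# Illustrative usage sketch for the copy variant (comments only, not executed;
# `grp`, `n_tokens`, `hid`, `max_in`, `max_out` and `nsplits` are placeholder
# names). Because the op stages `input` into `symm_in_buf` itself, `input` can
# be a regular tensor; the staging and output buffers are the ones that must be
# symmetrically allocated:
#
#   inp = torch.randn(n_tokens, hid, device="cuda")
#   symm_in_buf = symm_mem.empty(max_in, hid, dtype=inp.dtype, device="cuda")
#   out_buf = symm_mem.empty(max_out, hid, dtype=inp.dtype, device="cuda")
#   in_splits = symm_mem.empty(nsplits, dtype=torch.int64, device="cuda")  # fill with per-expert counts
#   out_splits_offsets = symm_mem.empty((2, nsplits), dtype=torch.int64, device="cuda")
#   torch.ops.symm_mem.all_to_all_vdev_2d_copy(
#       inp, symm_in_buf, out_buf, in_splits, out_splits_offsets, grp, major_align=8
#   )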


class AllToAllVDev2dOffset(torch.autograd.Function):
    """
    Autograd function for `all_to_all_vdev_2d_offset`
    """

    @staticmethod
    def forward(  # type: ignore[no-untyped-def]
        ctx,
        input: torch.Tensor,
        out: torch.Tensor,
        in_splits_offsets: torch.Tensor,
        out_splits_offsets: torch.Tensor,
        group_name: str,
        # Arguments below are only needed for the backward pass
        major_align: int,
        grad_out_buf: torch.Tensor,
        grad_in_buf: torch.Tensor,
        grad_in_splits_offsets: torch.Tensor,
    ) -> torch.Tensor:
        """
        Functionality is the same as `all_to_all_vdev_2d_offset`, but with functionalization.
        """
        # Shuffle input to output
        torch.ops.symm_mem.all_to_all_vdev_2d_offset(
            input, out, in_splits_offsets, out_splits_offsets, group_name
        )

        # The output splits in forward are the input splits in backward
        ctx.save_for_backward(
            out_splits_offsets, grad_out_buf, grad_in_buf, grad_in_splits_offsets
        )
        ctx.group_name = group_name
        ctx.major_align = major_align
        return out

    @staticmethod
    def backward(  # type: ignore[no-untyped-def]
        ctx,
        grad_output: torch.Tensor,
    ) -> tuple[torch.Tensor, None, None, None, None, None, None, None, None]:
        """
        The backward pass of `all_to_all_vdev_2d_offset` is `all_to_all_vdev_2d`.

        Args:
            `grad_output`: gradients of the output, passed back from downstream.

        Returns:
            `grad_input`: gradients of the input.
        """
        # Splits info
        # Splits/offsets of grad_output are the same as the out splits/offsets in forward
        (
            grad_out_splits_offsets,
            grad_out_buf,
            grad_in_buf,
            grad_in_splits_offsets,
        ) = ctx.saved_tensors
        grad_out_splits = grad_out_splits_offsets[0]

        # Shuffle gradients back to the input
        # TODO: create an op that takes both in_splits_offsets and
        # out_splits_offsets, instead of taking alignment
        torch.ops.symm_mem.all_to_all_vdev_2d_copy(
            grad_output,
            grad_out_buf,
            grad_in_buf,
            grad_out_splits,
            grad_in_splits_offsets,
            ctx.group_name,
            ctx.major_align,
        )

        return grad_in_buf, None, None, None, None, None, None, None, None
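
# Usage sketch for the autograd function itself (placeholder names, not
# executed). All tensor arguments are assumed to be symmetric buffers sized as
# in `TokenCombiner` below. Gradients flowing into `combined` are shuffled back
# via `all_to_all_vdev_2d_copy` in backward and surface as `payload.grad`:
#
#   payload.requires_grad_(True)
#   combined = AllToAllVDev2dOffset.apply(
#       payload, out_buf, in_splits_offsets, out_splits_offsets,
#       group_name, align, grad_out_buf, grad_in_buf, grad_in_splits_offsets,
#   )
#   combined.sum().backward()  # payload.grad now holds the shuffled gradients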


class TokenCombiner(torch.nn.Module):
    """
    Combine tokens from different experts, with a backward pass that shuffles gradients back to the input.

    Args:
        `group_name`: name of the group to use for communication.
        `align`: alignment of the token offsets from each expert. If a Grouped
            Gemm follows, this should be the same as the Grouped Gemm's alignment.
        `in_len`: length of the input.
        `out_len`: length of the output.
        `token_shape`: shape of the tokens.
        `num_ranks`: number of ranks in the group.
        `num_local_experts`: number of local experts.
        `dtype`: data type of the input/output.
    """

    def __init__(
        self,
        group_name: str,
        align: int,
        in_len,
        out_len,
        token_shape,
        num_ranks,
        num_local_experts,
        dtype,
    ) -> None:
        super().__init__()
        self.group_name = group_name
        self.align = align
        self.grad_out_buf = symm_mem.empty(out_len, *token_shape, dtype=dtype)
        self.grad_in_buf = symm_mem.empty(in_len, *token_shape, dtype=dtype)
        self.nsplits = num_ranks * num_local_experts
        self.grad_in_splits_offsets = symm_mem.empty(
            (2, self.nsplits), dtype=torch.int64
        )

    def forward(
        self,
        inp: torch.Tensor,
        out: torch.Tensor,
        in_splits_offsets: torch.Tensor,
        out_splits_offsets: torch.Tensor,
    ) -> torch.Tensor:
        """
        Args:
            `inp`: input tensor.
            `out`: buffer for the output tensor.
            `in_splits_offsets`: splits and offsets of the input tensor.
            `out_splits_offsets`: splits and offsets of the output tensor.
            See `all_to_all_vdev_2d_offset` for more details.

        Note:
            All tensor arguments must be symmetrically allocated, i.e.
            >>> inp = symm_mem.empty(max_inp_len, dtype=dtype, device=device)
            >>> out = symm_mem.empty(max_out_len, dtype=dtype, device=device)
            >>> in_splits_offsets = symm_mem.empty(
            ...     (2, nsplits), dtype=torch.int64, device=device)
            >>> out_splits_offsets = symm_mem.empty(
            ...     (2, nsplits), dtype=torch.int64, device=device)
        """

        if in_splits_offsets.shape != (2, self.nsplits):
            raise ValueError(
                f"Expected shape (2, {self.nsplits}), got {in_splits_offsets.shape}"
            )
        if out_splits_offsets.shape != (2, self.nsplits):
            raise ValueError(
                f"Expected shape (2, {self.nsplits}), got {out_splits_offsets.shape}"
            )

        return AllToAllVDev2dOffset.apply(
            inp,
            out,
            in_splits_offsets,
            out_splits_offsets,
            self.group_name,
            self.align,
            self.grad_out_buf,
            self.grad_in_buf,
            self.grad_in_splits_offsets,
        )
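
# Construction sketch (placeholder names): when a combine reverses a prior
# dispatch, the length arguments are swapped relative to the dispatch buffers;
# the dispatch output capacity becomes the combiner's `in_len` and the dispatch
# input capacity becomes its `out_len` (see `test_token_combine` below):
#
#   combiner = TokenCombiner(
#       group_name,
#       align,
#       max_out_len,        # in_len: capacity of the dispatched-token buffer
#       max_inp_len,        # out_len: capacity of the pre-dispatch layout
#       (hid,),             # token_shape: per-token feature shape
#       world_size,
#       num_local_experts,
#       dtype,
#   )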


def test_token_combine() -> None:
    # Init
    dist.init_process_group()
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    device_count = torch.cuda.device_count()
    device = torch.device("cuda", rank % device_count)

    # NVSHMEM backend specific
    torch.cuda.set_device(device)
    torch.empty(1, device=device)
    # Set NVSHMEM as the SymmMem backend
    symm_mem.set_backend("NVSHMEM")

    # Mimics Group GEMM alignment
    align = 8
    torch.manual_seed(42 + rank)

    group_name = dist.group.WORLD.group_name
    symm_mem.enable_symm_mem_for_group(group_name)

    dtype = torch.float
    # Number of experts per rank
    ne = 8
    nsplits = ne * world_size

    # Number of elements for an expert is random, in [0, k)
    k = 10
    inp_splits = torch.randint(k, (nsplits,), dtype=torch.int64, device=device)

    # Max number of input elements (must be a constant across ranks for symmetric memory allocation)
    max_inp_len = k * nsplits
    # Max number of output elements (must be a constant across ranks for symmetric memory allocation)
    overflow_factor = world_size  # worst case: one rank receives all data
    max_out_len = max_inp_len * overflow_factor

    # Use a dispatch to prepare the input for combine (this is just
    # preparation, not the test itself)
    # Buffers for dispatch
    hid = 4096
    inp = symm_mem.empty(max_inp_len, hid, dtype=dtype, device=device)
    out = symm_mem.empty(max_out_len, hid, dtype=dtype, device=device)
    # 2 rows: input splits, input offsets
    in_splits_offsets = symm_mem.empty(
        (2, nsplits), dtype=torch.int64, device=device
    )
    # 2 rows: output splits, output offsets
    out_splits_offsets = symm_mem.empty(
        (2, nsplits), dtype=torch.int64, device=device
    )

    # Dispatch the tokens first so that we have a nice input for combine
    in_splits_offsets[0].copy_(inp_splits)
    torch.ops.symm_mem.all_to_all_vdev_2d(
        inp, out, in_splits_offsets[0], out_splits_offsets, group_name, major_align=align
    )

    with device:
        combiner = TokenCombiner(
            group_name,
            align,
            max_out_len,
            max_inp_len,
            out.shape[1:],
            world_size,
            ne,
            dtype,
        )

    compiled_combiner = torch.compile(
        combiner,
        fullgraph=True,
    )

    # Perform a dot product with the output so that the gradients passed back
    # from different ranks are different
    weight = torch.empty(max_inp_len, dtype=dtype, device=device).fill_(rank + 1)

    # Now we start to test the autograd function

    # Require grad on the input of combine
    out.requires_grad_(True)

    combine_out = compiled_combiner(
        out,
        inp,
        out_splits_offsets,
        in_splits_offsets,
    )
    p = torch.matmul(weight, combine_out)
    p.sum().backward()

    # Check gradients
    # We also need to skip the padding in the input data
    out_splits = out_splits_offsets[0].tolist()
    out_offsets = out_splits_offsets[1].tolist()
    for i, (split, offset) in enumerate(zip(out_splits, out_offsets)):
        grad_chunk = out.grad[offset : offset + split]
        dst_rank = i % world_size
        torch.testing.assert_close(
            grad_chunk,
            torch.empty(split, hid, device=device).fill_(dst_rank + 1),
        )

    dist.destroy_process_group()
    print(f"Rank {rank} passed")


if __name__ == "__main__":
    # To run this test, use the following command:
    # torchrun --nproc-per-node 4 --standalone combine.py
    test_token_combine()
