
[a2av] Add autograd support for token combine op #1511


Merged: 2 commits merged into main from combine_autograd on Aug 12, 2025

Conversation

kwen2501 (Contributor) commented on Aug 1, 2025

Added a TokenCombiner class which combines tokens from different experts, with backward (autograd) support.

Usage:

combiner = TokenCombiner(group_name, align, max_inp_len, max_out_len, inp.shape[1:], world_size, ne, dtype)
# inp, out, in_splits_offsets, out_splits_offsets must be symmetric tensors
output = combiner(inp, out, in_splits_offsets, out_splits_offsets)

Supports:

torch.compile(combiner)
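
For a bit more context, a rough end-to-end sketch of how the class could be driven is below. The sizes, the splits/offsets layout, the expert/world-size choices, and the requires_grad_ call are illustrative assumptions, not taken from the PR; it assumes a torchrun launch with NVSHMEM-backed symmetric memory available, and TokenCombiner imported from the module added in this PR (path omitted here):

import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem
# TokenCombiner is imported from the module added in this PR (path omitted here).

# Illustrative sizes only; real values come from the MoE configuration.
ne = 8                      # experts per rank (assumption)
align = 8                   # token alignment expected by the a2av kernels (assumption)
max_inp_len = 4096          # upper bound on tokens entering the combine
max_out_len = 4096          # upper bound on tokens after the combine
dim = 1024                  # hidden size per token
dtype = torch.bfloat16

dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
world_size = dist.get_world_size()
group_name = dist.group.WORLD.group_name

# All four tensors passed to the combiner must be symmetric tensors.
inp = symm_mem.empty(max_inp_len, dim, dtype=dtype, device="cuda").requires_grad_()
out = symm_mem.empty(max_out_len, dim, dtype=dtype, device="cuda")
# Splits + offsets per (rank, expert) pair; the exact layout is an assumption here.
in_splits_offsets = symm_mem.empty(2, world_size * ne, dtype=torch.int64, device="cuda")
out_splits_offsets = symm_mem.empty(2, world_size * ne, dtype=torch.int64, device="cuda")

combiner = TokenCombiner(group_name, align, max_inp_len, max_out_len,
                         inp.shape[1:], world_size, ne, dtype)
combiner = torch.compile(combiner)   # torch.compile(combiner) is supported
output = combiner(inp, out, in_splits_offsets, out_splits_offsets)
output.sum().backward()              # gradients flow back through the combine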

The meta-cla bot added the "CLA Signed" label (managed by the Meta Open Source bot) on Aug 1, 2025.
kwen2501 requested a review from tianyu-l on August 1, 2025 at 00:08.
Comment on lines 167 to 168
self.grad_out_buf = symm_mem.empty(out_len, *token_shape, dtype=dtype)
self.grad_in_buf = symm_mem.empty(in_len, *token_shape, dtype=dtype)

Contributor review comment: similar request, please allow device as an input to the constructor.
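
A small sketch of what that suggestion could look like (hypothetical helper name and trimmed argument list; assumes symm_mem.empty accepts a device keyword):

import torch
import torch.distributed._symmetric_memory as symm_mem

def _alloc_grad_bufs(in_len, out_len, token_shape, dtype, device=None):
    # Hypothetical helper: forward an explicit `device` to the symmetric-memory
    # allocations instead of implicitly relying on the current CUDA device.
    if device is None:
        device = torch.device("cuda", torch.cuda.current_device())
    grad_out_buf = symm_mem.empty(out_len, *token_shape, dtype=dtype, device=device)
    grad_in_buf = symm_mem.empty(in_len, *token_shape, dtype=dtype, device=device)
    return grad_in_buf, grad_out_buf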

kwen2501 merged commit cf4de26 into main on Aug 12, 2025
4 checks passed
tianyu-l deleted the combine_autograd branch on August 12, 2025 at 20:42
vwxyzjn commented on Aug 12, 2025

Hi @kwen2501, are these kernels for all-to-all communication? Is there any way to test them out with DeepSeek V3?

kwen2501 (Contributor, Author) commented on Aug 12, 2025

@vwxyzjn Yes, the core of this PR is an op named torch.ops.symm_mem.all_to_all_vdev_2d_offset, which accepts token splits and offsets on the GPU. The PR just wraps that op with autograd support.

Similarly, for token dispatch, we have an op named all_to_all_vdev_2d. There is more documentation about these two ops here:
https://github.com/pytorch/pytorch/blob/main/torch/csrc/distributed/c10d/symm_mem/nvshmem_extension.cu#L557
https://github.com/pytorch/pytorch/blob/main/torch/csrc/distributed/c10d/symm_mem/nvshmem_extension.cu#L704
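
To illustrate the wrapping idea only (a minimal sketch, not the PR's actual code: the real TokenCombiner manages symmetric-memory buffers and split/offset bookkeeping, and the real op signatures differ), an autograd.Function whose forward runs a combine and whose backward runs the corresponding dispatch could look like this, with combine_op and dispatch_op as hypothetical stand-ins for the two ops above:

import torch

class _TokenCombineSketch(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inp, combine_op, dispatch_op, fwd_args, bwd_args):
        # Combine: route each expert's output tokens back to the ranks that
        # originally owned them, writing into pre-allocated output buffers.
        out = combine_op(inp, *fwd_args)
        # Remember how to reverse the shuffle for the backward pass.
        ctx.dispatch_op = dispatch_op
        ctx.bwd_args = bwd_args
        return out

    @staticmethod
    def backward(ctx, grad_out):
        # The backward of a combine is a dispatch: gradients arriving at the
        # combined output are scattered back into the experts' input layout.
        grad_in = ctx.dispatch_op(grad_out, *ctx.bwd_args)
        # Only `inp` receives a gradient; the other forward inputs do not.
        return grad_in, None, None, None, None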
