Skip to content

Conversation

danielvegamyhre
Copy link
Contributor

@danielvegamyhre danielvegamyhre commented Oct 17, 2025

Summary

  • Make EP a2a dispatch and a2a combine each be separately configurable to use either "default" or "mxfp8" impl
  • "mxfp8" impl uses torchao's new to_mxfp8_a2a_dequant, which has the exact same API as functional collective all_to_all_single_autograd and is differentiable, so it can be used as a drop-in replacement for the default a2a impl.
  • torchao to_mxfp8_a2a_dequant works as follows:
    • quantizes the inputs to mxfp8
    • all_to_all_single on e4m3 data
    • all_to_all_single on e8m0 scales
    • dequantize outputs back to original precision

Performance

  • Single node benchmarks with 4xB200

  • Llama4 16e default configs; FSDP=4, EP=4; AC=none; compile=True; seq_len=8192; local_bs=8

  • Reduced num layers from 48 -> 2 to avoid OOM in single node setting

  • Debug model config:

llama4_configs = {
    "debugmodel": TransformerModelArgs(
        dim=5120,
        n_layers=2,
        n_heads=40,
        n_kv_heads=8,
        ffn_dim_multiplier=1.2,
        multiple_of=2048,
        rope_theta=500000,
        max_seq_len=10485760,
        moe_args=MoEArgs(num_experts=16),
        interleave_moe_layer_step=1,
    ),
Configuration Throughput (Median Tokens/s) Max Memory (GiB)
bf16 baseline 49381.0 145.55
MXFP8 for Linears only 52038.0 146.62
MXFP8 for Grouped GEMMs only 69350.0 144.71
MXFP8 for Linears + Grouped GEMMs 70747.0 145.32
MXFP8 for Linears + Grouped GEMMs + A2A Dispatch 72602.5 145.45
MXFP8 for Linears + Grouped GEMMs + A2A Dispatch + A2A Combine 73152.0 146.08

Additional context on design/implementation choices

  • Note: both default and mxfp8 impls require the d2h sync to get input_splits/output_splits on the host for the a2a call.
    • I also explored a no-sync/on-device implementation using Triton + Symmetric memory, and got it working e2e in a torchtitan PoC: [mxfp8 moe training] mxfp8 a2a working e2e in torchtitan llama4 training; improve tests + bench scripts ao#3088
    • I found that this design of preallocating over-allocated symmetric memory buffers for exchange of variable token numbers (to avoid syncs required for exact allocation, while risking either crash or token dropping if overflow factor heuristic is wrong), is fundamentally in conflict with the torchtitan MoE design of doing a d2h sync to safely do exact allocation. Extracting out the variable size outputs from the padded buffers causes d2h sync (causing perf to regress below baseline), and we can't avoid this since otherwise downstream ops will break due to shape mismatches - the whole model basically would need to be designed assuming the static padded shapes.
    • Therefore, we choose to integrate this more straight-forward impl that is natively compatible with non-experimental titan MoE design

Additional background on motivation

  • MoE performance literature has shown ~47% average runtime for flagship OSS MoE models (Qwen2, Phi3.5, Mixtra8x7b) is due to exposed MoE comms.
  • Torchtitan Llama4 debug model with EP=4, ~30% of MoE training with EP is a2a comms, most of that exposed (see trace screenshot), which directionally corroborates this.
  • We can optimize this via (1) quantizing the comms to minimize data sent over NVLink/IB, (2) avoid d2h sync that can occur in implementations which move a2a output splits from device->host to compute exact preallocation necessary for incoming tokens, and (3) finer grained overlapping techniques.

30% of llama4 model profiled runtime is all2all comms

  • FSDP=4, EP=4, dim=5120, num_experts=16, seq_len=8192, local_batch_size=8
Screenshot 2025-09-29 at 3 08 47 PM

47% avg runtime devoted to MoE comms in profiled OSS models

Screenshot 2025-09-29 at 3 11 00 PM

@danielvegamyhre danielvegamyhre marked this pull request as draft October 17, 2025 18:47
@meta-cla meta-cla bot added the CLA Signed This label is managed by the Meta Open Source bot. label Oct 17, 2025
@danielvegamyhre danielvegamyhre marked this pull request as ready for review October 17, 2025 21:19

def apply_moe_ep_tp(
model: nn.Module,
job_config: JobConfig,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's only send in job_config.quantize instead of the whole job_config
Also let's make it a non-positional arg to avoid unnecessary BC breaking

grouped_mm: QuantizedGroupedMM = field(default_factory=QuantizedGroupedMM)
"""Quantized training config for grouped GEMMs"""

expert_parallel_a2a_dispatch_impl: Literal["default", "mxfp8"] = "default"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's not intuitive to let Quantize to dictate what a2a impl should be. Instead we should let Quantize to override the default setting.

I'd suggest we rename this to override_ep_a2a_dispatch. Since we only have mxfp8 for now, you can go with bool, or you could make it Literal | None.

Comment on lines +21 to +22
self._a2a_dispatch_impl = to_mxfp8_a2a_dequant
self._a2a_combine_impl = to_mxfp8_a2a_dequant
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

wait you are changing both together, not independently

Comment on lines +460 to +465
logger.info(
f"Using all-to-all dispatch implementation: {job_config.quantize.expert_parallel_a2a_dispatch_impl}"
)
logger.info(
f"Using all-to-all combine implementation: {job_config.quantize.expert_parallel_a2a_combine_impl}"
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

similar to the comment above, we should print these override info only when they are overridden during quantization.

Comment on lines +453 to +458
assert (
job_config.quantize.expert_parallel_a2a_dispatch_impl in EP_IMPLS
), f"Unknown EP impl: {job_config.quantize.expert_parallel_a2a_dispatch_impl}, must be one of {EP_IMPLS.keys()}"
assert (
job_config.quantize.expert_parallel_a2a_combine_impl in EP_IMPLS
), f"Unknown EP impl: {job_config.quantize.expert_parallel_a2a_combine_impl}, must be one of {EP_IMPLS.keys()}"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's do this in quantized EP class, if you adopt the Literal version of config

Comment on lines +449 to +452
EP_IMPLS = {
"default": ExpertParallel,
"mxfp8": MXExpertParallel,
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think some if-else on quantize_config is enough, no need to have this EP_IMPLS variable here.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Meta Open Source bot.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants