[moe training] refactor to share benchmarking and profiling utils #2767


Merged · 1 commit · Aug 15, 2025
@@ -14,8 +14,6 @@
import argparse
import copy
import os
import statistics
from time import perf_counter_ns

import pytest
import torch
@@ -24,6 +22,11 @@
from torch.distributed._composable.fsdp import fully_shard
from torch.nn import functional as F

from benchmarks.prototype.moe_training.utils import (
bench_fwd_bwd_microseconds,
profile_fn,
)

# this feature requires CUDA and SM89+
if not torch.cuda.is_available() or torch.cuda.get_device_capability() < (8, 9):
pytest.skip(
@@ -48,8 +51,12 @@
)


def bench_moe_float8_training_fsdp(enable_profile=False):
def bench_moe_float8_training_fsdp(
recipe_name: str, enable_profile: bool, use_compile: bool
):
assert torch.cuda.is_available()
assert recipe_name in ["fp8_rowwise", "mxfp8"]
recipe = MoEScalingType[recipe_name.upper()]

# setup distributed for fsdp
setup_distributed()
@@ -62,15 +69,19 @@ def bench_moe_float8_training_fsdp(enable_profile=False):
init_std = 0.02
device = torch.device("cuda")

# reference bf16 MoE
dim, hidden_dim = 5120, 4 * 5120
# reference bf16 MoE using llama4 shapes
dim, hidden_dim = 5120, 8192
ref_model = MoE(model_args, dim, hidden_dim).to(torch.bfloat16).cuda()
torch.manual_seed(42)
ref_model.init_weights(init_std, device)

# target MoE for testing conversion
model = copy.deepcopy(ref_model)

# Token group alignment size must be 16 for fp8 rowwise training, and 32 for mxfp8
alignment_size = 32 if recipe == MoEScalingType.MXFP8 else 16
set_token_group_alignment_size_m(alignment_size)

# assert starting params are identical for both models
for param1, param2 in zip(model.parameters(), ref_model.parameters()):
assert torch.equal(param1, param2)
@@ -83,15 +94,15 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
return False

# quantize test model
config = MoETrainingConfig(scaling_type=MoEScalingType.FP8_ROWWISE)
config = MoETrainingConfig(scaling_type=recipe)
quantize_(model, config=config, filter_fn=moe_module_filter_fn)

# FSDP2
fully_shard(model)
fully_shard(ref_model)

# inputs (llama4 shapes)
batch, seq = 1, 8192
batch, seq = 1, 16640
ref_x = torch.randn(
batch, seq, dim, dtype=torch.bfloat16, requires_grad=True, device=device
)
@@ -104,70 +115,34 @@ def warmup(model, input):
loss.backward()
torch.cuda.synchronize()

def bench_fn_microseconds(model, input):
labels = torch.ones_like(input)
times = []
for _ in range(10):
start_ns = perf_counter_ns()
out = model(input)
loss = F.mse_loss(out, labels)
loss.backward()
torch.cuda.synchronize()
end_ns = perf_counter_ns()
duration_us = (end_ns - start_ns) / 1000
times.append(duration_us)
return statistics.median(times)

def profile_fn(model, input, profile_name="profile"):
# Only profile on rank 0
if torch.distributed.get_rank() == 0:
labels = torch.ones_like(input)
wait, warmup, active = 1, 3, 1
total_steps = wait + warmup + active
with torch.profiler.profile(
activities=[
torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA,
],
schedule=torch.profiler.schedule(
wait=wait, warmup=warmup, active=active, repeat=0
),
record_shapes=True,
with_stack=True,
) as prof:
for _ in range(total_steps):
out = model(input)
loss = F.mse_loss(out, labels)
loss.backward()
prof.step()

# Save profiler results
prof.export_chrome_trace(f"{profile_name}.json")
print(f"Saved: {profile_name}.json")

# Compile models
ref_model = torch.compile(ref_model, fullgraph=False)
model = torch.compile(model, fullgraph=False)

print("Benchmarking MoE with FSDP2 using bf16 training")
warmup(ref_model, ref_x)
bf16_us = bench_fn_microseconds(ref_model, ref_x)
print(f"bf16 time: {bf16_us} us")
if enable_profile:
print("Profiling bf16 model")
profile_fn(ref_model, ref_x, profile_name="bf16_profile")
labels = torch.ones_like(x)

# Token group alignment size must be 16 for fp8 rowwise training
set_token_group_alignment_size_m(16)

print("Benchmarking MoE with FSDP2 using fp8 rowwise training")
warmup(model, x)
fp8_us = bench_fn_microseconds(model, x)
print(f"fp8 time: {fp8_us} us")
# TODO: bench with fullgraph=True if/when it is supported
bf16_us = bench_fwd_bwd_microseconds(
ref_model,
ref_x,
labels=labels,
use_compile=use_compile,
fullgraph=False,
)
print(f"BF16 time: {bf16_us} us")
if enable_profile:
print("Profiling bf16 training")
profile_fn(ref_model, ref_x, labels=labels, profile_name="bf16_profile")

scaled_us = bench_fwd_bwd_microseconds(
model,
x,
labels=labels,
use_compile=use_compile,
fullgraph=False,
)
print(f"Scaled time: {scaled_us} us")
if enable_profile:
print("Profiling fp8 model")
profile_fn(model, x, profile_name="fp8_profile")
print("Profiling quantized training")
profile_fn(model, x, labels=labels, profile_name=f"{recipe_name}_profile")

print(f"Speedup: {bf16_us / scaled_us:.3f}x")
dist.destroy_process_group()


@@ -185,5 +160,15 @@ def setup_distributed():
action="store_true",
help="Enable PyTorch profiling and save results to file",
)
parser.add_argument("--recipe", type=str, help="[fp8_rowwise, mxfp8]")
parser.add_argument(
"--compile",
action="store_true",
help="use torch.compile",
)
args = parser.parse_args()
bench_moe_float8_training_fsdp(enable_profile=args.profile)
bench_moe_float8_training_fsdp(
recipe_name=args.recipe,
enable_profile=args.profile,
use_compile=args.compile,
)
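
For reference, the new flags drive the refactored entry point end to end; the sketch below is hypothetical (the script's module path is assumed) and must be launched via torchrun so that setup_distributed() can initialize the process group:

# Hypothetical direct invocation, equivalent to e.g.:
#   torchrun --nproc_per_node=8 benchmark_moe_layer_fsdp.py --recipe mxfp8 --compile --profile
# The import path below is assumed for illustration only.
from benchmarks.prototype.moe_training.benchmark_moe_layer_fsdp import (
    bench_moe_float8_training_fsdp,
)

bench_moe_float8_training_fsdp(
    recipe_name="mxfp8",  # must be one of ["fp8_rowwise", "mxfp8"]
    enable_profile=True,  # exports mxfp8_profile.json via the shared profile_fn
    use_compile=True,     # torch.compile with fullgraph=False
)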
@@ -12,7 +12,7 @@
import torch
from tabulate import tabulate
from tqdm import tqdm
from utils import bench_fwd_bwd_microseconds
from utils import bench_fwd_bwd_microseconds, profile_fn

from torchao.prototype.moe_training import _scaled_grouped_mm
from torchao.prototype.moe_training.conversion_utils import MoEScalingType
@@ -47,7 +47,7 @@ class Experiment:

def get_configs() -> List[ExperimentConfig]:
A_shapes = [(16640, 5120)]
B_shapes = [(16, 8192, 5120), (128, 8192, 5120)]
B_shapes = [(16, 8192, 5120)]
recipes = [MoEScalingType.FP8_ROWWISE]
high_precision_dtypes = [torch.bfloat16]
configs = []
@@ -106,6 +106,16 @@ def run_experiment(
labels=labels,
use_compile=args.compile,
)
if args.profile:
profile_fn(
torch._grouped_mm,
A,
B_t,
offs,
labels=labels,
use_compile=args.compile,
profile_name="bf16_profile",
)

# benchmark scaled grouped mm with dynamic fp8 rowwise quant
fp8_us = bench_fwd_bwd_microseconds(
@@ -117,6 +127,17 @@
labels=labels,
use_compile=args.compile,
)
if args.profile:
profile_fn(
_scaled_grouped_mm,
A,
B_t,
offs,
scaling_type=config.recipe,
labels=labels,
use_compile=args.compile,
profile_name="scaled_profile",
)

return ExperimentResult(
bf16_us=round(bf16_us, 3),
@@ -164,5 +185,6 @@ def main(args: argparse.Namespace):
if __name__ == "__main__":
arg_parser = argparse.ArgumentParser()
arg_parser.add_argument("--compile", action="store_true")
arg_parser.add_argument("--profile", action="store_true")
args = arg_parser.parse_args()
main(args)
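
Since the shared profile_fn takes any callable plus its positional and keyword arguments, the same helper wraps raw ops here and whole MoE modules in the FSDP script above. A minimal self-contained sketch with toy shapes, using a plain matmul as a stand-in for the grouped kernels:

import torch

from benchmarks.prototype.moe_training.utils import profile_fn

# Toy operands; requires_grad on A so the backward pass shows up in the trace.
A = torch.randn(1024, 512, device="cuda", dtype=torch.bfloat16, requires_grad=True)
B = torch.randn(512, 256, device="cuda", dtype=torch.bfloat16)
labels = torch.ones(1024, 256, device="cuda", dtype=torch.bfloat16)

# Positional args after the callable are forwarded to it, mirroring the
# torch._grouped_mm / _scaled_grouped_mm calls above.
profile_fn(torch.matmul, A, B, labels=labels, profile_name="toy_matmul_profile")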
benchmarks/prototype/moe_training/utils.py (41 changes: 39 additions, 2 deletions)
@@ -5,9 +5,11 @@
from torch.nn import functional as F


def bench_fwd_bwd_microseconds(fn, *args, labels=None, use_compile=False, **kwargs):
def bench_fwd_bwd_microseconds(
fn, *args, labels=None, use_compile=False, fullgraph=True, **kwargs
):
assert labels is not None
fn = torch.compile(fn, fullgraph=False) if use_compile else fn
fn = torch.compile(fn, fullgraph=fullgraph) if use_compile else fn
times = []
for _ in range(10):
start_ns = perf_counter_ns()
@@ -19,3 +21,38 @@ def bench_fwd_bwd_microseconds(fn, *args, labels=None, use_compile=False, **kwargs):
duration_us = (end_ns - start_ns) / 1000
times.append(duration_us)
return statistics.median(times)


def profile_fn(
fn,
*args,
labels=None,
use_compile=False,
fullgraph=True,
profile_name="profile",
**kwargs,
):
assert labels is not None
fn = torch.compile(fn, fullgraph=fullgraph) if use_compile else fn
wait, warmup, active = 1, 3, 1
total_steps = wait + warmup + active
with torch.profiler.profile(
activities=[
torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA,
],
schedule=torch.profiler.schedule(
wait=wait, warmup=warmup, active=active, repeat=0
),
record_shapes=True,
with_stack=True,
) as prof:
for _ in range(total_steps):
out = fn(*args, **kwargs)
loss = F.mse_loss(out, labels)
loss.backward()
prof.step()

# Save profiler results
prof.export_chrome_trace(f"{profile_name}.json")
print(f"Saved: {profile_name}.json")