Integrate DeepGEMM FP8 blockwise MoE backend #4473
Helpful Links
See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/4473
Note: Links to docs will display an error until the docs builds have been completed.
1 Active SEV
There is 1 currently active SEV. If your PR is affected, please view it below:
1 New Failure
As of commit 4cce2ad with merge base 5165bfb (… NEW FAILURE - The following job has failed:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@claude review
Claude finished @iamzainhuda's task in 6m 32s
Code Review: DeepGEMM FP8 blockwise MoE backend
Nice work: the backend abstraction is clean, the layout/scale contracts between TorchAO and DeepGEMM are well-documented inline, and the test matrix (optional-dep handling, layout parity, compile, K-grouped quant, wgrad parity, end-to-end autograd parity) is genuinely comprehensive. A few things are worth addressing below.
Potential issues
1. PEP 604 unions (…
torchao's CI still targets Python 3.9 (… Cheapest fix: add …
2. Per-step host↔device sync in …
3. Double FP32 allocation in K-grouped wgrad (memory)
4. Minor / nits
CI failures
The 5 reported failures look unrelated to this PR:
Build the DeepGEMM M-grouped layout entirely on-device with bucketize and defer the host-side group sizes to a lazy cached_property, so the forward grouped GEMM no longer triggers a per-group device-to-host sync (the K-grouped wgrad keeps the single host read of ks that DeepGEMM requires). Skip the alignment check on the padded path, where groups are block-aligned by construction. Also add a per-kernel MoE grouped-GEMM benchmark mirroring the linear blockwise FP8 benchmarks: each quant cast and DeepGEMM grouped GEMM is timed in isolation and reported against the memory-bandwidth or FP8 compute roofline.
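The on-device layout described in this commit can be illustrated with a pure-Python stand-in. This is a sketch under stated assumptions, not the PR's code: `FakeDeviceTensor`, `build_grouped_layout`, and `OffsetPlanSketch` are hypothetical names, Python lists stand in for GPU tensors, and `bisect.bisect_right` plays the role that `torch.bucketize(..., right=True)` plays on the device. It also assumes each expert's valid rows come first in its padded segment, followed by padding rows up to the block boundary.

```python
import bisect
from dataclasses import dataclass
from functools import cached_property


class FakeDeviceTensor:
    """Hypothetical stand-in for a CUDA tensor: only .tolist() counts as a host sync."""

    def __init__(self, values):
        self.values = list(values)  # pretend these live on the GPU
        self.host_reads = 0

    def tolist(self):
        self.host_reads += 1  # a real .tolist() on a CUDA tensor blocks on the device
        return list(self.values)


def build_grouped_layout(padded_starts, padded_ends, valid_sizes, num_rows):
    """Return one expert id per row, or -1 for padding rows.

    bisect_right(padded_ends, row) counts the group ends <= row, which is the
    same index torch.bucketize(rows, padded_ends, right=True) returns per row.
    """
    layout = []
    for row in range(num_rows):
        g = bisect.bisect_right(padded_ends, row)
        # Short-circuit keeps padded_starts[g] in range for rows past the last group.
        is_valid = g < len(padded_ends) and row - padded_starts[g] < valid_sizes[g]
        layout.append(g if is_valid else -1)
    return layout


@dataclass(frozen=True)
class OffsetPlanSketch:
    """Hypothetical plan: device-side offsets plus a lazily synced host view."""

    padded_starts: FakeDeviceTensor
    padded_ends: FakeDeviceTensor
    valid_sizes: FakeDeviceTensor

    def grouped_layout(self, num_rows):
        # Uses only "device" data, so no host read is recorded here.
        return build_grouped_layout(
            self.padded_starts.values,
            self.padded_ends.values,
            self.valid_sizes.values,
            num_rows,
        )

    @cached_property
    def host_group_sizes(self):
        # First access pays the single device-to-host copy; later accesses reuse it.
        # cached_property writes to __dict__ directly, so it works on frozen dataclasses.
        return self.valid_sizes.tolist()
```

With a block size of 4 and two experts holding 3 and 5 tokens, the padded segments are rows [0, 4) and [4, 12), and `grouped_layout(12)` returns `[0, 0, 0, -1, 1, 1, 1, 1, 1, -1, -1, -1]`. The host-side sizes are only read once, on the first access to `host_group_sizes`.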
@claude review
Claude finished @andrewor14's task in 3m 28s
Code Review: DeepGEMM FP8 blockwise MoE backend
This is a clean, well-engineered integration. Re-reviewing the current state of the branch (commits through …). A few items remain worth a look.
Remaining / new findings
1. … but the real op (…
2. Compact K-grouped kernel still stores without an …
3. Redundant mask term in ragged store (nit)
Notes (not blockers)
CI
The reported failures (…
@claude can you verify whether the refactor is done correctly, such that the original emulated path is unchanged? Trace through the code paths and print them here. Then do the same for the new DeepGEMM code path.
Claude finished @andrewor14's task in 2m 41s
Refactor verification: emulated path unchanged + DeepGEMM path trace
Short answer
Yes: the refactor preserves the emulated path exactly. The pre-refactor inline kernel calls were moved verbatim into …
Emulated path trace (before vs after)
Forward: …
andrewor14 left a comment
Looks great; mostly just comments about docs and arg signatures. It looks like Claude pointed out a potentially wrong shape. Can you look into that?
Also, just wondering whether you've tried this on a real model yet to see if we get a reasonable loss. Not blocking for this PR, but it would be great to see some real examples.
@dataclass(frozen=True)
class DeepGemmGroupedOffsetPlan:
Add some docstrings here (and to all other new classes)? Explain what each of the fields refers to?
    group_sizes: list[int],
    block_size: int,
    dim: int,
) -> DeepGemmKGroupedQuantMetadata:
Need some docstrings on these helper functions too: what they do at a high level, what the args are, etc.
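A documented helper of the kind requested here might look like the sketch below. `KGroupedQuantMetadataSketch`, `k_grouped_quant_metadata_sketch`, and all of their fields are hypothetical and may not match the PR's `DeepGemmKGroupedQuantMetadata`. The sketch assumes each expert's tokens are split into `ceil(group_size / block_size)` scale blocks along the contraction axis, and that experts' flat `(dim, group_size)` data chunks are stored back to back.

```python
from dataclasses import dataclass
from itertools import accumulate


@dataclass(frozen=True)
class KGroupedQuantMetadataSketch:
    """Hypothetical per-expert bookkeeping for a K-grouped (token-contracted) operand.

    Attributes:
        token_blocks: ceil(group_size / block_size) scale columns per expert.
        block_offsets: first scale column of each expert in the (dim, total_blocks) scales.
        data_offsets: first element of each expert's flat (dim, group_size) data chunk.
        total_valid_token_blocks: total scale columns across all experts.
    """

    token_blocks: tuple[int, ...]
    block_offsets: tuple[int, ...]
    data_offsets: tuple[int, ...]
    total_valid_token_blocks: int


def k_grouped_quant_metadata_sketch(
    group_sizes: list[int], block_size: int, dim: int
) -> KGroupedQuantMetadataSketch:
    """Compute where each expert's FP8 data and scales start (hypothetical helper).

    Args:
        group_sizes: tokens routed to each expert, as host-side ints (may be 0).
        block_size: tokens per quantization block along the contraction axis.
        dim: the non-contracted dimension D of the operand.
    """
    if block_size <= 0 or dim <= 0:
        raise ValueError("block_size and dim must be positive")
    # Integer ceil-division: a partial trailing block still needs its own scale.
    token_blocks = tuple((s + block_size - 1) // block_size for s in group_sizes)
    # Exclusive prefix sums give each expert's starting scale column / data element.
    block_offsets = tuple(accumulate(token_blocks, initial=0))[:-1]
    data_offsets = tuple(accumulate((dim * s for s in group_sizes), initial=0))[:-1]
    return KGroupedQuantMetadataSketch(
        token_blocks, block_offsets, data_offsets, sum(token_blocks)
    )
```

An empty expert contributes zero blocks, so it shares its offsets with the next expert; for example, group sizes `[300, 0, 128]` with `block_size=128` give `token_blocks == (3, 0, 1)` and `block_offsets == (0, 3, 3)`.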
    padded_group_end_offsets,
    out_dtype,
    block_size,
    deepgemm_offset_plan=deepgemm_offset_plan,
Seems a bit weird to pass this DeepGEMM-specific plan to a general backend plan. Given a specific backend, the plan is the same for all grouped_mm calls, right? Can we just embed this offset plan in the backend_plan itself so you don't need to pass it to every grouped_mm call?
Yeah, this makes sense. I changed it so it's part of that backend's plan now.
class _EmulatedGroupedMMBackendPlan(_GroupedMMBackendPlan):
    kind = _GroupedMMBackend.EMULATED
nit: maybe just backend to be more specific?
Changed the names to make this clearer:
- _GroupedMMBackend enum -> _GroupedMMBackendKind
- _GroupedMMBackendPlan -> _GroupedMMBackend
- _EmulatedGroupedMMBackendPlan -> _EmulatedGroupedMMBackend
- _DeepGemmGroupedMMBackendPlan -> _DeepGemmGroupedMMBackend
    original_group_end_offsets: Optional[torch.Tensor] = None,
    padded_group_start_offsets: Optional[torch.Tensor] = None,
    num_rows: Optional[int] = None,
) -> _GroupedMMBackendSelection:
The docstring here should list all the scenarios in which the emulated backend is selected vs. DeepGEMM.
@andrewor14 addressed the comments |
Summary
Adds an optional DeepGEMM backend for the FP8 blockwise MoE grouped GEMM, alongside the existing emulated path. When DeepGEMM is installed, `KernelPreference.AUTO` dispatches the forward, dgrad, and wgrad grouped GEMMs to DeepGEMM; otherwise it falls back to the emulated kernels.
Two main parts are added here: a backend-selection layer, and DeepGEMM-layout FP8 quantization kernels that write DeepGEMM's exact data/scale contracts directly (no dispatch-time transpose/copy).
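A minimal sketch of the selection rule described above, with each scenario spelled out in the docstring. `KernelPreferenceSketch` and `select_grouped_mm_backend` are hypothetical; the explicit `EMULATED`/`DEEPGEMM` options and the error on a missing install are assumptions beyond what the summary states, and the availability probe assumes DeepGEMM's Python package is importable as `deep_gemm`.

```python
import importlib.util
from enum import Enum
from typing import Optional


class KernelPreferenceSketch(Enum):
    """Hypothetical stand-in for a kernel-preference knob."""

    AUTO = "auto"
    EMULATED = "emulated"
    DEEPGEMM = "deepgemm"


def deepgemm_importable() -> bool:
    # Assumption: DeepGEMM's Python package is importable as `deep_gemm`.
    return importlib.util.find_spec("deep_gemm") is not None


def select_grouped_mm_backend(
    preference: KernelPreferenceSketch,
    deepgemm_available: Optional[bool] = None,
) -> str:
    """Return "deepgemm" or "emulated" for the forward, dgrad and wgrad GEMMs.

    Scenarios covered by this sketch:
      * AUTO and DeepGEMM importable      -> "deepgemm"
      * AUTO and DeepGEMM not importable  -> "emulated" (silent fallback)
      * EMULATED                          -> "emulated", even if DeepGEMM exists
      * DEEPGEMM and not importable       -> RuntimeError (assumed: no silent
        fallback when the caller asked for DeepGEMM explicitly)
    """
    if deepgemm_available is None:
        deepgemm_available = deepgemm_importable()
    if preference is KernelPreferenceSketch.EMULATED:
        return "emulated"
    if preference is KernelPreferenceSketch.DEEPGEMM:
        if not deepgemm_available:
            raise RuntimeError("DeepGEMM was requested but `deep_gemm` is not importable")
        return "deepgemm"
    return "deepgemm" if deepgemm_available else "emulated"
```

Making availability an optional parameter keeps the policy testable without installing DeepGEMM; in normal use the probe runs automatically.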
DeepGEMM layout/scale contract vs TorchAO
DeepGEMM's grouped FP8 kernels expect a different operand memory layout and scale orientation from the ones TorchAO's existing blockwise quantizers (and `torch._grouped_mm`) produce.
- TorchAO stores the weight as `(E, K, N)` transposed in per-expert column-major layout; its 1x128 activation scales are stored as `(M_blocks, K)` (column-major-friendly for `torch._scaled_mm`); and its grouped RHS scales follow the `torch._grouped_mm` convention.
- DeepGEMM expects the forward RHS as `(E, N, K)` row-major with contiguous `K` and `(E, N//128, K//128)` scales; the dgrad RHS as `(E, K, N)` row-major with contiguous `N` and `(E, K//128, N//128)` scales; and, for the K-grouped wgrad, each operand as a flat per-expert `(D, expert_tokens)` buffer with `(D, total_valid_token_blocks)` scales. That is, the token dimension becomes the GEMM's contraction axis, and the scale axes are transposed relative to TorchAO's.

Rather than quantize in the TorchAO layout and then transpose/copy at dispatch time, we add kernels that write DeepGEMM's exact contract directly.
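The layout relationships in this contract can be checked with plain stride arithmetic. The sketch below assumes that "per-expert column-major" means `K` is the unit-stride axis of TorchAO's logical `(E, K, N)` weight, which is what makes an `(E*N, K)` view of it possible; under that assumption, the TorchAO weight, DeepGEMM's row-major `(E, N, K)` forward RHS, and the flat `(E*N, K)` view address every element at the same offset. All function names are hypothetical, and the tile mapping assumes `N` is a multiple of 128 so that no 128-row tile spans two experts.

```python
def row_major_strides(shape):
    """Element strides of a contiguous (row-major) tensor of `shape`."""
    strides, step = [], 1
    for size in reversed(shape):
        strides.append(step)
        step *= size
    return tuple(reversed(strides))


def element_offset(index, strides):
    return sum(i * s for i, s in zip(index, strides))


def torchao_weight_strides(E, K, N):
    # Assumed TorchAO layout: logical (E, K, N) with each expert's (K, N)
    # matrix column-major, so K has unit stride and N steps by K.
    return (K * N, 1, K)


def layouts_alias(E, K, N):
    """True if TorchAO's (E, K, N) weight, DeepGEMM's (E, N, K) row-major
    weight, and the flat (E*N, K) view address every element identically."""
    ta = torchao_weight_strides(E, K, N)
    dg = row_major_strides((E, N, K))
    flat = row_major_strides((E * N, K))
    for e in range(E):
        for n in range(N):
            for k in range(K):
                offset = element_offset((e, k, n), ta)
                if offset != element_offset((e, n, k), dg):
                    return False
                if offset != element_offset((e * N + n, k), flat):
                    return False
    return True


def weight_tile_owner(row_tile, N, block=128):
    """Map a block-row tile index of the (E*N, K) view to (expert, n_tile)."""
    if N % block:
        raise ValueError("N must be a multiple of block, or tiles straddle experts")
    return divmod(row_tile, N // block)


def forward_rhs_scale_shape(E, N, K, block=128):
    # One scale per block x block tile of each expert's (N, K) matrix.
    return (E, N // block, K // block)
```

For the benchmarked shape (E=8, N=2048, K=7168), `forward_rhs_scale_shape` gives `(8, 16, 56)`, matching the `(E, N//128, K//128)` contract above.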
- … `(E*N, K)` view of the column-major weight and emit DeepGEMM's `(E, N, K)` / `(E, K, N)` data and block-transposed scales straight from the cast, with no separate layout-copy kernel.
- … `_DEEPGEMM_DIRECT_K_GROUPED_QUANT_MIN_DIM`: wide operands take the direct K-grouped quant, which writes the flat `(D, expert_tokens)` K-major buffer and `(D, blocks)` scales in one pass; narrower operands reuse TorchAO's existing transposed-LHS/RHS quantizers, are then flattened per expert, and have their scales transposed (`(M_blocks, K)` → `(K, M_blocks)`) at launch time to match DeepGEMM's K-major contract.

The on-device `grouped_layout` (one int32 expert id per row, `-1` for padding) is what tells DeepGEMM's contiguous M-grouped kernel which expert each row belongs to.

H100 per-kernel roofline (balanced tokens, M=32768, E=8, N=2048, K=7168)
Roofline refs: mem BW peak 3352 GB/s / achievable 3084 GB/s (92%); FP8 compute peak 1979 TFLOP/s / achievable 1544 TFLOP/s (78%).
Quant/cast kernels are reported as % of achievable memory bandwidth; DeepGEMM grouped GEMMs as % of achievable FP8 compute.
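The roofline percentages are simple ratios, which a sketch like the following reproduces. The achievable-to-peak ratios (92% and 78%) follow directly from the reference numbers above. The FLOP count assumes M counts routed rows, each multiplied by exactly one expert's `(N, K)` weight, and the byte accounting in `activation_cast_bytes` (BF16 read, FP8 write, one 4-byte scale per 1x128 group) is a hypothetical assumption, not necessarily what the benchmark counts.

```python
# Roofline references quoted in the PR (H100).
PEAK_BW_GBS = 3352.0
ACHIEVABLE_BW_GBS = 3084.0
PEAK_FP8_TFLOPS = 1979.0
ACHIEVABLE_FP8_TFLOPS = 1544.0


def grouped_gemm_flops(m: int, n: int, k: int) -> int:
    # Assumes each of the m routed rows is multiplied by exactly one expert's
    # (n, k) weight, so the grouped GEMM does the work of one m x n x k GEMM.
    return 2 * m * n * k


def activation_cast_bytes(m: int, k: int, block: int = 128) -> int:
    # Hypothetical accounting: read BF16 (2 B/elem), write FP8 (1 B/elem),
    # and write one 4-byte scale per 1 x block group.
    return m * k * 2 + m * k + m * (k // block) * 4


def pct_of_compute_roofline(flops: float, seconds: float) -> float:
    # Achieved TFLOP/s as a percentage of the achievable FP8 compute.
    return 100.0 * (flops / seconds / 1e12) / ACHIEVABLE_FP8_TFLOPS


def pct_of_bandwidth_roofline(num_bytes: float, seconds: float) -> float:
    # Achieved GB/s as a percentage of the achievable memory bandwidth.
    return 100.0 * (num_bytes / seconds / 1e9) / ACHIEVABLE_BW_GBS
```

For the benchmarked forward shape, `grouped_gemm_flops(32768, 2048, 7168)` is about 0.962 TFLOP, so a kernel at 100% of the achievable 1544 TFLOP/s would take roughly 0.62 ms; a measured time converts to a percentage with `pct_of_compute_roofline`.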