
Integrate DeepGEMM FP8 blockwise MoE backend #4473

Open
iamzainhuda wants to merge 5 commits into main from deepgemm-fp8-blockwise-moe



@iamzainhuda iamzainhuda commented Jun 9, 2026

Contributor

Summary

Adds an optional DeepGEMM backend for the FP8 blockwise MoE grouped GEMM, alongside the existing emulated path. When DeepGEMM is installed, KernelPreference.AUTO dispatches the forward, dgrad, and wgrad grouped GEMMs to DeepGEMM; otherwise it falls back to the emulated kernels.

This PR adds two main parts: a backend-selection layer, and DeepGEMM-layout FP8 quantization kernels that write DeepGEMM's exact data/scale contracts directly (no dispatch-time transpose/copy).

DeepGEMM layout/scale contract vs TorchAO

DeepGEMM's grouped FP8 kernels expect a different operand memory layout and scale orientation from the ones that TorchAO's existing blockwise quantizers (and torch._grouped_mm) produce.

  • TorchAO's public expert weight is (E, K, N) transposed in per-expert column-major layout, its 1x128 activation scales are stored as (M_blocks, K) (column-major-friendly for torch._scaled_mm), and its grouped RHS scales follow the torch._grouped_mm convention.
  • DeepGEMM instead wants: the forward RHS as (E, N, K) row-major with contiguous K and (E, N//128, K//128) scales; the dgrad RHS as (E, K, N) row-major with contiguous N and (E, K//128, N//128) scales; and, for the K-grouped wgrad, each operand as a flat per-expert (D, expert_tokens) K-major buffer (experts concatenated) with scales as (D, total_valid_token_blocks). That is, the token dimension becomes the GEMM's contraction axis and the scale axes are transposed relative to TorchAO's.
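As a sanity check of the contract above, the shape bookkeeping can be written out in plain Python. This is a sketch under stated assumptions: `deepgemm_operand_shapes` is a hypothetical helper (not part of TorchAO or DeepGEMM) that only computes shapes; it assumes every per-expert token count is already padded to a multiple of 128, and that the two wgrad operands are grad_output (feature dim D = N) and A (feature dim D = K).

```python
BLOCK = 128  # FP8 block size used throughout the PR


def deepgemm_operand_shapes(E: int, N: int, K: int, expert_tokens: list[int]) -> dict:
    """Hypothetical helper: (data_shape, scale_shape) for each DeepGEMM operand.

    Assumes N, K and every entry of expert_tokens are multiples of BLOCK.
    """
    assert N % BLOCK == 0 and K % BLOCK == 0
    assert all(t % BLOCK == 0 for t in expert_tokens)
    total_tokens = sum(expert_tokens)
    return {
        # forward RHS: (E, N, K) row-major with K contiguous, one scale per 128x128 tile
        "fwd_rhs": ((E, N, K), (E, N // BLOCK, K // BLOCK)),
        # dgrad RHS: (E, K, N) row-major with N contiguous
        "dgrad_rhs": ((E, K, N), (E, K // BLOCK, N // BLOCK)),
        # K-grouped wgrad: per-expert (D, tokens_e) buffers concatenated into one flat
        # buffer; tokens are the contraction axis, one scale per (row, 128-token block)
        "wgrad_grad_out": ((N * total_tokens,), (N, total_tokens // BLOCK)),
        "wgrad_act": ((K * total_tokens,), (K, total_tokens // BLOCK)),
    }


shapes = deepgemm_operand_shapes(E=2, N=256, K=512, expert_tokens=[128, 384])
for name, (data, scale) in shapes.items():
    print(f"{name}: data {data}, scales {scale}")
```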

Rather than quantize in the TorchAO layout and then transpose/copy at dispatch time, we add kernels that write DeepGEMM's exact contract directly.

  • The weight quantizers read the contiguous (E*N, K) view of the column-major weight and emit DeepGEMM's (E, N, K)/(E, K, N) data and block-transposed scales straight from the cast — no separate layout-copy kernel.
  • For the wgrad operands there are two routes selected per operand by _DEEPGEMM_DIRECT_K_GROUPED_QUANT_MIN_DIM: wide operands take the direct K-grouped quant that writes the flat (D, expert_tokens) K-major buffer and (D, blocks) scales in one pass; narrower operands reuse TorchAO's existing transposed-LHS/RHS quantizers and are then flattened per-expert and have their scales transposed ((M_blocks, K) → (K, M_blocks)) at launch time to match DeepGEMM's K-major contract. The on-device grouped_layout (one int32 expert id per row, -1 for padding) is what tells DeepGEMM's contiguous M-grouped kernel which expert each row belongs to.
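The two wgrad routes and the flat K-major layout can be illustrated in plain Python. This is a hedged sketch: `choose_wgrad_route`, `transpose_scales`, and `flat_kmajor_offset` are hypothetical names; comparing the operand's feature dim D against the 4096 threshold with `>=` is an assumption about how _DEEPGEMM_DIRECT_K_GROUPED_QUANT_MIN_DIM is applied; and nested lists stand in for device tensors.

```python
DIRECT_K_GROUPED_QUANT_MIN_DIM = 4096  # value from the PR; the >= comparison below is assumed


def choose_wgrad_route(dim: int) -> str:
    """Wide operands use the direct K-grouped quant; narrow ones reuse TorchAO's quantizers."""
    return "direct" if dim >= DIRECT_K_GROUPED_QUANT_MIN_DIM else "torchao_transposed"


def transpose_scales(scales: list[list[float]]) -> list[list[float]]:
    """Narrow-route fix-up: (M_blocks, K) scales -> (K, M_blocks)."""
    return [list(col) for col in zip(*scales)]


def flat_kmajor_offset(expert_tokens: list[int], dim: int, expert: int, d: int, t: int) -> int:
    """Offset of element (d, t) of `expert` in the flat buffer of concatenated
    per-expert (dim, tokens_e) blocks, with tokens contiguous (K-major)."""
    base = dim * sum(expert_tokens[:expert])  # skip all earlier experts' blocks
    return base + d * expert_tokens[expert] + t


# H100 benchmark shapes: grad_out has D = N = 2048, A has D = K = 7168.
print(choose_wgrad_route(2048), choose_wgrad_route(7168))
```

Under this assumption the 2048-wide grad_out takes the transposed route and the 7168-wide A takes the direct route, which matches the [transposed] and [direct] tags in the per-kernel H100 table.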

H100 per-kernel roofline (balanced tokens, M=32768, E=8, N=2048, K=7168)

Roofline refs: mem BW peak 3352 GB/s / achievable 3084 GB/s (92%); FP8 compute peak 1979 TFLOP/s / achievable 1544 TFLOP/s (78%).

Quant/cast kernels are reported as % of achievable memory bandwidth; DeepGEMM grouped GEMMs as % of achievable FP8 compute.

kernel                                        time (µs)   % of achievable roofline
fwd: act_quant_lhs                                253.5   91.1 (BW)
fwd: weight_quant_forward_rhs                     128.1   89.2 (BW)
fwd: deepgemm_grouped_mm                          779.3   80.0 (compute)
bwd: act_quant_lhs(grad_out)                       77.6   85.0 (BW)
bwd: weight_quant_dgrad_rhs                       126.3   90.5 (BW)
bwd: deepgemm_grouped_mm_dgrad                    778.6   80.1 (compute)
bwd: wgrad_quant_lhs(grad_out) [transposed]        85.1   77.5 (BW)
bwd: wgrad_quant_rhs(A) [direct]                  275.3   83.8 (BW)
bwd: deepgemm_grouped_mm_wgrad                   2105.4   29.6 (compute)
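The roofline percentages follow from simple arithmetic. The sketch below reproduces several rows to within rounding, under traffic and FLOP models that are assumptions rather than code from the benchmark script: a cast reads bf16 input (2 bytes), writes one FP8 byte per element, and writes one fp32 scale per 1x128 block (activations) or per 128x128 tile (weights); a grouped GEMM with balanced routing performs 2*M*N*K FLOPs. The helper names are hypothetical.

```python
MEM_BW_ACHIEVABLE_GBPS = 3084.0  # achievable HBM bandwidth quoted above (GB/s)
FP8_ACHIEVABLE_TFLOPS = 1544.0   # achievable FP8 compute quoted above (TFLOP/s)


def cast_bytes(num_elems: int, elems_per_scale: int) -> int:
    """Assumed traffic of an FP8 cast: read bf16, write fp8, write fp32 scales."""
    return num_elems * (2 + 1) + (num_elems // elems_per_scale) * 4


def pct_of_bw(num_bytes: int, time_us: float) -> float:
    gb_per_s = num_bytes / (time_us * 1e-6) / 1e9
    return 100.0 * gb_per_s / MEM_BW_ACHIEVABLE_GBPS


def pct_of_fp8_compute(m: int, n: int, k: int, time_us: float) -> float:
    tflop_per_s = 2 * m * n * k / (time_us * 1e-6) / 1e12  # 2*M*N*K for balanced routing
    return 100.0 * tflop_per_s / FP8_ACHIEVABLE_TFLOPS


M, E, N, K = 32768, 8, 2048, 7168
rows = {
    "fwd: act_quant_lhs": pct_of_bw(cast_bytes(M * K, 128), 253.5),  # 1x128 scales
    "fwd: weight_quant_forward_rhs": pct_of_bw(cast_bytes(E * N * K, 128 * 128), 128.1),  # 128x128 tiles
    "bwd: act_quant_lhs(grad_out)": pct_of_bw(cast_bytes(M * N, 128), 77.6),
    "fwd: deepgemm_grouped_mm": pct_of_fp8_compute(M, N, K, 779.3),
    "bwd: deepgemm_grouped_mm_wgrad": pct_of_fp8_compute(M, N, K, 2105.4),
}
for name, pct in rows.items():
    print(f"{name}: {pct:.1f}%")
```

Under the same FLOP model the wgrad GEMM does the same work as the forward GEMM but takes about 2.7x as long, which is why it sits near 30% of achievable compute.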


pytorch-bot Bot commented Jun 9, 2026


🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/4473

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

❌ 1 New Failure

As of commit 4cce2ad with merge base 5165bfb:

NEW FAILURE - The following job has failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

meta-cla Bot added the CLA Signed label (This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.) on Jun 9, 2026
iamzainhuda added the module: training label (quantize_ api training flow) on Jun 9, 2026
@iamzainhuda

Contributor Author

@claude review


claude Bot commented Jun 15, 2026


Claude finished @iamzainhuda's task in 6m 32s


Code Review: DeepGEMM FP8 blockwise MoE backend

  • Read changed files
  • Analyze correctness, performance, and design
  • Post review feedback

Nice work — the backend abstraction is clean, the layout/scale contracts between TorchAO and DeepGEMM are well-documented inline, and the test matrix (optional-dep handling, layout parity, compile, K-grouped quant, wgrad parity, end-to-end autograd parity) is genuinely comprehensive. A few things worth addressing below.


Potential issues

1. PEP 604 unions (X | None) at runtime — may break Python 3.9 (likely blocking)
The three new files use | unions in positions that are evaluated at import time, not just type-checked:

  • deepgemm_metadata.py:34-35 dataclass fields original_group_sizes: list[int] | None, padded_group_start_offsets: torch.Tensor | None
  • deepgemm_grouped_kernels.py:61-64 module: object | None, import_error: ImportError | None, etc.
  • function-signature defaults like num_rows: int | None = None (deepgemm_metadata.py:87)

torchao's CI still targets Python 3.9 (.github/workflows/ruff_linter.yml, dashboard_perf_test.yml both pin 3.9). On 3.9, object | None / list[int] | None raise TypeError at class/def evaluation. Because grouped_mm.py imports grouped_mm_backend → deepgemm_grouped_kernels → deepgemm_metadata at module load, this would break import on 3.9 even for the emulated path. Existing code (e.g. moe_training/tensor.py:301) only uses | on local variable annotations, which are never evaluated, so it slips through.

Cheapest fix: add from __future__ import annotations to the top of all three new files (this defers evaluation of every annotation by storing it as a string). Please confirm whether torchao's runtime minimum is actually 3.9; if it has moved to 3.10+, this is moot.
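The mechanism behind this fix can be shown in isolation. In this minimal sketch (the class and function names are hypothetical), the snippet is compiled as its own module so that `from __future__ import annotations` is its first statement; the annotations then come back as unevaluated source strings, so `X | None` is never evaluated at class or def time.

```python
# Compiled as a separate module so the __future__ import is its first statement.
DEMO_SRC = '''
from __future__ import annotations


class OffsetPlanFields:  # hypothetical; mirrors the kind of field flagged above
    original_group_sizes: list[int] | None = None


def build_plan(num_rows: int | None = None):  # hypothetical signature
    return num_rows
'''

namespace: dict = {}
exec(compile(DEMO_SRC, "<pep563_demo>", "exec"), namespace)

# Under postponed evaluation the annotations are stored as strings and never
# evaluated, so `X | None` cannot raise TypeError at class/def time.
field_annotation = namespace["OffsetPlanFields"].__annotations__["original_group_sizes"]
param_annotation = namespace["build_plan"].__annotations__["num_rows"]
print(repr(field_annotation), repr(param_annotation))
```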

2. Per-step host↔device sync in build_deepgemm_grouped_offset_plan (perf)
group_sizes_from_offsets (deepgemm_metadata.py:69-79) calls int(group_end_offsets[i].item()) in a Python loop, and build_deepgemm_grouped_offset_plan additionally does padded_group_start_offsets.tolist(). This runs on every forward (inside _Float8BlockwiseGroupedMM.forward), forcing D2H syncs and a guaranteed graph break under torch.compile/CUDA graphs. The plan is correctly cached in ctx so backward doesn't repeat it, but the forward cost is real. Worth a comment acknowledging the sync, or computing offsets once on host if the caller already has them.

3. Double FP32 allocation in K-grouped wgrad (memory)
deepgemm_grouped_kernels.py:548-553 allocates both out_fp32 (E, N, K) and accum = torch.zeros_like(out_fp32) — two full FP32 buffers for the largest operand in the backward. If DeepGEMM's k_grouped_fp8_gemm_nt_contiguous computes out = accum + A@B with accum only needed as the accumulator seed, consider whether out_fp32 itself can be zero-initialized and passed as both, or whether accum is even required when not accumulating across calls. A short comment on why both are needed would help.

4. register_fake shape mismatch with the real op (correctness under compile)
deepgemm_quant.py:458-472 — the fake returns scale shape (x.shape[1], ceil_div(x.shape[0], block_size)), but the real op returns (D, valid_tokens // block_size) where valid_tokens = sum(group_sizes) ≤ x.shape[0] (padding rows are excluded). When groups are padded, ceil_div(x.shape[0], 128) over-counts blocks vs. the real output, so a traced/compiled call through this custom op would see a wrong meta shape. The autograd path uses _..._with_group_sizes directly rather than the registered op, so this only bites standalone/compiled use of triton_fp8_blockwise_act_quant_k_grouped_deepgemm, but it's still latent. Since the fake can't know group_sizes, consider documenting that the registered op is only valid for unpadded inputs, or deriving the count from offsets.
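A worked example of the mismatch, using the shape rules as this review and the follow-up review describe them (flat data buffer plus scales). `fake_meta` and `real_shapes` are hypothetical helpers, and the group sizes are made up for illustration; they are assumed block-aligned, as the wgrad path requires.

```python
BLOCK = 128


def ceil_div(a: int, b: int) -> int:
    return -(-a // b)


def fake_meta(num_rows: int, dim: int):
    """What the registered fake reports: it only sees x.shape == (num_rows, dim)."""
    return (num_rows * dim,), (dim, ceil_div(num_rows, BLOCK))


def real_shapes(group_sizes: list[int], dim: int):
    """What the real op returns: only tokens that belong to a group are quantized."""
    valid_tokens = sum(group_sizes)
    return (valid_tokens * dim,), (dim, valid_tokens // BLOCK)


# 1024 rows in the input, but the (block-aligned) groups only cover 768 of them.
num_rows, dim = 1024, 2048
padded_case = [256, 384, 128]
print(fake_meta(num_rows, dim), real_shapes(padded_case, dim))
```

When the groups cover every row (M == valid_tokens), the two agree, which is why documenting the op as exact only for unpadded inputs is a reasonable fix.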


Minor / nits

  • deepgemm_quant.py:31-36: _should_use_traceable_triton_launch ORs torch.compiler.is_compiling() and torch._dynamo.is_compiling() — the former already subsumes the latter; harmless but redundant.
  • deepgemm_grouped_kernels.py:36: _DEEPGEMM_DIRECT_K_GROUPED_QUANT_MIN_DIM = 4096 is an H100/DeepSeek-V3-tuned threshold applied uniformly to all SM90+ devices. Correctness is preserved either way, and the comment is excellent — just flagging that it may be suboptimal on non-H100 SM90+ parts.
  • deepgemm_grouped_kernels.py:289 compact kernel stores q/s without an offs_d < D mask, unlike the ragged kernel. This is safe only because the compact path is gated on D % 128 == 0 and NUM_GROUPS ∈ {2,16,32,64,128} all divide D — so offs_d never exceeds D. A one-line comment stating that invariant would prevent a future config change from silently introducing OOB stores.
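The invariant behind this nit can be checked exhaustively. This sketch assumes, as the review does, that each of NUM_GROUPS programs covers cdiv(D, NUM_GROUPS) consecutive rows of D, so the stores stay in bounds exactly when cdiv(D, NUM_GROUPS) * NUM_GROUPS does not exceed D; `compact_store_in_bounds` is a hypothetical name.

```python
def cdiv(a: int, b: int) -> int:
    return -(-a // b)


def compact_store_in_bounds(dim: int, num_groups: int) -> bool:
    """True if num_groups programs of cdiv(dim, num_groups) rows each never go past dim."""
    return cdiv(dim, num_groups) * num_groups <= dim


CURRENT_NUM_GROUPS = (2, 16, 32, 64, 128)  # configs listed in the review
gated_dims = range(128, 16384 + 1, 128)     # the compact path requires D % 128 == 0

assert all(compact_store_in_bounds(d, g) for d in gated_dims for g in CURRENT_NUM_GROUPS)
# A hypothetical future NUM_GROUPS that does not divide 128 breaks the invariant:
print(compact_store_in_bounds(128, 256), compact_store_in_bounds(384, 256))  # False False
```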

CI failures

The 5 reported failures look unrelated to this PR: test_save_load_int8woqtensors, test_x86inductor_fusion (×3), and a generic H100 exit code 1. The first four touch areas this PR doesn't modify and are almost certainly pre-existing/flaky. The H100 job is the one worth a closer look since this PR adds the H100 DeepGEMM tests — if the failure is in test_fp8_blockwise_deepgemm_backend.py it'd be in scope; if it's the int8woq test on H100, it's unrelated.

Build the DeepGEMM M-grouped layout entirely on-device with bucketize and
defer host-side group sizes to a lazy cached_property, so the forward grouped
GEMM no longer triggers a per-group device-to-host sync (the K-grouped wgrad
keeps the single host ks read DeepGEMM requires). Skip the alignment check on
the padded path, where groups are block-aligned by construction.
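The on-device construction can be mirrored on the host with `bisect.bisect_right`, which has the same boundary semantics as `torch.bucketize(rows, padded_end_offsets, right=True)`. This is a sketch, not the PR's code: `OffsetPlanSketch` is a hypothetical class, Python lists stand in for device tensors, each group is assumed to be padded up to the next multiple of 128 (empty groups stay empty), and the `cached_property` stands in for the single deferred host read.

```python
import bisect
from functools import cached_property

BLOCK = 128


def round_up(x: int, multiple: int = BLOCK) -> int:
    return (x + multiple - 1) // multiple * multiple


class OffsetPlanSketch:
    """Hypothetical offset plan; lists stand in for device tensors."""

    def __init__(self, group_sizes: list[int]):
        self._group_sizes = list(group_sizes)
        self.padded_end_offsets = []  # cumulative padded group ends
        end = 0
        for size in self._group_sizes:
            end += round_up(size)
            self.padded_end_offsets.append(end)
        self.host_reads = 0  # counts simulated device-to-host transfers

    def grouped_layout(self) -> list[int]:
        """One expert id per padded row, -1 for padding rows."""
        num_rows = self.padded_end_offsets[-1] if self.padded_end_offsets else 0
        layout = []
        for row in range(num_rows):
            # Same result as torch.bucketize(row, padded_end_offsets, right=True)
            expert = bisect.bisect_right(self.padded_end_offsets, row)
            start = self.padded_end_offsets[expert - 1] if expert > 0 else 0
            layout.append(expert if row - start < self._group_sizes[expert] else -1)
        return layout

    @cached_property
    def group_sizes(self) -> list[int]:
        """Host-side sizes, read lazily and only once (like a single .tolist())."""
        self.host_reads += 1
        return list(self._group_sizes)


plan = OffsetPlanSketch([100, 200, 0, 60])
layout = plan.grouped_layout()
print(len(layout), layout[98:102], plan.host_reads)
```

Building the layout never touches group_sizes, so the forward path does no host read; the cached property is only evaluated when a consumer such as the K-grouped wgrad needs host-side sizes, and then only once.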

Also add a per-kernel MoE grouped-GEMM benchmark mirroring the linear blockwise
FP8 benchmarks: each quant cast and DeepGEMM grouped GEMM is timed in isolation
and reported against the memory-bandwidth or FP8 compute roofline.
@iamzainhuda iamzainhuda requested a review from andrewor14 June 15, 2026 21:24
@iamzainhuda iamzainhuda marked this pull request as ready for review June 15, 2026 21:24
@iamzainhuda iamzainhuda requested a review from drisspg June 17, 2026 17:48
@andrewor14

Contributor

@claude review


claude Bot commented Jun 23, 2026


Claude finished @andrewor14's task in 3m 28s


Code Review: DeepGEMM FP8 blockwise MoE backend

  • Read changed files
  • Analyze correctness, performance, and design
  • Post review feedback

This is a clean, well-engineered integration. In the current state of the branch (commits through c51414b), most of the issues raised in the earlier review have already been addressed: from __future__ import annotations is now on all three new files, the forward-path host syncs were removed, the double-FP32 allocation in wgrad now carries a precise comment explaining DeepGEMM's d = c + A@B accumulator contract, and the redundant is_compiling() OR was simplified. Nice follow-through.

A few items remain worth a look.


Remaining / new findings

1. register_fake still reports the wrong shape for padded inputs (latent, correctness under compile)
deepgemm_quant.py:460-474 — the fake returns

  • q: (x.shape[0] * x.shape[1],) = (M * D,)
  • s: (D, ceil_div(M, block_size))

but the real op (deepgemm_quant.py:388-391) returns (valid_tokens * D,) and (D, valid_tokens // block_size), where valid_tokens = sum(group_sizes) ≤ M. When groups are padded, the fake over-counts both the flat buffer length and the scale-block count. The autograd path calls _..._with_group_sizes directly rather than the registered op, so this only bites a standalone/compiled call through triton_fp8_blockwise_act_quant_k_grouped_deepgemm. It is still a genuine latent bug: a traced consumer would later trip the a.numel() == total_tokens * lhs.dim assert in deepgemm_grouped_kernels.py:538. Since the fake can't see group_sizes, the cleanest fix is a one-line docstring note that the registered op's meta is only exact for unpadded inputs (where M == valid_tokens).

2. Compact K-grouped kernel still stores without an offs_d < D mask (robustness)
deepgemm_quant.py:289-330: the compact kernel's q/s stores (lines 327, 330) have no offs_d < D guard, unlike the ragged kernel (deepgemm_quant.py:278, 285). This is safe today only because the compact path is gated on D % 128 == 0 (deepgemm_quant.py:397) and every NUM_GROUPS ∈ {2,16,32,64,128} (kernels.py:473) divides 128, so cdiv(D, NUM_GROUPS) * NUM_GROUPS == D exactly and offs_d never exceeds D. That is a non-obvious invariant spanning two files; a one-line comment in the compact kernel stating it would stop a future NUM_GROUPS config change from silently introducing OOB stores.

3. Redundant mask term in ragged store (nit)
deepgemm_quant.py:279: mask=q_mask & valid_group_block, but q_mask is already valid_group_block & (offs_d[:, None] < D) (line 278), so the extra & valid_group_block is a no-op. Harmless; mask=q_mask alone would read cleaner.


Notes (not blockers)

  • PEP 604 unions are fine now. setup.py:24 pins min_supported_cpython_hexcode = 0x030A0000 (Python 3.10), so the X | None annotations evaluated at import time in deepgemm_metadata.py/deepgemm_grouped_kernels.py are valid — the earlier "may break 3.9" concern is moot.
  • Backward-path host syncs remain by necessity: offset_plan.group_sizes (deepgemm_metadata.py:45-48) does a .tolist() and build_deepgemm_k_grouped_quant_metadata (deepgemm_metadata.py:170-203) runs a Python loop over groups/blocks. These run in wgrad during backward because DeepGEMM's K-grouped API requires a host-side ks sequence — inherent to the kernel contract, not a regression. The forward path is correctly sync-free now.
  • _DEEPGEMM_DIRECT_K_GROUPED_QUANT_MIN_DIM = 4096 (deepgemm_grouped_kernels.py:37) is an H100/DeepSeek-V3-tuned threshold applied to all SM90+ parts. Correctness holds either way; the comment already explains the heuristic well — just flagging it may be suboptimal on non-H100 SM90+.
  • Test coverage is strong: optional-dep handling, broken-install messaging, AUTO fallback, layout parity vs TorchAO transposes, fullgraph compile, K-grouped quant parity, wgrad parity, and end-to-end autograd parity. Note the GEMM-executing tests (test_deepgemm_*_matches_emulated) importorskip("deep_gemm"), so they're skipped in CI without DeepGEMM + H100 — the layout/quant parity tests that do run on any SM90+ are what guard the contract in CI.

CI

The reported failures (test_x86inductor_fusion ×2, test_save_load_int8woqtensors) are in PT2E/int8 areas this PR doesn't touch and look pre-existing/flaky, not caused by this change.

@andrewor14

Contributor

@claude can you verify if the refactor is done correctly such that the original emulated path is unchanged? Trace through the code paths and print them here. Then do the same for the new deepgemm code path


claude Bot commented Jun 23, 2026


Claude finished @andrewor14's task in 2m 41s


Refactor verification: emulated path unchanged + DeepGEMM path trace

  • Read grouped_mm.py and diff against the pre-refactor commit (5165bfb)
  • Trace the emulated forward/backward path end-to-end
  • Trace the DeepGEMM forward/backward path end-to-end
  • Confirm behavioral equivalence of the emulated path before vs after

Short answer

Yes — the refactor preserves the emulated path exactly. The pre-refactor inline kernel calls were moved verbatim into _EmulatedGroupedMMBackendPlan (same functions, same arguments, same order, same scale recipes). The autograd function was renamed _Float8BlockwiseEmulatedGroupedMM → _Float8BlockwiseGroupedMM and now dispatches through a backend_plan object, but for the emulated backend every call resolves to the identical kernel it called before. The one observable addition on the emulated path is the _select_fp8_blockwise_grouped_mm_backend(...) call; when it returns the emulated plan it does no extra device work (deepgemm_offset_plan=None).


Emulated path trace (before vs after)

Forward — _Float8BlockwiseGroupedMM.forward (grouped_mm.py:91-187)

All structural steps are byte-for-byte identical to the old _Float8BlockwiseEmulatedGroupedMM.forward:

  1. Same asserts (grouped_mm.py:103-115).
  2. Same pad_token_groups(...) (:120-123).
  3. New: _select_fp8_blockwise_grouped_mm_backend(...) (:127-138). For EMULATED it short-circuits at grouped_mm_backend.py:306-310 → _EMULATED_GROUPED_MM_BACKEND_PLAN, deepgemm_offset_plan=None. No quant, no GEMM, no sync.
  4. LHS quant triton_fp8_blockwise_act_quant_lhs(...) (:142-146) — still inline, unchanged, identical for both backends.
  5. RHS quant: was inline triton_fp8_blockwise_weight_quant_grouped_transposed_rhs(B_t, block_size=, dtype=); now backend_plan.quantize_forward_rhs(...) → grouped_mm_backend.py:104-116 calls the same function with the same args.
  6. GEMM: was emulated_blockwise_scaled_grouped_mm(A_fp8, B_t_fp8, A_scale, 1x128, B_t_scale, 128x128, offs, out_dtype, block_size); now backend_plan.grouped_mm(...) → grouped_mm_backend.py:147-157 forwards the identical 9 args (the deepgemm_offset_plan kwarg is ignored by the emulated plan).
  7. Same unpad_token_groups(...) (:164-171).
  8. ctx saves are identical plus two new fields (backend_plan, deepgemm_offset_plan) consumed only in backward.

Backward — _Float8BlockwiseGroupedMM.backward (grouped_mm.py:189-264)

  • grad_A (dgrad): RHS quant was triton_fp8_blockwise_weight_quant_grouped_rhs(...); now quantize_dgrad_rhs → grouped_mm_backend.py:127-131 calls the same function. GEMM unchanged (same recipes/args).
  • grad_B (wgrad): the entire old block — triton_fp8_blockwise_act_quant_transposed_lhs(padded_grad_output.contiguous(), ...) + triton_fp8_blockwise_act_quant_rhs(padded_A.contiguous(), ...) + emulated_blockwise_scaled_grouped_mm(..., 1x128, ..., 1x128, ...) — was moved verbatim into _EmulatedGroupedMMBackendPlan.wgrad (grouped_mm_backend.py:170-192). Same .contiguous() calls, same recipes.
  • Return tuple gained one trailing None (:263) for the new kernel_preference parameter — correct, and harmless since it's non-differentiable.

The legacy entry point _to_fp8_blockwise_then_emulated_scaled_grouped_mm is retained as a thin wrapper (grouped_mm.py:65-87) that forces KernelPreference.EMULATED, so existing callers and test_fp8_blockwise_grouped_mm.py are unaffected.

Conclusion: with KernelPreference.EMULATED the numerics and op sequence are unchanged. With AUTO on a machine without DeepGEMM/SM90, can_use_deepgemm_grouped_training (deepgemm_grouped_kernels.py:148-160) returns False and you also get the exact emulated path — the only added cost is one @lru_cached import probe + capability check.
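The cached probe pattern mentioned above looks roughly like this. It is a minimal sketch: `deepgemm_importable` is a hypothetical name that only checks importability (the real can_use_deepgemm_grouped_training also checks the grouped symbols, block size, output dtype, and device capability), and `deep_gemm` is the module name the tests importorskip.

```python
import importlib
from functools import lru_cache


@lru_cache(maxsize=None)
def deepgemm_importable(module_name: str = "deep_gemm") -> bool:
    """Try the import once per module name; later calls return the cached answer."""
    try:
        importlib.import_module(module_name)
    except ImportError:
        return False
    return True


first = deepgemm_importable()
second = deepgemm_importable()  # served from the cache, no second import attempt
print(first == second, deepgemm_importable.cache_info().hits)
```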


DeepGEMM path trace (new)

Selected only when all of: kernel_preference == AUTO, can_use_deepgemm_grouped_training is True (block_size 128, out_dtype bf16, both DeepGEMM grouped symbols importable, tensor on CUDA SM90+), and deepgemm_offset_plan.groups_are_block_aligned(128) (grouped_mm_backend.py:312-338). Any failure falls back to the emulated plan with deepgemm_offset_plan=None.
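The selection rules above can be summarized as a small decision function. This is a hypothetical sketch, not grouped_mm_backend.py: the enum is a local stand-in for torchao's KernelPreference, dtypes are plain strings, and the capability inputs are passed in as booleans rather than probed.

```python
from enum import Enum


class KernelPreference(Enum):  # local stand-in, not torchao's enum
    AUTO = "auto"
    EMULATED = "emulated"


def select_grouped_mm_backend(
    preference: KernelPreference,
    *,
    deepgemm_symbols_importable: bool,
    is_cuda_sm90_plus: bool,
    block_size: int,
    out_dtype: str,
    groups_block_aligned: bool,
) -> str:
    """Return "deepgemm" only when every condition holds; otherwise fall back to "emulated"."""
    if preference is not KernelPreference.AUTO:
        return "emulated"
    can_use_deepgemm = (
        deepgemm_symbols_importable
        and is_cuda_sm90_plus
        and block_size == 128
        and out_dtype == "bfloat16"
    )
    if not can_use_deepgemm or not groups_block_aligned:
        return "emulated"
    return "deepgemm"


ok = dict(deepgemm_symbols_importable=True, is_cuda_sm90_plus=True,
          block_size=128, out_dtype="bfloat16", groups_block_aligned=True)
print(select_grouped_mm_backend(KernelPreference.AUTO, **ok))
```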

Forward

  1. build_deepgemm_grouped_offset_plan(...) builds the on-device grouped_layout / offset metadata (grouped_mm_backend.py:322-328).
  2. LHS quant: same shared triton_fp8_blockwise_act_quant_lhs as emulated.
  3. RHS quant: quantize_forward_rhs → triton_fp8_blockwise_weight_quant_grouped_transposed_rhs_deepgemm (grouped_mm_backend.py:198-211), which writes DeepGEMM's (E, N, K) + (E, N//128, K//128) contract directly.
  4. GEMM: grouped_mm → asserts the offset plan is present (:242-244) then deepgemm_blockwise_scaled_grouped_mm(..., offset_plan=...).

Backward

  • grad_A: quantize_dgrad_rhs → ..._grouped_rhs_deepgemm emits (E, K, N) + (E, K//128, N//128); same deepgemm_blockwise_scaled_grouped_mm.
  • grad_B (wgrad): _DeepGemmGroupedMMBackendPlan.wgrad (:258-288) calls prepare_deepgemm_wgrad_plan(...) (per-operand chooses direct K-grouped quant vs. TorchAO-transposed+flatten via _DEEPGEMM_DIRECT_K_GROUPED_QUANT_MIN_DIM), asserts block-aligned groups, then deepgemm_blockwise_scaled_grouped_mm_wgrad(...).

Both backends terminate in the same unpad_token_groups and the same grad_B.transpose(-2, -1) return, so the autograd contract (shapes, transpose convention) is shared. The dispatch boundary is clean: the emulated plan never touches DeepGEMM metadata, and the DeepGEMM plan asserts the metadata is non-None before use.


Notes

  • The strongest guarantee in the test suite is test_fp8_blockwise_deepgemm_backend.py:439-446, which runs _to_fp8_blockwise_then_scaled_grouped_mm (AUTO) against _to_fp8_blockwise_then_emulated_scaled_grouped_mm (forced EMULATED) and asserts parity — i.e. it directly pins emulated-vs-DeepGEMM equivalence end-to-end (skipped in CI without DeepGEMM+H100).
  • Minor risk surface: the per-forward _select_... call adds a capability probe even when the result is always-emulated; it's @lru_cached so cost is negligible, but it does run every forward.

@andrewor14 andrewor14 left a comment

Contributor


Looks great, just mostly comments about docs and arg signatures. Looks like claude pointed out there's potentially a wrong shape. Can you look into that?

Also just wondering if you've tried this on a real model yet to see if we get a reasonable loss? Not blocking for this PR, but it would be great to see some real examples.



@dataclass(frozen=True)
class DeepGemmGroupedOffsetPlan:

Contributor


add some docstrings here (and to all other new classes)? Explain what each of the fields refers to?

group_sizes: list[int],
block_size: int,
dim: int,
) -> DeepGemmKGroupedQuantMetadata:

Contributor


need some docstrings on these helper functions too. What are they doing on a high level, what the args are etc

padded_group_end_offsets,
out_dtype,
block_size,
deepgemm_offset_plan=deepgemm_offset_plan,

Contributor


seems a bit weird to pass this deepgemm specific plan to a general backend plan. Given a specific backend the plan is the same for all grouped_mm calls right? Can we just embed this offset plan in the backend_plan itself so you don't need to pass it to every grouped_mm call?

Contributor Author


yeah, this makes sense; changed it so it's part of that backend's plan now
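The shape of that refactor, sketched with hypothetical classes (these are not the PR's _EmulatedGroupedMMBackend/_DeepGemmGroupedMMBackend, and strings and tuples stand in for tensors): the DeepGEMM backend is constructed with its offset plan, so grouped_mm has the same call signature for every backend.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class OffsetPlan:
    """Hypothetical stand-in for the DeepGEMM offset metadata."""
    padded_group_end_offsets: Tuple[int, ...]


@dataclass(frozen=True)
class EmulatedBackend:
    def grouped_mm(self, a, b):
        return ("emulated", a, b)


@dataclass(frozen=True)
class DeepGemmBackend:
    offset_plan: OffsetPlan  # bound once, when the backend is selected

    def grouped_mm(self, a, b):
        # The plan travels with the backend, so callers never pass it explicitly.
        return ("deepgemm", a, b, self.offset_plan.padded_group_end_offsets)


for backend in (EmulatedBackend(), DeepGemmBackend(OffsetPlan((128, 384)))):
    print(backend.grouped_mm("A_fp8", "B_fp8"))  # identical call site for both backends
```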



class _EmulatedGroupedMMBackendPlan(_GroupedMMBackendPlan):
kind = _GroupedMMBackend.EMULATED

Contributor


nit: maybe just backend to be more specific?

Contributor Author


changed to make it clearer:

  • _GroupedMMBackend enum -> _GroupedMMBackendKind
  • _GroupedMMBackendPlan -> _GroupedMMBackend
  • _EmulatedGroupedMMBackendPlan -> _EmulatedGroupedMMBackend
  • _DeepGemmGroupedMMBackendPlan -> _DeepGemmGroupedMMBackend

original_group_end_offsets: Optional[torch.Tensor] = None,
padded_group_start_offsets: Optional[torch.Tensor] = None,
num_rows: Optional[int] = None,
) -> _GroupedMMBackendSelection:

Contributor


the docstring here should list all the scenarios in which emulated is selected vs deepgemm

Contributor Author


added

@iamzainhuda

Contributor Author

@andrewor14 addressed the comments

