Skip to content

[cuda] Compact int4/int6 weight quant metadata (bf16 -> uint8 + per-row super-scale)#20571

Open
Gasoonjia wants to merge 4 commits into
mainfrom
cuda-int4-int6-metadata-opt
Open

[cuda] Compact int4/int6 weight quant metadata (bf16 -> uint8 + per-row super-scale)#20571
Gasoonjia wants to merge 4 commits into
mainfrom
cuda-int4-int6-metadata-opt

Conversation

@Gasoonjia

@Gasoonjia Gasoonjia commented Jun 28, 2026

Copy link
Copy Markdown
Contributor

Decode for gemma4-31B on CUDA is weight-bandwidth-bound: the int4/int6 weight-only matvecs are ~72% of per-token decode time and already run at ~89% of the RTX 5090 HBM roofline using the same dp4a algorithm as llama.cpp. The only lever left is fewer bytes/token.

ET previously stored per-group (group_size=32) scale AND zero as bf16:
int4 = 5.0 bits/weight (20% metadata overhead), int6 = 7.0 bits/weight.
llama.cpp's Q4_K/Q6_K store far less metadata (4.5 / 6.5625 bpw).

This change re-encodes the quant metadata to match llama.cpp's byte density WITHOUT touching the dp4a inner loop or the 4/6-bit weight codes:

  • per-group scale/zero: bf16 -> uint8 codes
  • per-row bf16 "super-scale" (step = row_max/255) restores dynamic range (plain fp8/int8 alone fails: Q4_K scales span ~1e-4..8e-2 -> subnormal blowup, 15-18 dB SNR; the two-level encoding keeps 46.7-48.1 dB == bf16)
  • int6 is symmetric (no zero), int8 scale codes + per-row step (mirrors Q6_K)

Result -> 4.77 bpw (llama.cpp parity).

e2e (gemma4-31B, 128k context length, CUDA-graph decode):

A100 + bf16 kvcache (Prompt | Decode) Orig Decode (toks/s) Orig VRAM (GiB) New Decode (toks/s) New VRAM (GiB) Decode Improvement VRAM Improvement
512 | 512 46.5 33.1 49.89 31.16 +7.29% -5.85%
2K | 512 45.0 33.4 49.28 31.45 +9.51% -5.85%
8K | 512 44.5 33.6 48.77 31.48 +9.60% -6.31%
32K | 512 43.5 33.6 46.19 31.48 +6.18% -6.31%
127K | 512 34.9 33.6 37.80 31.48 +8.31% -6.31%
A100 + tq4 kvcache (Prompt | Decode) Orig Decode (toks/s) Orig VRAM (GiB) New Decode (toks/s) New VRAM (GiB) Decode Improvement VRAM Improvement
512 | 512 45.15 25.5 46.41 23.70 +2.79% -7.06%
2K | 512 43.07 25.8 43.60 23.98 +1.23% -7.05%
8K | 512 42.79 25.8 43.32 24.01 +1.24% -6.94%
32K | 512 41.99 25.8 42.99 24.01 +2.38% -6.94%
127K | 512 34.25 25.8 34.13 24.01 -0.35% -6.94%
5090 + tq4 kvcache (Prompt | Decode) Orig Decode (toks/s) Orig VRAM (GiB) New Decode (toks/s) New VRAM (GiB) Decode Improvement VRAM Improvement
512 | 512 57.16 25.06 60.78 23.29 +6.33% -7.06%
2K | 512 56.39 25.06 59.79 23.29 +6.03% -7.06%
8K | 512 55.60 25.06 58.88 23.29 +5.90% -7.06%
32K | 512 55.29 25.06 58.60 23.29 +5.99% -7.06%

…ow super-scale)

Decode for gemma4-31B on CUDA is weight-bandwidth-bound: the int4/int6
weight-only matvecs are ~72% of per-token decode time and already run at
~89% of the RTX 5090 HBM roofline using the same dp4a algorithm as
llama.cpp. The only lever left is fewer bytes/token.

ET previously stored per-group (group_size=32) scale AND zero as bf16:
  int4 = 5.0 bits/weight (20% metadata overhead), int6 = 7.0 bits/weight.
llama.cpp's Q4_K/Q6_K store far less metadata (4.5 / 6.5625 bpw).

This change re-encodes the quant metadata to match llama.cpp's byte
density WITHOUT touching the dp4a inner loop or the 4/6-bit weight codes:
  - per-group scale/zero: bf16 -> uint8 codes
  - per-row bf16 "super-scale" (step = row_max/255) restores dynamic range
    (plain fp8/int8 alone fails: Q4_K scales span ~1e-4..8e-2 -> subnormal
     blowup, 15-18 dB SNR; the two-level encoding keeps 46.7-48.1 dB == bf16)
  - int6 is symmetric (no zero), int8 scale codes + per-row step (mirrors Q6_K)

Result -> 4.77 bpw (llama.cpp parity). Touches the packer
(coalesced_int4_tensor / dp4a_planar_int6_tensor), the decode shims
(int4/int6_plain_mm .cu/.cuh/.h, +steps arg), the dispatch ops, the AOTI
ABI sig in cuda_backend, and the gemma4_31b concat source transform.

e2e (gemma4-31B, 128k+TurboQuant export, CUDA-graph decode, 3-rep median):
  ctx     baseline  ->  this
  512      57.16    ->  60.78   (+6.3%)
  2048     56.39    ->  59.79   (+6.0%)
  8192     55.60    ->  58.88   (+5.9%)
  32768    55.29    ->  58.60   (+6.0%)
VRAM peak 25.06 -> 23.29 GiB; .ptd 26.18 -> 24.28 GB. Accuracy: Paris
coherent, dequant SNR ~baseline. Win holds at long context under CUDA
graph (verified, not a microbenchmark estimate).
@pytorch-bot

pytorch-bot Bot commented Jun 28, 2026

Copy link
Copy Markdown

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/20571

Note: Links to docs will display an error until the docs builds have been completed.

❌ 2 New Failures, 1 Unrelated Failure

As of commit e50dfe5 with merge base d54a0c0 (image):

NEW FAILURES - The following jobs have failed:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Jun 28, 2026
@github-actions

Copy link
Copy Markdown

This PR needs a release notes: label

If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with release notes:. This helps us keep track and include your important work in the next release notes.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "release notes: none"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

@Gasoonjia Gasoonjia marked this pull request as ready for review June 30, 2026 22:26
The compact-metadata change (bf16 -> uint8/int8 codes + per-row bf16
super-scale) migrated the int4/int6 plain_mm kernels and Python dispatch
but left the tests on the old signatures, breaking the CUDA shim CMake
build and the dispatch unit tests.

- test_aoti_torch_cuda_int4_plain_mm.cpp: regenerate vectors as uint8
  scale/zero codes + [N,2] bf16 steps; update to the 7-arg signature.
- test_aoti_torch_cuda_int6_plain_mm.cpp: regenerate vectors as int8
  scale codes + [N,1] bf16 steps; update to the 6-arg signature.
- test_int4_dispatch.py / test_int6_dispatch.py: add the steps arg to
  the custom-op recorders, build tensors via the new constructors, and
  compare against the decoded (code * step) scale instead of raw codes.

Vectors are generated from the production pack path (CudaCoalescedInt4Tensor
/ pack_int6 + _encode_int8_per_row) with expected[] from the export-path
_dequant_matmul references.

Test Plan:
- Built the CUDA shim gtests and ran on GPU: int4 7/7, int6 3/3 pass.
- python -m pytest backends/cuda/tests/test_int4_dispatch.py
  backends/cuda/tests/test_int6_dispatch.py -> 33 passed.
- lintrunner init && lintrunner -a: no lint issues.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ciflow/cuda CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant