[cuda] Compact int4/int6 weight quant metadata (bf16 -> uint8 + per-row super-scale)#20571
Open
Gasoonjia wants to merge 4 commits into
Open
[cuda] Compact int4/int6 weight quant metadata (bf16 -> uint8 + per-row super-scale)#20571Gasoonjia wants to merge 4 commits into
Gasoonjia wants to merge 4 commits into
Conversation
…ow super-scale)
Decode for gemma4-31B on CUDA is weight-bandwidth-bound: the int4/int6
weight-only matvecs are ~72% of per-token decode time and already run at
~89% of the RTX 5090 HBM roofline using the same dp4a algorithm as
llama.cpp. The only lever left is fewer bytes/token.
ET previously stored per-group (group_size=32) scale AND zero as bf16:
int4 = 5.0 bits/weight (20% metadata overhead), int6 = 7.0 bits/weight.
llama.cpp's Q4_K/Q6_K store far less metadata (4.5 / 6.5625 bpw).
This change re-encodes the quant metadata to match llama.cpp's byte
density WITHOUT touching the dp4a inner loop or the 4/6-bit weight codes:
- per-group scale/zero: bf16 -> uint8 codes
- per-row bf16 "super-scale" (step = row_max/255) restores dynamic range
(plain fp8/int8 alone fails: Q4_K scales span ~1e-4..8e-2 -> subnormal
blowup, 15-18 dB SNR; the two-level encoding keeps 46.7-48.1 dB == bf16)
- int6 is symmetric (no zero), int8 scale codes + per-row step (mirrors Q6_K)
Result -> 4.77 bpw (llama.cpp parity). Touches the packer
(coalesced_int4_tensor / dp4a_planar_int6_tensor), the decode shims
(int4/int6_plain_mm .cu/.cuh/.h, +steps arg), the dispatch ops, the AOTI
ABI sig in cuda_backend, and the gemma4_31b concat source transform.
e2e (gemma4-31B, 128k+TurboQuant export, CUDA-graph decode, 3-rep median):
ctx baseline -> this
512 57.16 -> 60.78 (+6.3%)
2048 56.39 -> 59.79 (+6.0%)
8192 55.60 -> 58.88 (+5.9%)
32768 55.29 -> 58.60 (+6.0%)
VRAM peak 25.06 -> 23.29 GiB; .ptd 26.18 -> 24.28 GB. Accuracy: Paris
coherent, dequant SNR ~baseline. Win holds at long context under CUDA
graph (verified, not a microbenchmark estimate).
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/20571
Note: Links to docs will display an error until the docs builds have been completed. ❌ 2 New Failures, 1 Unrelated FailureAs of commit e50dfe5 with merge base d54a0c0 ( NEW FAILURES - The following jobs have failed:
FLAKY - The following job failed but was likely due to flakiness present on trunk:
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
This PR needs a
|
The compact-metadata change (bf16 -> uint8/int8 codes + per-row bf16 super-scale) migrated the int4/int6 plain_mm kernels and Python dispatch but left the tests on the old signatures, breaking the CUDA shim CMake build and the dispatch unit tests. - test_aoti_torch_cuda_int4_plain_mm.cpp: regenerate vectors as uint8 scale/zero codes + [N,2] bf16 steps; update to the 7-arg signature. - test_aoti_torch_cuda_int6_plain_mm.cpp: regenerate vectors as int8 scale codes + [N,1] bf16 steps; update to the 6-arg signature. - test_int4_dispatch.py / test_int6_dispatch.py: add the steps arg to the custom-op recorders, build tensors via the new constructors, and compare against the decoded (code * step) scale instead of raw codes. Vectors are generated from the production pack path (CudaCoalescedInt4Tensor / pack_int6 + _encode_int8_per_row) with expected[] from the export-path _dequant_matmul references. Test Plan: - Built the CUDA shim gtests and ran on GPU: int4 7/7, int6 3/3 pass. - python -m pytest backends/cuda/tests/test_int4_dispatch.py backends/cuda/tests/test_int6_dispatch.py -> 33 passed. - lintrunner init && lintrunner -a: no lint issues.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Decode for gemma4-31B on CUDA is weight-bandwidth-bound: the int4/int6 weight-only matvecs are ~72% of per-token decode time and already run at ~89% of the RTX 5090 HBM roofline using the same dp4a algorithm as llama.cpp. The only lever left is fewer bytes/token.
ET previously stored per-group (group_size=32) scale AND zero as bf16:
int4 = 5.0 bits/weight (20% metadata overhead), int6 = 7.0 bits/weight.
llama.cpp's Q4_K/Q6_K store far less metadata (4.5 / 6.5625 bpw).
This change re-encodes the quant metadata to match llama.cpp's byte density WITHOUT touching the dp4a inner loop or the 4/6-bit weight codes:
Result -> 4.77 bpw (llama.cpp parity).
e2e (gemma4-31B, 128k context length, CUDA-graph decode):