Skip to content

[ET-VK][quantized] Store dq8ca per-token zero-point as fp32#20491

Open
SS-JIA wants to merge 6 commits into
gh/SS-JIA/563/basefrom
gh/SS-JIA/563/head
Open

[ET-VK][quantized] Store dq8ca per-token zero-point as fp32#20491
SS-JIA wants to merge 6 commits into
gh/SS-JIA/563/basefrom
gh/SS-JIA/563/head

Conversation

@SS-JIA

@SS-JIA SS-JIA commented Jun 24, 2026

Copy link
Copy Markdown
Contributor

Stack from ghstack (oldest at bottom):

The per-token dynamic-activation-quant (dq8ca) zero-point was corrupted by a tensor-allocation vs shader-access dtype mismatch. The per-token zero-point tensor is created with a float dtype -- fp32, or fp16 under USE_VULKAN_FP16_INFERENCE -- so its backing image uses a float texel format (rgba32f / rgba16f). But the shader declared and accessed that image with an integer dtype (int8, an integer image format rgba8i). Reading a float-format image through an integer-format binding is the bug. On ARM Mali (Valhall) GPUs this mismatch corrupted the per-token zero-points: negative zero-points came back as garbage (-k read as -2^23 - k), driving the quantized activation to the int8 floor, the per-group sums to -4096, and the GEMM output to garbage, producing garbled, runaway generation for 8da4w models (e.g. the Llama4-mini TISO TTS backbone on Mali-G715/G710). Adreno happened to tolerate the same mismatch and read correct values, so the corruption was Mali-specific even though the mismatch itself is general.

The per-token zero-point is serialized as fp32 by torchao design: Int8DynamicActivationIntxWeightConfig (8da4w) uses asymmetric per-token activation quant with an explicit fp32 zero_point_dtype. Decoding the serialized .pte confirms the zero-point tensor is FLOAT32, and (like the scale) it is stored in a texture as an rgba32f texel -- never rgba8i. The float allocation is the truth; the int8 shader access was the mismatched side.

The fix is to declare, store, and read the per-token zero-point as fp32 across the dq8ca qparams shaders, so the shader access dtype matches the tensor's allocation dtype and the texture is read as the rgba32f image it actually is. The zero-point value is integer-valued (nudged to [-128, 127]), so fp32 represents it exactly and the consumer's int(zp) conversion for the integer dequant-correction is lossless. This touches the dq8ca qparams shaders -- choose_qparams_per_row, quantize_and_pack_4h4w_with_group_sums, linear_dq8ca_q4gsw_tiled, the shared linear_int8_input_scales_zps_load helper, and the linear_q4gsw_coop variant (whose zero-point binding only matches the descriptor-set layout and is never read) -- plus a documentation comment in ChooseQParams.cpp.

Because the per-token qparams remain in texture storage (unchanged from before) and only the zero-point dtype changes, this is a pure runtime shader fix: existing texture-qparams 8da4w .pte files are corrected without re-export, since the texture already bakes the zero-point as rgba32f and the shader now reads it as such.

Authored with Claude Code.

Differential Revision: D109595977

[ghstack-poisoned]
@pytorch-bot

pytorch-bot Bot commented Jun 24, 2026

Copy link
Copy Markdown

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/20491

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit ef9eec1 with merge base 1227757 (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Jun 24, 2026
@linux-foundation-easycla

linux-foundation-easycla Bot commented Jun 24, 2026

Copy link
Copy Markdown

CLA Missing ID

@github-actions

Copy link
Copy Markdown

This PR needs a release notes: label

If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with release notes:. This helps us keep track and include your important work in the next release notes.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "release notes: none"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

[ghstack-poisoned]
[ghstack-poisoned]
SS-JIA pushed a commit that referenced this pull request Jun 26, 2026
…type

Pull Request resolved: #20491

The per-token dynamic-activation-quant (`dq8ca`) zero-point image must be bound in the shader with the same dtype the tensor was allocated with; a binding-vs-allocation dtype mismatch corrupts the per-token zero-point. The allocation dtype differs by export path: standard `export_llama -qmode 8da4w` models (e.g. Qwen3-0.6B) serialize the zero-point as `int8`, while the Llama4-mini TISO backbone (torchao `per_token_dynamic_quant` / `Int8DynamicActivationIntxWeightConfig` with an explicit fp32 `zero_point_dtype`) serializes it as float, which `vulkan_graph_builder.get_effective_dtype` downcasts to `half` under `force_fp16`.

A single fixed binding dtype cannot satisfy both paths. Binding the zero-point as `int8` (`rgba8i`) corrupts the float-allocated TISO zero-point on ARM Mali (Valhall) -- negative values come back as garbage, garbling the 8da4w TTS backbone. Conversely, binding it as the codegen `DTYPE` (matching the scale's float dtype) corrupts the int8-allocated zero-point: under fp16 inference the `rgba8i` image is read and written as `rgba16f`, saturating the per-token zero-point to the int8 floor/ceiling and garbling standard fp16 8da4w models such as Qwen3-0.6B.

This change makes the zero-point binding a codegen variant so it always matches the tensor's allocation. A new `ZP_DTYPE_MODE` axis emits two variants of every dq8ca shader that binds the per-token zero-point: `zpint8` (binding declared `int8`, an `rgba8i` integer image) and `zpinherit` (binding declared with the codegen `DTYPE`, inheriting the inference float dtype to match the scale -- `rgba32f`, or `rgba16f` under `USE_VULKAN_FP16_INFERENCE`). The C++ shader pickers select the variant from `graph.dtype_of(zero_point)` (`kChar` -> `zpint8`; `kHalf` / `kFloat` -> `zpinherit`), so the shader binding matches the tensor's allocation regardless of how the model was exported. The shared read helper is unchanged: `ivec4(texelFetch(t_int8_input_zps, ...))` already reads both an integer image (identity) and a float image (exact truncation of the integer-valued zero-point in `[-128, 127]`).

Affected shaders: `choose_qparams_per_row` (writes the zero-point, storing `ivec4` or `VEC4_T` per variant), `quantize_and_pack_4h4w_with_group_sums`, `linear_dq8ca_q4gsw_tiled`, and the dq8ca `linear_q4gsw_coop` variants (read the zero-point). This fixes the fp16 8da4w regression for standard int8 zero-point exports while preserving the float zero-point path that the TISO backbone and the original Mali fix depend on. Only the runtime shader binding changes, so existing `.pte` files are handled correctly with no re-export.

Authored with Claude Code.
ghstack-source-id: 397279874
@exported-using-ghexport

Differential Revision: [D109595977](https://our.internmc.facebook.com/intern/diff/D109595977/)
[ghstack-poisoned]
[ghstack-poisoned]
@metascroy

Copy link
Copy Markdown
Contributor

@claude review this code

@claude

claude Bot commented Jun 26, 2026

Copy link
Copy Markdown

Claude finished @metascroy's task in 8m 41s —— View job


Code Review

  • Read the PR diff
  • Analyze shader dtype handling correctness
  • Check shader-name dispatch logic
  • Review for edge cases & consistency
  • Post review

I traced the full dq8ca per-token path (choose_qparams → quantize_and_pack → linear) plus the shared loader and the coop binding. The fix is correct and well-targeted. Summary and a few minor notes below.

Correctness — looks right ✅

  • Allocation/access dtype now match. The zero-point texture is allocated as a float texel (rgba32f/rgba16f, like the scale), and in zpinherit mode the shader binding is declared with DTYPE so the texture is read as the float image it actually is. This is exactly the mismatch that corrupted Mali reads.
  • The shared loader unifies both modes cleanly. linear_int8_input_scales_zps_load.glslh:23-24 wrapping the fetch in ivec4(...) is a no-op for the int8 image (texelFetch already returns ivec4) and a lossless float→int conversion for the float image. One helper serves both ZP_DTYPE_MODE variants — nice.
  • Lossless round-trip. choose_qparams_per_row.glsl:202 stores integer-valued zps_out as VEC4_T(zps_out); since the nudged zero-point is in [-128,127], it's exactly representable in both fp32 and fp16, and the consumer's ivec4(...) truncation recovers it exactly. Good.
  • Arg-index lookups in the dispatch fns are correct — I checked each against its DynamicDispatchNode arg group:
    • ChooseQParams.cpp:44 args.at(0).refs.at(1)input_zps (write group {input_scales, input_zps} at :121) ✅
    • QuantizeDequantize.cpp:69 args.at(1).refs.at(2)packed_input_zps (read group {fp_input, packed_input_scales, packed_input_zps} at :229) ✅
    • QuantizedLinear.cpp:148 args.at(1).refs.at(4)packed_input_zp (read group ordering at :500-508) ✅
  • add_zp_dtype_mode_suffix mapping (ShaderNameUtils.cpp:75) maps kChar_zpint8, kHalf/kFloat_zpinherit, and throws otherwise. Matches the codegen axis and fails loudly on an unexpected dtype.
  • yaml variant generation is scoped correctly: the tiled/choose/quantize yamls add the axis globally (all their variants are dq8ca), while linear_q4gsw_coop.yaml adds it only under the DYNAMIC_QUANT_VARIANT: true variants, leaving the weight-only coop variants untouched.

Minor notes (non-blocking)

  1. linear_q4gsw_coop.glsl:43 also changed intint8 for the zpint8 case (not just adding the float branch). As the description notes this binding is never read in the coop/gemv path, so it's inert — but it's a slightly larger change than "add fp32 mode." Worth a one-line code comment on the binding stating it exists only to satisfy the descriptor-set layout and is never sampled, so a future reader doesn't try to use it.

  2. zpinherit assumes the zp tensor's float dtype equals the shader's DTYPE (input/scale dtype). That holds today because USE_VULKAN_FP16_INFERENCE flips input, scale, and zp together — but the suffix encodes only "inherit," not the actual dtype, so a future flow that allocates an fp32 zp alongside an fp16 input would silently reintroduce a texel mismatch. The VK_THROW only guards non-float dtypes, not float-vs-float divergence. Probably fine to leave, but a VK_CHECK_COND(dtype_of(input_zp) == dtype_of(input)) in the dispatch would make the invariant explicit.

  3. Test coverage. test_q4gsw_linear.cpp:100 and the choose-qparams custom op test use kChar zero-points, which exercise the zpint8 path. The newly-added zpinherit path (the one that fixes the real Mali bug) doesn't appear to have direct C++ test coverage. Since this is a runtime-only fix validated against a real model, consider adding a variant with an fp32 zero-point so the corrected path is regression-protected. Add fp32-zp coverage →

  4. define_required_extensions("texture3d", "int8") is now emitted unconditionally in choose_qparams_per_row.glsl:12 and quantize_and_pack_4h4w_with_group_sums.glsl:12 even in zpinherit mode where no int8 image is bound. Harmless (just enables an unused extension), but could be guarded behind $if ZP_DTYPE_MODE == "zpint8" for cleanliness.

Overall: focused, correct fix with a sensible codegen abstraction. The notes above are polish, not blockers.
· gh/SS-JIA/563/head

[ghstack-poisoned]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. meta-exported

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants