Skip to content

[ET-VK][patterns] Fuse torchao 4-bit quantized embedding to embedding_q4gsw#20381

Open
SS-JIA wants to merge 6 commits into
gh/SS-JIA/560/basefrom
gh/SS-JIA/560/head
Open

[ET-VK][patterns] Fuse torchao 4-bit quantized embedding to embedding_q4gsw#20381
SS-JIA wants to merge 6 commits into
gh/SS-JIA/560/basefrom
gh/SS-JIA/560/head

Conversation

@SS-JIA

@SS-JIA SS-JIA commented Jun 18, 2026

Copy link
Copy Markdown
Contributor

Stack from ghstack (oldest at bottom):

TISO and other torchao-quantized models emit a torchao.dequantize_affine -> aten.embedding subgraph for their weight-only int4 quantized embedding. The existing QuantizedEmbeddingMatch only matches the quantized_decomposed.embedding_4bit.dtype fused op, so the torchao embedding never fused: its dequantize_affine const-folded to an fp32 weight, the resulting aten.embedding exceeded the buffer-element limit and fell back to CPU, and the fp32 constant bloated the serialized model.

This adds a separate TorchAOQuantizedEmbeddingMatch matcher that recognizes the torchao int4 dequantize_affine -> aten.embedding shape (qmin=-8/qmax=7, per-row group block_size [1, G]) and rewrites it to the existing et_vk.embedding_q4gsw.default op, repacking the unpacked int8 weight into the packed 4-bit layout. It asserts symmetric quantization (zero_point == 0, which the shader assumes) and guards against repacking a shared/tied weight more than once via an et_vk_embedding_q4gsw_packed meta flag. It is kept as a separate class from QuantizedEmbeddingMatch because the two dialects produce different graph shapes (one fused op vs a split dequant+gather), so a single class would only co-locate two disjoint parse paths.

On the en_US TISO backbone the embedding now delegates to Vulkan instead of falling back to CPU, and the serialized .pte drops from 418 MiB to 348 MiB.

This change was authored with Claude.

Differential Revision: D108457797

[ghstack-poisoned]
@pytorch-bot

pytorch-bot Bot commented Jun 18, 2026

Copy link
Copy Markdown

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/20381

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit f009138 with merge base 1227757 (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Jun 18, 2026
@linux-foundation-easycla

linux-foundation-easycla Bot commented Jun 18, 2026

Copy link
Copy Markdown

CLA Missing ID

@github-actions

Copy link
Copy Markdown

This PR needs a release notes: label

If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with release notes:. This helps us keep track and include your important work in the next release notes.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "release notes: none"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
SS-JIA pushed a commit that referenced this pull request Jun 26, 2026
…_q4gsw

Pull Request resolved: #20381

TISO and other torchao-quantized models emit a `torchao.dequantize_affine -> aten.embedding` subgraph for their weight-only int4 quantized embedding. The existing `QuantizedEmbeddingMatch` only matches the `quantized_decomposed.embedding_4bit.dtype` fused op, so the torchao embedding never fused: its `dequantize_affine` const-folded to an fp32 weight, the resulting `aten.embedding` exceeded the buffer-element limit and fell back to CPU, and the fp32 constant bloated the serialized model.

This adds a separate `TorchAOQuantizedEmbeddingMatch` matcher that recognizes the torchao int4 `dequantize_affine -> aten.embedding` shape (qmin=-8/qmax=7, per-row group block_size `[1, G]`) and rewrites it to the existing `et_vk.embedding_q4gsw.default` op, repacking the unpacked int8 weight into the packed 4-bit layout. It asserts symmetric quantization (zero_point == 0, which the shader assumes) and guards against repacking a shared/tied weight more than once by recording the repack against the weight's state-dict FQN via `register_param_mutation` (a per-ExportedProgram registry on `ep._et_vk_param_modification_tags`); a second match on the same tied weight sees the recorded tag and skips the in-place repack. It is kept as a separate class from `QuantizedEmbeddingMatch` because the two dialects produce different graph shapes (one fused op vs a split dequant+gather), so a single class would only co-locate two disjoint parse paths.

On the en_US TISO backbone the embedding now delegates to Vulkan instead of falling back to CPU, and the serialized `.pte` drops from 418 MiB to 348 MiB.

This change was authored with Claude.
ghstack-source-id: 397415221
@exported-using-ghexport

Differential Revision: [D108457797](https://our.internmc.facebook.com/intern/diff/D108457797/)
[ghstack-poisoned]
SS-JIA pushed a commit that referenced this pull request Jun 26, 2026
…_q4gsw

Pull Request resolved: #20381

TISO and other torchao-quantized models emit a `torchao.dequantize_affine -> aten.embedding` subgraph for their weight-only int4 quantized embedding. The existing `QuantizedEmbeddingMatch` only matches the `quantized_decomposed.embedding_4bit.dtype` fused op, so the torchao embedding never fused: its `dequantize_affine` const-folded to an fp32 weight, the resulting `aten.embedding` exceeded the buffer-element limit and fell back to CPU, and the fp32 constant bloated the serialized model.

This adds a separate `TorchAOQuantizedEmbeddingMatch` matcher that recognizes the torchao int4 `dequantize_affine -> aten.embedding` shape (qmin=-8/qmax=7, per-row group block_size `[1, G]`) and rewrites it to the existing `et_vk.embedding_q4gsw.default` op, repacking the unpacked int8 weight into the packed 4-bit layout. It asserts symmetric quantization (zero_point == 0, which the shader assumes) and guards against repacking a shared/tied weight more than once by recording the repack against the weight's state-dict FQN via `register_param_mutation` (a per-ExportedProgram registry on `ep._et_vk_param_modification_tags`); a second match on the same tied weight sees the recorded tag and skips the in-place repack. It is kept as a separate class from `QuantizedEmbeddingMatch` because the two dialects produce different graph shapes (one fused op vs a split dequant+gather), so a single class would only co-locate two disjoint parse paths.

On the en_US TISO backbone the embedding now delegates to Vulkan instead of falling back to CPU, and the serialized `.pte` drops from 418 MiB to 348 MiB.

This change was authored with Claude.
ghstack-source-id: 397447720
@exported-using-ghexport

Differential Revision: [D108457797](https://our.internmc.facebook.com/intern/diff/D108457797/)
@metascroy

Copy link
Copy Markdown
Contributor

@claude review this code

@claude

claude Bot commented Jun 26, 2026

Copy link
Copy Markdown

Claude finished @metascroy's task in 7m 54s —— View job


Code Review: Fuse torchao 4-bit quantized embedding

  • Read the PR diff
  • Analyze quantized_embedding.py (core change)
  • Analyze weight_packing_utils.py and utils.py helpers
  • Review refactors in quantized_linear.py / quantized_convolution.py
  • Check tests
  • Post review

Overall this is a well-structured change. The new TorchAOQuantizedEmbeddingMatch is carefully guarded (qmin/qmax, zero_point symmetry, block_size, embed_dim % 32), the matcher correctly keys off the dequant node's declared dtype/shape rather than the (mutable) weight placeholder meta, and register_param_mutation is a clean replacement for the old value-equality _detect_tied_linear_weight scan. Tests cover the happy path, the rejection guard, shared-weight dedup, and the mutation registry. A few things worth a look:


1. Tied weight + shared scales may be mismatched (correctness — please confirm)

The whole point of always repacking to the linear-weight layout and reusing the "4 bit linear weight" tag is to dedup a weight tied between the embedding and the LM-head linear. The weight is handled correctly, but the scales are not handled symmetrically:

  • make_linear_q4gsw_op transposes the scales to [num_groups, N] in place and registers "4 bit linear scales" (quantized_linear.py:348-361).
  • The torchao embedding replacement uses match.scales_node as-is and never transposes/registers it. embedding_q4gsw_impl expects scales shaped [vocab, num_groups] (custom_ops_lib.py:1011).

If a tied embedding/LM-head shares the same scales placeholder (which torchao tied-weight quantization would produce) and the linear pattern also fuses to linear_q4gsw, the in-place transpose corrupts the layout the embedding op expects — and there's no register_param_mutation guard linking the two scale conventions, so this wouldn't raise; it would silently miscompute. Whichever pattern runs second loses.

Could you confirm whether tied weight with shared scales + both ops fusing to q4gsw can co-occur on the TISO backbone? If so this needs a guard or a separate scales tensor for the embedding. If the LM-head isn't q4gsw on this model, a note in the PR description would help.

2. even_value_high_nibble is dead code

pack_4bit_weight_tensor(..., even_value_high_nibble=...) is never called with True anywhere — the decomposed embedding path unpacks and repacks with the default (linear) convention. Per the repo's "choose simpler, more concise" guidance, consider dropping the parameter and its docstring branch until there's a caller. Fix this →

3. BUCK srcs ordering (nit)

In backends/vulkan/patterns/BUCK, weight_packing_utils.py is inserted between rms_norm.py and sdpa.py, breaking the otherwise-alphabetical srcs list (it should go after select_as_symint.py).


Minor / non-blocking

  • register_param_mutation persists on ep with no reset — correctly documented, and the persistence is actually what makes repeated FusePatternsPass runs idempotent. Good. Just be aware that sharing the "4 bit linear weight" tag across the linear, decomposed-embedding, and torchao-embedding paths is load-bearing for dedup; the conv path's distinct "8 bit conv2d weight" tag is correct.
  • zero_point re-verification on sibling matches — for a shared weight, only the first match verifies zero_point == 0 (the check lives inside the register_param_mutation guard). This assumes sibling dequant nodes over the same weight share quant params, which is virtually always true, but technically unverified for the 2nd+ call site.
  • pack_4bit_weight_tensor calls .min().item()/.max().item() — fine here since it only runs on real param tensors inside the replacement, never on fake meta.

Nice work on the thorough inline comments explaining the meta-vs-placeholder reasoning — they make the matcher much easier to follow.
• branch gh/SS-JIA/560/head

@metascroy

Copy link
Copy Markdown
Contributor

@claude review this code

@claude

claude Bot commented Jun 26, 2026

Copy link
Copy Markdown

Claude finished @metascroy's task in 7m 16s —— View job


Code Review

  • Read changed files
  • Analyze the new TorchAO matcher and repacking logic
  • Review shared-weight repacking guard (register_param_mutation) and refactors
  • Post review feedback

Overall this is a well-structured, carefully-documented change. The new TorchAOQuantizedEmbeddingMatch cleanly mirrors the existing matcher, the matcher-vs-replacement split (cheap guards in the matcher reading stable dequant meta, expensive tensor verification in the replacement) is the right call, and extracting register_param_mutation + pack_4bit_weight_tensor removes real duplication. The tests are thorough — shared-weight regression, bad-embed_dim rejection, and end-to-end numerical checks. A few things worth a look:


1. Tied weight + shared scales between an embedding and a q4gsw linear (most important to confirm)

The embedding replacement passes match.scales_node through untouched — it expects scales in the native [vocab, num_groups] layout (see embedding_q4gsw_impl, custom_ops_lib.py:1011, num_groups = weight_scales.shape[1]). It does not call register_param_mutation on the scales node.

The linear replacement, however, transposes scales in place to [num_groups, N] and guards it under tag "4 bit linear scales" (quantized_linear.py:349-361, force_update=True).

The PR's whole premise is that an LLM ties the embedding weight to the LM-head linear weight, and the weight repack is deliberately shared via the "4 bit linear weight" tag so both ops see identical packed bytes. But if that tied pair also shares the same scales placeholder, then:

  • the linear transposes the shared scales in place, and
  • the embedding then reads [num_groups, vocab] while interpreting it as [vocab, num_groups] → silently wrong output (or a shape error when num_groups != vocab).

Because the embedding path never registers a scales mutation, the register_param_mutation conflict guard won't catch this — the transpose is invisible to it. Could you confirm whether the tied embedding/LM-head case shares the scales constant, or only the weight? If scales can be shared, this needs handling (e.g. the embedding registering a distinct scales tag so a conflicting in-place transpose raises instead of corrupting). The existing tests share only an embedding weight across two aten.embedding sites, so they don't exercise this cross-op path.

2. quant_min/quant_max read positionally (quantized_embedding.py:202-204)

quant_min = dequant_node.args[5], quant_max = dequant_node.args[6]. This assumes the edge-dialect node has these normalized into positional args. If torchao.dequantize_affine ever surfaces them as kwargs, len(args) > 5/6 is false → quant_min/quant_max become None → the match is silently rejected and the embedding falls back to CPU (the exact bloat regression this PR fixes). Edge ops are normally arg-normalized so this is likely fine in practice, but a quick check of dequant_node.kwargs (or asserting it's empty) would make the failure mode loud rather than a silent un-fuse.

3. Dead parameter even_value_high_nibble in pack_4bit_weight_tensor (weight_packing_utils.py:13)

No caller passes even_value_high_nibble=True — the original embedding path unpacks and then repacks into the linear convention (default False), and both linear/conv use the default. The embedding-convention branch is currently unreachable. Either drop it to keep the helper minimal (matches the repo's "no speculative API" guidance), or keep it only if a follow-up in the stack uses it.

4. Minor: PR description is stale vs. implementation

The body says repacking is guarded "via an et_vk_embedding_q4gsw_packed meta flag," but the code actually uses utils.register_param_mutation (storage-FQN-keyed registry on the ExportedProgram). Worth syncing the description so reviewers aren't hunting for a meta flag that doesn't exist.


Things I checked that look correct

  • Nibble round-trip: original-embedding unpack (high/low) → pack_4bit_weight_tensor (linear convention odd<<4 | even) → embedding_q4gsw_impl(is_linear_weight=True) restores the right element order. ✓
  • register_param_mutation keying on the state-dict FQN rather than node identity or data_ptr() is correct — aliased placeholders resolving to one storage are handled, and the non-param rejection up front is a good guardrail.
  • embed_dim % 32 guard read from the stable dequant output meta (not the possibly-repacked weight placeholder meta) — correctly rejects shapes the shader's VK_CHECK would abort on, falling back to CPU instead. The conv weight-scales align_width call staying unguarded is also fine since it's idempotent without force_update, whereas the guarded weight/linear-scales mutations are not.

[ghstack-poisoned]
SS-JIA pushed a commit that referenced this pull request Jun 26, 2026
…_q4gsw

Pull Request resolved: #20381

TISO and other torchao-quantized models emit a `torchao.dequantize_affine -> aten.embedding` subgraph for their weight-only int4 quantized embedding. The existing `QuantizedEmbeddingMatch` only matches the `quantized_decomposed.embedding_4bit.dtype` fused op, so the torchao embedding never fused: its `dequantize_affine` const-folded to an fp32 weight, the resulting `aten.embedding` exceeded the buffer-element limit and fell back to CPU, and the fp32 constant bloated the serialized model.

This adds a separate `TorchAOQuantizedEmbeddingMatch` matcher that recognizes the torchao int4 `dequantize_affine -> aten.embedding` shape (qmin=-8/qmax=7, per-row group block_size `[1, G]`) and rewrites it to the existing `et_vk.embedding_q4gsw.default` op, repacking the unpacked int8 weight into the packed 4-bit layout. It asserts symmetric quantization (zero_point == 0, which the shader assumes) and guards against repacking a shared/tied weight more than once by recording the repack against the weight's state-dict FQN via `register_param_mutation` (a per-ExportedProgram registry on `ep._et_vk_param_modification_tags`); a second match on the same tied weight sees the recorded tag and skips the in-place repack. It is kept as a separate class from `QuantizedEmbeddingMatch` because the two dialects produce different graph shapes (one fused op vs a split dequant+gather), so a single class would only co-locate two disjoint parse paths.

On the en_US TISO backbone the embedding now delegates to Vulkan instead of falling back to CPU, and the serialized `.pte` drops from 418 MiB to 348 MiB.

This change was authored with Claude.
ghstack-source-id: 397529341
@exported-using-ghexport

Differential Revision: [D108457797](https://our.internmc.facebook.com/intern/diff/D108457797/)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. meta-exported

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants