[ET-VK][patterns] Fuse torchao 4-bit quantized embedding to embedding_q4gsw #20381
SS-JIA wants to merge 6 commits into
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/20381
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures
As of commit f009138 with merge base 1227757.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
…_q4gsw

Pull Request resolved: #20381

TISO and other torchao-quantized models emit a `torchao.dequantize_affine -> aten.embedding` subgraph for their weight-only int4 quantized embedding. The existing `QuantizedEmbeddingMatch` only matches the `quantized_decomposed.embedding_4bit.dtype` fused op, so the torchao embedding never fused: its `dequantize_affine` const-folded to an fp32 weight, the resulting `aten.embedding` exceeded the buffer-element limit and fell back to CPU, and the fp32 constant bloated the serialized model.

This adds a separate `TorchAOQuantizedEmbeddingMatch` matcher that recognizes the torchao int4 `dequantize_affine -> aten.embedding` shape (qmin=-8/qmax=7, per-row group block_size `[1, G]`) and rewrites it to the existing `et_vk.embedding_q4gsw.default` op, repacking the unpacked int8 weight into the packed 4-bit layout. It asserts symmetric quantization (zero_point == 0, which the shader assumes) and guards against repacking a shared/tied weight more than once by recording the repack against the weight's state-dict FQN via `register_param_mutation` (a per-ExportedProgram registry on `ep._et_vk_param_modification_tags`); a second match on the same tied weight sees the recorded tag and skips the in-place repack. It is kept as a separate class from `QuantizedEmbeddingMatch` because the two dialects produce different graph shapes (one fused op vs a split dequant+gather), so a single class would only co-locate two disjoint parse paths.

On the en_US TISO backbone the embedding now delegates to Vulkan instead of falling back to CPU, and the serialized `.pte` drops from 418 MiB to 348 MiB.

This change was authored with Claude.

ghstack-source-id: 397415221
@exported-using-ghexport

Differential Revision: [D108457797](https://our.internmc.facebook.com/intern/diff/D108457797/)
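The tied-weight guard described in this commit message can be illustrated with a small sketch. This is not the PR's code: the real signature of `register_param_mutation` and the shape of `ep._et_vk_param_modification_tags` are not shown on this page, so the `ParamMutationRegistry` class, the `repack_once` helper, and the tag string used below are all hypothetical stand-ins that only demonstrate the "record the repack against the FQN, skip on the second match" idea.

```python
class ParamMutationRegistry:
    """Hypothetical stand-in for a per-program registry of in-place rewrites.

    Maps a parameter's fully qualified name (FQN) to the set of tags that
    describe rewrites already applied to that parameter.
    """

    def __init__(self):
        self._tags = {}  # FQN -> set of tag strings

    def register(self, fqn, tag):
        """Record `tag` for `fqn`. Return True only on the first registration."""
        tags = self._tags.setdefault(fqn, set())
        if tag in tags:
            return False
        tags.add(tag)
        return True


def repack_once(registry, state_dict, fqn, tag, repack_fn):
    """Apply `repack_fn` to state_dict[fqn] unless `tag` was already recorded.

    A second match on the same (tied) weight sees the recorded tag and
    leaves the already-repacked data untouched.
    """
    if registry.register(fqn, tag):
        state_dict[fqn] = repack_fn(state_dict[fqn])
    return state_dict[fqn]
```

Keying the record on the state-dict FQN rather than on a graph node means that two graph nodes referring to the same underlying tensor share one record, which appears to be the property the commit message relies on for tied weights.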
@claude review this code
Claude finished @metascroy's task in 7m 54s

Code Review: Fuse torchao 4-bit quantized embedding

Overall this is a well-structured change. The new …

1. Tied weight + shared scales may be mismatched (correctness; please confirm)

The whole point of always repacking to the linear-weight layout and reusing the …

If a tied embedding/LM-head shares the same scales placeholder (which torchao tied-weight quantization would produce) and the linear pattern also fuses to …

Could you confirm whether tied weight with shared scales + both ops fusing to q4gsw can co-occur on the TISO backbone? If so this needs a guard or a separate scales tensor for the embedding. If the LM-head isn't q4gsw on this model, a note in the PR description would help.

2. …
@claude review this code
Claude finished @metascroy's task in 7m 16s

Code Review

Overall this is a well-structured, carefully-documented change. The new …

1. Tied weight + shared scales between an embedding and a q4gsw linear (most important to confirm)

The embedding replacement passes …

The linear replacement, however, transposes scales in place to …

The PR's whole premise is that an LLM ties the embedding weight to the LM-head linear weight, and the weight repack is deliberately shared via the …

Because the embedding path never registers a scales mutation, the …

2. …
Stack from ghstack (oldest at bottom):
TISO and other torchao-quantized models emit a `torchao.dequantize_affine -> aten.embedding` subgraph for their weight-only int4 quantized embedding. The existing `QuantizedEmbeddingMatch` only matches the `quantized_decomposed.embedding_4bit.dtype` fused op, so the torchao embedding never fused: its `dequantize_affine` const-folded to an fp32 weight, the resulting `aten.embedding` exceeded the buffer-element limit and fell back to CPU, and the fp32 constant bloated the serialized model.

This adds a separate `TorchAOQuantizedEmbeddingMatch` matcher that recognizes the torchao int4 `dequantize_affine -> aten.embedding` shape (qmin=-8/qmax=7, per-row group block_size `[1, G]`) and rewrites it to the existing `et_vk.embedding_q4gsw.default` op, repacking the unpacked int8 weight into the packed 4-bit layout. It asserts symmetric quantization (zero_point == 0, which the shader assumes) and guards against repacking a shared/tied weight more than once via an `et_vk_embedding_q4gsw_packed` meta flag. It is kept as a separate class from `QuantizedEmbeddingMatch` because the two dialects produce different graph shapes (one fused op vs a split dequant+gather), so a single class would only co-locate two disjoint parse paths.

On the en_US TISO backbone the embedding now delegates to Vulkan instead of falling back to CPU, and the serialized `.pte` drops from 418 MiB to 348 MiB.

This change was authored with Claude.
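To make "repacking the unpacked int8 weight into the packed 4-bit layout" concrete, here is a minimal pure-Python sketch of packing signed 4-bit values two per byte. The exact layout expected by `et_vk.embedding_q4gsw.default` is not given on this page, so the choices below (offset-by-8 encoding, even-indexed element in the low nibble) are assumptions made for illustration, and the function names are hypothetical.

```python
def pack_int4_row(values):
    """Pack signed 4-bit values (each in -8..7) two per byte.

    Assumed layout (illustrative only): each value is stored as value + 8,
    the even-indexed element goes in the low nibble and the odd-indexed
    element in the high nibble.
    """
    if len(values) % 2 != 0:
        raise ValueError("row length must be even to pack two values per byte")
    packed = bytearray()
    for lo, hi in zip(values[0::2], values[1::2]):
        for v in (lo, hi):
            if not -8 <= v <= 7:
                raise ValueError(f"value {v} is outside the int4 range [-8, 7]")
        packed.append(((hi + 8) << 4) | (lo + 8))
    return bytes(packed)


def unpack_int4_row(packed):
    """Inverse of pack_int4_row: recover the signed 4-bit values."""
    values = []
    for byte in packed:
        values.append((byte & 0x0F) - 8)  # low nibble: even-indexed element
        values.append((byte >> 4) - 8)    # high nibble: odd-indexed element
    return values
```

Packing halves the storage of the quantized weight relative to one int8 per element, and it is not idempotent: running it twice on the same buffer would corrupt the data, which is why a repack-once guard matters for tied weights.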
Differential Revision: D108457797
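For readers unfamiliar with the quantization scheme named in the description, the following sketch shows what a group-wise, symmetric int4 embedding lookup computes for one token. It is a reference computation in plain Python, not the Vulkan shader or the torchao implementation; the function name `embedding_lookup_q4` and the list-of-lists data layout are assumptions. With block_size `[1, G]`, each run of `G` consecutive columns in a row shares one scale, and with zero_point == 0 the dequantized value is simply `q * scale`.

```python
def embedding_lookup_q4(qweight, scales, group_size, index):
    """Return the dequantized embedding row for token `index`.

    qweight:    list of rows, each a list of ints in [-8, 7]
    scales:     list of rows, each holding one float per group of columns
    group_size: G, the number of consecutive columns sharing a scale
    Symmetric quantization is assumed (zero_point == 0).
    """
    row = qweight[index]
    row_scales = scales[index]
    if len(row) != len(row_scales) * group_size:
        raise ValueError("row length must equal number of groups * group_size")
    # Column `col` belongs to group `col // group_size`.
    return [q * row_scales[col // group_size] for col, q in enumerate(row)]
```

For example, with `group_size=2`, a row `[-8, 7, 1, -1]` and scales `[0.5, 2.0]` dequantize to `[-4.0, 3.5, 2.0, -2.0]`.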