[ExecuTorch][WebGPU] Dynamic resize hook for linear_q4gsw#20576
[ExecuTorch][WebGPU] Dynamic resize hook for linear_q4gsw#20576JulianCloudNTH wants to merge 2 commits into
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/20576
Note: Links to docs will display an error until the docs builds have been completed. This comment was automatically generated by Dr. CI and updates every 15 minutes. |
This PR needs a
|
|
@claude review and check for any areas or opportunities for modularization |
|
Claude finished @JulianCloudNTH's task in 1m 35s —— View job Review: Dynamic resize hook for
|
Stack from ghstack (oldest at bottom):
Make the 4-bit quantized linear serve any live M (rows) from one graph, so a dynamic prefill+decode graph computes correct-size outputs.
Problem:
linear_q4gswbaked its dispatch count,params.M, and output shape atbuild()for the max M. On a dynamic-shape graph at a smaller live M (e.g. decode M=1 vs prefill M=S) it would over-dispatch and leave the output sized at the max.Solution:
cur_dims, rewritesparams.M, updates the dispatchworkgroup_count_xusing the SAME kernel chosen at build (tiled GEMM or coop4 GEMV), and sets the outputcur_dims(= input dims with the last dim replaced by N). Inert until the input is resized.Implementation:
use_gemv= M==1 GEMV else tiled GEMM) is fixed at build; the hook only recomputes the dispatch count, the param UBO, and the output dims for the live M — it does not switch kernels (runtime M-switching is a separate optimization).own_uniform_bufferkeeps the param UBO alive so the hook can rewrite it.resize_q4gsw_linear_node(recompute M-derived dispatch each execute).Constraints: Behavior-neutral on static graphs (hook fires only when the input's live M differs from the max). No kernel/WGSL/numerics change. The coop4-vs-GEMM runtime kernel-switch decode optimization is deliberately out of scope (a later opt diff).
Co-authored-with: Claude Code.
Differential Revision: D109906094