[ExecuTorch][WebGPU] Dynamic resize hook for SDPA (live seq-len S)#20580

Open
JulianCloudNTH wants to merge 1 commit into gh/JulianCloudNTH/72/base from gh/JulianCloudNTH/72/head

JulianCloudNTH commented Jun 28, 2026

Contributor

Stack from ghstack (oldest at bottom):

Make sdpa_with_kv_cache serve any live seq-len S from one graph (batched prefill S=K and decode S=1).

Problem: the existing dynamic path only reacted to a live input_pos (decode), with S captured at build time. It rewrote the QK dispatch (which depends on context_len) but left update_cache, softmax, and AV sized for the build-time S. Under a dynamic seq-len S (one graph serving prefill and decode), kv_numel, the QK/AV tile grids, and the softmax row count all depend on S and were stale.

Solution: a single recompute hook driven by either a live S (q tensor resize) or a live input_pos (SymInt), recomputing every per-step quantity from the live shape.

  • Before: hook keyed only on input_pos; recomputes ctx + QK count; S fixed.
  • After: hook keyed on q (always) and input_pos (when SymInt); reads live S from cur_dims(q) and live pos, recomputes all five dispatches' counts + UBOs (update_cache K/V, QK, softmax, AV), and sets the output cur_dims to q's.
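The "After" behaviour can be sketched as a small model: one recompute function registered on q always, and on input_pos only when it is symbolic. Every name here (`Tensor`, `SymInt`, `SdpaStep`, `register_sdpa_hook`) is a hypothetical stand-in, not the backend's API. The `[B, S, Hq, D]` layout for q and the relation `ctx = pos + S` are assumptions made for illustration.

```cpp
#include <cassert>
#include <cstdint>
#include <functional>
#include <utility>
#include <vector>

// Hypothetical stand-ins for the backend's tensor and SymInt bookkeeping.
struct Tensor {
  std::vector<int64_t> cur_dims;                 // live shape
  std::vector<std::function<void()>> on_resize;  // fired when the shape changes
};

struct SymInt {
  int64_t value = 0;
  std::vector<std::function<void()>> on_update;  // fired when the value changes
};

inline void resize(Tensor& t, std::vector<int64_t> dims) {
  t.cur_dims = std::move(dims);
  for (auto& hook : t.on_resize) hook();
}

inline void set_symint(SymInt& s, int64_t v) {
  s.value = v;
  for (auto& hook : s.on_update) hook();
}

// Per-step quantities derived from the live shape (illustrative subset).
struct SdpaStep {
  int64_t S = 0;
  int64_t ctx = 0;
  int recomputes = 0;
};

// Registers one recompute function on q (always) and on input_pos (only if
// it is a SymInt, i.e. non-null here). Nothing runs until a resize/update.
inline void register_sdpa_hook(Tensor& q, Tensor& out, SymInt* input_pos,
                               int64_t static_pos, SdpaStep& step) {
  auto recompute = [&q, &out, input_pos, static_pos, &step]() {
    assert(q.cur_dims.size() == 4);  // assumed layout: [B, S, Hq, D]
    const int64_t pos = input_pos ? input_pos->value : static_pos;
    step.S = q.cur_dims[1];    // live seq-len read from q
    step.ctx = pos + step.S;   // assumption: context covers [0, pos + S)
    out.cur_dims = q.cur_dims; // output shape follows q
    ++step.recomputes;
  };
  q.on_resize.push_back(recompute);
  if (input_pos) input_pos->on_update.push_back(recompute);
}
```

In this model a graph that never resizes q and has a constant input_pos never runs `recompute`, which matches the claim that the static path is unaffected.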

Implementation:

  • Capture the update_cache/softmax/AV dispatch indices (previously only QK) so their workgroup counts can be rewritten per step.
  • QK/AV workgroup counts use the landed register-tiled grids, Hq*ceil(S/TM)*ceil(N/TN) with N=ctx for QK and N=D for AV; softmax uses one workgroup per row, Hq*S in total.
  • Register the hook on q unconditionally — inert until q is resized, so a static graph is byte-identical.
  • Mirrors Vulkan DynamicDispatchNode (recompute workgroups per execute); scratch is sized at build (S=max, ctx=Cmax) so buffers never move and bind groups stay valid.
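The grid arithmetic in the list above can be written out as a minimal sketch. The tile sizes `kTM` and `kTN` are placeholders, since the PR text does not state their values. Reading "ctx-or-D" as ctx for QK and D for AV is an interpretation, and `sdpa_workgroup_counts` is a hypothetical helper name.

```cpp
#include <cassert>
#include <cstdint>

// Placeholder tile sizes; the actual TM/TN values are not given in the PR text.
constexpr int64_t kTM = 4;
constexpr int64_t kTN = 4;

constexpr int64_t ceil_div(int64_t a, int64_t b) { return (a + b - 1) / b; }

struct SdpaCounts {
  int64_t qk;       // workgroups for Q*K^T
  int64_t softmax;  // workgroups for the row-wise softmax
  int64_t av;       // workgroups for attn*V
};

// Hq: query heads, S: live seq-len, ctx: live context length, D: head dim.
inline SdpaCounts sdpa_workgroup_counts(int64_t Hq, int64_t S, int64_t ctx,
                                        int64_t D) {
  assert(Hq > 0 && S > 0 && ctx > 0 && D > 0);
  return SdpaCounts{
      Hq * ceil_div(S, kTM) * ceil_div(ctx, kTN),  // QK tiles an S x ctx output per head
      Hq * S,                                      // one workgroup per softmax row
      Hq * ceil_div(S, kTM) * ceil_div(D, kTN),    // AV tiles an S x D output per head
  };
}
```

With these placeholder tiles, a prefill step with Hq=8, S=16, ctx=16, D=64 gives 128 QK, 128 softmax, and 512 AV workgroups. A following decode step with S=1 and ctx=17 gives 40, 8, and 128. This is why counts fixed at build-time S go stale.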

Constraints: fp32-only, batch=1, GQA, is_causal=true, D%4==0 invariants unchanged; the static / decode-only paths are unaffected (the q hook never fires without a resize).
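As a rough illustration, the invariants listed above could be checked with a predicate like the one below. `SdpaConfig`, `DType`, and `sdpa_config_supported` are hypothetical names. Treating GQA as "Hq is a positive multiple of Hkv" is an assumption about how the grouping constraint is expressed.

```cpp
#include <cstdint>

// Hypothetical config record; field names are illustrative only.
enum class DType { Float32, Float16 };

struct SdpaConfig {
  DType dtype;
  int64_t batch;
  int64_t Hq;   // query heads
  int64_t Hkv;  // key/value heads
  int64_t D;    // head dim
  bool is_causal;
};

// Returns true only when every invariant from the PR description holds.
inline bool sdpa_config_supported(const SdpaConfig& c) {
  const bool fp32 = c.dtype == DType::Float32;
  const bool single_batch = c.batch == 1;
  // Assumption: GQA means query heads divide evenly into KV-head groups.
  const bool gqa = c.Hkv > 0 && c.Hq > 0 && c.Hq % c.Hkv == 0;
  const bool head_dim_ok = c.D > 0 && c.D % 4 == 0;
  return fp32 && single_batch && gqa && c.is_causal && head_dim_ok;
}
```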

Co-authored-with: Claude Code.

Differential Revision: D109906097

[ghstack-poisoned]
pytorch-bot Bot commented Jun 28, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/20580

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit d2a1ecc with merge base 55a71e6:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@github-actions

This PR needs a release notes: label

If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with release notes:. This helps us keep track and include your important work in the next release notes.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "release notes: none"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

meta-cla Bot added the CLA Signed label Jun 28, 2026