[ExecuTorch][WebGPU] 2D compute dispatch — lift the 65535 per-dim cap (prefill path)#20583
Open
JulianCloudNTH wants to merge 1 commit into
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/20583
Note: Links to docs will display an error until the docs builds have been completed.
❌ 1 New Failure
As of commit d54c2d2 with merge base 55a71e6: NEW FAILURE - The following job has failed:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
This was referenced Jun 28, 2026
Stack from ghstack (oldest at bottom):
Lift the 65535 workgroup-per-dim dispatch cap so single-shot SDPA prefill runs at any sequence length.
Problem: The WebGPU backend is 1D-dispatch-only and throws when a kernel's workgroup count exceeds the device per-dim limit (`maxComputeWorkgroupsPerDimension`, spec floor 65535). SDPA prefill QK exceeds it around S~362 (softmax/AV at S=2048), blocking single-shot / long-context prefill.

Solution: Fold a 1D workgroup count that exceeds the limit into 2D; the shader reconstructs the linear index from `@builtin(num_workgroups)`.

- `compute_1d_workgroup_count` throws if `count > limit`; dispatch `(count, 1, 1)`.
- `compute_2d_workgroup_count` returns `{count, 1}` (fast path) or `{limit, div_up(count, limit)}`; dispatch `(x, y, 1)`.

Implementation:

- `WgCount` + pure `fold_workgroup_count_2d` + `compute_2d_workgroup_count` in `WebGPUUtils.h` (device-free, unit-testable; `queried_max_workgroups` factored out of the 1D path)
- `WebGPUDispatch.workgroup_count_y` (default 1, declared last so existing aggregate inits are unchanged); both `dispatchWorkgroups` calls + the profiling record pass `(x, y, 1)`
- `idx = gid.x + gid.y*(num_workgroups.x*wg_size)` (QK/AV/add); row-form `row_idx = wid.x + wid.y*num_workgroups.x` (softmax; keeps a `valid` predicate, not an early return, so `workgroupBarrier()`s stay uniform)
- `Sdpa.cpp`: QK/softmax/AV counts via the 2D helper; the dynamic-`input_pos` resize hook recomputes both x and y for QK

Constraints:

- `y=1` fast path keeps every non-folded dispatch byte-identical to the prior 1D path
- `rms_norm`/`embedding`/`lm_head`/`update_cache` are row/token-indexed and never hit the cap, so they keep the 1D path
- `uint32` element guard fires first at S~11585)

Co-authored-with: Claude Code.
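For reviewers, the fold rule described above can be sketched as a small device-free helper. This is a minimal illustration, not the code in `WebGPUUtils.h`: the names `WgCount`, `fold_workgroup_count_2d`, and `div_up` come from the description, but the signatures, the field names, and the overflow check on `y` are assumptions.

```cpp
#include <cassert>
#include <cstdint>
#include <stdexcept>

// Assumed shape of WgCount: workgroup counts for dispatch(x, y, 1).
struct WgCount {
  uint32_t x;
  uint32_t y;
};

// Ceiling division; callers must pass b > 0.
constexpr uint64_t div_up(uint64_t a, uint64_t b) {
  return (a + b - 1) / b;
}

// Sketch of the pure fold: stay 1D when the count fits within `limit`,
// otherwise use `limit` workgroups per row and enough rows to cover `count`.
inline WgCount fold_workgroup_count_2d(uint64_t count, uint32_t limit) {
  if (limit == 0) {
    throw std::invalid_argument("limit must be non-zero");
  }
  if (count <= limit) {
    // Fast path: y == 1, so the dispatch matches the old 1D form.
    return {static_cast<uint32_t>(count), 1u};
  }
  const uint64_t y = div_up(count, limit);
  if (y > limit) {
    // Assumption: a count that does not fit in 2D is treated as an error.
    throw std::runtime_error("workgroup count exceeds 2D dispatch capacity");
  }
  return {limit, static_cast<uint32_t>(y)};
}
```

With the spec-floor limit of 65535, a count of 65536 folds to `{65535, 2}`. That dispatches 131070 workgroups, so the shader still needs a bounds check to discard the excess.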
Differential Revision: D109517684
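One way to sanity-check the two index formulas from the description is a CPU model that enumerates every invocation of a folded dispatch and counts how often each element or row is produced. The sketch below is hypothetical test scaffolding rather than code from this PR. It assumes a workgroup size of `(wg_size, 1, 1)`, so that `gid.y` equals `wid.y`, and it uses a tiny `limit` so the loops stay cheap.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Flat form (QK/AV/add): idx = gid.x + gid.y * (num_workgroups.x * wg_size).
// Returns true if every idx in [0, n) is produced exactly once.
inline bool flat_fold_covers_once(uint32_t n, uint32_t wg_size, uint32_t limit) {
  if (wg_size == 0 || limit == 0) return false;
  const uint64_t count = (static_cast<uint64_t>(n) + wg_size - 1) / wg_size;
  // Same fold as described: {count, 1} or {limit, div_up(count, limit)}.
  const uint32_t nx = static_cast<uint32_t>(count <= limit ? count : limit);
  const uint32_t ny =
      static_cast<uint32_t>(count <= limit ? 1 : (count + limit - 1) / limit);
  std::vector<uint32_t> hits(n, 0);
  for (uint32_t wy = 0; wy < ny; ++wy) {
    for (uint32_t wx = 0; wx < nx; ++wx) {
      for (uint32_t lid = 0; lid < wg_size; ++lid) {
        const uint64_t gid_x = static_cast<uint64_t>(wx) * wg_size + lid;
        const uint64_t idx = gid_x + static_cast<uint64_t>(wy) * nx * wg_size;
        if (idx >= n) continue;  // bounds guard for the padded tail
        ++hits[idx];
      }
    }
  }
  for (uint32_t h : hits) {
    if (h != 1) return false;
  }
  return true;
}

// Row form (softmax): row_idx = wid.x + wid.y * num_workgroups.x.
// Out-of-range workgroups are masked by a predicate instead of exiting early.
inline bool row_fold_covers_once(uint32_t rows, uint32_t limit) {
  if (limit == 0) return false;
  const uint32_t nx = rows <= limit ? rows : limit;
  const uint32_t ny = rows <= limit ? 1u : (rows + limit - 1) / limit;
  std::vector<uint32_t> hits(rows, 0);
  for (uint32_t wy = 0; wy < ny; ++wy) {
    for (uint32_t wx = 0; wx < nx; ++wx) {
      const uint64_t row_idx = wx + static_cast<uint64_t>(wy) * nx;
      const bool valid = row_idx < rows;  // predicate, not an early return
      if (valid) ++hits[row_idx];
    }
  }
  for (uint32_t h : hits) {
    if (h != 1) return false;
  }
  return true;
}
```

For example, `flat_fold_covers_once(100, 8, 4)` folds 13 workgroups into a 4x4 grid of 128 invocations, and the bounds guard drops the 28 that fall past the end.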