[ExecuTorch][WebGPU] 2D compute dispatch — lift the 65535 per-dim cap (prefill path)#20583

Open
JulianCloudNTH wants to merge 1 commit into gh/JulianCloudNTH/75/base from gh/JulianCloudNTH/75/head

@JulianCloudNTH JulianCloudNTH commented Jun 28, 2026

Contributor

Stack from ghstack (oldest at bottom):

Lift the 65535 workgroups-per-dimension dispatch cap so single-shot SDPA prefill runs at sequence lengths that previously exceeded it (up to the uint32 element guard noted under Constraints).

Problem: The WebGPU backend supports only 1D dispatch and throws when a kernel's workgroup count exceeds the device's per-dimension limit (maxComputeWorkgroupsPerDimension, spec floor 65535). In SDPA prefill, the QK kernel exceeds the limit at roughly S=362, and the softmax and AV kernels exceed it at S=2048, which blocks single-shot and long-context prefill.

Solution: Fold any 1D workgroup count that exceeds the limit into a 2D dispatch; the shader reconstructs the linear index from @builtin(num_workgroups).

  • Before: compute_1d_workgroup_count throws if count > limit; dispatch (count, 1, 1).
  • After: compute_2d_workgroup_count returns {count, 1} (fast path) or {limit, div_up(count, limit)}; dispatch (x, y, 1).
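
The fold described above can be sketched in a few lines of C++. This is a minimal illustration, not the PR's code: the `WgCount` field names, the function signature, the 64-bit count type, and the exception type are assumptions. Only the returned shapes `{count, 1}` and `{limit, div_up(count, limit)}` come from the description.

```cpp
#include <cassert>
#include <cstdint>
#include <stdexcept>

// Hypothetical stand-in for the PR's WgCount: workgroups along x and y.
struct WgCount {
  uint32_t x;
  uint32_t y;
};

// Ceiling division on 64-bit values so large counts cannot overflow.
inline uint64_t div_up(uint64_t a, uint64_t b) {
  assert(b > 0);
  return (a + b - 1) / b;
}

// Fold a linear workgroup count into (x, y) with both dimensions <= limit.
inline WgCount fold_workgroup_count_2d(uint64_t count, uint32_t limit) {
  assert(limit > 0);
  if (count <= limit) {
    // Fast path: identical to the old 1D dispatch (count, 1, 1).
    return {static_cast<uint32_t>(count), 1u};
  }
  const uint64_t y = div_up(count, limit);
  if (y > limit) {
    // A third dispatch dimension would be required.
    throw std::runtime_error("workgroup count does not fit in a 2D dispatch");
  }
  return {limit, static_cast<uint32_t>(y)};
}
```

After folding, `x * y` can exceed `count` (65536 workgroups fold to 65535 x 2, for example), so the surplus workgroups depend on the shader's bounds check against the reconstructed linear index.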

Implementation:

  • WgCount + pure fold_workgroup_count_2d + compute_2d_workgroup_count in WebGPUUtils.h (device-free, unit-testable; queried_max_workgroups factored out of the 1D path)
  • WebGPUDispatch.workgroup_count_y (default 1, declared last so existing aggregate inits are unchanged); both dispatchWorkgroups calls + the profiling record pass (x, y, 1)
  • Per-kernel in-shader reconstruction: thread-form idx = gid.x + gid.y*(num_workgroups.x*wg_size) (QK/AV/add); row-form row_idx = wid.x + wid.y*num_workgroups.x (softmax — keeps a valid predicate, not an early return, so workgroupBarrier()s stay uniform)
  • Sdpa.cpp: QK/softmax/AV counts via the 2D helper; the dynamic-input_pos resize hook recomputes both x and y for QK
  • Mirrors Vulkan dispatch (Vulkan itself does not guard the per-dim limit)
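
The two in-shader reconstruction formulas listed above can be checked on the CPU. The sketch below is a C++ simulation, not the WGSL shaders themselves. It assumes a workgroup size of `(wg_size, 1, 1)`, so that `gid.y` equals `wid.y`, and all function names are hypothetical. It counts how often each linear index is visited by a folded `(x, y, 1)` dispatch, with out-of-range invocations masked by a predicate.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Thread-form: idx = gid.x + gid.y * (num_workgroups.x * wg_size).
inline uint64_t thread_form_index(uint32_t gid_x, uint32_t gid_y,
                                  uint32_t num_wg_x, uint32_t wg_size) {
  const uint64_t row_stride = static_cast<uint64_t>(num_wg_x) * wg_size;
  return gid_x + gid_y * row_stride;
}

// Row-form: row_idx = wid.x + wid.y * num_workgroups.x.
inline uint64_t row_form_index(uint32_t wid_x, uint32_t wid_y,
                               uint32_t num_wg_x) {
  return wid_x + static_cast<uint64_t>(wid_y) * num_wg_x;
}

// Simulate a (wg_x, wg_y, 1) dispatch with workgroup size (wg_size, 1, 1)
// and return how many times each of the n element indices was visited.
inline std::vector<uint32_t> simulate_thread_form(uint32_t wg_x, uint32_t wg_y,
                                                  uint32_t wg_size, uint32_t n) {
  assert(wg_size > 0);
  std::vector<uint32_t> hits(n, 0);
  for (uint32_t wy = 0; wy < wg_y; ++wy) {
    for (uint32_t wx = 0; wx < wg_x; ++wx) {
      for (uint32_t lid = 0; lid < wg_size; ++lid) {
        const uint32_t gid_x = wx * wg_size + lid;
        const uint32_t gid_y = wy;  // workgroup size in y is 1
        const uint64_t idx = thread_form_index(gid_x, gid_y, wg_x, wg_size);
        if (idx < n) {
          ++hits[idx];  // surplus invocations are masked out
        }
      }
    }
  }
  return hits;
}

// Same simulation for row-form kernels: one workgroup per row.
inline std::vector<uint32_t> simulate_row_form(uint32_t wg_x, uint32_t wg_y,
                                               uint32_t num_rows) {
  std::vector<uint32_t> hits(num_rows, 0);
  for (uint32_t wy = 0; wy < wg_y; ++wy) {
    for (uint32_t wx = 0; wx < wg_x; ++wx) {
      const uint64_t row = row_form_index(wx, wy, wg_x);
      const bool valid = row < num_rows;  // predicate, not an early return
      if (valid) {
        ++hits[row];
      }
    }
  }
  return hits;
}
```

In both simulations every in-range index is visited exactly once, which is the property the folded dispatch needs. The CPU loop cannot model barriers; the `valid` flag only mirrors the description's point that the softmax kernel masks work instead of returning early.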

Constraints:

  • y=1 fast path keeps every non-folded dispatch byte-identical to the prior 1D path
  • Scope = prefill path only; rms_norm/embedding/lm_head/update_cache are row/token-indexed and never hit the cap, so they keep the 1D path
  • Throws if a 3rd dispatch dimension would be needed — unreachable for real prefill (the uint32 element guard fires first at S~11585)
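
The claim that a third dimension is unreachable can be illustrated with rough arithmetic. The sketch below assumes a thread-form kernel with one thread per element and a workgroup size of 64; the PR description states neither, and the helper names are hypothetical. It compares the capacity of a 2D fold with the workgroup count needed for the largest element count that fits in a uint32.

```cpp
#include <cassert>
#include <cstdint>

// Ceiling division on 64-bit values.
inline uint64_t div_up(uint64_t a, uint64_t b) {
  assert(b > 0);
  return (a + b - 1) / b;
}

// Largest workgroup count a 2D fold can cover: limit along x times limit along y.
inline uint64_t max_foldable_workgroups(uint32_t limit) {
  return static_cast<uint64_t>(limit) * limit;
}

// Workgroups needed to cover UINT32_MAX elements at one thread per element,
// with wg_size threads per workgroup (wg_size is an assumed parameter).
inline uint64_t workgroups_for_max_uint32_elements(uint32_t wg_size) {
  return div_up(UINT32_MAX, wg_size);
}
```

With a limit of 65535 the fold covers up to 4,294,836,225 workgroups, while the assumed kernel needs only 67,108,864 workgroups (a fold with y = 1025) at the uint32 element ceiling, so under these assumptions the element guard is hit long before the 2D capacity.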

Co-authored-with: Claude Code.

Differential Revision: D109517684

pytorch-bot Bot commented Jun 28, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/20583

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure

As of commit d54c2d2 with merge base 55a71e6 (image):

NEW FAILURE - The following job has failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@github-actions

This PR needs a release notes: label

If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with release notes:. This helps us keep track and include your important work in the next release notes.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "release notes: none"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

@meta-cla meta-cla Bot added the CLA Signed label Jun 28, 2026