[ExecuTorch][WebGPU] Dynamic resize hook for linear_q4gsw#20576

Open
JulianCloudNTH wants to merge 2 commits into gh/JulianCloudNTH/68/base from gh/JulianCloudNTH/68/head
JulianCloudNTH (Contributor) commented Jun 28, 2026

Stack from ghstack (oldest at bottom):

Make the 4-bit quantized linear serve any live M (rows) from one graph, so a dynamic prefill+decode graph computes correct-size outputs.

Problem: linear_q4gsw baked its dispatch count, params.M, and output shape at build() for the max M. On a dynamic-shape graph at a smaller live M (e.g. decode M=1 vs prefill M=S) it would over-dispatch and leave the output sized at the max.

Solution:

  • Before: one fixed dispatch sized for the build-time M.
  • After: a tensor resize hook on the input recomputes the live M from cur_dims, rewrites params.M, updates the dispatch workgroup_count_x using the SAME kernel chosen at build (tiled GEMM or coop4 GEMV), and sets the output cur_dims (= input dims with the last dim replaced by N). Inert until the input is resized.
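The shape arithmetic described above can be sketched in isolation. This is a minimal illustration, not the PR's code: `live_rows` and `output_dims` are hypothetical helper names, and the dims are plain vectors rather than the graph's `cur_dims` storage.

```cpp
#include <cassert>
#include <cstdint>
#include <stdexcept>
#include <vector>

// Hypothetical helper: rows M implied by the live input dims, where the
// last input dim is the fixed inner dimension K.
inline uint32_t live_rows(const std::vector<int64_t>& in_dims, int64_t K) {
  assert(K > 0);
  int64_t numel = 1;
  for (int64_t d : in_dims) {
    numel *= d;
  }
  const int64_t m = numel / K;
  if (m <= 0 || m > static_cast<int64_t>(UINT32_MAX)) {
    throw std::runtime_error("live M out of range");
  }
  return static_cast<uint32_t>(m);
}

// Hypothetical helper: output dims are the input dims with the last dim
// replaced by N (the number of output features).
inline std::vector<int64_t> output_dims(std::vector<int64_t> in_dims, int64_t N) {
  if (in_dims.empty()) {
    throw std::runtime_error("input must have at least one dim");
  }
  in_dims.back() = N;
  return in_dims;
}
```

With K = 64 and N = 256, a prefill input of dims {1, 128, 64} gives M = 128 and output dims {1, 128, 256}, while a decode input of dims {1, 1, 64} gives M = 1 and output dims {1, 1, 256}.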

Implementation:

  • The build-time kernel select (use_gemv = M==1 GEMV else tiled GEMM) is fixed at build; the hook only recomputes the dispatch count, the param UBO, and the output dims for the live M — it does not switch kernels (runtime M-switching is a separate optimization).
  • own_uniform_buffer keeps the param UBO alive so the hook can rewrite it.
  • Mirrors Vulkan resize_q4gsw_linear_node (recompute M-derived dispatch each execute).
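As a rough sketch of the first point, the kernel choice can be stored once at build while only the amount of work is recomputed per live M. Everything here is an assumption made for illustration: the `Q4Kernel` enum, the `Q4DispatchPlan` struct, and the tile sizes passed as parameters are not taken from the PR, and the result is the raw work-item count before any device-limit clamping.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical names; the PR stores a use_gemv flag instead.
enum class Q4Kernel { TiledGemm, Coop4Gemv };

struct Q4DispatchPlan {
  Q4Kernel kernel;  // chosen once at build, never changed by the hook
  uint32_t tile_m;  // assumed GEMM tile height
  uint32_t tile_n;  // assumed GEMM tile width
  uint32_t N;       // output features, fixed by the weights
};

// Build-time selection: GEMV only when the build-time (max) M is 1.
inline Q4DispatchPlan plan_at_build(
    uint32_t max_m, uint32_t N, uint32_t tile_m, uint32_t tile_n) {
  assert(max_m > 0 && N > 0 && tile_m > 0 && tile_n > 0);
  const Q4Kernel kernel =
      (max_m == 1) ? Q4Kernel::Coop4Gemv : Q4Kernel::TiledGemm;
  return {kernel, tile_m, tile_n, N};
}

// Resize-time recompute: same kernel, new amount of work for the live M.
inline uint64_t work_items(const Q4DispatchPlan& p, uint32_t live_m) {
  if (p.kernel == Q4Kernel::Coop4Gemv) {
    return static_cast<uint64_t>(live_m) * p.N;  // one item per output
  }
  const uint64_t tiles_m = (static_cast<uint64_t>(live_m) + p.tile_m - 1) / p.tile_m;
  const uint64_t tiles_n = (static_cast<uint64_t>(p.N) + p.tile_n - 1) / p.tile_n;
  return tiles_m * tiles_n;  // one item per output tile
}
```

Under these assumptions, a graph built for M = 128 keeps the tiled GEMM kernel even when it runs at live M = 1; only the tile count shrinks. That matches the stated scope, where switching to the GEMV kernel at decode is left to a later change.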

Constraints: Behavior-neutral on static graphs (hook fires only when the input's live M differs from the max). No kernel/WGSL/numerics change. The coop4-vs-GEMM runtime kernel-switch decode optimization is deliberately out of scope (a later opt diff).

Co-authored-with: Claude Code.

Differential Revision: D109906094

[ghstack-poisoned]
pytorch-bot Bot commented Jun 28, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/20576

Note: Links to docs will display an error until the docs builds have been completed.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@github-actions

This PR needs a release notes: label

If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with release notes:. This helps us keep track and include your important work in the next release notes.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "release notes: none"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

@JulianCloudNTH

Contributor Author

@claude review and check for any areas or opportunities for modularization

claude Bot commented Jun 29, 2026

Claude finished @JulianCloudNTH's task in 1m 35s.


Review: Dynamic resize hook for linear_q4gsw

  • Read the PR diff and surrounding code
  • Review correctness (resize hook, M computation, dispatch update)
  • Assess modularization opportunities
  • Post review feedback

The change is well-scoped and reads cleanly — it mirrors the Vulkan resize_q4gsw_linear_node pattern, keeps the kernel-select build-time-fixed as the PR description claims, and is correctly inert on static graphs (the hook only fires when in_id is dirtied). own_uniform_buffer correctly replaces the immediate wgpuBufferRelease so the UBO survives for wgpuQueueWriteBuffer. No numerics/WGSL change. Overall this looks correct.

Below are correctness notes and the modularization opportunities you asked about.


Correctness / robustness

1. The hook drops the numel % K == 0 check the build path has. The build path validates divisibility (QuantizedLinear.cpp:73), but the hook just does numel / K (lines 281-282), so integer division silently truncates if a live shape isn't a clean multiple of K. In practice the last input dim is always K so this can't trigger, but mirroring the build-time guard is cheap insurance and keeps the two paths behaviorally identical.

2. Minor: the GEMV branch in the hook omits the outputs == 0 guard. Build checks outputs == 0u || outputs > UINT32_MAX (line 131); the hook only checks the upper bound (line 289). Harmless here since m == 0 already throws and N is fixed-nonzero, but see the dedup below which makes the two consistent for free.

3. Lint is red. Dr. CI reports lintrunner failing on this file. Lines 260-261 (add_dispatch(...)) look like a clang-format wrap issue. Worth a local lintrunner -a before merge.
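To make note 1 concrete, here is a minimal sketch of a guarded live-M computation. `checked_live_m` is a hypothetical name and the error strings are illustrative; the point is only that the divisibility check turns a silent truncation into an explicit failure.

```cpp
#include <cassert>
#include <cstdint>
#include <stdexcept>

// Hypothetical helper: derive the live M from the element count and K,
// refusing shapes that are not a whole number of rows.
inline uint32_t checked_live_m(uint64_t numel, uint64_t K) {
  if (K == 0 || numel % K != 0) {
    throw std::runtime_error("linear_q4gsw: numel is not a multiple of K");
  }
  const uint64_t m = numel / K;
  if (m == 0 || m > UINT32_MAX) {
    throw std::runtime_error("linear_q4gsw: live M out of range");
  }
  return static_cast<uint32_t>(m);
}
```

For example, with K = 64 an element count of 130 would yield M = 2 under plain integer division and leave two elements unaccounted for; the guarded version throws instead.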


Modularization opportunities

The build path and the hook now contain two near-identical blocks: the workgroup-count computation and the Q4gswParams population. Both can be deduped, which also shrinks the 12-entry capture list.

A. Extract the workgroup-count computation. The GEMV/GEMM branch logic is duplicated verbatim (build 127-148, hook 287-306). A file-local helper makes them a single source of truth:

uint32_t q4gsw_workgroup_count(
    WGPUDevice device, uint32_t m, uint32_t N, bool use_gemv,
    uint32_t wg_size, const char* op_name) {
  if (use_gemv) {
    const uint64_t outputs = static_cast<uint64_t>(m) * N;
    if (outputs == 0u || outputs > UINT32_MAX) {
      throw std::runtime_error("WebGPU linear_q4gsw: M*N out of range");
    }
    return utils::clamp_workgroup_count(device, static_cast<uint32_t>(outputs));
  }
  const int64_t total_tiles = utils::div_up<int64_t>(m, kQ4gswTileM) *
      utils::div_up<int64_t>(N, kQ4gswTileN);
  if (total_tiles > static_cast<int64_t>(UINT32_MAX)) {
    throw std::runtime_error(
        "WebGPU linear_q4gsw: tile count exceeds the 1D dispatch limit");
  }
  return utils::compute_1d_workgroup_count(
      device, static_cast<uint32_t>(total_tiles), wg_size, op_name);
}

Both sites then call q4gsw_workgroup_count(...), which also resolves note #2 and the slightly divergent error string ("exceeds the dispatch limit" vs "...1D dispatch limit").

B. Capture the params struct by value instead of rebuilding it. The hook re-populates every field of Q4gswParams (lines 307-314) when only M changes. Capturing the already-built params in a mutable lambda eliminates four captures (K_packed, gs, padded_N, has_bias) and the whole re-populate block:

graph.add_tensor_resize_hook(
    in_id,
    [in_id, out_id, K, N, wg_size, use_gemv, dispatch_idx, uniform_buffer, params]
    (WebGPUGraph& g) mutable {
      const auto& d = g.cur_dims(in_id);
      const uint32_t m = static_cast<uint32_t>(utils::numel_of(d) / K);
      if (m == 0u) {
        throw std::runtime_error("WebGPU linear_q4gsw: live M == 0");
      }
      params.M = m;
      wgpuQueueWriteBuffer(g.queue(), uniform_buffer, 0, &params, sizeof(params));
      g.dispatch_at(dispatch_idx).workgroup_count_x =
          q4gsw_workgroup_count(g.device(), m, N, use_gemv, wg_size,
                                "linear_q4gsw(resize)");
      std::vector<int64_t> od(d.begin(), d.end());
      od.back() = static_cast<int64_t>(N);
      g.set_cur_dims(out_id, od);
    });

This also lets you drop the WGPUBuffer params_buf = uniform_buffer; alias (line 264) and capture uniform_buffer directly. The capture list goes from 12 entries to the 9 shown above, and the lambda body roughly halves.
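The capture-by-value pattern in B can be tried outside the WebGPU code. In this sketch the `Params` struct and `make_resize_hook` are hypothetical stand-ins (the real Q4gswParams layout is not shown in this thread), and the hook returns the struct instead of writing it to a GPU buffer.

```cpp
#include <cassert>
#include <cstdint>
#include <functional>

// Hypothetical stand-in for the uniform-buffer struct; field names are
// assumptions, not the real Q4gswParams layout.
struct Params {
  uint32_t M;
  uint32_t N;
  uint32_t K;
  uint32_t group_size;
};

// The lambda owns its own copy of params. `mutable` lets it update M in
// place, so the other fields never need to be captured or rebuilt.
inline std::function<Params(uint32_t)> make_resize_hook(Params params) {
  return [params](uint32_t live_m) mutable {
    params.M = live_m;
    return params;  // the real hook would upload this to the UBO instead
  };
}
```

Calling the hook with 1 and then with 16 changes only M between calls; N, K, and group_size keep their build-time values because they live in the captured copy.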

C. (Optional, larger) The build path itself could call the same helper for its workgroup_count so there is exactly one dispatch-sizing code path. Given the build path has extra zero-asserts, this is lower-value than A/B, but it would make the "build and resize agree by construction" invariant structural rather than reviewer-verified.


None of these block correctness; A and B are the high-value cleanups (one source of truth for dispatch sizing + a much smaller lambda). I haven't pushed any changes — this is review-only per your request.

[ghstack-poisoned]
Labels

CLA Signed, meta-exported

2 participants