[ExecuTorch][WebGPU] Enable FlashDecoding by default for decode SDPA (runtime shape gate)#20544
JulianCloudNTH wants to merge 3 commits into
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/20544
Note: Links to docs will display an error until the docs builds have been completed.
❌ 2 New Failures
As of commit 0c7e1d6 with merge base 1227757:
NEW FAILURES - The following jobs have failed:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@claude Check for overengineering, areas for modularization, and functional correctness
Claude finished @JulianCloudNTH's task in 9m 2s

**PR Review: Enable FlashDecoding by default for decode SDPA**

I read the gate change in

Overall this is a clean, well-scoped change. The gate is correct, the scratch-buffer move is a nice win, and the kernels are numerically equivalent to the materialized path for the decode case. Findings below, ordered by impact.

**Functional correctness ✅ (with notes)**

The gate and the no-mask assumption are correct. For decode (

Empty splits in the capped case are handled correctly; worth confirming you intended this.

Defensive guard message is stale.

    "WebGPU sdpa FlashDecoding: head dim must be <= " +
    std::to_string(kSdpaFdMaxHeadDim);

**Modularization: main finding**

These differences are all expressible in one signature:

I understand the kill-switch design wants FD code to compile out cleanly when

**Overengineering: minor**

**Build / packaging: verify outside OSS**

The PR description says the option default is flipped in "fbcode + xplat" CMake and that

Nice work: the runtime shape gate is the right call, and the dead-code analysis in the description (FD's limits all below WebGPU baseline) is convincing. The one thing I'd act on before landing is the
…(runtime shape gate)

Pull Request resolved: #20544

**Makes split-KV FlashDecoding the default decode-attention path** (it was shipped dormant behind a default-OFF compile flag). FD is the fastest WebGPU SDPA decode arm (**+178% vs naive**, M4 Pro, isolated op); this turns it on for production and selects it at runtime by a shape-capability predicate.

{F1991715077}

**Problem:** the FD kernel is correct and measured (+178%) but compile-gated OFF, so no production build used it. A device-limit gate (web-llm-style `maxStorageBufferBindingSize`) was considered but is dead code here: FD's resource needs (workgroup size 64, 512 B shared memory, 5 storage bindings) are all below WebGPU's baseline minimum limits, and FD binds the same K/V caches as the materialized fallback, so no spec-compliant device can run materialized decode but fail FD. The only selection criterion with real effect is shape.

**Solution:** enable FD by default and select it at runtime on shape, not device.
- **Before:** `EXECUTORCH_BUILD_WEBGPU_SDPA_FD` default OFF; FD code unlinked; every decode used the materialized QK/softmax/AV path.
- **After:** flag default ON (kept as a build-time kill-switch); decode (`S == 1`, static input_pos) with head dim `<= kSdpaFdMaxHeadDim` uses FD; other shapes (including head dim > 128) fall through to the materialized path.

**Implementation:**
- `Sdpa.cpp`: extend the FD selection predicate with `D <= kSdpaFdMaxHeadDim` so unsupported head dims fall through instead of throwing.
- `SdpaFdDecode.h`: expose `kSdpaFdMaxHeadDim` (FD's lane-owns-D reach) as the single source of truth; `SdpaFdDecode.cpp` ties it to `WG_SIZE * MAX_D_PER_LANE` with a `static_assert`.
- `CMakeLists.txt` (fbcode + xplat): flip the option default to ON; OFF remains a kill-switch that drops all FlashDecoding code.
- `test_webgpu_native_ci.sh`: drop the now-redundant explicit `=ON` flag so CI builds and tests the default.
- Mirrors Vulkan `backends/vulkan/runtime/graph/ops/impl/SDPA.cpp` shape-based kernel selection (`is_single_token`); no device-adaptive gate, matching the Vulkan delegate.

**Constraints:** decode-only (`S == 1`), static input_pos (dynamic-pos decode still uses the materialized path); fp32, buffer-only; the FD kernels are unchanged by this diff.

Co-authored with Claude Code.

ghstack-source-id: 397435149
@exported-using-ghexport

Differential Revision: [D109520722](https://our.internmc.facebook.com/intern/diff/D109520722/)
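The "dead code" argument in the Problem paragraph can be checked at compile time. The sketch below is illustrative only and is not part of the PR: the struct and function names are hypothetical, the FD figures are the ones quoted in the description, and the baseline values are the default limits from the WebGPU specification (`maxComputeInvocationsPerWorkgroup`, `maxComputeWorkgroupStorageSize`, `maxStorageBuffersPerShaderStage`).

```cpp
#include <cstdint>

// Default (guaranteed-minimum) limits from the WebGPU specification.
// Hypothetical holder struct; only the numeric values come from the spec.
struct WebGpuBaselineLimits {
  static constexpr uint32_t kMaxComputeInvocationsPerWorkgroup = 256;
  static constexpr uint32_t kMaxComputeWorkgroupStorageSize = 16384; // bytes
  static constexpr uint32_t kMaxStorageBuffersPerShaderStage = 8;
};

// FlashDecoding resource needs as quoted in the PR description.
// Hypothetical holder struct; the real kernels may name these differently.
struct FdResourceNeeds {
  static constexpr uint32_t kWorkgroupSize = 64;
  static constexpr uint32_t kSharedMemoryBytes = 512;
  static constexpr uint32_t kStorageBindings = 5;
};

// True when every FD requirement is within the baseline limits, i.e. a
// device-limit gate could never reject FD on a spec-compliant device.
constexpr bool fd_fits_baseline() {
  return FdResourceNeeds::kWorkgroupSize <=
          WebGpuBaselineLimits::kMaxComputeInvocationsPerWorkgroup &&
      FdResourceNeeds::kSharedMemoryBytes <=
          WebGpuBaselineLimits::kMaxComputeWorkgroupStorageSize &&
      FdResourceNeeds::kStorageBindings <=
          WebGpuBaselineLimits::kMaxStorageBuffersPerShaderStage;
}

static_assert(
    fd_fits_baseline(),
    "FD resource needs must stay within WebGPU baseline limits");
```

If a later kernel change pushed any of these figures past the baseline, the `static_assert` would fail at build time, which is the point at which a runtime device gate would start to matter.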
…(runtime shape gate)

Pull Request resolved: #20544

ghstack-source-id: 397454762
@exported-using-ghexport

Differential Revision: [D109520722](https://our.internmc.facebook.com/intern/diff/D109520722/)
SS-JIA
left a comment
Review automatically exported from Phabricator review in Meta.
Stack from ghstack (oldest at bottom):

**Makes split-KV FlashDecoding the default decode-attention path** (it was shipped dormant behind a default-OFF compile flag). FD is the fastest WebGPU SDPA decode arm (**+178% vs naive**, M4 Pro, isolated op); this turns it on for production and selects it at runtime by a shape-capability predicate.

{F1991715077}

**Problem:** the FD kernel is correct and measured (+178%) but compile-gated OFF, so no production build used it. A device-limit gate (web-llm-style `maxStorageBufferBindingSize`) was considered but is dead code here: FD's resource needs (workgroup size 64, 512 B shared memory, 5 storage bindings) are all below WebGPU's baseline minimum limits, and FD binds the same K/V caches as the materialized fallback, so no spec-compliant device can run materialized decode but fail FD. The only selection criterion with real effect is shape.

**Solution:** enable FD by default and select it at runtime on shape, not device.
- **Before:** `EXECUTORCH_BUILD_WEBGPU_SDPA_FD` default OFF; FD code unlinked; every decode used the materialized QK/softmax/AV path.
- **After:** flag default ON (kept as a build-time kill-switch); decode (`S == 1`, static input_pos) with head dim `<= kSdpaFdMaxHeadDim` uses FD; other shapes (including head dim > 128) fall through to the materialized path.

**Implementation:**
- `Sdpa.cpp`: extend the FD selection predicate with `D <= kSdpaFdMaxHeadDim` so unsupported head dims fall through instead of throwing.
- `SdpaFdDecode.h`: expose `kSdpaFdMaxHeadDim` (FD's lane-owns-D reach) as the single source of truth; `SdpaFdDecode.cpp` ties it to `WG_SIZE * MAX_D_PER_LANE` with a `static_assert`.
- `CMakeLists.txt` (fbcode + xplat): flip the option default to ON; OFF remains a kill-switch that drops all FlashDecoding code.
- `test_webgpu_native_ci.sh`: drop the now-redundant explicit `=ON` flag so CI builds and tests the default.
- Mirrors Vulkan `backends/vulkan/runtime/graph/ops/impl/SDPA.cpp` shape-based kernel selection (`is_single_token`); no device-adaptive gate, matching the Vulkan delegate.

**Constraints:** decode-only (`S == 1`), static input_pos (dynamic-pos decode still uses the materialized path); fp32, buffer-only; the FD kernels are unchanged by this diff.

Co-authored with Claude Code.

Differential Revision: D109520722
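As a rough illustration of the runtime shape gate described above, the following sketch shows one way the predicate and the `static_assert` tie-in could look. It is not the PR's code: `SdpaShape` and `use_flash_decoding` are hypothetical names, and `MAX_D_PER_LANE = 2` is an assumption inferred from the stated workgroup size of 64 and the 128 head-dim cutoff.

```cpp
#include <cstdint>

// Assumed values: the description gives workgroup size 64 and says head
// dims above 128 fall through, which implies 2 head-dim elements per lane.
constexpr uint32_t WG_SIZE = 64;
constexpr uint32_t MAX_D_PER_LANE = 2;

// Single source of truth for the largest head dim FD can handle.
constexpr uint32_t kSdpaFdMaxHeadDim = 128;

// Keeps the public constant in sync with the kernel's lane layout.
static_assert(
    kSdpaFdMaxHeadDim == WG_SIZE * MAX_D_PER_LANE,
    "kSdpaFdMaxHeadDim must equal WG_SIZE * MAX_D_PER_LANE");

// Hypothetical summary of the shape facts the gate needs.
struct SdpaShape {
  uint32_t seq_len;      // S: number of query tokens in this call
  uint32_t head_dim;     // D: per-head embedding size
  bool static_input_pos; // true when input_pos is known statically
};

// Shape-only selection: FD for single-token decode with a static position
// and a supported head dim; every other shape takes the materialized path.
constexpr bool use_flash_decoding(const SdpaShape& s) {
  return s.seq_len == 1 && s.static_input_pos &&
      s.head_dim <= kSdpaFdMaxHeadDim;
}
```

Under these assumptions, `use_flash_decoding(SdpaShape{1, 128, true})` selects FD, while a head dim of 256, a prefill call with `S > 1`, or a dynamic `input_pos` all fall through to the materialized path instead of throwing.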