From d2a1ecc1ba308b6d346888f6045e04f40c8ed215 Mon Sep 17 00:00:00 2001 From: Julian Ng-Thow-Hing Date: Sun, 28 Jun 2026 09:22:34 -0700 Subject: [PATCH] Update [ghstack-poisoned] --- backends/webgpu/runtime/ops/sdpa/Sdpa.cpp | 179 +++++++++++++--------- 1 file changed, 106 insertions(+), 73 deletions(-) diff --git a/backends/webgpu/runtime/ops/sdpa/Sdpa.cpp b/backends/webgpu/runtime/ops/sdpa/Sdpa.cpp index c5bcaf05a0e..9ce7148821c 100644 --- a/backends/webgpu/runtime/ops/sdpa/Sdpa.cpp +++ b/backends/webgpu/runtime/ops/sdpa/Sdpa.cpp @@ -256,7 +256,7 @@ static WGPUBuffer record_update_cache_dispatch( uint32_t kv_dst_offset, uint64_t cache_numel, uint32_t uc_wg, - bool dynamic_pos, + bool retain_uniform, const char* label) { const uint32_t wgc = utils::compute_1d_workgroup_count( device, static_cast(kv_numel), uc_wg, label); @@ -274,7 +274,7 @@ static WGPUBuffer record_update_cache_dispatch( sizeof(uc), wgc, uc_wg, - dynamic_pos, + retain_uniform, "update_cache"); return ubuf; } @@ -436,7 +436,7 @@ void sdpa_with_kv_cache_impl(WebGPUGraph& graph, const std::vector& args) { // Dynamic input_pos: the resize hook rewrites these per step. WGPUBuffer uc_k_buf = nullptr, uc_v_buf = nullptr, qk_buf = nullptr, softmax_buf = nullptr, av_buf = nullptr; - size_t qk_idx = 0; + size_t qk_idx = 0, uc_k_idx = 0, uc_v_idx = 0, softmax_idx = 0, av_idx = 0; const WGPUDevice device = graph.device(); const uint32_t uc_wg = @@ -461,7 +461,7 @@ void sdpa_with_kv_cache_impl(WebGPUGraph& graph, const std::vector& args) { kv_dst_offset, numel(k_cache), uc_wg, - dynamic_pos, + true, "update_cache(K)"); uc_v_buf = record_update_cache_dispatch( graph, @@ -472,8 +472,10 @@ void sdpa_with_kv_cache_impl(WebGPUGraph& graph, const std::vector& args) { kv_dst_offset, numel(v_cache), uc_wg, - dynamic_pos, + true, "update_cache(V)"); + uc_k_idx = graph.num_dispatches() - 2; + uc_v_idx = graph.num_dispatches() - 1; // --- Dispatch 3: QK -> attn_weights. One thread per TM x TN tile. { @@ -501,7 +503,7 @@ void sdpa_with_kv_cache_impl(WebGPUGraph& graph, const std::vector& args) { sizeof(p), wgc, qk_wg, - dynamic_pos, + true, "sdpa_compute_attn_weights"); qk_buf = ubuf; qk_idx = graph.num_dispatches() - 1; @@ -525,9 +527,10 @@ void sdpa_with_kv_cache_impl(WebGPUGraph& graph, const std::vector& args) { sizeof(p), wgc, 0, - dynamic_pos, + true, "sdpa_softmax"); softmax_buf = ubuf; + softmax_idx = graph.num_dispatches() - 1; } // --- Dispatch 5: AV -> out. One thread per TM x TN tile. @@ -551,77 +554,107 @@ void sdpa_with_kv_cache_impl(WebGPUGraph& graph, const std::vector& args) { sizeof(p), wgc, av_wg, - dynamic_pos, + true, "sdpa_compute_out"); av_buf = ubuf; + av_idx = graph.num_dispatches() - 1; } - // Per-step recompute hook; mirrors Vulkan DynamicDispatchNode. + // Per-step recompute: live S (q resize) or input_pos (SymInt); inert if + // static. + const int64_t pos_const = input_pos; + auto sdpa_resize = [q_id, + qn, + out_id, + dynamic_pos, + input_pos_id, + pos_const, + Hq, + Hkv, + D, + Cmax, + g, + scale, + qk_idx, + uc_k_idx, + uc_v_idx, + softmax_idx, + av_idx, + uc_wg, + qk_wg, + av_wg, + uc_k_buf, + uc_v_buf, + qk_buf, + softmax_buf, + av_buf](WebGPUGraph& gr) { + const int64_t s = gr.cur_dims(q_id)[qn - 3]; + const int64_t pos = dynamic_pos + ? static_cast(gr.read_symint(input_pos_id)) + : pos_const; + if (s <= 0 || pos < 0) { + throw std::runtime_error("WebGPU sdpa: invalid live S or input_pos"); + } + const int64_t ctx = s + pos; + if (ctx <= 0 || ctx > Cmax) { + throw std::runtime_error( + "WebGPU sdpa: context_len exceeds cache capacity"); + } + const uint32_t kv_off = static_cast( + static_cast(pos) * static_cast(Hkv) * + static_cast(D)); + const uint64_t aw_floats = static_cast(Hq) * + static_cast(s) * static_cast(ctx); + if (aw_floats > UINT32_MAX) { + throw std::runtime_error("WebGPU sdpa: Hq*S*context_len exceeds uint32"); + } + const uint64_t kv_numel = static_cast(s) * + static_cast(Hkv) * static_cast(D); + const uint64_t k_cache_numel = static_cast(Cmax) * + static_cast(Hkv) * static_cast(D); + + // update_cache K/V: dispatch (kv_numel) + dst offset scale with live S/pos. + UpdateCacheParams uc = + make_update_cache_params(kv_numel, kv_off, k_cache_numel); + wgpuQueueWriteBuffer(gr.queue(), uc_k_buf, 0, &uc, sizeof(uc)); + wgpuQueueWriteBuffer(gr.queue(), uc_v_buf, 0, &uc, sizeof(uc)); + const uint32_t uc_wgc = utils::compute_1d_workgroup_count( + gr.device(), static_cast(kv_numel), uc_wg, "uc(resize)"); + gr.dispatch_at(uc_k_idx).workgroup_count_x = uc_wgc; + gr.dispatch_at(uc_v_idx).workgroup_count_x = uc_wgc; + + // QK: one thread per TM x TN tile; grid = Hq*ceil(S/TM)*ceil(ctx/TN). + AttnWeightsParams qp = + make_attn_weights_params(s, Hq, Hkv, D, ctx, pos, g, scale); + wgpuQueueWriteBuffer(gr.queue(), qk_buf, 0, &qp, sizeof(qp)); + const int64_t qk_tiles = + Hq * utils::div_up(s, kSdpaTileM) * utils::div_up(ctx, kSdpaTileN); + gr.dispatch_at(qk_idx).workgroup_count_x = + utils::compute_1d_workgroup_count( + gr.device(), static_cast(qk_tiles), qk_wg, "QK(resize)"); + + // softmax: one workgroup per (h,s) row. + SoftmaxParams sp = make_softmax_params(Hq, s, ctx); + wgpuQueueWriteBuffer(gr.queue(), softmax_buf, 0, &sp, sizeof(sp)); + gr.dispatch_at(softmax_idx).workgroup_count_x = + utils::compute_1d_workgroup_count( + gr.device(), static_cast(Hq * s), 1, "softmax(resize)"); + + // AV: one thread per TM x TN tile; grid = Hq*ceil(S/TM)*ceil(D/TN). + ComputeOutParams op = make_compute_out_params(s, Hq, Hkv, D, ctx, g); + wgpuQueueWriteBuffer(gr.queue(), av_buf, 0, &op, sizeof(op)); + const int64_t av_tiles = + Hq * utils::div_up(s, kSdpaTileM) * utils::div_up(D, kSdpaTileN); + gr.dispatch_at(av_idx).workgroup_count_x = + utils::compute_1d_workgroup_count( + gr.device(), static_cast(av_tiles), av_wg, "AV(resize)"); + + // Output attn has the same shape as q: [.., S, Hq, D]. + gr.set_cur_dims(out_id, gr.cur_dims(q_id)); + }; + graph.add_tensor_resize_hook(q_id, sdpa_resize); if (dynamic_pos) { - graph.add_resize_hook( - input_pos_id, - [input_pos_id, - S, - Hq, - Hkv, - D, - Cmax, - g, - scale, - qk_idx, - qk_wg, - uc_k_buf, - uc_v_buf, - qk_buf, - softmax_buf, - av_buf](WebGPUGraph& gr) { - const int32_t pos = gr.read_symint(input_pos_id); - if (pos < 0) { - throw std::runtime_error( - "WebGPU sdpa: input_pos must be non-negative"); - } - const int64_t ctx = S + pos; - if (ctx <= 0 || ctx > Cmax) { - throw std::runtime_error( - "WebGPU sdpa: context_len exceeds cache capacity"); - } - const uint32_t kv_off = static_cast( - static_cast(pos) * static_cast(Hkv) * - static_cast(D)); - const uint64_t aw_floats = static_cast(Hq) * - static_cast(S) * static_cast(ctx); - if (aw_floats > UINT32_MAX) { - throw std::runtime_error( - "WebGPU sdpa: Hq*S*context_len exceeds uint32 max"); - } - const uint64_t kv_numel = static_cast(S) * - static_cast(Hkv) * static_cast(D); - const uint64_t k_cache_numel = static_cast(Cmax) * - static_cast(Hkv) * static_cast(D); - - UpdateCacheParams uc = - make_update_cache_params(kv_numel, kv_off, k_cache_numel); - wgpuQueueWriteBuffer(gr.queue(), uc_k_buf, 0, &uc, sizeof(uc)); - wgpuQueueWriteBuffer(gr.queue(), uc_v_buf, 0, &uc, sizeof(uc)); - - AttnWeightsParams qp = - make_attn_weights_params(S, Hq, Hkv, D, ctx, pos, g, scale); - wgpuQueueWriteBuffer(gr.queue(), qk_buf, 0, &qp, sizeof(qp)); - const int64_t qk_tiles = Hq * utils::div_up(S, kSdpaTileM) * - utils::div_up(ctx, kSdpaTileN); - const uint32_t qk_wgc = utils::compute_1d_workgroup_count( - gr.device(), - static_cast(qk_tiles), - qk_wg, - "QK(resize)"); - gr.dispatch_at(qk_idx).workgroup_count_x = qk_wgc; - - SoftmaxParams sp = make_softmax_params(Hq, S, ctx); - wgpuQueueWriteBuffer(gr.queue(), softmax_buf, 0, &sp, sizeof(sp)); - - ComputeOutParams op = make_compute_out_params(S, Hq, Hkv, D, ctx, g); - wgpuQueueWriteBuffer(gr.queue(), av_buf, 0, &op, sizeof(op)); - }); + graph.add_resize_hook(input_pos_id, sdpa_resize); } }