diff --git a/backends/webgpu/runtime/ops/embedding_q4gsw/EmbeddingQ4gsw.cpp b/backends/webgpu/runtime/ops/embedding_q4gsw/EmbeddingQ4gsw.cpp index 5801b650f27..065403674d1 100644 --- a/backends/webgpu/runtime/ops/embedding_q4gsw/EmbeddingQ4gsw.cpp +++ b/backends/webgpu/runtime/ops/embedding_q4gsw/EmbeddingQ4gsw.cpp @@ -16,6 +16,7 @@ #include #include #include +#include namespace executorch::backends::webgpu { @@ -36,14 +37,6 @@ static_assert( sizeof(EmbeddingParams) == 32, "EmbeddingParams must be 32 bytes"); -uint64_t numel_of(const std::vector& dims) { - uint64_t n = 1; - for (int64_t d : dims) { - n *= static_cast(d); - } - return n; -} - // arg order mirrors Vulkan EmbeddingQ4gsw.cpp. void embedding_q4gsw_impl(WebGPUGraph& graph, const std::vector& args) { const int weight_id = args.at(0); @@ -102,7 +95,7 @@ void embedding_q4gsw_impl(WebGPUGraph& graph, const std::vector& args) { } // Leading index dims flatten row-major (mirrors Vulkan num_indices). - const uint64_t out_numel = numel_of(out.dims); + const uint64_t out_numel = utils::numel_of(out.dims); const uint32_t num_indices = static_cast(out_numel / embed_dim); const uint32_t groups_per_row = static_cast(scales.dims[1]); const uint32_t blocks_per_row = embed_dim / 32u; @@ -119,9 +112,9 @@ void embedding_q4gsw_impl(WebGPUGraph& graph, const std::vector& args) { } // Per-type byte guards (no runtime dtype): indices i32, weight u8, fp32 rest. - const uint64_t indices_numel = numel_of(indices.dims); - const uint64_t weight_numel = numel_of(weight.dims); - const uint64_t scales_numel = numel_of(scales.dims); + const uint64_t indices_numel = utils::numel_of(indices.dims); + const uint64_t weight_numel = utils::numel_of(weight.dims); + const uint64_t scales_numel = utils::numel_of(scales.dims); if (indices_numel != num_indices || indices.nbytes != indices_numel * sizeof(int32_t) || weight.nbytes != weight_numel || @@ -230,13 +223,56 @@ void embedding_q4gsw_impl(WebGPUGraph& graph, const std::vector& args) { bg_desc.entries = bg_entries; WGPUBindGroup bind_group = wgpuDeviceCreateBindGroup(device, &bg_desc); - graph.add_dispatch( + const size_t dispatch_idx = graph.add_dispatch( {pipeline, bind_group, workgroup_count, "embedding_q4gsw"}); + // Dynamic shapes: recompute counts/dispatch; out = indices + [embed_dim]. + const uint32_t gs_u = static_cast(group_size); + WGPUBuffer params_buf = uniform_buffer; + graph.add_tensor_resize_hook( + indices_id, + [indices_id, + out_id, + embed_dim, + blocks_per_row, + gs_u, + groups_per_row, + bytes_per_row, + wg_size, + dispatch_idx, + params_buf](WebGPUGraph& g) { + const auto& id = g.cur_dims(indices_id); + const uint64_t ni = utils::numel_of(id); + const uint64_t total_blocks = ni * blocks_per_row; + if (total_blocks > UINT32_MAX) { + throw std::runtime_error( + "WebGPU embedding_q4gsw: total_blocks exceeds uint32"); + } + std::vector od = id; + od.push_back(static_cast(embed_dim)); + g.set_cur_dims(out_id, od); + EmbeddingParams p = {}; + p.embed_dim = embed_dim; + p.blocks_per_row = blocks_per_row; + p.num_indices = static_cast(ni); + p.group_size = gs_u; + p.groups_per_row = groups_per_row; + p.bytes_per_row = bytes_per_row; + p.total_blocks = static_cast(total_blocks); + wgpuQueueWriteBuffer(g.queue(), params_buf, 0, &p, sizeof(p)); + g.dispatch_at(dispatch_idx).workgroup_count_x = + utils::compute_1d_workgroup_count( + g.device(), + static_cast(total_blocks), + wg_size, + "embedding_q4gsw(resize)"); + }); + wgpuShaderModuleRelease(shader); wgpuBindGroupLayoutRelease(bgl); wgpuPipelineLayoutRelease(pipeline_layout); - wgpuBufferRelease(uniform_buffer); + // Graph owns it so the resize hook can rewrite it; freed in the dtor. + graph.own_uniform_buffer(uniform_buffer); } } // namespace diff --git a/backends/webgpu/runtime/ops/rms_norm/RmsNorm.cpp b/backends/webgpu/runtime/ops/rms_norm/RmsNorm.cpp index e73c6e23a88..5223b6530f4 100644 --- a/backends/webgpu/runtime/ops/rms_norm/RmsNorm.cpp +++ b/backends/webgpu/runtime/ops/rms_norm/RmsNorm.cpp @@ -7,6 +7,7 @@ */ #include +#include #include #include #include @@ -187,14 +188,38 @@ void rms_norm_impl(WebGPUGraph& graph, const std::vector& args) { static_assert( kRmsNormVec4WorkgroupSizeX == 64, "must match @workgroup_size and WG_SIZE in rms_norm_vec4.wgsl"); - graph.add_dispatch({pipeline, bind_group, num_rows}); + const size_t dispatch_idx = + graph.add_dispatch({pipeline, bind_group, num_rows}); + + // Dynamic shapes: recompute num_rows + rewrite the UBO for the live input. + WGPUBuffer params_buf = uniform_buffer; + graph.add_tensor_resize_hook( + in_id, + [in_id, out_id, row_width, epsilon, dispatch_idx, params_buf]( + WebGPUGraph& g) { + const auto& d = g.cur_dims(in_id); + const uint64_t numel = utils::numel_of(d); + const uint32_t rows = + static_cast(numel / static_cast(row_width)); + if (rows == 0 || rows > 65535u) { + throw std::runtime_error( + "WebGPU rms_norm: num_rows exceeds the 1D dispatch limit (65535)"); + } + RmsNormParams p = {}; + p.num_rows = rows; + p.row_width = row_width; + p.epsilon = epsilon; + wgpuQueueWriteBuffer(g.queue(), params_buf, 0, &p, sizeof(p)); + g.dispatch_at(dispatch_idx).workgroup_count_x = rows; + g.set_cur_dims(out_id, d); + }); // Release intermediate objects (pipeline and bind_group are kept by dispatch) wgpuShaderModuleRelease(shader); wgpuBindGroupLayoutRelease(bgl); wgpuPipelineLayoutRelease(pipeline_layout); - // Drop our ref; the bind group keeps the uniform buffer alive until release. - wgpuBufferRelease(uniform_buffer); + // Graph owns it so the resize hook can rewrite it; freed in the dtor. + graph.own_uniform_buffer(uniform_buffer); } } // namespace diff --git a/backends/webgpu/runtime/ops/rope/RotaryEmbedding.cpp b/backends/webgpu/runtime/ops/rope/RotaryEmbedding.cpp index cf4fa0a1ca2..afca26cb9c6 100644 --- a/backends/webgpu/runtime/ops/rope/RotaryEmbedding.cpp +++ b/backends/webgpu/runtime/ops/rope/RotaryEmbedding.cpp @@ -34,16 +34,9 @@ struct RotaryParams { }; static_assert(sizeof(RotaryParams) == 32, "RotaryParams must be 32 bytes"); -uint64_t numel_of(const std::vector& dims) { - uint64_t n = 1; - for (int64_t d : dims) { - n *= static_cast(d); - } - return n; -} - // Rotate one (x->out) with the shared shader; freqs shared between xq and xk. -void add_rope_dispatch( +// Returns the param-uniform handle so a resize hook can rewrite seq/num_pairs. +WGPUBuffer add_rope_dispatch( WebGPUGraph& graph, WGPUDevice device, WGPUComputePipeline pipeline, @@ -58,7 +51,8 @@ void add_rope_dispatch( uint32_t workgroup_count) { const uint32_t half_dim = head_dim / 2u; // out.dims == in.dims (asserted in impl), so this matches the caller's wgc. - const uint32_t num_pairs = static_cast(numel_of(out.dims) / 2u); + const uint32_t num_pairs = + static_cast(utils::numel_of(out.dims) / 2u); RotaryParams params = {}; params.n_heads = n_heads; @@ -104,7 +98,9 @@ void add_rope_dispatch( graph.add_dispatch( {pipeline, bind_group, workgroup_count, "apply_rotary_emb"}); - wgpuBufferRelease(uniform_buffer); + // Graph owns it so a resize hook can rewrite it; freed in the dtor. + graph.own_uniform_buffer(uniform_buffer); + return uniform_buffer; } // args: [xq, xk, freqs_cos, freqs_sin, out_list(ValueList[xq_out, xk_out])]. @@ -164,9 +160,9 @@ void apply_rotary_emb_impl(WebGPUGraph& graph, const std::vector& args) { } // All tensors are fp32; output shapes equal their inputs. - const uint64_t xq_numel = numel_of(xq.dims); - const uint64_t xk_numel = numel_of(xk.dims); - const uint64_t freqs_numel = numel_of(freqs_cos.dims); + const uint64_t xq_numel = utils::numel_of(xq.dims); + const uint64_t xk_numel = utils::numel_of(xk.dims); + const uint64_t freqs_numel = utils::numel_of(freqs_cos.dims); if (freqs_numel != static_cast(seq) * half_dim || xq.nbytes != xq_numel * sizeof(float) || xk.nbytes != xk_numel * sizeof(float) || @@ -246,7 +242,7 @@ void apply_rotary_emb_impl(WebGPUGraph& graph, const std::vector& args) { WGPUComputePipeline pipeline_k = wgpuDeviceCreateComputePipeline(device, &pipeline_desc); - add_rope_dispatch( + WGPUBuffer q_ubuf = add_rope_dispatch( graph, device, pipeline_q, @@ -259,7 +255,8 @@ void apply_rotary_emb_impl(WebGPUGraph& graph, const std::vector& args) { seq, head_dim, xq_wgc); - add_rope_dispatch( + const size_t q_idx = graph.num_dispatches() - 1; + WGPUBuffer k_ubuf = add_rope_dispatch( graph, device, pipeline_k, @@ -272,6 +269,57 @@ void apply_rotary_emb_impl(WebGPUGraph& graph, const std::vector& args) { seq, head_dim, xk_wgc); + const size_t k_idx = graph.num_dispatches() - 1; + + // Dynamic shapes: recompute S/num_pairs + both dispatches; out follows xq/xk. + const int xq_out_id = out_list[0]; + const int xk_out_id = out_list[1]; + graph.add_tensor_resize_hook( + xq_id, + [xq_id, + xk_id, + xq_out_id, + xk_out_id, + n_heads_q, + n_heads_k, + head_dim, + half_dim, + wg_size, + q_idx, + k_idx, + q_ubuf, + k_ubuf](WebGPUGraph& g) { + const auto& qd = g.cur_dims(xq_id); + const auto& kd = g.cur_dims(xk_id); + const uint32_t s = static_cast(qd[qd.size() - 3]); + const uint64_t qn = utils::numel_of(qd); + const uint64_t kn = utils::numel_of(kd); + RotaryParams pq = {}; + pq.n_heads = n_heads_q; + pq.seq = s; + pq.head_dim = head_dim; + pq.half_dim = half_dim; + pq.num_pairs = static_cast(qn / 2u); + RotaryParams pk = pq; + pk.n_heads = n_heads_k; + pk.num_pairs = static_cast(kn / 2u); + wgpuQueueWriteBuffer(g.queue(), q_ubuf, 0, &pq, sizeof(pq)); + wgpuQueueWriteBuffer(g.queue(), k_ubuf, 0, &pk, sizeof(pk)); + g.dispatch_at(q_idx).workgroup_count_x = + utils::compute_1d_workgroup_count( + g.device(), + static_cast(qn / 2u), + wg_size, + "apply_rotary_emb(resize)"); + g.dispatch_at(k_idx).workgroup_count_x = + utils::compute_1d_workgroup_count( + g.device(), + static_cast(kn / 2u), + wg_size, + "apply_rotary_emb(resize)"); + g.set_cur_dims(xq_out_id, qd); + g.set_cur_dims(xk_out_id, kd); + }); wgpuShaderModuleRelease(shader); wgpuBindGroupLayoutRelease(bgl);