Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 50 additions & 14 deletions backends/webgpu/runtime/ops/embedding_q4gsw/EmbeddingQ4gsw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include <cstdint>
#include <cstring>
#include <stdexcept>
#include <vector>

namespace executorch::backends::webgpu {

Expand All @@ -36,14 +37,6 @@ static_assert(
sizeof(EmbeddingParams) == 32,
"EmbeddingParams must be 32 bytes");

// Returns the number of elements described by a shape: the product of all
// extents in `dims`. An empty shape yields 1. Each extent is converted to
// uint64_t before it is multiplied in.
uint64_t numel_of(const std::vector<int64_t>& dims) {
  uint64_t product = 1;
  for (auto it = dims.begin(); it != dims.end(); ++it) {
    product *= static_cast<uint64_t>(*it);
  }
  return product;
}

// arg order mirrors Vulkan EmbeddingQ4gsw.cpp.
void embedding_q4gsw_impl(WebGPUGraph& graph, const std::vector<int>& args) {
const int weight_id = args.at(0);
Expand Down Expand Up @@ -102,7 +95,7 @@ void embedding_q4gsw_impl(WebGPUGraph& graph, const std::vector<int>& args) {
}

// Leading index dims flatten row-major (mirrors Vulkan num_indices).
const uint64_t out_numel = numel_of(out.dims);
const uint64_t out_numel = utils::numel_of(out.dims);
const uint32_t num_indices = static_cast<uint32_t>(out_numel / embed_dim);
const uint32_t groups_per_row = static_cast<uint32_t>(scales.dims[1]);
const uint32_t blocks_per_row = embed_dim / 32u;
Expand All @@ -119,9 +112,9 @@ void embedding_q4gsw_impl(WebGPUGraph& graph, const std::vector<int>& args) {
}

// Per-type byte guards (no runtime dtype): indices i32, weight u8, fp32 rest.
const uint64_t indices_numel = numel_of(indices.dims);
const uint64_t weight_numel = numel_of(weight.dims);
const uint64_t scales_numel = numel_of(scales.dims);
const uint64_t indices_numel = utils::numel_of(indices.dims);
const uint64_t weight_numel = utils::numel_of(weight.dims);
const uint64_t scales_numel = utils::numel_of(scales.dims);
if (indices_numel != num_indices ||
indices.nbytes != indices_numel * sizeof(int32_t) ||
weight.nbytes != weight_numel ||
Expand Down Expand Up @@ -230,13 +223,56 @@ void embedding_q4gsw_impl(WebGPUGraph& graph, const std::vector<int>& args) {
bg_desc.entries = bg_entries;
WGPUBindGroup bind_group = wgpuDeviceCreateBindGroup(device, &bg_desc);

graph.add_dispatch(
const size_t dispatch_idx = graph.add_dispatch(
{pipeline, bind_group, workgroup_count, "embedding_q4gsw"});

// Dynamic shapes: register a hook keyed on the indices tensor. When it runs,
// the hook re-derives the block count from the live index shape, sets the
// output shape to indices.dims + [embed_dim], rewrites the params uniform
// buffer, and refreshes the workgroup count of the dispatch recorded above.
WGPUBuffer embedding_ubo = uniform_buffer;
graph.add_tensor_resize_hook(
    indices_id,
    [indices_id,
     out_id,
     embed_dim,
     blocks_per_row,
     groups_per_row,
     bytes_per_row,
     wg_size,
     dispatch_idx,
     embedding_ubo,
     group_size_u32 = static_cast<uint32_t>(group_size)](WebGPUGraph& g) {
      const auto& index_dims = g.cur_dims(indices_id);
      const uint64_t index_count = utils::numel_of(index_dims);

      // The block count is narrowed to uint32 below, so reject anything
      // larger before touching graph state.
      const uint64_t block_count = index_count * blocks_per_row;
      if (block_count > UINT32_MAX) {
        throw std::runtime_error(
            "WebGPU embedding_q4gsw: total_blocks exceeds uint32");
      }
      const uint32_t block_count_u32 = static_cast<uint32_t>(block_count);

      // Output shape is the index shape with embed_dim appended.
      std::vector<int64_t> out_dims = index_dims;
      out_dims.push_back(static_cast<int64_t>(embed_dim));
      g.set_cur_dims(out_id, out_dims);

      // Re-upload the full parameter block; only num_indices and
      // total_blocks depend on the live shape, the rest are captured values.
      EmbeddingParams params = {};
      params.num_indices = static_cast<uint32_t>(index_count);
      params.total_blocks = block_count_u32;
      params.embed_dim = embed_dim;
      params.blocks_per_row = blocks_per_row;
      params.group_size = group_size_u32;
      params.groups_per_row = groups_per_row;
      params.bytes_per_row = bytes_per_row;
      wgpuQueueWriteBuffer(
          g.queue(), embedding_ubo, 0, &params, sizeof(params));

      const auto workgroups = utils::compute_1d_workgroup_count(
          g.device(), block_count_u32, wg_size, "embedding_q4gsw(resize)");
      g.dispatch_at(dispatch_idx).workgroup_count_x = workgroups;
    });

wgpuShaderModuleRelease(shader);
wgpuBindGroupLayoutRelease(bgl);
wgpuPipelineLayoutRelease(pipeline_layout);
wgpuBufferRelease(uniform_buffer);
// Graph owns it so the resize hook can rewrite it; freed in the dtor.
graph.own_uniform_buffer(uniform_buffer);
}

} // namespace
Expand Down
31 changes: 28 additions & 3 deletions backends/webgpu/runtime/ops/rms_norm/RmsNorm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
*/

#include <executorch/backends/webgpu/runtime/WebGPUGraph.h>
#include <executorch/backends/webgpu/runtime/WebGPUUtils.h>
#include <executorch/backends/webgpu/runtime/ops/OperatorRegistry.h>
#include <executorch/backends/webgpu/runtime/ops/rms_norm/rms_norm_vec4_wgsl.h>
#include <executorch/backends/webgpu/runtime/ops/rms_norm/rms_norm_wgsl.h>
Expand Down Expand Up @@ -187,14 +188,38 @@ void rms_norm_impl(WebGPUGraph& graph, const std::vector<int>& args) {
static_assert(
kRmsNormVec4WorkgroupSizeX == 64,
"must match @workgroup_size and WG_SIZE in rms_norm_vec4.wgsl");
graph.add_dispatch({pipeline, bind_group, num_rows});
const size_t dispatch_idx =
graph.add_dispatch({pipeline, bind_group, num_rows});

// Dynamic shapes: register a hook keyed on the input tensor. When it runs,
// the hook recomputes num_rows from the live input dims, rewrites the params
// uniform buffer, updates the dispatch's workgroup count (one workgroup per
// row), and sets the output shape to the input shape.
WGPUBuffer params_buf = uniform_buffer;
graph.add_tensor_resize_hook(
    in_id,
    [in_id, out_id, row_width, epsilon, dispatch_idx, params_buf](
        WebGPUGraph& g) {
      const auto& d = g.cur_dims(in_id);
      const uint64_t numel = utils::numel_of(d);
      // Range-check in 64 bits BEFORE narrowing. Casting first would let a
      // quotient above 2^32 wrap into [1, 65535] and slip past the guard,
      // dispatching far fewer rows than the tensor holds.
      const uint64_t rows64 = numel / static_cast<uint64_t>(row_width);
      // A zero row count is rejected with the same message as an oversized
      // one.
      if (rows64 == 0 || rows64 > 65535u) {
        throw std::runtime_error(
            "WebGPU rms_norm: num_rows exceeds the 1D dispatch limit (65535)");
      }
      const uint32_t rows = static_cast<uint32_t>(rows64);
      // NOTE(review): numel is assumed to be a multiple of row_width; a
      // remainder would be silently dropped by the division above — confirm
      // the caller guarantees this.
      RmsNormParams p = {};
      p.num_rows = rows;
      p.row_width = row_width;
      p.epsilon = epsilon;
      wgpuQueueWriteBuffer(g.queue(), params_buf, 0, &p, sizeof(p));
      g.dispatch_at(dispatch_idx).workgroup_count_x = rows;
      g.set_cur_dims(out_id, d);
    });

// Release intermediate objects (pipeline and bind_group are kept by dispatch)
wgpuShaderModuleRelease(shader);
wgpuBindGroupLayoutRelease(bgl);
wgpuPipelineLayoutRelease(pipeline_layout);
// Drop our ref; the bind group keeps the uniform buffer alive until release.
wgpuBufferRelease(uniform_buffer);
// Graph owns it so the resize hook can rewrite it; freed in the dtor.
graph.own_uniform_buffer(uniform_buffer);
}

} // namespace
Expand Down
80 changes: 64 additions & 16 deletions backends/webgpu/runtime/ops/rope/RotaryEmbedding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,16 +34,9 @@ struct RotaryParams {
};
static_assert(sizeof(RotaryParams) == 32, "RotaryParams must be 32 bytes");

// Element count of a shape: multiplies every extent in `dims` together and
// returns the result as uint64_t. An empty `dims` gives 1.
uint64_t numel_of(const std::vector<int64_t>& dims) {
  using Index = std::vector<int64_t>::size_type;
  const Index rank = dims.size();
  uint64_t count = 1;
  for (Index axis = 0; axis < rank; ++axis) {
    count *= static_cast<uint64_t>(dims[axis]);
  }
  return count;
}

// Rotate one (x->out) with the shared shader; freqs shared between xq and xk.
void add_rope_dispatch(
// Returns the param-uniform handle so a resize hook can rewrite seq/num_pairs.
WGPUBuffer add_rope_dispatch(
WebGPUGraph& graph,
WGPUDevice device,
WGPUComputePipeline pipeline,
Expand All @@ -58,7 +51,8 @@ void add_rope_dispatch(
uint32_t workgroup_count) {
const uint32_t half_dim = head_dim / 2u;
// out.dims == in.dims (asserted in impl), so this matches the caller's wgc.
const uint32_t num_pairs = static_cast<uint32_t>(numel_of(out.dims) / 2u);
const uint32_t num_pairs =
static_cast<uint32_t>(utils::numel_of(out.dims) / 2u);

RotaryParams params = {};
params.n_heads = n_heads;
Expand Down Expand Up @@ -104,7 +98,9 @@ void add_rope_dispatch(
graph.add_dispatch(
{pipeline, bind_group, workgroup_count, "apply_rotary_emb"});

wgpuBufferRelease(uniform_buffer);
// Graph owns it so a resize hook can rewrite it; freed in the dtor.
graph.own_uniform_buffer(uniform_buffer);
return uniform_buffer;
}

// args: [xq, xk, freqs_cos, freqs_sin, out_list(ValueList[xq_out, xk_out])].
Expand Down Expand Up @@ -164,9 +160,9 @@ void apply_rotary_emb_impl(WebGPUGraph& graph, const std::vector<int>& args) {
}

// All tensors are fp32; output shapes equal their inputs.
const uint64_t xq_numel = numel_of(xq.dims);
const uint64_t xk_numel = numel_of(xk.dims);
const uint64_t freqs_numel = numel_of(freqs_cos.dims);
const uint64_t xq_numel = utils::numel_of(xq.dims);
const uint64_t xk_numel = utils::numel_of(xk.dims);
const uint64_t freqs_numel = utils::numel_of(freqs_cos.dims);
if (freqs_numel != static_cast<uint64_t>(seq) * half_dim ||
xq.nbytes != xq_numel * sizeof(float) ||
xk.nbytes != xk_numel * sizeof(float) ||
Expand Down Expand Up @@ -246,7 +242,7 @@ void apply_rotary_emb_impl(WebGPUGraph& graph, const std::vector<int>& args) {
WGPUComputePipeline pipeline_k =
wgpuDeviceCreateComputePipeline(device, &pipeline_desc);

add_rope_dispatch(
WGPUBuffer q_ubuf = add_rope_dispatch(
graph,
device,
pipeline_q,
Expand All @@ -259,7 +255,8 @@ void apply_rotary_emb_impl(WebGPUGraph& graph, const std::vector<int>& args) {
seq,
head_dim,
xq_wgc);
add_rope_dispatch(
const size_t q_idx = graph.num_dispatches() - 1;
WGPUBuffer k_ubuf = add_rope_dispatch(
graph,
device,
pipeline_k,
Expand All @@ -272,6 +269,57 @@ void apply_rotary_emb_impl(WebGPUGraph& graph, const std::vector<int>& args) {
seq,
head_dim,
xk_wgc);
const size_t k_idx = graph.num_dispatches() - 1;

// Dynamic shapes: register a hook keyed on xq. When it runs, the hook
// re-reads the live xq/xk dims, rewrites both RotaryParams uniform buffers
// (seq and num_pairs change; head counts and head_dim are captured values),
// refreshes both dispatches' workgroup counts, and sets each output's shape
// to its input's shape.
// NOTE(review): the hook is registered on xq_id only but also reads xk's
// dims — presumably xq and xk are always resized together; confirm.
const int xq_out_id = out_list[0];
const int xk_out_id = out_list[1];
graph.add_tensor_resize_hook(
    xq_id,
    [xq_id,
     xk_id,
     xq_out_id,
     xk_out_id,
     n_heads_q,
     n_heads_k,
     head_dim,
     half_dim,
     wg_size,
     q_idx,
     k_idx,
     q_ubuf,
     k_ubuf](WebGPUGraph& g) {
      const auto& qd = g.cur_dims(xq_id);
      const auto& kd = g.cur_dims(xk_id);
      // seq is taken from the third-from-last axis of xq. Use bounds-checked
      // access: with rank < 3 the index `size() - 3` wraps around, and plain
      // operator[] would read out of bounds instead of failing loudly.
      const uint32_t s = static_cast<uint32_t>(qd.at(qd.size() - 3));
      const uint64_t qn = utils::numel_of(qd);
      const uint64_t kn = utils::numel_of(kd);
      // NOTE(review): num_pairs is narrowed to uint32 without a range check,
      // and the new seq is not re-validated against the freqs tensors here —
      // confirm upstream limits make both safe.
      RotaryParams pq = {};
      pq.n_heads = n_heads_q;
      pq.seq = s;
      pq.head_dim = head_dim;
      pq.half_dim = half_dim;
      pq.num_pairs = static_cast<uint32_t>(qn / 2u);
      // xk shares seq/head_dim/half_dim with xq; only the head count and the
      // pair count differ.
      RotaryParams pk = pq;
      pk.n_heads = n_heads_k;
      pk.num_pairs = static_cast<uint32_t>(kn / 2u);
      wgpuQueueWriteBuffer(g.queue(), q_ubuf, 0, &pq, sizeof(pq));
      wgpuQueueWriteBuffer(g.queue(), k_ubuf, 0, &pk, sizeof(pk));
      g.dispatch_at(q_idx).workgroup_count_x =
          utils::compute_1d_workgroup_count(
              g.device(),
              static_cast<uint32_t>(qn / 2u),
              wg_size,
              "apply_rotary_emb(resize)");
      g.dispatch_at(k_idx).workgroup_count_x =
          utils::compute_1d_workgroup_count(
              g.device(),
              static_cast<uint32_t>(kn / 2u),
              wg_size,
              "apply_rotary_emb(resize)");
      g.set_cur_dims(xq_out_id, qd);
      g.set_cur_dims(xk_out_id, kd);
    });

wgpuShaderModuleRelease(shader);
wgpuBindGroupLayoutRelease(bgl);
Expand Down
Loading