Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 62 additions & 9 deletions backends/webgpu/runtime/ops/select/Select.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,19 @@ int64_t read_scalar(WebGPUGraph& graph, int id, const char* what) {
throw std::runtime_error(std::string("select: dynamic/unsupported ") + what);
}

// Serializes a shape into a uniform buffer.
//
// Wraps `dims` in a temporary WebGPUTensor, derives the matching TensorMeta
// via fill_tensor_meta, and queues a write of that struct at offset 0 of
// `buf` on the graph's queue.
//
// Returns the `numel` field of the TensorMeta that was written.
uint32_t write_meta_from_dims(
    WebGPUGraph& g,
    WGPUBuffer buf,
    const std::vector<int64_t>& dims) {
  // Only the dims are populated; the tensor exists solely to feed
  // fill_tensor_meta.
  WebGPUTensor shape_only;
  shape_only.dims = dims;

  TensorMeta meta;
  fill_tensor_meta(shape_only, &meta);

  wgpuQueueWriteBuffer(g.queue(), buf, 0, &meta, sizeof(meta));
  return meta.numel;
}

void select_impl(WebGPUGraph& graph, const std::vector<int>& args) {
// args: [self, dim, index, out]; output rank = in rank - 1.
const int in_id = args.at(0);
Expand All @@ -58,10 +71,9 @@ void select_impl(WebGPUGraph& graph, const std::vector<int>& args) {
throw std::runtime_error("select: dim out of range");
}
const int64_t in_size = in_tensor.dims[dim];
int64_t index = read_scalar(graph, args.at(2), "index");
if (index < 0) {
index += in_size;
}
// Keep the raw (possibly negative) index: the resize hook re-normalizes it
// against the live dim size, so e.g. -1 always selects the last slice.
const int64_t raw_index = read_scalar(graph, args.at(2), "index");
int64_t index = raw_index < 0 ? raw_index + in_size : raw_index;
if (index < 0 || index >= in_size) {
throw std::runtime_error("select: index out of range");
}
Expand Down Expand Up @@ -164,15 +176,56 @@ void select_impl(WebGPUGraph& graph, const std::vector<int>& args) {
bg_desc.entries = bg_entries;
WGPUBindGroup bind_group = wgpuDeviceCreateBindGroup(device, &bg_desc);

graph.add_dispatch({pipeline, bind_group, workgroup_count});
const size_t dispatch_idx =
graph.add_dispatch({pipeline, bind_group, workgroup_count});

// Dynamic shapes: out = in minus `dim`; re-resolve index, meta, dispatch.
//
// Registers a resize hook keyed on in_id. When run with the graph `g`, the
// hook:
//   1. validates `dim` and the resolved index against the input's current
//      dims,
//   2. sets the output's current dims to the input dims with `dim` removed,
//   3. rewrites the out/in TensorMeta uniforms and the SelectParams uniform,
//   4. updates workgroup_count_x of the dispatch stored at dispatch_idx.
// The hook throws std::runtime_error if `dim` or the index is out of range.
// Everything is captured by value, so the captured buffer handles must stay
// valid for as long as the hook can run.
// NOTE(review): presumably the graph runs this whenever in_id's dims change —
// confirm against add_tensor_resize_hook.
graph.add_tensor_resize_hook(
in_id,
[in_id,
out_id,
dim,
raw_index,
out_meta_buf,
in_meta_buf,
params_buf,
wg_size,
dispatch_idx](WebGPUGraph& g) {
// Current dims of the input tensor.
// NOTE(review): `ind` is a reference obtained from `g` and is still used
// after g.set_cur_dims(out_id, ...) below — confirm that call cannot
// invalidate it.
const auto& ind = g.cur_dims(in_id);
// Re-check `dim` against the current rank of the input.
if (dim < 0 || dim >= static_cast<int>(ind.size())) {
throw std::runtime_error("select(resize): dim out of range");
}
const int64_t live_in_size = ind[dim];
// A negative raw index counts back from the end of the current dim size.
int64_t idx = raw_index < 0 ? raw_index + live_in_size : raw_index;
if (idx < 0 || idx >= live_in_size) {
throw std::runtime_error("select(resize): index out of range");
}
// Output dims: the input dims with the entry at `dim` dropped. The rank is
// at least 1 here because the `dim` check above passed.
std::vector<int64_t> od;
od.reserve(ind.size() - 1);
for (size_t k = 0; k < ind.size(); k++) {
if (static_cast<int>(k) != dim) {
od.push_back(ind[k]);
}
}
g.set_cur_dims(out_id, od);
// Refresh both meta uniforms; the returned element count of the output
// sizes the dispatch below.
const uint32_t out_numel = write_meta_from_dims(g, out_meta_buf, od);
write_meta_from_dims(g, in_meta_buf, ind);
// Rebuild the params uniform with the freshly resolved index. `idx` was
// range-checked above, so it is non-negative when narrowed.
SelectParams p = {};
p.dim = static_cast<uint32_t>(dim);
p.index = static_cast<uint32_t>(idx);
wgpuQueueWriteBuffer(g.queue(), params_buf, 0, &p, sizeof(p));
// Recompute the 1-D workgroup count from the new output element count.
g.dispatch_at(dispatch_idx).workgroup_count_x =
utils::compute_1d_workgroup_count(
g.device(), out_numel, wg_size, "select(resize)");
});

wgpuShaderModuleRelease(shader);
wgpuBindGroupLayoutRelease(bgl);
wgpuPipelineLayoutRelease(pipeline_layout);
// Drop our refs; the bind group keeps the uniforms alive until release.
wgpuBufferRelease(out_meta_buf);
wgpuBufferRelease(in_meta_buf);
wgpuBufferRelease(params_buf);
// Graph owns them so the resize hook can rewrite them; freed in the dtor.
graph.own_uniform_buffer(out_meta_buf);
graph.own_uniform_buffer(in_meta_buf);
graph.own_uniform_buffer(params_buf);
}

} // namespace
Expand Down
26 changes: 23 additions & 3 deletions backends/webgpu/runtime/ops/sigmoid/UnaryOp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -135,14 +135,34 @@ void add_unary_op(
bg_desc.entries = bg_entries;
WGPUBindGroup bind_group = wgpuDeviceCreateBindGroup(device, &bg_desc);

graph.add_dispatch({pipeline, bind_group, workgroup_count});
const size_t dispatch_idx =
graph.add_dispatch({pipeline, bind_group, workgroup_count});

// Dynamic shapes: recompute num_elements/dispatch for the live shape.
//
// Registers a resize hook keyed on in_id. When run with the graph `g`, the
// hook copies the input's current dims to the output, rewrites the
// UnaryParams uniform with the new element count, and updates
// workgroup_count_x of the dispatch stored at dispatch_idx.
// NOTE(review): presumably the graph runs this whenever in_id's dims change —
// confirm against add_tensor_resize_hook.
// Copy of the uniform_buffer handle under the name the lambda captures (by
// value).
WGPUBuffer params_buf = uniform_buffer;
graph.add_tensor_resize_hook(
in_id,
[in_id, out_id, wg_size, dispatch_idx, params_buf](WebGPUGraph& g) {
// Current dims of the input tensor and their element count.
const auto& d = g.cur_dims(in_id);
const uint64_t numel = utils::numel_of(d);
// The output takes the same dims as the input.
// NOTE(review): `d` is a reference obtained from `g` and is passed straight
// back into set_cur_dims — confirm this is safe if in_id == out_id.
g.set_cur_dims(out_id, d);
// Value-initialize the params, then set the element count.
// NOTE(review): numel is narrowed from uint64_t to uint32_t without a range
// check — confirm element counts cannot exceed UINT32_MAX here.
UnaryParams p = {};
p.num_elements = static_cast<uint32_t>(numel);
wgpuQueueWriteBuffer(g.queue(), params_buf, 0, &p, sizeof(p));
// Recompute the 1-D workgroup count from the new element count.
g.dispatch_at(dispatch_idx).workgroup_count_x =
utils::compute_1d_workgroup_count(
g.device(),
static_cast<uint32_t>(numel),
wg_size,
"unary(resize)");
});

// Release intermediates (pipeline + bind_group are kept by dispatch).
wgpuShaderModuleRelease(shader);
wgpuBindGroupLayoutRelease(bgl);
wgpuPipelineLayoutRelease(pipeline_layout);
// Drop our ref; the bind group keeps the uniform buffer alive until release.
wgpuBufferRelease(uniform_buffer);
// Graph owns it so the resize hook can rewrite it; freed in the dtor.
graph.own_uniform_buffer(uniform_buffer);
}

void sigmoid_impl(WebGPUGraph& graph, const std::vector<int>& args) {
Expand Down
Loading