Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 92 additions & 12 deletions backends/webgpu/runtime/ops/slice/Slice.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,40 @@ read_scalar(WebGPUGraph& graph, int id, int64_t dflt, const char* what) {
}
}

// Resolve a slice bound (start or end) stored as graph value `id`.
//   - SymInt: returns whatever graph.read_symint(id) reports.
//   - Int:    returns the stored integer, except that INT64_MAX is treated
//             as a sentinel and replaced by `dflt`.
//   - Null:   returns `dflt`.
// Any other value type throws std::runtime_error.
int64_t read_index(WebGPUGraph& graph, int id, int64_t dflt) {
  const auto kind = graph.get_value_type(id);

  if (kind == WebGPUGraph::ValueType::SymInt) {
    return graph.read_symint(id);
  }
  if (kind == WebGPUGraph::ValueType::Null) {
    return dflt;
  }
  if (kind == WebGPUGraph::ValueType::Int) {
    const int64_t raw = graph.get_int(id);
    if (raw == INT64_MAX) {
      return dflt;
    }
    return raw;
  }
  throw std::runtime_error("slice: dynamic/unsupported start/end index");
}

// True when the graph reports value `id` as a SymInt; false for every other
// value type.
bool is_symint(WebGPUGraph& graph, int id) {
  const auto kind = graph.get_value_type(id);
  return WebGPUGraph::ValueType::SymInt == kind;
}

// Map a slice index onto the closed range [0, size].
// A negative `idx` is first offset by `size` (so -1 refers to the last
// element); the shifted value is then saturated at 0 below and `size` above.
int64_t norm_clamp(int64_t idx, int64_t size) {
  const int64_t shifted = (idx < 0) ? idx + size : idx;
  if (shifted < 0) {
    return 0;
  }
  if (shifted > size) {
    return size;
  }
  return shifted;
}

void slice_impl(WebGPUGraph& graph, const std::vector<int>& args) {
// args: [self, dim, start, end, step, out]; end unread (out shape is AOT).
// args: [self, dim, start, end, step, out]. start/end may be dynamic SymInts;
// a resize hook recomputes the live extent on `dim` (out[dim] / cur_dims).
const int in_id = args.at(0);
const int start_id = args.at(2);
const int end_id = args.at(3);
const int out_id = args.at(5);

WGPUDevice device = graph.device();
Expand All @@ -63,17 +94,14 @@ void slice_impl(WebGPUGraph& graph, const std::vector<int>& args) {
if (dim < 0 || dim >= in_ndim) {
throw std::runtime_error("slice: dim out of range");
}
const int64_t in_size = in_tensor.dims[dim];
int64_t start = read_scalar(graph, args.at(2), 0, "start");
if (start < 0) {
start += in_size;
}
// Clamp start to [0, in_size] (guards the gather offset; out size is AOT).
start = start < 0 ? 0 : (start > in_size ? in_size : start);
const int64_t step = read_scalar(graph, args.at(4), 1, "step");
if (step < 1) {
throw std::runtime_error("slice: step must be >= 1");
}
// start/end may be dynamic SymInts; seed from current (max) dims, the resize
// hook recomputes live. Clamp guards the gather offset.
const int64_t in_size = in_tensor.dims[dim];
const int64_t start = norm_clamp(read_index(graph, start_id, 0), in_size);

TensorMeta out_meta;
TensorMeta in_meta;
Expand Down Expand Up @@ -175,14 +203,66 @@ void slice_impl(WebGPUGraph& graph, const std::vector<int>& args) {
WGPUBindGroup bind_group = wgpuDeviceCreateBindGroup(device, &bg_desc);

graph.add_dispatch({pipeline, bind_group, workgroup_count});
const size_t dispatch_idx = graph.num_dispatches() - 1;

// Dynamic shapes: live start/end -> out[dim] len + meta/params/dispatch.
auto recompute = [in_id,
out_id,
start_id,
end_id,
dim,
step,
wg_size,
out_meta_buf,
in_meta_buf,
params_buf,
dispatch_idx](WebGPUGraph& g) {
const auto& in_dims = g.cur_dims(in_id);
const int64_t live_in_size = in_dims[dim];
const int64_t start = norm_clamp(read_index(g, start_id, 0), live_in_size);
const int64_t end =
norm_clamp(read_index(g, end_id, live_in_size), live_in_size);
const int64_t len = end > start ? (end - start + step - 1) / step : 0;

// Out dims = live input dims (mirror Vulkan resize_slice_copy_node).
std::vector<int64_t> od = in_dims;
od[dim] = len;
g.set_cur_dims(out_id, od);

WebGPUTensor t_out;
t_out.dims = od;
WebGPUTensor t_in;
t_in.dims = in_dims;
TensorMeta om;
TensorMeta im;
fill_tensor_meta(t_out, &om);
fill_tensor_meta(t_in, &im);
wgpuQueueWriteBuffer(g.queue(), out_meta_buf, 0, &om, sizeof(om));
wgpuQueueWriteBuffer(g.queue(), in_meta_buf, 0, &im, sizeof(im));
SliceParams p = {};
p.dim = static_cast<uint32_t>(dim);
p.start = static_cast<uint32_t>(start);
p.step = static_cast<uint32_t>(step);
wgpuQueueWriteBuffer(g.queue(), params_buf, 0, &p, sizeof(p));
g.dispatch_at(dispatch_idx).workgroup_count_x =
utils::compute_1d_workgroup_count(
g.device(), om.numel, wg_size, "slice(resize)");
};
if (is_symint(graph, start_id)) {
graph.add_resize_hook(start_id, recompute);
}
if (is_symint(graph, end_id) && end_id != start_id) {
graph.add_resize_hook(end_id, recompute);
}
graph.add_tensor_resize_hook(in_id, recompute);

wgpuShaderModuleRelease(shader);
wgpuBindGroupLayoutRelease(bgl);
wgpuPipelineLayoutRelease(pipeline_layout);
// Drop our refs; the bind group keeps the uniforms alive until release.
wgpuBufferRelease(out_meta_buf);
wgpuBufferRelease(in_meta_buf);
wgpuBufferRelease(params_buf);
// Graph owns the uniforms so the resize hook can rewrite them; freed in dtor.
graph.own_uniform_buffer(out_meta_buf);
graph.own_uniform_buffer(in_meta_buf);
graph.own_uniform_buffer(params_buf);
}

} // namespace
Expand Down
Loading