From caedbba934f775be70ca4a5d7d65e16646276298 Mon Sep 17 00:00:00 2001 From: Julian Ng-Thow-Hing Date: Sun, 28 Jun 2026 09:22:38 -0700 Subject: [PATCH] Update [ghstack-poisoned] --- backends/webgpu/runtime/ops/slice/Slice.cpp | 98 ++++++++++++++++++--- 1 file changed, 87 insertions(+), 11 deletions(-) diff --git a/backends/webgpu/runtime/ops/slice/Slice.cpp b/backends/webgpu/runtime/ops/slice/Slice.cpp index 1d4406bbd1a..604d19b2422 100644 --- a/backends/webgpu/runtime/ops/slice/Slice.cpp +++ b/backends/webgpu/runtime/ops/slice/Slice.cpp @@ -46,9 +46,37 @@ read_scalar(WebGPUGraph& graph, int id, int64_t dflt, const char* what) { } } +// Read a slice index (start/end) that MAY be a dynamic SymInt; else Int/Null. +int64_t read_index(WebGPUGraph& graph, int id, int64_t dflt) { + switch (graph.get_value_type(id)) { + case WebGPUGraph::ValueType::SymInt: + return graph.read_symint(id); + case WebGPUGraph::ValueType::Int: { + const int64_t v = graph.get_int(id); + return v == INT64_MAX ? dflt : v; + } + default: + return dflt; + } +} + +bool is_symint(WebGPUGraph& graph, int id) { + return graph.get_value_type(id) == WebGPUGraph::ValueType::SymInt; +} + +// Clamp + normalize a (possibly negative) index into [0, size]. +int64_t norm_clamp(int64_t idx, int64_t size) { + if (idx < 0) { + idx += size; + } + return idx < 0 ? 0 : (idx > size ? size : idx); +} + void slice_impl(WebGPUGraph& graph, const std::vector& args) { // args: [self, dim, start, end, step, out]; end unread (out shape is AOT). const int in_id = args.at(0); + const int start_id = args.at(2); + const int end_id = args.at(3); const int out_id = args.at(5); WGPUDevice device = graph.device(); @@ -63,17 +91,14 @@ void slice_impl(WebGPUGraph& graph, const std::vector& args) { if (dim < 0 || dim >= in_ndim) { throw std::runtime_error("slice: dim out of range"); } - const int64_t in_size = in_tensor.dims[dim]; - int64_t start = read_scalar(graph, args.at(2), 0, "start"); - if (start < 0) { - start += in_size; - } - // Clamp start to [0, in_size] (guards the gather offset; out size is AOT). - start = start < 0 ? 0 : (start > in_size ? in_size : start); const int64_t step = read_scalar(graph, args.at(4), 1, "step"); if (step < 1) { throw std::runtime_error("slice: step must be >= 1"); } + // start/end may be dynamic SymInts; seed from current (max) dims, the resize + // hook recomputes live. Clamp guards the gather offset. + const int64_t in_size = in_tensor.dims[dim]; + const int64_t start = norm_clamp(read_index(graph, start_id, 0), in_size); TensorMeta out_meta; TensorMeta in_meta; @@ -175,14 +200,65 @@ void slice_impl(WebGPUGraph& graph, const std::vector& args) { WGPUBindGroup bind_group = wgpuDeviceCreateBindGroup(device, &bg_desc); graph.add_dispatch({pipeline, bind_group, workgroup_count}); + const size_t dispatch_idx = graph.num_dispatches() - 1; + + // Dynamic shapes: live start/end -> out[dim] len + meta/params/dispatch. + auto recompute = [in_id, + out_id, + start_id, + end_id, + dim, + step, + wg_size, + out_meta_buf, + in_meta_buf, + params_buf, + dispatch_idx](WebGPUGraph& g) { + const auto& in_dims = g.cur_dims(in_id); + const int64_t live_in_size = in_dims[dim]; + const int64_t start = norm_clamp(read_index(g, start_id, 0), live_in_size); + const int64_t end = + norm_clamp(read_index(g, end_id, live_in_size), live_in_size); + const int64_t len = end > start ? (end - start + step - 1) / step : 0; + + std::vector od = g.cur_dims(out_id); + od[dim] = len; + g.set_cur_dims(out_id, od); + + WebGPUTensor t_out; + t_out.dims = od; + WebGPUTensor t_in; + t_in.dims = in_dims; + TensorMeta om; + TensorMeta im; + fill_tensor_meta(t_out, &om); + fill_tensor_meta(t_in, &im); + wgpuQueueWriteBuffer(g.queue(), out_meta_buf, 0, &om, sizeof(om)); + wgpuQueueWriteBuffer(g.queue(), in_meta_buf, 0, &im, sizeof(im)); + SliceParams p = {}; + p.dim = static_cast(dim); + p.start = static_cast(start); + p.step = static_cast(step); + wgpuQueueWriteBuffer(g.queue(), params_buf, 0, &p, sizeof(p)); + g.dispatch_at(dispatch_idx).workgroup_count_x = + utils::compute_1d_workgroup_count( + g.device(), om.numel, wg_size, "slice(resize)"); + }; + if (is_symint(graph, start_id)) { + graph.add_resize_hook(start_id, recompute); + } + if (is_symint(graph, end_id)) { + graph.add_resize_hook(end_id, recompute); + } + graph.add_tensor_resize_hook(in_id, recompute); wgpuShaderModuleRelease(shader); wgpuBindGroupLayoutRelease(bgl); wgpuPipelineLayoutRelease(pipeline_layout); - // Drop our refs; the bind group keeps the uniforms alive until release. - wgpuBufferRelease(out_meta_buf); - wgpuBufferRelease(in_meta_buf); - wgpuBufferRelease(params_buf); + // Graph owns the uniforms so the resize hook can rewrite them; freed in dtor. + graph.own_uniform_buffer(out_meta_buf); + graph.own_uniform_buffer(in_meta_buf); + graph.own_uniform_buffer(params_buf); } } // namespace