diff --git a/backends/webgpu/runtime/WebGPUBackend.cpp b/backends/webgpu/runtime/WebGPUBackend.cpp index ceca89d1710..eba406426d9 100644 --- a/backends/webgpu/runtime/WebGPUBackend.cpp +++ b/backends/webgpu/runtime/WebGPUBackend.cpp @@ -14,8 +14,11 @@ #include #include +#include #include +#include + #include namespace executorch { @@ -35,6 +38,7 @@ using executorch::runtime::Error; using executorch::runtime::EValue; using executorch::runtime::FreeableBuffer; using executorch::runtime::register_backend; +using executorch::runtime::resize_tensor; using executorch::runtime::Result; using executorch::runtime::Span; @@ -108,11 +112,32 @@ Error WebGPUBackend::execute( } // Fail loud as a runtime Error so a throw never crosses the backend boundary. try { + // Dynamic shapes: shrink each input to its live sizes before upload + // (mirrors Vulkan maybe_resize_input). No-op when unchanged, so a static + // graph is byte-identical. + for (size_t i = 0; i < num_inputs; i++) { + const auto sizes = args[i]->toTensor().sizes(); + std::vector new_dims(sizes.begin(), sizes.end()); + graph->resize_input(graph->input_ids()[i], new_dims); + } graph->copy_inputs(inputs); graph->update_symints_from_inputs(inputs); graph->propagate_resize(); + // Resize each output EValue to its live shape so the readback length is + // correct (mirrors Vulkan maybe_resize_output). + for (size_t i = 0; i < num_outputs; i++) { + const auto& cd = graph->cur_dims(graph->output_ids()[i]); + std::vector osizes(cd.begin(), cd.end()); + Error e = resize_tensor( + args[num_inputs + i]->toTensor(), + ArrayRef(osizes.data(), osizes.size())); + if (e != Error::Ok) { + ET_LOG(Error, "WebGPU: output %zu resize failed", i); + return Error::Internal; + } + } } catch (const std::exception& e) { - ET_LOG(Error, "WebGPU input copy / symint refresh failed: %s", e.what()); + ET_LOG(Error, "WebGPU input/output resize / copy failed: %s", e.what()); return Error::Internal; } diff --git a/backends/webgpu/runtime/WebGPUGraph.cpp b/backends/webgpu/runtime/WebGPUGraph.cpp index b72c5256e69..6d3956ea994 100644 --- a/backends/webgpu/runtime/WebGPUGraph.cpp +++ b/backends/webgpu/runtime/WebGPUGraph.cpp @@ -15,6 +15,7 @@ #include #include +#include #include #include #include @@ -62,6 +63,18 @@ bool vk_datatype_is_int(vkgraph::VkDataType dtype) { } } +// Normalize a possibly-negative dim against rank; throws (fail-loud) if OOR. +int normalize_dim(int dim, int rank, const char* op) { + if (dim < 0) { + dim += rank; + } + if (dim < 0 || dim >= rank) { + throw std::runtime_error( + std::string("WebGPU ") + op + ": dim out of range"); + } + return dim; +} + } // namespace WebGPUGraph::WebGPUGraph() = default; @@ -104,11 +117,10 @@ void WebGPUGraph::update_symints_from_inputs( throw std::runtime_error( "select_as_symint: source tensor is not a graph input"); } - const auto& dims = tensors_[src.input_tensor_id].dims; - int dim = src.dim < 0 ? src.dim + static_cast(dims.size()) : src.dim; - if (dim < 0 || dim >= static_cast(dims.size())) { - throw std::runtime_error("select_as_symint: dim out of range"); - } + // Live cur_dims: the source may be a dynamic-shape input. + const auto& dims = tensors_[src.input_tensor_id].cur_dims; + int dim = normalize_dim( + src.dim, static_cast(dims.size()), "select_as_symint"); int index = src.index; if (index < 0) { index += static_cast(dims[dim]); @@ -129,9 +141,9 @@ void WebGPUGraph::update_symints_from_inputs( } // Reads the [0,..,index,..,0] element; symint sources are scalar-ish. const int64_t offset = static_cast(index) * stride; - // elem_size back-derived from build-time numel (sources are static-shaped). const void* host = inputs[pos].data; - const size_t elem_size = inputs[pos].nbytes / static_cast(numel); + // Stored elem_size (live nbytes/numel mis-derives for a dynamic source). + const size_t elem_size = tensors_[src.input_tensor_id].elem_size; int32_t val; if (elem_size == sizeof(int64_t)) { val = static_cast(static_cast(host)[offset]); @@ -143,6 +155,14 @@ void WebGPUGraph::update_symints_from_inputs( } set_symint(src.symint_id, val); } + // sym_size.int: SymInt = a tensor's live dim (cur_dims). Usually unused (ops + // read cur_dims directly); for an intermediate source cur_dims is the build + // max here (hooks run later in propagate_resize), which is fine while unused. + for (const auto& s : symint_dim_sources_) { + const auto& d = tensors_[s.tensor_id].cur_dims; + int dim = normalize_dim(s.dim, static_cast(d.size()), "sym_size"); + set_symint(s.symint_id, static_cast(d[dim])); + } } void WebGPUGraph::set_symint(int id, int32_t val) { @@ -158,16 +178,78 @@ void WebGPUGraph::set_symint(int id, int32_t val) { } } +void WebGPUGraph::set_cur_dims( + int value_id, + const std::vector& new_dims) { + auto& t = tensors_[value_id]; + if (new_dims.size() != t.dims.size()) { + throw std::runtime_error("WebGPU resize: tensor rank changed"); + } + size_t numel = 1; + for (size_t d = 0; d < new_dims.size(); d++) { + // 0-sized dims unsupported: live shapes are always in [1, max] per dim. + if (new_dims[d] <= 0) { + throw std::runtime_error("WebGPU resize: new dim must be positive"); + } + if (new_dims[d] > t.dims[d]) { + throw std::runtime_error( + "WebGPU resize: new dim exceeds the max (serialized) allocation"); + } + numel *= static_cast(new_dims[d]); + } + const size_t new_nbytes = numel * t.elem_size; + if (t.cur_dims != new_dims) { + t.cur_dims = new_dims; + t.cur_nbytes = new_nbytes; + dirty_tensors_.insert(value_id); + } +} + +void WebGPUGraph::resize_input( + int value_id, + const std::vector& new_dims) { + if (std::find(input_ids_.begin(), input_ids_.end(), value_id) == + input_ids_.end()) { + throw std::runtime_error( + "WebGPUGraph::resize_input: value_id is not a graph input"); + } + set_cur_dims(value_id, new_dims); +} + void WebGPUGraph::propagate_resize() { - if (dirty_symints_.empty()) { + if (dirty_symints_.empty() && dirty_tensors_.empty()) { return; } + // Hooks fire in registration (topological) order: operands update first. for (auto& hook : resize_hooks_) { if (dirty_symints_.count(hook.symint_id) != 0) { hook.fn(*this); } } dirty_symints_.clear(); + // Tensor hooks: bounded fixpoint. A hook may dirty its output (cascading to a + // consumer); each pass handles the currently-dirty set. A forward DAG + // converges in <= depth passes (set_cur_dims re-dirties only on a change). + for (size_t pass = 0; + !dirty_tensors_.empty() && pass <= tensor_resize_hooks_.size(); + pass++) { + std::unordered_set processing; + processing.swap(dirty_tensors_); + for (auto& hook : tensor_resize_hooks_) { + if (processing.count(hook.trigger_tensor_id) != 0) { + hook.fn(*this); + } + } + } + if (!dirty_tensors_.empty()) { + throw std::runtime_error( + "WebGPU resize: tensor resize hooks did not converge"); + } + // Tensor hooks must not set_symint (dirty_symints_ already drained above). + if (!dirty_symints_.empty()) { + throw std::runtime_error( + "WebGPU resize: a tensor resize hook set a SymInt; not supported"); + } } WebGPUGraph::~WebGPUGraph() { @@ -322,6 +404,10 @@ void WebGPUGraph::build( tensor.elem_size = vk_datatype_size(vk_tensor->datatype()); tensor.is_int = vk_datatype_is_int(vk_tensor->datatype()); tensor.nbytes = numel * tensor.elem_size; + // Live dims start == max (serialized upper bound); resize_input shrinks + // them per call. Static graphs keep cur == max forever. + tensor.cur_dims = tensor.dims; + tensor.cur_nbytes = tensor.nbytes; int constant_id = vk_tensor->constant_id(); int mem_obj_id = vk_tensor->mem_obj_id(); @@ -624,17 +710,20 @@ void WebGPUGraph::copy_inputs(const std::vector& inputs) { } int tid = input_ids_[i]; const auto& tensor = tensors_[tid]; + // Upload only the live (cur) bytes, not the max allocation; cur_nbytes == + // nbytes on a static graph, so this is byte-identical there. + const size_t live_nbytes = tensor.cur_nbytes; // Fast path: host and GPU element types match byte-for-byte. - if (in.nbytes == tensor.nbytes) { - wgpuQueueWriteBuffer(queue_, tensor.buffer, 0, in.data, tensor.nbytes); + if (in.nbytes == live_nbytes) { + wgpuQueueWriteBuffer(queue_, tensor.buffer, 0, in.data, live_nbytes); continue; } // Narrow int64 host indices into the int32 buffer (mirrors Vulkan). const bool buffer_is_int32 = tensor.is_int && tensor.elem_size == 4; - if (in.host_is_int64 && buffer_is_int32 && in.nbytes == tensor.nbytes * 2) { - const size_t numel = tensor.nbytes / 4; + if (in.host_is_int64 && buffer_is_int32 && in.nbytes == live_nbytes * 2) { + const size_t numel = live_nbytes / 4; const int64_t* src = static_cast(in.data); std::vector narrowed(numel); for (size_t e = 0; e < numel; e++) { @@ -648,7 +737,7 @@ void WebGPUGraph::copy_inputs(const std::vector& inputs) { narrowed[e] = static_cast(src[e]); } wgpuQueueWriteBuffer( - queue_, tensor.buffer, 0, narrowed.data(), tensor.nbytes); + queue_, tensor.buffer, 0, narrowed.data(), live_nbytes); continue; } @@ -656,7 +745,7 @@ void WebGPUGraph::copy_inputs(const std::vector& inputs) { "WebGPU: unsupported input copy for input " + std::to_string(i) + " (host " + std::to_string(in.nbytes) + " bytes" + (in.host_is_int64 ? " int64" : "") + " vs buffer " + - std::to_string(tensor.nbytes) + " bytes)"); + std::to_string(live_nbytes) + " bytes)"); } } diff --git a/backends/webgpu/runtime/WebGPUGraph.h b/backends/webgpu/runtime/WebGPUGraph.h index 87f1576a5f3..5474ee4667a 100644 --- a/backends/webgpu/runtime/WebGPUGraph.h +++ b/backends/webgpu/runtime/WebGPUGraph.h @@ -23,8 +23,14 @@ namespace executorch::backends::webgpu { struct WebGPUTensor { WGPUBuffer buffer = nullptr; + // Max (allocation) dims/nbytes: the serialized upper-bound shape. The GPU + // buffer is sized from these and never reallocated (Vulkan allocate-at-max). std::vector dims; size_t nbytes = 0; + // Live dims/nbytes for dynamic shapes; always <= the max. == max on a static + // graph, so dynamic-resize logic keyed off these is inert there. + std::vector cur_dims; + size_t cur_nbytes = 0; // Serialized (GPU-side) element type, used to narrow wider host inputs. size_t elem_size = 0; bool is_int = false; @@ -171,6 +177,17 @@ class WebGPUGraph { return symint_sources_; } + // Records that a SymInt is a tensor's live dim size (sym_size.int), read from + // cur_dims at execute; distinct from SymIntSource (a scalar data element). + struct SymIntDimSource { + int symint_id; + int tensor_id; + int dim; + }; + void add_symint_dim_source(int symint_id, int tensor_id, int dim) { + symint_dim_sources_.push_back({symint_id, tensor_id, dim}); + } + // Execute-time select_as_symint read; mirrors Vulkan select_as_symint_impl. void update_symints_from_inputs(const std::vector& inputs); @@ -178,7 +195,25 @@ class WebGPUGraph { void add_resize_hook(int symint_id, std::function fn) { resize_hooks_.push_back({symint_id, std::move(fn)}); } - // Run hooks for changed SymInts then clear; call before execute(). + + // Set a graph input's live dims (<= max) + dirty it; static path stays inert. + void resize_input(int value_id, const std::vector& new_dims); + // Set a tensor's live dims (an op resize hook calls this for its output to + // cascade to consumers); validates the new dims fit the max, never reallocs. + void set_cur_dims(int value_id, const std::vector& new_dims); + const std::vector& cur_dims(int value_id) const { + return tensors_[value_id].cur_dims; + } + + // Per-tensor resize hook; mirrors Vulkan ExecuteNode::resize_fn. Runs in + // propagate_resize when trigger_tensor_id is dirty. + void add_tensor_resize_hook( + int trigger_tensor_id, + std::function fn) { + tensor_resize_hooks_.push_back({trigger_tensor_id, std::move(fn)}); + } + + // Run hooks for changed SymInts and tensors, then clear; call before execute. void propagate_resize(); // Mutable dispatch access for resize hooks (to rewrite workgroup_count_x). @@ -196,12 +231,14 @@ class WebGPUGraph { return queue_; } - void add_dispatch(WebGPUDispatch dispatch) { + // Returns the new dispatch's index (resize hooks rewrite it via dispatch_at). + size_t add_dispatch(WebGPUDispatch dispatch) { dispatches_.push_back(dispatch); + return dispatches_.size() - 1; } - // Record an in-graph-order buffer-to-buffer DMA (e.g. a flat copy). - void add_buffer_copy(WGPUBuffer src, WGPUBuffer dst, size_t nbytes) { + // In-graph buffer-to-buffer DMA (e.g. flat copy); returns the dispatch index. + size_t add_buffer_copy(WGPUBuffer src, WGPUBuffer dst, size_t nbytes) { WebGPUDispatch d; d.kind = WebGPUDispatch::Kind::Copy; d.copy_src = src; @@ -209,6 +246,7 @@ class WebGPUGraph { d.copy_nbytes = nbytes; d.kernel_name = "flat_copy"; dispatches_.push_back(d); + return dispatches_.size() - 1; } // Materialize a recorded prepack-routed constant into dst via one CPU->GPU @@ -297,6 +335,7 @@ class WebGPUGraph { }; std::unordered_map symints_; std::vector symint_sources_; + std::vector symint_dim_sources_; // Resize hooks + the set of SymInts changed since the last propagate_resize. struct ResizeHook { @@ -306,6 +345,15 @@ class WebGPUGraph { std::vector resize_hooks_; std::unordered_set dirty_symints_; + // Tensor-shape resize hooks + the set of tensors changed since the last + // propagate_resize (mirrors the SymInt pair above, for dynamic shapes). + struct TensorResizeHook { + int trigger_tensor_id; + std::function fn; + }; + std::vector tensor_resize_hooks_; + std::unordered_set dirty_tensors_; + std::vector input_ids_; std::vector output_ids_; diff --git a/backends/webgpu/runtime/WebGPUUtils.h b/backends/webgpu/runtime/WebGPUUtils.h index c5c779ffd5e..afa90d54aec 100644 --- a/backends/webgpu/runtime/WebGPUUtils.h +++ b/backends/webgpu/runtime/WebGPUUtils.h @@ -15,6 +15,7 @@ #include #include #include +#include namespace executorch::backends::webgpu::utils { @@ -24,6 +25,18 @@ inline T div_up(T a, T b) { return (a + b - 1) / b; } +// Product of dims (live element count); used by dynamic-resize hooks. +inline uint64_t numel_of(const std::vector& dims) { + uint64_t n = 1; + for (int64_t v : dims) { + if (v < 0) { + throw std::runtime_error("numel_of: negative dimension"); + } + n *= static_cast(v); + } + return n; +} + // Clamp workgroup size to device limit (SwiftShader caps at 128). inline uint32_t clamp_workgroup_size(WGPUDevice device, uint32_t desired) { WGPULimits limits = {}; diff --git a/backends/webgpu/runtime/ops/select_as_symint/SelectAsSymint.cpp b/backends/webgpu/runtime/ops/select_as_symint/SelectAsSymint.cpp index 44445d5b2cf..4f1c822aba8 100644 --- a/backends/webgpu/runtime/ops/select_as_symint/SelectAsSymint.cpp +++ b/backends/webgpu/runtime/ops/select_as_symint/SelectAsSymint.cpp @@ -40,6 +40,28 @@ void select_as_symint_impl(WebGPUGraph& graph, const std::vector& args) { static_cast(graph.get_int(index_id))); } +// aten.sym_size.int(self, dim) -> SymInt = self.size(dim). The WebGPU ops read +// live sizes from cur_dims directly, so this SymInt is usually unused. +void sym_size_impl(WebGPUGraph& graph, const std::vector& args) { + if (args.size() < 3) { + throw std::runtime_error("sym_size.int: expected [self, dim, out] args"); + } + const int self_id = args.at(0); + const int dim_id = args.at(1); + const int out_id = args.at(2); + if (graph.get_value_type(out_id) != WebGPUGraph::ValueType::SymInt) { + return; // folded to a static Int -> nothing live to source + } + if (graph.get_value_type(dim_id) != WebGPUGraph::ValueType::Int) { + throw std::runtime_error("sym_size.int: dim arg is not an Int"); + } + if (graph.get_value_type(self_id) != WebGPUGraph::ValueType::Tensor) { + throw std::runtime_error("sym_size.int: self arg is not a Tensor"); + } + graph.add_symint_dim_source( + out_id, self_id, static_cast(graph.get_int(dim_id))); +} + // An operand is a live SymInt or a static Int constant. int32_t read_scalar(WebGPUGraph& graph, int id) { if (graph.get_value_type(id) == WebGPUGraph::ValueType::SymInt) { @@ -107,6 +129,7 @@ void sym_floordiv_impl(WebGPUGraph& graph, const std::vector& args) { WEBGPU_REGISTER_OPERATORS { WEBGPU_REGISTER_OP(et_vk.select_as_symint.default, select_as_symint_impl); + WEBGPU_REGISTER_OP(sym_size.int, sym_size_impl); WEBGPU_REGISTER_OP(add, sym_add_impl); WEBGPU_REGISTER_OP(sub, sym_sub_impl); WEBGPU_REGISTER_OP(mul, sym_mul_impl);