Skip to content

Commit 8c2ee7b

Browse files
authored
[WebNN EP] Create MLGraphBuilder for every model builder (#21514)
Currently WebNN spec only allows MLGraphBuilder.build() to be called once, we need to create new builder for every subgraph in WebNN EP. Spec change: webmachinelearning/webnn#717
1 parent 3b73ef2 commit 8c2ee7b

File tree

6 files changed

+28
-28
lines changed

6 files changed

+28
-28
lines changed

onnxruntime/core/providers/webnn/builders/helper.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ bool IsInputSupported(const NodeArg& input, const std::string& parent_name, cons
8484
}
8585

8686
std::vector<std::vector<NodeIndex>> GetSupportedNodes(const GraphViewer& graph_viewer,
87-
const emscripten::val& wnn_builder_,
87+
const emscripten::val& wnn_builder,
8888
const WebnnDeviceType device_type,
8989
const logging::Logger& logger) {
9090
std::vector<std::vector<size_t>> supported_node_groups;
@@ -103,7 +103,7 @@ std::vector<std::vector<NodeIndex>> GetSupportedNodes(const GraphViewer& graph_v
103103
const auto* node(graph_viewer.GetNode(node_idx));
104104
bool supported = false;
105105
// Firstly check if platform supports the WebNN op.
106-
if (CheckSingleOp(node->OpType(), wnn_builder_, device_type)) {
106+
if (CheckSingleOp(node->OpType(), wnn_builder, device_type)) {
107107
LOGS(logger, VERBOSE) << "Operator type: [" << node->OpType() << "] is supported by browser";
108108
supported = IsNodeSupported(*node, graph_viewer, device_type, logger);
109109
}

onnxruntime/core/providers/webnn/builders/helper.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ bool IsInputSupported(const NodeArg& node_arg, const std::string& parent_name, c
151151

152152
// Get a list of groups of supported nodes, each group represents a subgraph supported by WebNN EP.
153153
std::vector<std::vector<NodeIndex>> GetSupportedNodes(const GraphViewer& graph_viewer,
154-
const emscripten::val& wnn_builder_,
154+
const emscripten::val& wnn_builder,
155155
const WebnnDeviceType device_type,
156156
const logging::Logger& logger);
157157
static const InlinedHashMap<std::string, WebnnOpInfo> op_map = {
@@ -241,14 +241,14 @@ static const InlinedHashMap<std::string, WebnnOpInfo> op_map = {
241241
{"Where", {"where", true}},
242242
};
243243

244-
inline bool CheckSingleOp(const std::string& op_type, const emscripten::val& wnn_builder_,
244+
inline bool CheckSingleOp(const std::string& op_type, const emscripten::val& wnn_builder,
245245
const WebnnDeviceType device_type) {
246246
// Returns false if the op_type is not listed in the op_map.
247247
if (op_map.find(op_type) == op_map.end()) {
248248
return false;
249249
}
250250
// Returns false if the WebNN op has not been implemented in MLGraphBuilder in current browser.
251-
if (!wnn_builder_[op_map.find(op_type)->second.opName].as<bool>()) {
251+
if (!wnn_builder[op_map.find(op_type)->second.opName].as<bool>()) {
252252
return false;
253253
}
254254
// The current WebNN CPU (TFLite) backend supports a limited op list, and we'd rather

onnxruntime/core/providers/webnn/builders/model_builder.cc

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,20 @@ namespace onnxruntime {
2020
namespace webnn {
2121

2222
ModelBuilder::ModelBuilder(const GraphViewer& graph_viewer, const logging::Logger& logger,
23-
const emscripten::val& context, const emscripten::val& builder,
24-
const DataLayout preferred_layout, const WebnnDeviceType wnn_device_type)
23+
const emscripten::val& context, const DataLayout preferred_layout,
24+
const WebnnDeviceType wnn_device_type)
2525
: graph_viewer_(graph_viewer),
2626
logger_(logger),
2727
wnn_context_(context),
28-
wnn_builder_(builder),
2928
preferred_layout_(preferred_layout),
30-
wnn_device_type_(wnn_device_type) {}
29+
wnn_device_type_(wnn_device_type) {
30+
// Create WebNN MLGraphBuilder for each ModelBuilder, because MLGraphBuilder.build()
31+
// is only allowed to be called once.
32+
wnn_builder_ = emscripten::val::global("MLGraphBuilder").new_(context);
33+
if (!wnn_builder_.as<bool>()) {
34+
ORT_THROW("Failed to create WebNN builder.");
35+
}
36+
}
3137

3238
Status ModelBuilder::Initialize() {
3339
PreprocessInitializers();
@@ -332,6 +338,8 @@ Status ModelBuilder::Compile(std::unique_ptr<Model>& model) {
332338
if (!wnn_graph.as<bool>()) {
333339
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to build WebNN graph.");
334340
}
341+
// Explicitly release the WebNN builder to free memory.
342+
wnn_builder_ = emscripten::val::undefined();
335343
model.reset(new Model(std::move(wnn_context_), std::move(wnn_graph), logger_));
336344
model->SetInputs(std::move(input_names_));
337345
model->SetOutputs(std::move(output_names_));

onnxruntime/core/providers/webnn/builders/model_builder.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@ class IOpBuilder;
2222
class ModelBuilder {
2323
public:
2424
ModelBuilder(const GraphViewer& graph_viewer, const logging::Logger& logger,
25-
const emscripten::val& context, const emscripten::val& builder,
26-
const DataLayout preferred_layout, const WebnnDeviceType wnn_device_type);
25+
const emscripten::val& context, const DataLayout preferred_layout,
26+
const WebnnDeviceType wnn_device_type);
2727
~ModelBuilder() = default;
2828

2929
Status Compile(std::unique_ptr<Model>& model) ORT_MUST_USE_RESULT;
@@ -62,8 +62,8 @@ class ModelBuilder {
6262
const GraphViewer& graph_viewer_;
6363
const logging::Logger& logger_;
6464

65-
emscripten::val wnn_context_ = emscripten::val::object();
66-
emscripten::val wnn_builder_ = emscripten::val::object();
65+
emscripten::val wnn_context_ = emscripten::val::undefined();
66+
emscripten::val wnn_builder_ = emscripten::val::undefined();
6767
DataLayout preferred_layout_;
6868
WebnnDeviceType wnn_device_type_;
6969
InlinedHashMap<std::string, emscripten::val> wnn_operands_;

onnxruntime/core/providers/webnn/webnn_execution_provider.cc

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,6 @@ WebNNExecutionProvider::WebNNExecutionProvider(const std::string& webnn_device_f
3838
if (!wnn_context_.as<bool>()) {
3939
ORT_THROW("Failed to create WebNN context.");
4040
}
41-
wnn_builder_ = emscripten::val::global("MLGraphBuilder").new_(wnn_context_);
42-
if (!wnn_builder_.as<bool>()) {
43-
ORT_THROW("Failed to create WebNN builder.");
44-
}
4541
}
4642

4743
WebNNExecutionProvider::~WebNNExecutionProvider() {}
@@ -81,14 +77,13 @@ WebNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_view
8177

8278
const auto& logger = *GetLogger();
8379

84-
if (!wnn_builder_.as<bool>()) {
85-
// The GetCapability function may be called again after Compile due to the logic in the
86-
// PartitionOnnxFormatModel function (see onnxruntime/core/framework/graph_partitioner.cc).
87-
// We need to re-create the wnn_builder_ here to avoid it's been released in last Compile.
88-
wnn_builder_ = emscripten::val::global("MLGraphBuilder").new_(wnn_context_);
80+
emscripten::val wnn_builder = emscripten::val::global("MLGraphBuilder").new_(wnn_context_);
81+
if (!wnn_builder.as<bool>()) {
82+
ORT_THROW("Failed to create WebNN builder.");
8983
}
9084

91-
const auto node_groups = webnn::GetSupportedNodes(graph_viewer, wnn_builder_, wnn_device_type_, logger);
85+
const auto node_groups = webnn::GetSupportedNodes(graph_viewer, wnn_builder, wnn_device_type_, logger);
86+
wnn_builder = emscripten::val::undefined();
9287

9388
if (node_groups.empty()) {
9489
return result;
@@ -218,9 +213,10 @@ common::Status WebNNExecutionProvider::Compile(const std::vector<FusedNodeAndGra
218213
const onnxruntime::GraphViewer& graph_viewer(fused_node_and_graph.filtered_graph);
219214

220215
webnn::ModelBuilder builder(graph_viewer, *GetLogger(), wnn_context_,
221-
wnn_builder_, preferred_layout_, wnn_device_type_);
216+
preferred_layout_, wnn_device_type_);
222217
std::unique_ptr<webnn::Model> model;
223218
ORT_RETURN_IF_ERROR(builder.Compile(model));
219+
224220
// Build map from input name to its index in input definitions.
225221
{
226222
InlinedHashMap<std::string, size_t> input_map;
@@ -329,9 +325,6 @@ common::Status WebNNExecutionProvider::Compile(const std::vector<FusedNodeAndGra
329325
node_compute_funcs.push_back(compute_info);
330326
}
331327

332-
// Explicitly release the WebNN builder to free memory.
333-
wnn_builder_ = emscripten::val::undefined();
334-
335328
return Status::OK();
336329
}
337330

onnxruntime/core/providers/webnn/webnn_execution_provider.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,6 @@ class WebNNExecutionProvider : public IExecutionProvider {
4343

4444
private:
4545
emscripten::val wnn_context_ = emscripten::val::undefined();
46-
mutable emscripten::val wnn_builder_ = emscripten::val::undefined();
4746

4847
DataLayout preferred_layout_;
4948
webnn::WebnnDeviceType wnn_device_type_;

0 commit comments

Comments
 (0)