From cbba0ddd884aec0ee43d2a07d04788bce7c40f2d Mon Sep 17 00:00:00 2001 From: Wanming Lin Date: Fri, 14 Apr 2023 11:19:09 +0800 Subject: [PATCH 1/7] Support WebNN EP This PR enables WebNN EP in ONNX Runtime Web. It translates the ONNX nodes by WebNN API, which is implemented in C++ and uses Emscripten Embind API. Temporarily using preferred layout NHWC for WebNN graph partitions since the restriction in WebNN XNNPack backend implementation and the ongoing discussion in WebNN spec that whether WebNN should support both 'NHWC' and 'NCHW' layouts. No WebNN native EP, only for Web. --- cmake/CMakeLists.txt | 6 + cmake/adjust_global_compile_flags.cmake | 6 +- cmake/onnxruntime.cmake | 1 + cmake/onnxruntime_providers.cmake | 28 ++ cmake/onnxruntime_webassembly.cmake | 10 + include/onnxruntime/core/graph/constants.h | 1 + js/common/lib/inference-session.ts | 8 +- js/web/karma.conf.js | 57 ++- js/web/lib/index.ts | 1 + js/web/lib/wasm/session-options.ts | 20 + js/web/script/test-runner-cli-args.ts | 11 +- js/web/script/test-runner-cli.ts | 20 +- js/web/test/suite-test-list.jsonc | 99 +++++ js/web/test/test-runner.ts | 2 +- .../transpose_optimizer/optimizer_api_impl.cc | 2 +- .../core/providers/get_execution_providers.cc | 8 + .../providers/provider_factory_creators.h | 4 + .../core/providers/webnn/builders/helper.cc | 107 +++++ .../core/providers/webnn/builders/helper.h | 60 +++ .../builders/impl/activation_op_builder.cc | 87 ++++ .../webnn/builders/impl/base_op_builder.cc | 110 +++++ .../webnn/builders/impl/base_op_builder.h | 50 +++ .../webnn/builders/impl/binary_op_builder.cc | 79 ++++ .../webnn/builders/impl/builder_utils.cc | 86 ++++ .../webnn/builders/impl/builder_utils.h | 27 ++ .../webnn/builders/impl/clip_op_builder.cc | 81 ++++ .../webnn/builders/impl/concat_op_builder.cc | 80 ++++ .../webnn/builders/impl/conv_op_builder.cc | 287 +++++++++++++ .../webnn/builders/impl/gemm_op_builder.cc | 147 +++++++ .../webnn/builders/impl/pool_op_builder.cc | 165 ++++++++ .../webnn/builders/impl/reshape_op_builder.cc | 126 ++++++ .../webnn/builders/impl/resize_op_builder.cc | 275 +++++++++++++ .../builders/impl/transpose_op_builder.cc | 59 +++ .../core/providers/webnn/builders/model.cc | 127 ++++++ .../core/providers/webnn/builders/model.h | 94 +++++ .../providers/webnn/builders/model_builder.cc | 369 +++++++++++++++++ .../providers/webnn/builders/model_builder.h | 106 +++++ .../providers/webnn/builders/op_builder.h | 34 ++ .../webnn/builders/op_builder_factory.cc | 76 ++++ .../webnn/builders/op_builder_factory.h | 34 ++ .../webnn/webnn_execution_provider.cc | 388 ++++++++++++++++++ .../webnn/webnn_execution_provider.h | 48 +++ .../providers/webnn/webnn_provider_factory.cc | 39 ++ .../webnn/webnn_provider_factory_creator.h | 18 + .../core/session/provider_registration.cc | 10 + tools/ci_build/build.py | 7 + 46 files changed, 3438 insertions(+), 22 deletions(-) create mode 100644 onnxruntime/core/providers/webnn/builders/helper.cc create mode 100644 onnxruntime/core/providers/webnn/builders/helper.h create mode 100644 onnxruntime/core/providers/webnn/builders/impl/activation_op_builder.cc create mode 100644 onnxruntime/core/providers/webnn/builders/impl/base_op_builder.cc create mode 100644 onnxruntime/core/providers/webnn/builders/impl/base_op_builder.h create mode 100644 onnxruntime/core/providers/webnn/builders/impl/binary_op_builder.cc create mode 100644 onnxruntime/core/providers/webnn/builders/impl/builder_utils.cc create mode 100644 onnxruntime/core/providers/webnn/builders/impl/builder_utils.h create mode 100644 onnxruntime/core/providers/webnn/builders/impl/clip_op_builder.cc create mode 100644 onnxruntime/core/providers/webnn/builders/impl/concat_op_builder.cc create mode 100644 onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc create mode 100644 onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc create mode 100644 onnxruntime/core/providers/webnn/builders/impl/pool_op_builder.cc create mode 100644 onnxruntime/core/providers/webnn/builders/impl/reshape_op_builder.cc create mode 100644 onnxruntime/core/providers/webnn/builders/impl/resize_op_builder.cc create mode 100644 onnxruntime/core/providers/webnn/builders/impl/transpose_op_builder.cc create mode 100644 onnxruntime/core/providers/webnn/builders/model.cc create mode 100644 onnxruntime/core/providers/webnn/builders/model.h create mode 100644 onnxruntime/core/providers/webnn/builders/model_builder.cc create mode 100644 onnxruntime/core/providers/webnn/builders/model_builder.h create mode 100644 onnxruntime/core/providers/webnn/builders/op_builder.h create mode 100644 onnxruntime/core/providers/webnn/builders/op_builder_factory.cc create mode 100644 onnxruntime/core/providers/webnn/builders/op_builder_factory.h create mode 100644 onnxruntime/core/providers/webnn/webnn_execution_provider.cc create mode 100644 onnxruntime/core/providers/webnn/webnn_execution_provider.h create mode 100644 onnxruntime/core/providers/webnn/webnn_provider_factory.cc create mode 100644 onnxruntime/core/providers/webnn/webnn_provider_factory_creator.h diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index d5ab7f4b74a9a..2804063e1fc68 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -122,6 +122,7 @@ option(onnxruntime_TVM_CUDA_RUNTIME "Build TVM with CUDA support" OFF) option(onnxruntime_TVM_USE_LLVM "Build TVM with LLVM. Set customized path to llvm-config.exe here if need" OFF) option(onnxruntime_TVM_USE_HASH "Build ipp-crypto library for support hash algorithm. It is defined for TVM only") option(onnxruntime_USE_XNNPACK "Build with XNNPACK support. Provides an alternative math library on ARM, WebAssembly and x86." OFF) +option(onnxruntime_USE_WEBNN "Build with WebNN support. Enable hardware acceleration in web browsers." OFF) # Options related to reducing the binary size produced by the build # XNNPACK EP requires the internal NHWC contrib ops to be available, so this option must be OFF when onnxruntime_USE_XNNPACK is ON @@ -722,6 +723,11 @@ if (onnxruntime_USE_XNNPACK) list(APPEND ORT_PROVIDER_CMAKE_FLAGS -Donnxruntime_USE_XNNPACK=1) list(APPEND ONNXRUNTIME_PROVIDER_NAMES xnnpack) endif() +if (onnxruntime_USE_WEBNN) + list(APPEND ORT_PROVIDER_FLAGS -DUSE_WEBNN=1) + list(APPEND ORT_PROVIDER_CMAKE_FLAGS -Donnxruntime_USE_WEBNN=1) + list(APPEND ONNXRUNTIME_PROVIDER_NAMES webnn) +endif() if (onnxruntime_USE_CANN) list(APPEND ORT_PROVIDER_FLAGS -DUSE_CANN=1) list(APPEND ORT_PROVIDER_CMAKE_FLAGS -Donnxruntime_USE_CANN=1) diff --git a/cmake/adjust_global_compile_flags.cmake b/cmake/adjust_global_compile_flags.cmake index 58a9271d26e7f..58027b2cf2e96 100644 --- a/cmake/adjust_global_compile_flags.cmake +++ b/cmake/adjust_global_compile_flags.cmake @@ -131,7 +131,11 @@ if (onnxruntime_DISABLE_RTTI) # Disable RTTI and turn usage of dynamic_cast and typeid into errors add_compile_options("$<$:/GR->" "$<$:/we4541>") else() - add_compile_options("$<$:-fno-rtti>") + # Avoid unboundTypeError for WebNN EP since unbound type names are illegal with RTTI disabled + # in Embind API, relevant issue: https://github.com/emscripten-core/emscripten/issues/16911 + if(NOT onnxruntime_USE_WEBNN) + add_compile_options("$<$:-fno-rtti>") + endif() endif() else() #MSVC RTTI flag /GR is not added to CMAKE_CXX_FLAGS by default. But, anyway VC++2019 treats "/GR" default on. diff --git a/cmake/onnxruntime.cmake b/cmake/onnxruntime.cmake index 9f34d1f46751d..571ceec193e8b 100644 --- a/cmake/onnxruntime.cmake +++ b/cmake/onnxruntime.cmake @@ -185,6 +185,7 @@ set(onnxruntime_INTERNAL_LIBRARIES ${PROVIDERS_ROCM} ${PROVIDERS_VITISAI} ${PROVIDERS_XNNPACK} + ${PROVIDERS_WEBNN} ${PROVIDERS_AZURE} ${PROVIDERS_INTERNAL_TESTING} ${onnxruntime_winml} diff --git a/cmake/onnxruntime_providers.cmake b/cmake/onnxruntime_providers.cmake index caae2aacfd582..73469afa6116b 100644 --- a/cmake/onnxruntime_providers.cmake +++ b/cmake/onnxruntime_providers.cmake @@ -147,6 +147,9 @@ endif() if (onnxruntime_USE_XNNPACK) set(PROVIDERS_XNNPACK onnxruntime_providers_xnnpack) endif() +if(onnxruntime_USE_WEBNN) + set(PROVIDERS_WEBNN onnxruntime_providers_webnn) +endif() if(onnxruntime_USE_SNPE) include(onnxruntime_snpe_provider.cmake) endif() @@ -983,6 +986,31 @@ if (onnxruntime_USE_COREML) endif() endif() +if (onnxruntime_USE_WEBNN) + if (onnxruntime_MINIMAL_BUILD AND NOT onnxruntime_EXTENDED_MINIMAL_BUILD) + message(FATAL_ERROR "WebNN EP can not be used in a basic minimal build. Please build with '--minimal_build extended'") + endif() + + add_compile_definitions(USE_WEBNN=1) + if (onnxruntime_ENABLE_WEBASSEMBLY_THREADS) + add_definitions(-DENABLE_WEBASSEMBLY_THREADS=1) + endif() + file(GLOB_RECURSE onnxruntime_providers_webnn_cc_srcs CONFIGURE_DEPENDS + "${ONNXRUNTIME_ROOT}/core/providers/webnn/*.h" + "${ONNXRUNTIME_ROOT}/core/providers/webnn/*.cc" + "${ONNXRUNTIME_ROOT}/core/providers/shared/utils/utils.h" + "${ONNXRUNTIME_ROOT}/core/providers/shared/utils/utils.cc" + ) + + source_group(TREE ${REPO_ROOT} FILES ${onnxruntime_providers_webnn_cc_srcs}) + onnxruntime_add_static_library(onnxruntime_providers_webnn ${onnxruntime_providers_webnn_cc_srcs}) + onnxruntime_add_include_to_target(onnxruntime_providers_webnn onnxruntime_common onnx onnx_proto Boost::mp11) + + add_dependencies(onnxruntime_providers_webnn onnx ${onnxruntime_EXTERNAL_DEPENDENCIES}) + set_target_properties(onnxruntime_providers_webnn PROPERTIES FOLDER "ONNXRuntime") + set_target_properties(onnxruntime_providers_webnn PROPERTIES LINKER_LANGUAGE CXX) +endif() + if (onnxruntime_USE_NNAPI_BUILTIN) if (onnxruntime_MINIMAL_BUILD AND NOT onnxruntime_EXTENDED_MINIMAL_BUILD) message(FATAL_ERROR "NNAPI can not be used in a basic minimal build. Please build with '--minimal_build extended'") diff --git a/cmake/onnxruntime_webassembly.cmake b/cmake/onnxruntime_webassembly.cmake index 193315f541b85..0d6ef7a9711b5 100644 --- a/cmake/onnxruntime_webassembly.cmake +++ b/cmake/onnxruntime_webassembly.cmake @@ -110,6 +110,7 @@ if (onnxruntime_BUILD_WEBASSEMBLY_STATIC_LIB) onnxruntime_providers ${PROVIDERS_JS} ${PROVIDERS_XNNPACK} + ${PROVIDERS_WEBNN} onnxruntime_session onnxruntime_util re2::re2 @@ -186,6 +187,7 @@ else() onnxruntime_providers ${PROVIDERS_JS} ${PROVIDERS_XNNPACK} + ${PROVIDERS_WEBNN} onnxruntime_session onnxruntime_util re2::re2 @@ -194,6 +196,10 @@ else() target_link_libraries(onnxruntime_webassembly PRIVATE XNNPACK) endif() + if(onnxruntime_USE_WEBNN) + target_link_libraries(onnxruntime_webassembly PRIVATE onnxruntime_providers_webnn) + endif() + if (onnxruntime_ENABLE_TRAINING) target_link_libraries(onnxruntime_webassembly PRIVATE tensorboard) endif() @@ -255,6 +261,10 @@ else() ) endif() + if (onnxruntime_USE_WEBNN) + set_property(TARGET onnxruntime_webassembly APPEND_STRING PROPERTY LINK_FLAGS " --bind") + endif() + # Set link flag to enable exceptions support, this will override default disabling exception throwing behavior when disable exceptions. target_link_options(onnxruntime_webassembly PRIVATE "SHELL:-s DISABLE_EXCEPTION_THROWING=0") diff --git a/include/onnxruntime/core/graph/constants.h b/include/onnxruntime/core/graph/constants.h index 6fc9ef6e1c8c3..7e59aad80cc47 100644 --- a/include/onnxruntime/core/graph/constants.h +++ b/include/onnxruntime/core/graph/constants.h @@ -48,6 +48,7 @@ constexpr const char* kJsExecutionProvider = "JsExecutionProvider"; constexpr const char* kSnpeExecutionProvider = "SNPEExecutionProvider"; constexpr const char* kTvmExecutionProvider = "TvmExecutionProvider"; constexpr const char* kXnnpackExecutionProvider = "XnnpackExecutionProvider"; +constexpr const char* kWebNNExecutionProvider = "WebNNExecutionProvider"; constexpr const char* kCannExecutionProvider = "CANNExecutionProvider"; constexpr const char* kAzureExecutionProvider = "AzureExecutionProvider"; diff --git a/js/common/lib/inference-session.ts b/js/common/lib/inference-session.ts index 638cb90f36716..f629858f99b72 100644 --- a/js/common/lib/inference-session.ts +++ b/js/common/lib/inference-session.ts @@ -165,7 +165,7 @@ export declare namespace InferenceSession { // Currently, we have the following backends to support execution providers: // Backend Node.js binding: supports 'cpu' and 'cuda'. - // Backend WebAssembly: supports 'cpu', 'wasm' and 'xnnpack'. + // Backend WebAssembly: supports 'cpu', 'wasm', 'xnnpack' and 'webnn'. // Backend ONNX.js: supports 'webgl'. interface ExecutionProviderOptionMap { cpu: CpuExecutionProviderOption; @@ -173,6 +173,7 @@ export declare namespace InferenceSession { wasm: WebAssemblyExecutionProviderOption; webgl: WebGLExecutionProviderOption; xnnpack: XnnpackExecutionProviderOption; + webnn: WebNNExecutionProviderOption; } type ExecutionProviderName = keyof ExecutionProviderOptionMap; @@ -200,6 +201,11 @@ export declare namespace InferenceSession { export interface XnnpackExecutionProviderOption extends ExecutionProviderOption { readonly name: 'xnnpack'; } + export interface WebNNExecutionProviderOption extends ExecutionProviderOption { + readonly name: 'webnn'; + deviceType?: number; // 0 - auto, 1 - gpu, 2 - cpu + powerPreference?: number; // 0 - auto, 1 - high-performance, 2 - low-power + } // #endregion // #endregion diff --git a/js/web/karma.conf.js b/js/web/karma.conf.js index 2a4e71e064632..3d71df7687b01 100644 --- a/js/web/karma.conf.js +++ b/js/web/karma.conf.js @@ -86,13 +86,56 @@ module.exports = function (config) { hostname, listenAddress, customLaunchers: { - ChromeTest: { base: 'ChromeHeadless', flags: ['--enable-features=SharedArrayBuffer'] }, - ChromePerf: { base: 'Chrome', flags: ['--window-size=1,1', '--enable-features=SharedArrayBuffer'] }, - ChromeDebug: { debug: true, base: 'Chrome', flags: ['--remote-debugging-port=9333', '--enable-features=SharedArrayBuffer'] }, - ChromeCanaryTest: { base: 'ChromeCanary', flags: ['--window-size=1,1', '--enable-features=SharedArrayBuffer', '--enable-unsafe-webgpu'] }, - ChromeCanaryProfileTest: { base: 'ChromeCanary', flags: ['--window-size=1,1', '--enable-features=SharedArrayBuffer', '--enable-unsafe-webgpu', '--disable-dawn-features=disallow_unsafe_apis'] }, - ChromeCanaryDebug: { debug: true, base: 'ChromeCanary', flags: ['--remote-debugging-port=9333', '--enable-features=SharedArrayBuffer', '--enable-unsafe-webgpu'] }, - ChromeCanaryProfileDebug: { debug: true, base: 'ChromeCanary', flags: ['--remote-debugging-port=9333', '--enable-features=SharedArrayBuffer', '--enable-unsafe-webgpu', '--disable-dawn-features=disallow_unsafe_apis'] }, + ChromeTest: { + base: 'ChromeHeadless', + flags: ['--enable-features=SharedArrayBuffer'] + }, + ChromePerf: { + base: 'Chrome', + flags: ['--window-size=1,1', '--enable-features=SharedArrayBuffer'] + }, + ChromeDebug: { + debug: true, + base: 'Chrome', flags: ['--remote-debugging-port=9333', '--enable-features=SharedArrayBuffer'] + }, + ChromeCanaryTest: { + base: 'ChromeCanary', + flags: [ + '--window-size=1,1', + '--enable-features=SharedArrayBuffer', + '--enable-unsafe-webgpu', + '--enable-experimental-web-platform-features' + ] + }, + ChromeCanaryProfileTest: { + base: 'ChromeCanary', + flags: [ + '--window-size=1,1', + '--enable-features=SharedArrayBuffer', + '--enable-unsafe-webgpu', + '--disable-dawn-features=disallow_unsafe_apis' + ] + }, + ChromeCanaryDebug: { + debug: true, + base: 'ChromeCanary', + flags: [ + '--remote-debugging-port=9333', + '--enable-features=SharedArrayBuffer', + '--enable-unsafe-webgpu', + '--enable-experimental-web-platform-features' + ] + }, + ChromeCanaryProfileDebug: { + debug: true, + base: 'ChromeCanary', + flags: [ + '--remote-debugging-port=9333', + '--enable-features=SharedArrayBuffer', + '--enable-unsafe-webgpu', + '--disable-dawn-features=disallow_unsafe_apis', + ] + }, // // ==== BrowserStack browsers ==== // diff --git a/js/web/lib/index.ts b/js/web/lib/index.ts index 749331058cc4a..41024de108acb 100644 --- a/js/web/lib/index.ts +++ b/js/web/lib/index.ts @@ -22,4 +22,5 @@ if (!BUILD_DEFS.DISABLE_WASM) { registerBackend('cpu', wasmBackend, 10); registerBackend('wasm', wasmBackend, 10); registerBackend('xnnpack', wasmBackend, 9); + registerBackend('webnn', wasmBackend, 9); } diff --git a/js/web/lib/wasm/session-options.ts b/js/web/lib/wasm/session-options.ts index 10cc48257dc52..abd27e8f7f5ac 100644 --- a/js/web/lib/wasm/session-options.ts +++ b/js/web/lib/wasm/session-options.ts @@ -64,6 +64,26 @@ const setExecutionProviders = case 'xnnpack': epName = 'XNNPACK'; break; + case 'webnn': + epName = 'WEBNN'; + if (typeof ep !== 'string') { + const webnnOptions = ep as InferenceSession.WebNNExecutionProviderOption; + if (webnnOptions?.deviceType) { + const keyDataOffset = allocWasmString("deviceType", allocs); + const valueDataOffset = allocWasmString(webnnOptions.deviceType.toString(), allocs); + if (getInstance()._OrtAddSessionConfigEntry(sessionOptionsHandle, keyDataOffset, valueDataOffset) !== 0) { + throw new Error(`Can't set a session config entry: "deviceType" - ${webnnOptions.deviceType}`); + } + } + if (webnnOptions?.powerPreference) { + const keyDataOffset = allocWasmString("powerPreference", allocs); + const valueDataOffset = allocWasmString(webnnOptions.powerPreference.toString(), allocs); + if (getInstance()._OrtAddSessionConfigEntry(sessionOptionsHandle, keyDataOffset, valueDataOffset) !== 0) { + throw new Error(`Can't set a session config entry: "powerPreference" - ${webnnOptions.powerPreference}`); + } + } + } + break; case 'webgpu': epName = 'JS'; break; diff --git a/js/web/script/test-runner-cli-args.ts b/js/web/script/test-runner-cli-args.ts index e20c391513c67..c189c22cfdf91 100644 --- a/js/web/script/test-runner-cli-args.ts +++ b/js/web/script/test-runner-cli-args.ts @@ -37,6 +37,7 @@ Options: webgpu wasm xnnpack + webnn -e=<...>, --env=<...> Specify the environment to run the test. Should be one of the following: chrome (default) edge (Windows only) @@ -104,7 +105,7 @@ Examples: export declare namespace TestRunnerCliArgs { type Mode = 'suite0'|'suite1'|'model'|'unittest'|'op'; - type Backend = 'cpu'|'webgl'|'webgpu'|'wasm'|'onnxruntime'|'xnnpack'; + type Backend = 'cpu'|'webgl'|'webgpu'|'wasm'|'onnxruntime'|'xnnpack'|'webnn'; type Environment = 'chrome'|'edge'|'firefox'|'electron'|'safari'|'node'|'bs'; type BundleMode = 'prod'|'dev'|'perf'; } @@ -359,12 +360,12 @@ export function parseTestRunnerCliArgs(cmdlineArgs: string[]): TestRunnerCliArgs } // Option: -b=<...>, --backend=<...> - const browserBackends = ['webgl', 'webgpu', 'wasm', 'xnnpack']; + const browserBackends = ['webgl', 'webgpu', 'wasm', 'xnnpack', 'webnn']; - // TODO: remove this when Chrome support WebGPU. - // we need this for now because Chrome does not support webgpu yet, + // TODO: remove this when Chrome support WebGPU or WebNN. + // we need this for now because Chrome does not support webgpu and webnn yet, // and ChromeCanary is not in CI. - const defaultBrowserBackends = ['webgl', /* 'webgpu', */ 'wasm', 'xnnpack']; + const defaultBrowserBackends = ['webgl', /* 'webgpu', */ 'wasm', 'xnnpack'/*, 'webnn'*/]; const nodejsBackends = ['cpu', 'wasm']; const backendArgs = args.backend || args.b; const backend = (typeof backendArgs !== 'string') ? (env === 'node' ? nodejsBackends : defaultBrowserBackends) : diff --git a/js/web/script/test-runner-cli.ts b/js/web/script/test-runner-cli.ts index 72938789bc2df..90eb32ddaddd8 100644 --- a/js/web/script/test-runner-cli.ts +++ b/js/web/script/test-runner-cli.ts @@ -53,7 +53,7 @@ async function main() { // The default backends and opset version lists. Those will be used in suite tests. const DEFAULT_BACKENDS: readonly TestRunnerCliArgs.Backend[] = - args.env === 'node' ? ['cpu', 'wasm'] : ['wasm', 'webgl', 'webgpu']; + args.env === 'node' ? ['cpu', 'wasm'] : ['wasm', 'webgl', 'webgpu', 'webnn']; const DEFAULT_OPSET_VERSIONS = fs.readdirSync(TEST_DATA_MODEL_NODE_ROOT, {withFileTypes: true}) .filter(dir => dir.isDirectory() && dir.name.startsWith('opset')) .map(dir => dir.name.slice(5)); @@ -459,12 +459,13 @@ async function main() { // STEP 5. use Karma to run test npmlog.info('TestRunnerCli.Run', '(4/4) Running karma to start test runner...'); const webgpu = args.backends.indexOf('webgpu') > -1; + const webnn = args.backends.indexOf('webnn') > -1; const browser = getBrowserNameFromEnv( args.env, args.bundleMode === 'perf' ? 'perf' : args.debug ? 'debug' : 'test', - webgpu, config.options.globalEnvFlags?.webgpu?.profilingMode === 'default'); + webgpu, webnn, config.options.globalEnvFlags?.webgpu?.profilingMode === 'default'); const karmaArgs = ['karma', 'start', `--browsers ${browser}`]; if (args.debug) { karmaArgs.push('--log-level info --timeout-mocha 9999999'); @@ -474,7 +475,7 @@ async function main() { if (args.noSandbox) { karmaArgs.push('--no-sandbox'); } - if (webgpu) { + if (webgpu || webnn) { karmaArgs.push('--force-localhost'); } karmaArgs.push(`--bundle-mode=${args.bundleMode}`); @@ -569,10 +570,10 @@ async function main() { } function getBrowserNameFromEnv( - env: TestRunnerCliArgs['env'], mode: 'debug'|'perf'|'test', webgpu: boolean, profile: boolean) { + env: TestRunnerCliArgs['env'], mode: 'debug'|'perf'|'test', webgpu: boolean, webnn: boolean, profile: boolean) { switch (env) { case 'chrome': - return selectChromeBrowser(mode, webgpu, profile); + return selectChromeBrowser(mode, webgpu, webnn, profile); case 'edge': return 'Edge'; case 'firefox': @@ -588,7 +589,7 @@ async function main() { } } - function selectChromeBrowser(mode: 'debug'|'perf'|'test', webgpu: boolean, profile: boolean) { + function selectChromeBrowser(mode: 'debug'|'perf'|'test', webgpu: boolean, webnn: boolean, profile: boolean) { if (webgpu) { switch (mode) { case 'debug': @@ -596,6 +597,13 @@ async function main() { default: return profile ? 'ChromeCanaryProfileTest' : 'ChromeCanaryDebug'; } + } else if (webnn) { + switch (mode) { + case 'debug': + return 'ChromeCanaryDebug'; + default: + return 'ChromeCanaryTest'; + } } else { switch (mode) { case 'debug': diff --git a/js/web/test/suite-test-list.jsonc b/js/web/test/suite-test-list.jsonc index 17928899c91b1..f2965c506b2bf 100644 --- a/js/web/test/suite-test-list.jsonc +++ b/js/web/test/suite-test-list.jsonc @@ -1419,5 +1419,104 @@ "test_instancenorm_example" ], "ops": [] + }, + "webnn": { + "onnx": ["resnet50", "squeezenet", "tiny_yolov2", "emotion_ferplus"], + "node": [ + // Check in node tests that have native Wasm implementations. + // (i.e.) not tests that rely on the fallback cpu implementations. + // Use the 'cpu' level of node tests to test those implementations. + "test_add_bcast", + "test_add", + "test_sub_bcast", + "test_sub_example", + "test_sub", + "test_mul_bcast", + "test_mul_example", + "test_mul", + "test_div_bcast", + "test_div_example", + "test_div", + "test_xor_bcast3v1d", + "test_xor_bcast3v2d", + "test_xor_bcast4v2d", + "test_xor_bcast4v3d", + "test_xor_bcast4v4d", + "test_xor2d", + "test_xor3d", + "test_xor4d", + "test_or_bcast3v1d", + "test_or_bcast3v2d", + "test_or_bcast4v2d", + "test_or_bcast4v3d", + "test_or_bcast4v4d", + "test_and_bcast3v1d", + "test_and_bcast3v2d", + "test_and_bcast4v2d", + "test_and_bcast4v3d", + "test_and_bcast4v4d", + "test_and2d", + "test_and3d", + "test_and4d", + "test_prelu_broadcast", + "test_prelu_example", + "test_basic_conv_with_padding", + "test_basic_conv_without_padding", + "test_batchnorm_epsilon", + "test_batchnorm_example", + "opset{10,11,12}/test_cast_STRING_to_FLOAT", + "test_clip_splitbounds", + "test_clip_outbounds", + "test_clip_inbounds", + "test_clip_example", + "test_clip_default_min", + "test_clip_default_max", + "test_clip_default_inbounds", + "test_clip", + "test_conv_with_strides_and_asymmetric_padding", + "test_conv_with_strides_no_padding", + "test_conv_with_strides_padding", + "test_gemm_nobroadcast", + "test_gemm_broadcast", + "test_matmul_2d", + "test_matmul_3d", + "test_matmul_4d", + "test_softmax_axis_0", + "test_softmax_axis_1", + "test_softmax_axis_2", + "test_softmax_default_axis", + "test_softmax_example", + "test_softmax_large_number", + "test_sum_example", + "test_sum_one_input", + "test_sum_two_inputs", + "test_averagepool_1d_default", + "test_averagepool_2d_default", + "test_averagepool_2d_pads", + "test_averagepool_2d_precomputed_pads", + "test_averagepool_2d_precomputed_same_upper", + "test_averagepool_2d_precomputed_strides", + "test_averagepool_2d_same_upper", + "test_averagepool_2d_same_lower", + "test_averagepool_2d_strides", + "test_averagepool_3d_default", + "test_maxpool_1d_default", + "test_maxpool_2d_default", + "test_maxpool_2d_pads", + "test_maxpool_2d_precomputed_pads", + "test_maxpool_2d_precomputed_same_upper", + "test_maxpool_2d_precomputed_strides", + "test_maxpool_2d_same_lower", + "test_maxpool_2d_same_upper", + "test_maxpool_2d_strides", + "test_maxpool_3d_default", + "test_globalaveragepool_precomputed", + "test_globalaveragepool", + "test_globalmaxpool_precomputed", + "test_globalmaxpool", + "test_instancenorm_epsilon", + "test_instancenorm_example" + ], + "ops": [] } } diff --git a/js/web/test/test-runner.ts b/js/web/test/test-runner.ts index 26ebcbbd6e212..6035b32a8f149 100644 --- a/js/web/test/test-runner.ts +++ b/js/web/test/test-runner.ts @@ -311,7 +311,7 @@ export class TensorResultValidator { } else if (backend === 'webgpu') { this.absoluteThreshold = WEBGPU_THRESHOLD_ABSOLUTE_ERROR; this.relativeThreshold = WEBGPU_THRESHOLD_RELATIVE_ERROR; - } else if (backend === 'wasm' || backend === 'xnnpack') { + } else if (backend === 'wasm' || backend === 'xnnpack' || backend === 'webnn') { this.absoluteThreshold = WASM_THRESHOLD_ABSOLUTE_ERROR; this.relativeThreshold = WASM_THRESHOLD_RELATIVE_ERROR; } else if (backend === 'onnxruntime') { diff --git a/onnxruntime/core/optimizer/transpose_optimizer/optimizer_api_impl.cc b/onnxruntime/core/optimizer/transpose_optimizer/optimizer_api_impl.cc index 65a037cfa0a0f..b04b02e74c1bf 100644 --- a/onnxruntime/core/optimizer/transpose_optimizer/optimizer_api_impl.cc +++ b/onnxruntime/core/optimizer/transpose_optimizer/optimizer_api_impl.cc @@ -873,7 +873,7 @@ const std::unordered_set& GetORTLayoutSensitiveOps() { { "FusedConv", "QLinearAveragePool", "QLinearGlobalAveragePool" -#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_QNN) +#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_QNN) || defined(USE_WEBNN) // The CUDA/ROCM Resize kernel is layout sensitive as it only handles NCHW input. // The CPU kernel and ONNX spec are not limited to handling NCHW input so are not layout sensitive, and // onnx_layout_transformation::HandleResize is used. diff --git a/onnxruntime/core/providers/get_execution_providers.cc b/onnxruntime/core/providers/get_execution_providers.cc index f7b372d1eff74..b0f510f054a03 100644 --- a/onnxruntime/core/providers/get_execution_providers.cc +++ b/onnxruntime/core/providers/get_execution_providers.cc @@ -146,6 +146,14 @@ constexpr ProviderInfo kProvidersInPriorityOrder[] = true, #else false, +#endif + }, + { + kWebNNExecutionProvider, +#ifdef USE_WEBNN + true, +#else + false, #endif }, { diff --git a/onnxruntime/core/providers/provider_factory_creators.h b/onnxruntime/core/providers/provider_factory_creators.h index b019ede434b83..42a58097e1635 100644 --- a/onnxruntime/core/providers/provider_factory_creators.h +++ b/onnxruntime/core/providers/provider_factory_creators.h @@ -86,6 +86,10 @@ #include "core/providers/xnnpack/xnnpack_provider_factory_creator.h" #endif +#if defined(USE_WEBNN) +#include "core/providers/webnn/webnn_provider_factory_creator.h" +#endif + #if defined(USE_CANN) #include "core/providers/cann/cann_provider_factory_creator.h" #endif diff --git a/onnxruntime/core/providers/webnn/builders/helper.cc b/onnxruntime/core/providers/webnn/builders/helper.cc new file mode 100644 index 0000000000000..c52a6b18c696f --- /dev/null +++ b/onnxruntime/core/providers/webnn/builders/helper.cc @@ -0,0 +1,107 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) Intel Corporation. All rights reserved. +// Licensed under the MIT License. + +#include + +#include "helper.h" +#include + +#include "op_builder_factory.h" + +namespace onnxruntime { +namespace webnn { + +bool GetShape(const NodeArg& node_arg, std::vector& shape, const logging::Logger& logger) { + const auto* shape_proto = node_arg.Shape(); + if (!shape_proto) { + LOGS(logger, WARNING) << "NodeArg [" << node_arg.Name() << "] has no shape info"; + return false; + } + + // We already checked the shape has no dynamic dimension. + for (const auto& dim : shape_proto->dim()) { + shape.push_back(dim.dim_value()); + } + + return true; +} + +bool IsNodeSupported(const Node& node, const GraphViewer& graph_viewer, const logging::Logger& logger) { + const auto& op_builders = GetOpBuilders(); + if (Contains(op_builders, node.OpType())) { + const auto* op_builder = op_builders.at(node.OpType()); + return op_builder->IsOpSupported(graph_viewer.GetAllInitializedTensors(), node, logger); + } else { + return false; + } +} + +bool IsInputSupported(const NodeArg& input, const std::string& parent_name, const logging::Logger& logger) { + const auto& input_name = input.Name(); + const auto* shape_proto = input.Shape(); + // We do not support input with no shape. + if (!shape_proto) { + LOGS(logger, VERBOSE) << "Input [" << input_name << "] of [" << parent_name + << "] has not shape"; + return false; + } + + for (const auto& dim : shape_proto->dim()) { + // For now we workaround dynamic shape support by assuming 1. + if (!dim.has_dim_value()) { + LOGS(logger, VERBOSE) << "Dynamic shape is not supported for now, assume to be 1, for input:" << input_name; + } + } + + return true; +} + +std::vector> GetSupportedNodes(const GraphViewer& graph_viewer, + const emscripten::val& wnn_builder_, + const logging::Logger& logger) { + std::vector> supported_node_groups; + + for (const auto* input : graph_viewer.GetInputs()) { + if (!IsInputSupported(*input, "graph", logger)) { + return supported_node_groups; + } + } + + std::vector supported_node_group; + const auto& node_indices = graph_viewer.GetNodesInTopologicalOrder(); + + for (size_t i = 0; i < node_indices.size(); i++) { + auto node_idx = node_indices[i]; + const auto* node(graph_viewer.GetNode(node_idx)); + bool supported = false; + // Firstly check if platform supports the WebNN op. + if (CheckSingleOp(node->OpType(), wnn_builder_)) { + LOGS(logger, VERBOSE) << "Operator type: [" << node->OpType() << "] is supported by browser"; + supported = IsNodeSupported(*node, graph_viewer, logger); + } + + LOGS(logger, VERBOSE) << "Operator type: [" << node->OpType() + << "] index: [" << node_idx + << "] name: [" << node->Name() + << "] supported: [" << supported + << "]"; + if (supported) { + supported_node_group.push_back(node_idx); + } else { + if (!supported_node_group.empty()) { + supported_node_groups.push_back(supported_node_group); + supported_node_group.clear(); + } + } + } + + if (!supported_node_group.empty()) { + supported_node_groups.push_back(supported_node_group); + } + + return supported_node_groups; +} + +} // namespace webnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webnn/builders/helper.h b/onnxruntime/core/providers/webnn/builders/helper.h new file mode 100644 index 0000000000000..9a386e6e34f25 --- /dev/null +++ b/onnxruntime/core/providers/webnn/builders/helper.h @@ -0,0 +1,60 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) Intel Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include "core/common/inlined_containers.h" +#include +#include "core/providers/common.h" + +#include +#include + +namespace onnxruntime { + +class GraphViewer; +class NodeArg; + +namespace logging { +class Logger; +} + +namespace webnn { + +bool GetShape(const NodeArg& node_arg, std::vector& shape, const logging::Logger& logger); + +bool IsInputSupported(const NodeArg& node_arg, const std::string& parent_name, const logging::Logger& logger); + +// Get a list of groups of supported nodes, each group represents a subgraph supported by WebNN EP. +std::vector> GetSupportedNodes(const GraphViewer& graph_viewer, + const emscripten::val& wnn_builder_, + const logging::Logger& logger); +static const InlinedHashMap op_map = { + {"Add", "add"}, + {"Sub", "sub"}, + {"Mul", "mul"}, + {"Div", "div"}, + {"Relu", "relu"}, + {"LeakyRelu", "leakyRelu"}, + {"Sigmoid", "sigmoid"}, + {"Clip", "clamp"}, + {"Conv", "conv2d"}, + {"ConvTranspose", "convTranspose2d"}, + {"Concat", "concat"}, + {"Gemm", "gemm"}, + {"GlobalAveragePool", "averagePool2d"}, + {"GlobalMaxPool", "maxPool2d"}, + {"AveragePool", "averagePool2d"}, + {"MaxPool", "maxPool2d"}, + {"Reshape", "reshape"}, + {"Resize", "resample2d"}, + {"Transpose", "transpose"}}; + +inline bool CheckSingleOp(const std::string& op_type, const emscripten::val& wnn_builder_) { + return op_map.find(op_type) != op_map.end() && wnn_builder_[op_map.find(op_type)->second].as(); +} + +} // namespace webnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webnn/builders/impl/activation_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/activation_op_builder.cc new file mode 100644 index 0000000000000..0dcbf9b527ab3 --- /dev/null +++ b/onnxruntime/core/providers/webnn/builders/impl/activation_op_builder.cc @@ -0,0 +1,87 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) Intel Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/common.h" +#include "core/providers/shared/utils/utils.h" +#include "core/providers/webnn/builders/model_builder.h" +#include "core/providers/webnn/builders/op_builder_factory.h" + +#include "base_op_builder.h" + +namespace onnxruntime { +namespace webnn { + +class ActivationOpBuilder : public BaseOpBuilder { + // Add operator related. + private: + Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, + const logging::Logger& logger) const override ORT_MUST_USE_RESULT; + + // Operator support related. + private: + int GetMinSupportedOpSet(const Node& node) const override; +}; + +// Add operator related. + +Status ActivationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, + const Node& node, + const logging::Logger& /* logger */) const { + const auto& op_type(node.OpType()); + emscripten::val input = model_builder.GetOperand(node.InputDefs()[0]->Name()); + emscripten::val output = emscripten::val::object(); + if (op_type == "Relu") { + if (Contains(model_builder.GetFusedActivations(), node.InputDefs()[0]->Name())) { + LOGS_DEFAULT(VERBOSE) << "Relu Node [" << node.Name() << "] fused"; + output = input; + } else { + output = model_builder.GetBuilder().call("relu", input); + } + } else if (op_type == "LeakyRelu") { + if (Contains(model_builder.GetFusedActivations(), node.InputDefs()[0]->Name())) { + LOGS_DEFAULT(VERBOSE) << "LeakyRelu Node [" << node.Name() << "] fused"; + output = input; + } else { + NodeAttrHelper helper(node); + emscripten::val options = emscripten::val::object(); + options.set("alpha", helper.Get("alpha", (float)0.0)); + output = model_builder.GetBuilder().call("leakyRelu", input, options); + } + } else if (op_type == "Sigmoid") { + if (Contains(model_builder.GetFusedActivations(), node.InputDefs()[0]->Name())) { + LOGS_DEFAULT(VERBOSE) << "Sigmoid Node [" << node.Name() << "] fused"; + output = input; + } else { + output = model_builder.GetBuilder().call("sigmoid", input); + } + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "ActivationOpBuilder::AddToModelBuilderImpl, unknown op: ", op_type); + } + + model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output)); + return Status::OK(); +} + +// Operator support related. + +int ActivationOpBuilder::GetMinSupportedOpSet(const Node& /* node */) const { + // All ops opset 5- uses consumed_inputs attribute which is not supported for now. + return 6; +} + +void CreateActivationOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { + if (op_registrations.op_builder_map.find(op_type) != op_registrations.op_builder_map.cend()) + return; + + static std::vector op_types = {"Relu", "LeakyRelu", "Sigmoid"}; + + op_registrations.builders.push_back(std::make_unique()); + for (const auto& type : op_types) { + op_registrations.op_builder_map.emplace(type, op_registrations.builders.back().get()); + } +} + +} // namespace webnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.cc new file mode 100644 index 0000000000000..13c0f0131f2a6 --- /dev/null +++ b/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.cc @@ -0,0 +1,110 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) Intel Corporation. All rights reserved. +// Licensed under the MIT License. + +#include + +#include "core/providers/webnn/builders/model_builder.h" +#include "core/providers/webnn/builders/helper.h" +#include "core/providers/shared/utils/utils.h" + +#include "base_op_builder.h" + +namespace onnxruntime { +namespace webnn { + +// Shared functions. +bool HasExternalInitializer(const InitializedTensorSet& initializers, const Node& node, + const logging::Logger& logger) { + for (const auto* node_arg : node.InputDefs()) { + const auto& input_name(node_arg->Name()); + if (!Contains(initializers, input_name)) + continue; + + const auto& tensor = *initializers.at(input_name); + if (tensor.has_data_location() && + tensor.data_location() == ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL) { + LOGS(logger, VERBOSE) << "Initializer [" << input_name + << "] with external data location are not currently supported"; + return true; + } + } + + return false; +} + +// Add operator related. + +Status BaseOpBuilder::AddToModelBuilder(ModelBuilder& model_builder, const Node& node, + const logging::Logger& logger) const { + ORT_RETURN_IF_NOT( + IsOpSupported(model_builder.GetInitializerTensors(), node, logger), + "Unsupported operator ", + node.OpType()); + ORT_RETURN_IF_ERROR(AddToModelBuilderImpl(model_builder, node, logger)); + LOGS(logger, VERBOSE) << "Operator name: [" << node.Name() + << "] type: [" << node.OpType() << "] was added"; + return Status::OK(); +} + +// Operator support related. + +bool BaseOpBuilder::IsOpSupported(const InitializedTensorSet& initializers, const Node& node, + const logging::Logger& logger) const { + if (!HasSupportedInputs(node, logger)) + return false; + + // We do not support external initializers for now. + if (HasExternalInitializer(initializers, node, logger)) + return false; + + if (!HasSupportedOpSet(node, logger)) + return false; + + return IsOpSupportedImpl(initializers, node, logger); +} + +bool BaseOpBuilder::HasSupportedInputs(const Node& node, const logging::Logger& logger) const { + const auto node_name = MakeString("Node [", node.Name(), "] type [", node.OpType(), "]"); + for (const auto* input : node.InputDefs()) { + if (!IsInputSupported(*input, node_name, logger)) { + return false; + } + } + + return HasSupportedInputsImpl(node, logger); +} + +bool BaseOpBuilder::HasSupportedInputsImpl(const Node& node, const logging::Logger& logger) const { + // We only check the type of input 0 by default, specific op builder can override this. + const auto& input = *node.InputDefs()[0]; + + int32_t input_type; + if (!GetType(input, input_type, logger)) + return false; + + if (input_type != ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { + LOGS(logger, VERBOSE) << "[" << node.OpType() + << "] Input type: [" << input_type + << "] is not supported for now"; + return false; + } + + return true; +} + +bool BaseOpBuilder::HasSupportedOpSet(const Node& node, + const logging::Logger& logger) const { + auto since_version = node.SinceVersion(); + if (since_version < GetMinSupportedOpSet(node) || since_version > GetMaxSupportedOpSet(node)) { + LOGS(logger, VERBOSE) << node.OpType() << "is only supported for opset [" + << GetMinSupportedOpSet(node) << ", " + << GetMaxSupportedOpSet(node) << "]"; + return false; + } + + return true; +} + +} // namespace webnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.h b/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.h new file mode 100644 index 0000000000000..7b4148c34d4a7 --- /dev/null +++ b/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.h @@ -0,0 +1,50 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) Intel Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/webnn/builders/op_builder.h" + +namespace onnxruntime { +namespace webnn { + +class ModelBuilder; + +class BaseOpBuilder : public IOpBuilder { + public: + virtual ~BaseOpBuilder() = default; + + // Add operator related. + public: + virtual void AddInitializersToSkip(ModelBuilder& /* model_builder */, const Node& /* node */) const override {} + Status AddToModelBuilder(ModelBuilder& model_builder, const Node& node, + const logging::Logger& logger) const override final ORT_MUST_USE_RESULT; + + protected: + virtual Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, + const logging::Logger& logger) const ORT_MUST_USE_RESULT = 0; + + // Operator support related. + public: + bool IsOpSupported(const InitializedTensorSet& initializers, const Node& node, + const logging::Logger& logger) const override; + + protected: + virtual bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& /* node */, + const logging::Logger& /* logger */) const { + return true; + } + + virtual bool HasSupportedInputsImpl(const Node& node, const logging::Logger& logger) const; + + virtual int GetMinSupportedOpSet(const Node& /* node */) const { return 1; } + virtual int GetMaxSupportedOpSet(const Node& /* node */) const { return 14; } + + private: + bool HasSupportedOpSet(const Node& node, const logging::Logger& logger) const; + bool HasSupportedInputs(const Node& node, const logging::Logger& logger) const; +}; + +} // namespace webnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webnn/builders/impl/binary_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/binary_op_builder.cc new file mode 100644 index 0000000000000..96ee9cf3f11d9 --- /dev/null +++ b/onnxruntime/core/providers/webnn/builders/impl/binary_op_builder.cc @@ -0,0 +1,79 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) Intel Corporation. All rights reserved. +// Licensed under the MIT License. + +#include + +#include "core/providers/webnn/builders/model_builder.h" +#include "core/providers/webnn/builders/op_builder_factory.h" +#include "core/providers/webnn/builders/helper.h" + +#include "base_op_builder.h" + +namespace onnxruntime { +namespace webnn { + +class BinaryOpBuilder : public BaseOpBuilder { + // Add operator related. + private: + Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, + const logging::Logger& logger) const override ORT_MUST_USE_RESULT; + + // Operator support related. + private: + int GetMinSupportedOpSet(const Node& node) const override; +}; + +// Add operator related. + +Status BinaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, + const logging::Logger& /* logger */) const { + const auto& op_type(node.OpType()); + + emscripten::val input0 = model_builder.GetOperand(node.InputDefs()[0]->Name()); + emscripten::val input1 = model_builder.GetOperand(node.InputDefs()[1]->Name()); + emscripten::val output = emscripten::val::object(); + if (op_type == "Add") { + output = model_builder.GetBuilder().call("add", input0, input1); + } else if (op_type == "Sub") { + output = model_builder.GetBuilder().call("sub", input0, input1); + } else if (op_type == "Mul") { + output = model_builder.GetBuilder().call("mul", input0, input1); + } else if (op_type == "Div") { + output = model_builder.GetBuilder().call("div", input0, input1); + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "BinaryOpBuilder::AddToModelBuilderImpl, unknown op: ", op_type); + } + + model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output)); + return Status::OK(); +} + +// Operator support related. + +int BinaryOpBuilder::GetMinSupportedOpSet(const Node& /* node */) const { + // Add/Sub/Mul/Div opset 6- has broadcast attributes we do not support now. + return 7; +} + +void CreateBinaryOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { + if (op_registrations.op_builder_map.find(op_type) != op_registrations.op_builder_map.cend()) + return; + + static std::vector op_types = + { + "Add", + "Sub", + "Mul", + "Div", + }; + + op_registrations.builders.push_back(std::make_unique()); + for (const auto& type : op_types) { + op_registrations.op_builder_map.emplace(type, op_registrations.builders.back().get()); + } +} + +} // namespace webnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webnn/builders/impl/builder_utils.cc b/onnxruntime/core/providers/webnn/builders/impl/builder_utils.cc new file mode 100644 index 0000000000000..516ac7464345b --- /dev/null +++ b/onnxruntime/core/providers/webnn/builders/impl/builder_utils.cc @@ -0,0 +1,86 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) Intel Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include +#include "core/providers/shared/utils/utils.h" + +#include "builder_utils.h" +#include "core/providers/webnn/builders/helper.h" + +namespace onnxruntime { +namespace webnn { + +common::Status ComputeConvPads(const std::vector input_shape, + const int64_t weight_size_y, + const int64_t weight_size_x, + const std::vector& onnx_pads, + const std::vector& onnx_strides, + const std::vector& onnx_dilations, + AutoPadType auto_pad_type, + std::vector& pads_out) { + const int64_t input_size_y = input_shape[2]; + const int64_t input_size_x = input_shape[3]; + const int64_t stride_y = onnx_strides[0]; + const int64_t stride_x = onnx_strides[1]; + const int64_t dilation_y = onnx_dilations[0]; + const int64_t dilation_x = onnx_dilations[1]; + + int64_t padding_top = onnx_pads[0]; + int64_t padding_bottom = onnx_pads[2]; + int64_t padding_left = onnx_pads[1]; + int64_t padding_right = onnx_pads[3]; + + ORT_RETURN_IF_ERROR(ComputePad(input_size_y, + stride_y, weight_size_y, dilation_y, + auto_pad_type, + padding_top, padding_bottom)); + ORT_RETURN_IF_ERROR(ComputePad(input_size_x, + stride_x, weight_size_x, dilation_x, + auto_pad_type, + padding_left, padding_right)); + + pads_out = {padding_top, padding_left, padding_bottom, padding_right}; + + return Status::OK(); +} + +common::Status HandleAutoPad(const std::vector input_shape, + const int64_t weight_size_y, + const int64_t weight_size_x, + const std::vector& onnx_pads, + const std::vector& onnx_strides, + const std::vector& onnx_dilations, + AutoPadType auto_pad_type, + AutoPadType& auto_pad_type_out) { + auto_pad_type_out = auto_pad_type; + if (auto_pad_type == AutoPadType::NOTSET && onnx_dilations == std::vector{1, 1}) { + { + std::vector same_upper_pads; + ORT_RETURN_IF_ERROR(ComputeConvPads(input_shape, weight_size_y, weight_size_x, + onnx_pads, onnx_strides, onnx_dilations, + AutoPadType::SAME_UPPER, same_upper_pads)); + if (onnx_pads == same_upper_pads) { + auto_pad_type_out = AutoPadType::SAME_UPPER; + return Status::OK(); + } + } + + { + std::vector same_lower_pads; + ORT_RETURN_IF_ERROR(ComputeConvPads(input_shape, weight_size_y, weight_size_x, + onnx_pads, onnx_strides, onnx_dilations, + AutoPadType::SAME_LOWER, same_lower_pads)); + if (onnx_pads == same_lower_pads) { + auto_pad_type_out = AutoPadType::SAME_LOWER; + return Status::OK(); + } + } + } + + return Status::OK(); +} + +} // namespace webnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webnn/builders/impl/builder_utils.h b/onnxruntime/core/providers/webnn/builders/impl/builder_utils.h new file mode 100644 index 0000000000000..76acbca0536ea --- /dev/null +++ b/onnxruntime/core/providers/webnn/builders/impl/builder_utils.h @@ -0,0 +1,27 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) Intel Corporation. All rights reserved. +// Licensed under the MIT License. + +// This contains the utility functions which will be used to build a webnn model + +#pragma once + +#include "core/common/status.h" +#include "core/graph/basic_types.h" + +namespace onnxruntime { +namespace webnn { + +// Try to see if we can map explicit padding to auto padding for Conv/Pool. +// Since usually use auto padding is more efficient. +common::Status HandleAutoPad(const std::vector input_shape, + const int64_t weight_size_y, + const int64_t weight_size_x, + const std::vector& onnx_pads, + const std::vector& onnx_strides, + const std::vector& onnx_dilations, + AutoPadType auto_pad_type, + AutoPadType& auto_pad_type_out) ORT_MUST_USE_RESULT; + +} // namespace webnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webnn/builders/impl/clip_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/clip_op_builder.cc new file mode 100644 index 0000000000000..1c07439c8e8cc --- /dev/null +++ b/onnxruntime/core/providers/webnn/builders/impl/clip_op_builder.cc @@ -0,0 +1,81 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) Intel Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/common.h" +#include "core/providers/webnn/builders/model_builder.h" +#include "core/providers/webnn/builders/op_builder_factory.h" +#include "core/providers/shared/utils/utils.h" + +#include "base_op_builder.h" + +namespace onnxruntime { +namespace webnn { + +class ClipOpBuilder : public BaseOpBuilder { + // Add operator related. + public: + void AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const override; + + private: + Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, + const logging::Logger& logger) const override ORT_MUST_USE_RESULT; + + // Operator support related. + private: + bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, + const logging::Logger& logger) const override; +}; + +// Add operator related. + +void ClipOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const { + // Both min and max values will be injected into the layer, no need to add to the model. + if (node.SinceVersion() >= 11) { + if (node.InputDefs().size() > 1) + model_builder.AddInitializerToSkip(node.InputDefs()[1]->Name()); + + if (node.InputDefs().size() > 2) + model_builder.AddInitializerToSkip(node.InputDefs()[2]->Name()); + } +} + +Status ClipOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, + const Node& node, + const logging::Logger& logger) const { + const auto& input_name = node.InputDefs()[0]->Name(); + const auto& output_name = node.OutputDefs()[0]->Name(); + emscripten::val options = emscripten::val::object(); + float minValue, maxValue; + ORT_RETURN_IF_NOT(GetClipMinMax(model_builder.GetInitializerTensors(), node, minValue, maxValue, logger), + "GetClipMinMax failed"); + options.set("minValue", minValue); + options.set("maxValue", maxValue); + emscripten::val input = model_builder.GetOperand(input_name); + emscripten::val output = emscripten::val::object(); + if (Contains(model_builder.GetFusedActivations(), input_name)) { + LOGS_DEFAULT(VERBOSE) << "Clip Node [" << node.Name() << "] fused"; + output = input; + } else { + output = model_builder.GetBuilder().call("clamp", input, options); + } + + model_builder.AddOperand(output_name, std::move(output)); + return Status::OK(); +} + +// Operator support related. + +bool ClipOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, + const logging::Logger& logger) const { + float min, max; + return GetClipMinMax(initializers, node, min, max, logger); +} + +void CreateClipOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { + op_registrations.builders.push_back(std::make_unique()); + op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get()); +} + +} // namespace webnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webnn/builders/impl/concat_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/concat_op_builder.cc new file mode 100644 index 0000000000000..1ed516a3c4d9e --- /dev/null +++ b/onnxruntime/core/providers/webnn/builders/impl/concat_op_builder.cc @@ -0,0 +1,80 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) Intel Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/common/safeint.h" +#include "core/providers/common.h" +#include "core/providers/shared/utils/utils.h" +#include "core/providers/webnn/builders/helper.h" +#include "core/providers/webnn/builders/model_builder.h" +#include "core/providers/webnn/builders/op_builder_factory.h" + +#include "base_op_builder.h" + +namespace onnxruntime { +namespace webnn { + +class ConcatOpBuilder : public BaseOpBuilder { + // Add operator related. + private: + Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, + const logging::Logger& logger) const override ORT_MUST_USE_RESULT; + + // Operator support related. + private: + bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, + const logging::Logger& logger) const override; +}; + +// Add operator related. + +Status ConcatOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, + const Node& node, + const logging::Logger& logger) const { + const auto& input_defs = node.InputDefs(); + std::vector input_shape; + if (!GetShape(*input_defs[0], input_shape, logger)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "ConcatOpBuilder::AddToModelBuilderImpl, cannot get input shape"); + } + auto rank = input_shape.size(); + NodeAttrHelper helper(node); + uint32_t axis = static_cast(HandleNegativeAxis(helper.Get("axis", 1), rank)); + + emscripten::val inputs = emscripten::val::array(); + for (const auto* input : node.InputDefs()) { + LOGS(logger, VERBOSE) << "input name " << input->Name(); + inputs.call("push", model_builder.GetOperand(input->Name())); + } + + emscripten::val output = model_builder.GetBuilder().call("concat", inputs, axis); + + model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output)); + return Status::OK(); +} + +// Operator support related. +bool ConcatOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const logging::Logger& logger) const { + std::vector input_shape; + const auto& input_defs(node.InputDefs()); + if (!GetShape(*input_defs[0], input_shape, logger)) + return false; + + const auto input_size = input_shape.size(); + if (input_size > 4 || input_size == 0) { + LOGS_DEFAULT(VERBOSE) << "Concat only supports up to 1-4d shape, input is " + << input_size << "d shape"; + return false; + } + + return true; +} + +void CreateConcatOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { + op_registrations.builders.push_back(std::make_unique()); + op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get()); +} + +} // namespace webnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc new file mode 100644 index 0000000000000..93746090a29bd --- /dev/null +++ b/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc @@ -0,0 +1,287 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) Intel Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/common/safeint.h" +#include "core/optimizer/initializer.h" +#include "core/providers/common.h" +#include "core/providers/shared/utils/utils.h" +#include "core/providers/webnn/builders/helper.h" +#include "core/providers/webnn/builders/model_builder.h" +#include "core/providers/webnn/builders/op_builder_factory.h" + +#include "base_op_builder.h" +#include "builder_utils.h" + +namespace onnxruntime { +namespace webnn { + +class ConvOpBuilder : public BaseOpBuilder { + // Add operator related. + public: + void AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const override; + + private: + Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, + const logging::Logger& logger) const override ORT_MUST_USE_RESULT; + + // Operator support related. + private: + bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& /* node */, + const logging::Logger& /* logger */) const override; +}; + +void ConvOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const { + // skip the weight for conv as we need to transpose. + model_builder.AddInitializerToSkip(node.InputDefs()[1]->Name()); // W + model_builder.AddInputToSkip(node.InputDefs()[1]->Name()); +} + +// Helper functions +common::Status SetConvBaseOptions(ModelBuilder& model_builder, + const Node& node, emscripten::val& options, + const std::vector& strides, + const std::vector& dilations, + const std::vector& pads, + const logging::Logger& logger) { + NodeAttrHelper helper(node); + const auto group = helper.Get("group", static_cast(1)); + const auto& input_defs = node.InputDefs(); + const auto& weight_tensor = *model_builder.GetInitializerTensors().at(input_defs[1]->Name()); + const auto& weight_shape = weight_tensor.dims(); + + options.set("strides", emscripten::val::array(strides)); + options.set("dilations", emscripten::val::array(dilations)); + options.set("inputLayout", emscripten::val("nhwc")); + options.set("groups", group); + // Add Padding. + // Usually using autopadding is more efficient than using explicit padding. + // Try to see if we can map explicit padding to auto padding. + std::vector input_shape; + ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get shape"); + AutoPadType auto_pad_type; + ORT_RETURN_IF_ERROR(HandleAutoPad(input_shape, weight_shape[2], weight_shape[3], + helper.Get("pads", std::vector{0, 0, 0, 0}), + helper.Get("strides", std::vector{1, 1}), + helper.Get("dilations", std::vector{1, 1}), + StringToAutoPadType(helper.Get("auto_pad", "NOTSET")), + auto_pad_type)); + if (AutoPadType::SAME_UPPER == auto_pad_type || AutoPadType::SAME_LOWER == auto_pad_type) { + if (AutoPadType::SAME_LOWER == auto_pad_type) { // default is SAME_UPPER + options.set("autoPad", emscripten::val("same-lower")); + } else { + options.set("autoPad", emscripten::val("same-upper")); + } + } else { + options.set("padding", emscripten::val::array(pads)); + } + + // Add bias if present. + if (input_defs.size() > 2) { + options.set("bias", model_builder.GetOperand(input_defs[2]->Name())); + } + InlinedHashSet supported_nodes{"Clip", "Relu"}; + emscripten::val activation = model_builder.FindActivation(node, *node.OutputDefs()[0], supported_nodes); + if (emscripten::val::null() != activation) { + options.set("activation", activation); + } + + return Status::OK(); +} + +// Both depthwise Conv and ConvTranspose share the same logic to add the layout. +Status AddInitializerInNewLayout(ModelBuilder& model_builder, + const std::string& name, + bool is_conv) { + const auto& tensor = *model_builder.GetInitializerTensors().at(name); + auto data_type = tensor.data_type(); + if (data_type != ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "The initializer of graph has unsupported type, name: ", + tensor.name(), " type: ", data_type); + } + + const auto& shape = tensor.dims(); + std::vector dims; + std::transform(shape.cbegin(), shape.cend(), + std::back_inserter(dims), + [](int64_t dim) -> int32_t { return SafeInt(dim); }); + + ORT_RETURN_IF_NOT(dims.size() == 4, + "The initializer is not 4D: ", name, " actual dim ", dims.size()); + const uint8_t* src = nullptr; + Initializer unpacked_tensor(tensor, model_builder.GetGraphViewer().ModelPath()); + src = unpacked_tensor.DataAsByteSpan().data(); + const auto out_t = dims[0], in_t = dims[1], + h_t = dims[2], w_t = dims[3]; + std::vector dest_shape; + if (is_conv == 1) + dest_shape = {out_t, h_t, w_t, in_t}; // L_0231 + else + dest_shape = {in_t, h_t, w_t, out_t}; // L_1230 for depthwise conv weight + + SafeInt num_elements = SafeInt(Product(dest_shape)); + + size_t element_size = 4; + std::unique_ptr buffer_holder(new uint8_t[element_size * num_elements]); + uint8_t* buffer = buffer_holder.get(); + + for (uint32_t out = 0; out < out_t; out++) { + for (uint32_t in = 0; in < in_t; in++) { + for (uint32_t h = 0; h < h_t; h++) { + for (uint32_t w = 0; w < w_t; w++) { + auto onnx_idx = out * in_t * h_t * w_t + + in * h_t * w_t + + h * w_t + + w; + + uint32_t nnapi_idx; + if (is_conv == 1) { // L_0231 + nnapi_idx = out * h_t * w_t * in_t + + h * w_t * in_t + + w * in_t + + in; + } else { // L_1230 for depthwise conv weight + nnapi_idx = in * h_t * w_t * out_t + + h * w_t * out_t + + w * out_t + + out; + } + + for (size_t i = 0; i < element_size; i++) { + buffer[element_size * nnapi_idx + i] = src[element_size * onnx_idx + i]; + } + } + } + } + } + ORT_RETURN_IF_ERROR(model_builder.AddOperandFromPersistMemoryBuffer(name, buffer, num_elements * element_size, + dest_shape, 4)); + return Status::OK(); +} + +// Add operator related. + +Status ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, + const logging::Logger& logger) const { + const auto& input_defs = node.InputDefs(); + const auto& op_type = node.OpType(); + emscripten::val input = model_builder.GetOperand(input_defs[0]->Name()); + emscripten::val output = emscripten::val::object(); + + NodeAttrHelper helper(node); + const auto strides = helper.Get("strides", std::vector{1, 1}); + const auto dilations = helper.Get("dilations", std::vector{1, 1}); + auto pads = helper.Get("pads", std::vector{0, 0, 0, 0}); + const auto& weight = input_defs[1]->Name(); + + if (op_type == "Conv") { + emscripten::val options = emscripten::val::object(); + ORT_RETURN_IF_ERROR(SetConvBaseOptions(model_builder, node, options, strides, dilations, pads, logger)); + int groups = options["groups"].as(); + std::vector input_shape; + ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get shape"); + bool depthwise = (groups == input_shape[3] && groups != 1); + ORT_RETURN_IF_ERROR(AddInitializerInNewLayout( + model_builder, weight, !depthwise)); + if (!depthwise) { + options.set("filterLayout", emscripten::val("ohwi")); + } else { + options.set("filterLayout", emscripten::val("ihwo")); + } + emscripten::val filter = model_builder.GetOperand(input_defs[1]->Name()); + output = model_builder.GetBuilder().call("conv2d", input, filter, options); + } else { + emscripten::val options = emscripten::val::object(); + ORT_RETURN_IF_ERROR(SetConvBaseOptions(model_builder, node, options, strides, dilations, pads, logger)); + ORT_RETURN_IF_ERROR(AddInitializerInNewLayout( + model_builder, weight, false)); + options.set("filterLayout", emscripten::val("ohwi")); + // When the 'output_shape' is specificed, the 'output_padding' values + // in options.outputPadding are ignored. + std::vector dim; + std::vector output_padding{0, 0}; + if (helper.HasAttr("output_shape")) { + // Default value of 'output_shape' will be ignore as we already check if + // it's existed. + dim = helper.Get("output_shape", std::vector{-1, -1}); + // Extract the height and width. + std::vector output_shape; + if (dim.size() == 2) { + output_shape = dim; + } else if (dim.size() == 4) { + output_shape = {dim[2], dim[3]}; + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid output shape"); + } + // Padding values are auto generated. + if (helper.HasAttr("kernel_shape")) { + std::vector kernel_shape = helper.Get("kernel_shape", std::vector{-1, -1}); + std::vector total_padding(2); + std::vector input_shape; + ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get shape"); + for (size_t i = 0; i < 2; i++) { + total_padding[i] = strides[i] * (input_shape[i + 1] - 1) + output_padding[i] + ((kernel_shape[i] - 1) * dilations[i] + 1) - output_shape[i]; + } + pads[0] = total_padding[0] - (total_padding[0] / 2); + pads[1] = total_padding[0] / 2; + pads[2] = total_padding[1] - (total_padding[1] / 2); + pads[3] = total_padding[1] / 2; + options.set("padding", emscripten::val::array(pads)); + } + options.set("outputSizes", emscripten::val::array(output_shape)); + } else { + output_padding = helper.Get("output_padding", std::vector{0, 0}); + options.set("outputPadding", emscripten::val::array(output_padding)); + } + emscripten::val filter = model_builder.GetOperand(input_defs[1]->Name()); + + output = model_builder.GetBuilder().call("convTranspose2d", input, filter, options); + } + + model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output)); + return Status::OK(); +} + +// Operator support related. + +bool ConvOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, + const logging::Logger& logger) const { + const auto& name = node.Name(); + const auto& op_type = node.OpType(); + const auto& input_defs = node.InputDefs(); + + const auto& weight_name = input_defs[1]->Name(); + if (Contains(initializers, weight_name)) { + const auto& tensor = *initializers.at(weight_name); + if (tensor.dims().size() != 4) { + LOGS(logger, VERBOSE) << op_type << " [" << name << "] dimension: " << tensor.dims().size() + << " Only conv 2d is supported."; + return false; + } + } else { + LOGS(logger, VERBOSE) << "The weight of " << op_type << " [" << name << "] must be known"; + return false; + } + + return true; +} + +void CreateConvOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { + if (op_registrations.op_builder_map.find(op_type) != op_registrations.op_builder_map.cend()) + return; + + static std::vector op_types = + { + "Conv", + "ConvTranspose", + }; + + op_registrations.builders.push_back(std::make_unique()); + for (const auto& type : op_types) { + op_registrations.op_builder_map.emplace(type, op_registrations.builders.back().get()); + } +} + +} // namespace webnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc new file mode 100644 index 0000000000000..6a41b75e26a14 --- /dev/null +++ b/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc @@ -0,0 +1,147 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) Intel Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/common/safeint.h" +#include "core/providers/common.h" +#include "core/providers/shared/utils/utils.h" +#include "core/providers/webnn/builders/helper.h" +#include "core/providers/webnn/builders/model_builder.h" +#include "core/providers/webnn/builders/op_builder_factory.h" + +#include "base_op_builder.h" +#include "builder_utils.h" + +namespace onnxruntime { +namespace webnn { + +class GemmOpBuilder : public BaseOpBuilder { + // Add operator related. + private: + Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, + const logging::Logger& logger) const override ORT_MUST_USE_RESULT; + + // Operator support related. + private: + bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& /* node */, + const logging::Logger& /* logger */) const override; +}; + +// Add operator related. +Status GemmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, + const logging::Logger& /* logger */) const { + const auto& op_type = node.OpType(); + const auto& input_defs = node.InputDefs(); + const size_t a_idx = 0, b_idx = 1, c_idx = 2; // A*B+C + + emscripten::val a = model_builder.GetOperand(node.InputDefs()[a_idx]->Name()); + emscripten::val b = model_builder.GetOperand(node.InputDefs()[b_idx]->Name()); + emscripten::val output = emscripten::val::object(); + if (op_type == "MatMul") { + output = model_builder.GetBuilder().call("matmul", a, b); + } else { // Gemm + emscripten::val options = emscripten::val::object(); + NodeAttrHelper helper(node); + const auto transA = helper.Get("transA", 0); + options.set("aTranspose", emscripten::val(transA == 1)); + const auto transB = helper.Get("transB", 0); + options.set("bTranspose", emscripten::val(transB == 1)); + const auto alpha = helper.Get("alpha", 1.0f); + const auto beta = helper.Get("beta", 1.0f); + options.set("alpha", alpha); + options.set("beta", beta); + + // Add bias if present. + if (input_defs.size() > 2) { + options.set("c", model_builder.GetOperand(node.InputDefs()[c_idx]->Name())); + } + + output = model_builder.GetBuilder().call("gemm", a, b, options); + } + + model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output)); + return Status::OK(); +} + +// Operator support related. + +bool GemmOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, + const logging::Logger& logger) const { + (void)initializers; + const auto& op_type = node.OpType(); + const auto& input_defs(node.InputDefs()); + const size_t a_idx = 0, b_idx = 1, c_idx = 2; // A*B+C + + if (op_type == "Gemm") { + std::vector a_shape; + { + if (!GetShape(*input_defs[a_idx], a_shape, logger)) + return false; + + if (a_shape.size() != 2) { + LOGS(logger, VERBOSE) << "A must be 2D"; + return false; + } + + if (Product(a_shape) == 0) { + LOGS(logger, VERBOSE) << "A must be non-empty"; + return false; + } + } + + std::vector b_shape; + { + if (!GetShape(*input_defs[b_idx], b_shape, logger)) + return false; + + if (b_shape.size() != 2) { + LOGS(logger, VERBOSE) << "B must be 2D"; + return false; + } + + if (Product(b_shape) == 0) { + LOGS(logger, VERBOSE) << "B must be non-empty"; + return false; + } + } + + // C of Gemm. + if (input_defs.size() == 3) { + std::vector c_shape; + if (!GetShape(*input_defs[c_idx], c_shape, logger)) + return false; + + size_t c_dim = c_shape.size(); + + if (c_dim > 1) { + // TODO: Supports other shape of C. + // Currently WebNN implementation in Chromium only supports 1-D C. + return false; + } + if (c_dim == 0) { + LOGS(logger, VERBOSE) << "C of Gemm is a scalar"; + } else { + auto c_size = c_shape[c_dim - 1]; + NodeAttrHelper helper(node); + const auto transB = helper.Get("transB", 0); + if (c_size != (transB == 0 ? b_shape[1] : b_shape[0])) { + LOGS(logger, VERBOSE) << "C of Gemm must be a vector of b_shape[" + << (transB == 0 ? "1" : "0") << "]" + << " b_shape: [" << b_shape[0] << ", " << b_shape[1] << "]" + << " c_size: " << c_size; + + return false; + } + } + } + } + + return true; +} + +void CreateGemmOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { + op_registrations.builders.push_back(std::make_unique()); + op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get()); +} +} // namespace webnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webnn/builders/impl/pool_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/pool_op_builder.cc new file mode 100644 index 0000000000000..a5200f612e860 --- /dev/null +++ b/onnxruntime/core/providers/webnn/builders/impl/pool_op_builder.cc @@ -0,0 +1,165 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) Intel Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/common/safeint.h" +#include "core/providers/common.h" +#include "core/providers/shared/utils/utils.h" +#include "core/providers/webnn/builders/helper.h" +#include "core/providers/webnn/builders/model_builder.h" +#include "core/providers/webnn/builders/op_builder_factory.h" + +#include "base_op_builder.h" +#include "builder_utils.h" + +namespace onnxruntime { +namespace webnn { + +class PoolOpBuilder : public BaseOpBuilder { + // Add operator related. + private: + Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, + const logging::Logger& logger) const override ORT_MUST_USE_RESULT; + + // Operator support related. + private: + bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, + const logging::Logger& logger) const override; +}; + +// Add operator related. + +Status PoolOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, + const Node& node, + const logging::Logger& logger) const { + const auto& op_type = node.OpType(); + const auto& input_defs = node.InputDefs(); + + emscripten::val input = model_builder.GetOperand(input_defs[0]->Name()); + + bool is_global_pooling = false; + bool is_average_pool = false; + if (op_type == "GlobalAveragePool") { + is_global_pooling = true; + is_average_pool = true; + } else if (op_type == "GlobalMaxPool") { + is_global_pooling = true; + } else if (op_type == "AveragePool") { + is_average_pool = true; + } else if (op_type == "MaxPool") { + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "PoolOpBuilder, unknown op: ", op_type); + } + + emscripten::val options = emscripten::val::object(); + NodeAttrHelper helper(node); + + const auto kernel_shape = helper.Get("kernel_shape", std::vector{0, 0}); + if (!is_global_pooling) { + options.set("windowDimensions", emscripten::val::array(kernel_shape)); + } + const auto strides = helper.Get("strides", std::vector{1, 1}); + options.set("strides", emscripten::val::array(strides)); + const auto dilations = helper.Get("dilations", std::vector{1, 1}); + options.set("dilations", emscripten::val::array(dilations)); + options.set("layout", emscripten::val("nhwc")); + + // Add Padding. + // Usually using autopadding is more efficient than using explicit padding. + // Try to see if we can map explicit padding to auto padding. + const auto onnx_kernel_shape = helper.Get("kernel_shape", std::vector{0, 0}); + const auto onnx_strides = helper.Get("strides", std::vector{1, 1}); + const auto onnx_pads = helper.Get("pads", std::vector{0, 0, 0, 0}); + const auto pads = helper.Get("pads", std::vector{0, 0, 0, 0}); + std::vector input_shape; + ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get shape"); + AutoPadType auto_pad_type; + ORT_RETURN_IF_ERROR(HandleAutoPad(input_shape, onnx_kernel_shape[0], onnx_kernel_shape[1], + onnx_pads, onnx_strides, {1, 1} /* dilations */, + StringToAutoPadType(helper.Get("auto_pad", "NOTSET")), + auto_pad_type)); + + if (AutoPadType::SAME_UPPER == auto_pad_type || AutoPadType::SAME_LOWER == auto_pad_type) { + if (AutoPadType::SAME_LOWER == auto_pad_type) { // default is SAME_UPPER + options.set("autoPad", "same-lower"); + } else { + options.set("autoPad", "same-upper"); + } + } else { + options.set("padding", emscripten::val::array(pads)); + } + + const auto ceil_mode = helper.Get("ceil_mode", 0); + options.set("roundingType", ceil_mode == 0 ? emscripten::val("floor") + : emscripten::val("ceil")); + + emscripten::val output = emscripten::val::object(); + if (is_average_pool) { + output = model_builder.GetBuilder().call("averagePool2d", input, options); + } else { + output = model_builder.GetBuilder().call("maxPool2d", input, options); + } + model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output)); + return Status::OK(); +} + +// Operator support related. +bool PoolOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const logging::Logger& logger) const { + const auto& op_type = node.OpType(); + const auto& input_defs = node.InputDefs(); + + std::vector input_shape; + if (!GetShape(*input_defs[0], input_shape, logger)) + return false; + + const auto input_size = input_shape.size(); + if (input_size != 4) { + LOGS(logger, VERBOSE) + << op_type << " only supports rank-4 tensor, input [" + << input_defs[0]->Name() << "] has actual dim count " << input_size; + return false; + } + + if (op_type == "AveragePool" || op_type == "MaxPool") { + NodeAttrHelper helper(node); + const auto storage_order = helper.Get("storage_order", 0); + if (storage_order == 1) { + LOGS(logger, VERBOSE) << "storage_order == 1 is not supported"; + return false; + } + + if (helper.Get("kernel_shape", std::vector{1, 1}).size() != 2) { + LOGS(logger, VERBOSE) << "Only pooling 2d is supported"; + return false; + } + + if (node.OutputDefs().size() != 1) { + LOGS(logger, VERBOSE) << "Argmax in maxpooling is not supported"; + return false; + } + } + + return true; +} + +void CreatePoolOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { + if (op_registrations.op_builder_map.find(op_type) != op_registrations.op_builder_map.cend()) + return; + + static std::vector op_types = + { + "GlobalAveragePool", + "GlobalMaxPool", + "AveragePool", + "MaxPool", + }; + + op_registrations.builders.push_back(std::make_unique()); + for (const auto& type : op_types) { + op_registrations.op_builder_map.emplace(type, op_registrations.builders.back().get()); + } +} + +} // namespace webnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webnn/builders/impl/reshape_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/reshape_op_builder.cc new file mode 100644 index 0000000000000..2e4dc3f4addbb --- /dev/null +++ b/onnxruntime/core/providers/webnn/builders/impl/reshape_op_builder.cc @@ -0,0 +1,126 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) Intel Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/common/safeint.h" +#include "core/framework/tensorprotoutils.h" +#include "core/optimizer/initializer.h" +#include "core/providers/common.h" +#include "core/providers/cpu/tensor/reshape_helper.h" +#include "core/providers/shared/utils/utils.h" +#include "core/providers/webnn/builders/helper.h" +#include "core/providers/webnn/builders/model_builder.h" +#include "core/providers/webnn/builders/op_builder_factory.h" + +#include "base_op_builder.h" + +namespace onnxruntime { +namespace webnn { + +class ReshapeOpBuilder : public BaseOpBuilder { + // Add operator related. + public: + void AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const override; + + private: + Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, + const logging::Logger& logger) const override ORT_MUST_USE_RESULT; + + // Operator support related. + private: + bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, + const logging::Logger& logger) const override; + + // Reshape opset 4- uses attributes for new shape which we do not support for now. + int GetMinSupportedOpSet(const Node& /* node */) const override { return 5; } +}; + +// Add operator related. + +void ReshapeOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const { + model_builder.AddInitializerToSkip(node.InputDefs()[1]->Name()); +} + +Status ReshapeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, + const Node& node, + const logging::Logger& logger) const { + const auto& input_defs = node.InputDefs(); + const auto& initializers(model_builder.GetInitializerTensors()); + const auto& target_shape_tensor = *initializers.at(input_defs[1]->Name()); + const int64_t* raw_target_shape = target_shape_tensor.int64_data().empty() + ? reinterpret_cast(target_shape_tensor.raw_data().data()) + : target_shape_tensor.int64_data().data(); + + const auto size = target_shape_tensor.dims()[0]; + TensorShapeVector target_shape{raw_target_shape, raw_target_shape + size}; + std::vector input_shape; + ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get shape"); + ReshapeHelper helper(TensorShape(input_shape), target_shape); + emscripten::val input = model_builder.GetOperand(input_defs[0]->Name()); + std::vector new_shape; + std::transform(target_shape.cbegin(), target_shape.cend(), + std::back_inserter(new_shape), + [](int64_t dim) -> uint32_t { return SafeInt(dim); }); + emscripten::val output = model_builder.GetBuilder().call("reshape", + input, emscripten::val::array(new_shape)); + model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output)); + return Status::OK(); +} + +// Operator support related. + +bool ReshapeOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, + const logging::Logger& logger) const { + const auto& input_defs = node.InputDefs(); + const auto& perm_name = input_defs[1]->Name(); + if (!Contains(initializers, perm_name)) { + LOGS(logger, VERBOSE) << "New shape of reshape must be a constant initializer"; + return false; + } + + const auto& perm_tensor = *initializers.at(perm_name); + std::vector unpacked_tensor; + auto status = onnxruntime::utils::UnpackInitializerData(perm_tensor, unpacked_tensor); + if (!status.IsOK()) { + LOGS(logger, ERROR) << "Error while unpacking perm_tensor: " << status.ErrorMessage(); + return false; + } + + const int64_t* raw_new_shape = reinterpret_cast(unpacked_tensor.data()); + const auto& perm_dims = perm_tensor.dims(); + if (perm_dims.empty() || perm_dims[0] == 0) { + LOGS(logger, VERBOSE) << "New shape of reshape cannot be empty"; + return false; + } + + std::vector input_shape; + if (!GetShape(*input_defs[0], input_shape, logger)) + return false; + + if (input_shape.empty()) { + LOGS(logger, VERBOSE) << "Reshape does not support empty input shape"; + return false; + } + + // WebNN reshape does not support 0 as dimension. + NodeAttrHelper helper(node); + const bool allow_zero = helper.Get("allowzero ", 0) == 1; + if (allow_zero) { + for (int64_t i = 0; i < perm_dims[0]; i++) { + if (raw_new_shape[i] == 0) { + LOGS_DEFAULT(VERBOSE) << "Reshape doesn't support 0 reshape dimension when allowzero is enabled"; + return false; + } + } + } + + return true; +} + +void CreateReshapeOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { + op_registrations.builders.push_back(std::make_unique()); + op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get()); +} + +} // namespace webnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webnn/builders/impl/resize_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/resize_op_builder.cc new file mode 100644 index 0000000000000..27e7310002564 --- /dev/null +++ b/onnxruntime/core/providers/webnn/builders/impl/resize_op_builder.cc @@ -0,0 +1,275 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) Intel Corporation. All rights reserved. +// Licensed under the MIT License. + +#include + +#include "core/common/safeint.h" +#include "core/providers/common.h" +#include "core/framework/tensorprotoutils.h" +#include "core/providers/webnn/builders/helper.h" +#include "core/providers/cpu/tensor/reshape_helper.h" +#include "core/providers/shared/utils/utils.h" +#include "core/providers/webnn/builders/model_builder.h" +#include "core/providers/webnn/builders/op_builder_factory.h" + +#include "base_op_builder.h" + +namespace onnxruntime { +namespace webnn { + +class ResizeOpBuilder : public BaseOpBuilder { + // Add operator related. + public: + void AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const override; + + private: + Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, + const logging::Logger& logger) const override ORT_MUST_USE_RESULT; + + // Operator support related. + private: + bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, + const logging::Logger& logger) const override; + + // Resize opset 10- is very different than Resize opset 11+, with many key attributes missing. + // We only support Resize opset 11+ here. + int GetMinSupportedOpSet(const Node& /* node */) const override { return 11; } +}; + +// Helper functions +bool GetResizeScales(const InitializedTensorSet& initializers, + const Node& node, std::vector& scales, + const logging::Logger& logger) { + const auto& input_defs = node.InputDefs(); + if (input_defs.size() < 3) + return false; + + const auto& scales_tensor = *initializers.at(input_defs[2]->Name()); + if (scales_tensor.dims_size() != 1 || scales_tensor.dims()[0] != 4) + return false; + + std::vector unpacked_tensor; + auto status = onnxruntime::utils::UnpackInitializerData(scales_tensor, unpacked_tensor); + if (!status.IsOK()) { + LOGS(logger, ERROR) << "Error while unpacking scales_tensor: " << status.ErrorMessage(); + return false; + } + const float* scales_data = reinterpret_cast(unpacked_tensor.data()); + scales = std::vector{scales_data, scales_data + 4}; + return true; +} + +bool GetResizeOutputSizes(const InitializedTensorSet& initializers, + const Node& node, std::vector& sizes, + const logging::Logger& logger) { + const auto& input_defs = node.InputDefs(); + if (input_defs.size() < 4) + return false; + + const auto& sizes_tensor = *initializers.at(input_defs[3]->Name()); + if (sizes_tensor.dims_size() != 1 || sizes_tensor.dims()[0] != 4) + return false; + + std::vector unpacked_tensor; + auto status = onnxruntime::utils::UnpackInitializerData(sizes_tensor, unpacked_tensor); + if (!status.IsOK()) { + LOGS(logger, ERROR) << "Error while unpacking sizes_tensor: " << status.ErrorMessage(); + return false; + } + const int64_t* sizes_data = reinterpret_cast(unpacked_tensor.data()); + sizes = std::vector{sizes_data, sizes_data + 4}; + return true; +} + +// Add operator related. + +void ResizeOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const { + // We don't really use ROI here, so add it to skipped list if it's an initializer tensor. + model_builder.AddInitializerToSkip(node.InputDefs()[1]->Name()); // ROI + model_builder.AddInputToSkip(node.InputDefs()[1]->Name()); // ROI + + // We will still add scales to the skipped list even sizes are present, + // since there is no use of it, we will not process it later. + model_builder.AddInitializerToSkip(node.InputDefs()[2]->Name()); // scales + model_builder.AddInputToSkip(node.InputDefs()[2]->Name()); // scales + + if (node.InputDefs().size() > 3) { + model_builder.AddInitializerToSkip(node.InputDefs()[3]->Name()); // sizes + model_builder.AddInputToSkip(node.InputDefs()[3]->Name()); // sizes + } +} + +Status ResizeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, + const Node& node, + const logging::Logger& logger) const { + emscripten::val options = emscripten::val::object(); + NodeAttrHelper helper(node); + const auto mode = helper.Get("mode", "nearest"); + if (mode == "linear") { + options.set("mode", emscripten::val("linear")); + } else { // we already checked the mode must be NN or Bilinear in IsOpSupportedImpl. + options.set("mode", emscripten::val("nearest-neighbor")); + } + + const auto& input_defs = node.InputDefs(); + const auto& initializers(model_builder.GetInitializerTensors()); + + std::vector scales; + std::vector sizes; + std::vector scales_hw; + std::vector sizes_hw; + if (input_defs.size() == 3) { // Use scales. + ORT_RETURN_IF_NOT(GetResizeScales(initializers, node, scales, logger), "Error getting resize scales"); + scales_hw = {scales[2], scales[3]}; + options.set("scales", emscripten::val::array(scales_hw)); + } else { // We already checked number of inputs in IsOpSupportedImpl. + std::vector output_sizes; + ORT_RETURN_IF_NOT(GetResizeOutputSizes(initializers, node, output_sizes, logger), + "Error getting resize output_sizes"); + std::transform(output_sizes.cbegin(), output_sizes.cend(), + std::back_inserter(sizes), + [](int64_t dim) -> int32_t { return SafeInt(dim); }); + sizes_hw = {sizes[1], sizes[2]}; + options.set("sizes", emscripten::val::array(sizes_hw)); + } + + std::vector axes = {1, 2}; + options.set("axes", emscripten::val::array(axes)); + emscripten::val input = model_builder.GetOperand(input_defs[0]->Name()); + emscripten::val output = model_builder.GetBuilder().call("resample2d", input, options); + model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output)); + return Status::OK(); +} + +// Operator support related. + +bool ResizeOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, + const logging::Logger& logger) const { + const auto& input_defs = node.InputDefs(); + + std::vector input_shape; + if (!GetShape(*input_defs[0], input_shape, logger)) + return false; + + const auto input_size = input_shape.size(); + if (input_size != 4) { + LOGS(logger, VERBOSE) << "Resize only support 4d shape, input is " + << input_size << "d shape"; + return false; + } + + { // Check attributes. + NodeAttrHelper helper(node); + const auto mode = helper.Get("mode", "nearest"); + bool is_linear_resize = mode == "linear"; + bool is_nearest_resize = mode == "nearest"; + if (!is_linear_resize && !is_nearest_resize) { + LOGS(logger, VERBOSE) << "Resize unsupported input mode, " << mode; + return false; + } + + const auto exclude_outside = helper.Get("exclude_outside", 0); + if (exclude_outside != 0) { + LOGS(logger, VERBOSE) << "Resize does not support exclude_outside for now"; + return false; + } + } + + { // scales and sizes (if present) must be initializers. + if (input_defs.size() < 3) { + LOGS(logger, VERBOSE) << "Input scales or sizes of Resize must be known"; + return false; + } + + // scales + if (input_defs.size() == 3 && !Contains(initializers, input_defs[2]->Name())) { + LOGS(logger, VERBOSE) << "Input scales of Resize must be known"; + return false; + } + + // sizes + if (input_defs.size() > 3 && !Contains(initializers, input_defs[3]->Name())) { + LOGS(logger, VERBOSE) << "Input sizes of Resize must be known"; + return false; + } + + // We want to check if the scales or sizes are not trying to resize on N/C channels here. + if (input_defs.size() == 3) { // We are using scales. + std::vector scales; + if (!GetResizeScales(initializers, node, scales, logger)) + return false; + + float scale_n = scales[0]; + float scale_c = scales[1]; + if (scale_n != 1.0f || scale_c != 1.0f) { + LOGS(logger, VERBOSE) << "Scales of N/C channel should be 1" + << "Resize of N/C channels are not supported" + << ", scale_n, " << scale_n << ", scale_c, " << scale_c; + return false; + } + + // For now we only support upscale, so the scale_h and scale_w should be an integer >= 1. + // TODO support ResizeBilinear. + float scale_h = scales[2]; + float scale_w = scales[3]; + + // Onnx spec requires scale to be a positive float, so we are not checking that here. + if (roundf(scale_h) != scale_h) { + LOGS(logger, VERBOSE) << "Resize: scale_h: " << scale_h << " is not a whole number"; + return false; + } + + if (roundf(scale_w) != scale_w) { + LOGS(logger, VERBOSE) << "Resize: scale_w: " << scale_w << " is not a whole number"; + return false; + } + } else { + // We are using sizes. + std::vector output_sizes; + if (!GetResizeOutputSizes(initializers, node, output_sizes, logger)) + return false; + + bool is_NHWC = input_shape[3] == output_sizes[3]; + auto output_size_n = output_sizes[0]; + const int c_idx = is_NHWC ? 3 : 1; + if (output_size_n != input_shape[0] || output_sizes[c_idx] != input_shape[c_idx]) { + LOGS(logger, VERBOSE) << "Output sizes of N/C chanel should match the input sizes, " + << "Resize of N/C channels are not supported" + << ", input_size_n, " << input_shape[0] << ", output_size_n, " << output_size_n + << ". input_size_c, " << input_shape[c_idx] << ", output_size_c, " << output_sizes[c_idx]; + return false; + } + + // For now we only support upscale, so the output_size_h and output_size_w should be an integer >= 1. + // TODO support ResizeBilinear + auto output_size_h = output_sizes[2]; + auto output_size_w = output_sizes[3]; + auto input_size_h = input_shape[2]; + auto input_size_w = input_shape[3]; + + // Onnx spec requires output sizes to be a positive integer, so we are not checking that here. + if (output_size_h % input_size_h != 0) { + LOGS(logger, VERBOSE) << "Resize: output_size_h: " << output_size_h + << " is not a multiple of input_size_h: " << input_size_h; + return false; + } + + if (output_size_w % input_size_w != 0) { + LOGS(logger, VERBOSE) << "Resize: output_size_w: " << output_size_w + << " is not a multiple of input_size_w: " << input_size_w; + return false; + } + } + } + + return true; +} + +void CreateResizeOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { + op_registrations.builders.push_back(std::make_unique()); + op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get()); +} + +} // namespace webnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webnn/builders/impl/transpose_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/transpose_op_builder.cc new file mode 100644 index 0000000000000..eca1521384643 --- /dev/null +++ b/onnxruntime/core/providers/webnn/builders/impl/transpose_op_builder.cc @@ -0,0 +1,59 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) Intel Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/common/safeint.h" +#include "core/providers/shared/utils/utils.h" +#include "core/providers/webnn/builders/helper.h" +#include "core/providers/webnn/builders/model_builder.h" +#include "core/providers/webnn/builders/op_builder_factory.h" + +#include "base_op_builder.h" + +namespace onnxruntime { +namespace webnn { + +class TransposeOpBuilder : public BaseOpBuilder { + // Add operator related. + private: + Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, + const logging::Logger& logger) const override ORT_MUST_USE_RESULT; +}; + +// Add operator related. + +Status TransposeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, + const Node& node, + const logging::Logger& logger) const { + const auto& input_defs = node.InputDefs(); + NodeAttrHelper helper(node); + std::vector perm = helper.Get("perm", std::vector()); + std::vector input_shape; + ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get shape"); + auto input_dims = input_shape.size(); + if (perm.empty()) { + for (int64_t i = input_dims - 1; i >= 0; i--) + perm.push_back(i); + } else { + ORT_RETURN_IF_NOT(perm.size() == input_dims, "Perm and input should have same dimension"); + } + + emscripten::val input = model_builder.GetOperand(input_defs[0]->Name()); + emscripten::val options = emscripten::val::object(); + std::vector permutation; + std::transform(perm.cbegin(), perm.cend(), + std::back_inserter(permutation), + [](int64_t dim) -> int32_t { return SafeInt(dim); }); + options.set("permutation", emscripten::val::array(permutation)); + emscripten::val output = model_builder.GetBuilder().call("transpose", input, options); + model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output)); + return Status::OK(); +} + +void CreateTransposeOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { + op_registrations.builders.push_back(std::make_unique()); + op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get()); +} + +} // namespace webnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webnn/builders/model.cc b/onnxruntime/core/providers/webnn/builders/model.cc new file mode 100644 index 0000000000000..bc385dea6c2a8 --- /dev/null +++ b/onnxruntime/core/providers/webnn/builders/model.cc @@ -0,0 +1,127 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) Intel Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include + +#include "core/common/common.h" +#include "core/common/inlined_containers.h" +#include "core/common/logging/logging.h" +#include "core/common/safeint.h" +#include "core/graph/onnx_protobuf.h" +#include "core/providers/common.h" +#include "core/providers/webnn/builders/helper.h" +#include "model.h" + +namespace onnxruntime { +namespace webnn { + +Model::Model(const emscripten::val& context, const emscripten::val& graph, const logging::Logger& logger) + : wnn_context_(context), + wnn_graph_(graph), + logger_(logger) {} + +Model::~Model() {} + +Status Model::Predict(const InlinedHashMap& inputs, + const InlinedHashMap& outputs) { + for (const auto& input : inputs) { + const std::string& name = input.first; + const struct OnnxTensorData tensor = input.second; + if (tensor.tensor_info.data_type != ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "The input of graph has unsupported type, name: ", + name, " type: ", tensor.tensor_info.data_type); + } + auto num_elements = SafeInt(Product(tensor.tensor_info.shape)); + emscripten::val view{emscripten::typed_memory_view(num_elements, static_cast(tensor.buffer))}; +#ifdef ENABLE_WEBASSEMBLY_THREADS + // Copy the inputs from Wasm SharedArrayBuffer to the pre-allocated ArrayBuffers. + wnn_inputs_[name].call("set", view); +#else + wnn_inputs_.set(name, view); +#endif + } + +#ifdef ENABLE_WEBASSEMBLY_THREADS + // This vector uses for recording output buffers from WebNN graph compution when WebAssembly + // multi-threads is enabled, since WebNN API only accepts non-shared ArrayBufferView, + // https://www.w3.org/TR/webnn/#typedefdef-mlnamedarraybufferviews + // and at this time the 'view' defined by Emscripten is shared ArrayBufferView, the memory + // address is different from the non-shared one, additional memory copy is required here. + InlinedHashMap output_views; +#endif + for (const auto& output : outputs) { + const std::string& name = output.first; + const struct OnnxTensorData tensor = output.second; + if (tensor.tensor_info.data_type != ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "The input of graph has unsupported type, name: ", + name, " type: ", tensor.tensor_info.data_type); + } + auto num_elements = SafeInt(Product(tensor.tensor_info.shape)); + emscripten::val view{emscripten::typed_memory_view(num_elements, static_cast(tensor.buffer))}; +#ifdef ENABLE_WEBASSEMBLY_THREADS + output_views.insert({name, view}); +#else + wnn_outputs_.set(name, view); +#endif + } + wnn_context_.call("computeSync", wnn_graph_, wnn_inputs_, wnn_outputs_); + +#ifdef ENABLE_WEBASSEMBLY_THREADS + // Copy the outputs from pre-allocated ArrayBuffers back to the Wasm SharedArrayBuffer. + for (const auto& output : outputs) { + const std::string& name = output.first; + emscripten::val view = output_views.at(name); + view.call("set", wnn_outputs_[name]); + } +#endif + return Status::OK(); +} + +bool Model::IsScalarOutput(const std::string& output_name) const { + return Contains(scalar_outputs_, output_name); +} + +const OnnxTensorInfo& Model::GetInputOutputInfo(const std::string& name) const { + return input_output_info_.at(name); +} + +void Model::SetInputMap(InlinedHashMap&& input_map) { + input_map_ = std::move(input_map); +} + +void Model::SetOutputMap(InlinedHashMap&& output_map) { + output_map_ = std::move(output_map); +} + +// Pre-allocate the input and output buffers for the WebNN graph. +void Model::AllocateInputOutputBuffers() { + for (const auto& input : inputs_) { + const auto& input_info = input_output_info_.at(input); + const auto input_shape = input_info.shape; + const auto num_elements = SafeInt(Product(input_shape)); + wnn_inputs_.set(input, + emscripten::val::global("Float32Array").new_(static_cast(num_elements))); + } + for (const auto& output : outputs_) { + const auto& output_info = input_output_info_.at(output); + const auto output_shape = output_info.shape; + const auto num_elements = SafeInt(Product(output_shape)); + wnn_outputs_.set(output, + emscripten::val::global("Float32Array").new_(static_cast(num_elements))); + } +} + +size_t Model::GetMappedInputIdx(const std::string& name) const { + return input_map_.at(name); +} + +size_t Model::GetMappedOutputIdx(const std::string& name) const { + return output_map_.at(name); +} + +} // namespace webnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webnn/builders/model.h b/onnxruntime/core/providers/webnn/builders/model.h new file mode 100644 index 0000000000000..4af82a2675691 --- /dev/null +++ b/onnxruntime/core/providers/webnn/builders/model.h @@ -0,0 +1,94 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) Intel Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/common/inlined_containers.h" +#include "core/common/status.h" +#include "core/platform/ort_mutex.h" + +#include +#include + +namespace onnxruntime { +namespace webnn { + +struct OnnxTensorInfo { + const int32_t data_type; // Uses TensorProto::DataType. + const std::vector shape; +}; + +struct OnnxTensorData { + OnnxTensorInfo tensor_info; + void* buffer{nullptr}; +}; + +class Model { + friend class ModelBuilder; + + public: + ~Model(); + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(Model); + + onnxruntime::common::Status Predict(const InlinedHashMap& inputs, + const InlinedHashMap& outputs); + + bool IsScalarOutput(const std::string& output_name) const; + + // Mutex for exclusive lock to this model object. + OrtMutex& GetMutex() { return mutex_; } + + // Input and output names in the onnx model's order. + const std::vector& GetInputs() const { return inputs_; } + void SetInputs(std::vector&& inputs) { inputs_ = std::move(inputs); } + + const std::vector& GetOutputs() const { return outputs_; } + void SetOutputs(std::vector&& outputs) { outputs_ = std::move(outputs); } + + const OnnxTensorInfo& GetInputOutputInfo(const std::string& name) const; + + // Set the mapping between input/output name and ORT kernel context + // input/output index, at execution time. + void SetInputMap(InlinedHashMap&& input_map); + void SetOutputMap(InlinedHashMap&& output_map); + + // Get the ORT kernel context input/output index with given name. + size_t GetMappedInputIdx(const std::string& name) const; + size_t GetMappedOutputIdx(const std::string& name) const; + + private: + emscripten::val wnn_context_ = emscripten::val::object(); + emscripten::val wnn_graph_ = emscripten::val::object(); + const logging::Logger& logger_; + + emscripten::val wnn_inputs_ = emscripten::val::object(); + emscripten::val wnn_outputs_ = emscripten::val::object(); + + InlinedHashSet scalar_outputs_; + + std::vector inputs_; + std::vector outputs_; + + InlinedHashMap input_output_info_; + + InlinedHashMap input_map_; + InlinedHashMap output_map_; + + OrtMutex mutex_; + + Model(const emscripten::val& context, const emscripten::val& path, const logging::Logger& logger); + + void SetInputOutputInfo(InlinedHashMap&& input_output_info) { + input_output_info_ = std::move(input_output_info); + } + + void SetScalarOutputs(InlinedHashSet&& scalar_outputs) { + scalar_outputs_ = std::move(scalar_outputs); + } + + void AllocateInputOutputBuffers(); +}; + +} // namespace webnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webnn/builders/model_builder.cc b/onnxruntime/core/providers/webnn/builders/model_builder.cc new file mode 100644 index 0000000000000..1b3ba5fcde88e --- /dev/null +++ b/onnxruntime/core/providers/webnn/builders/model_builder.cc @@ -0,0 +1,369 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) Intel Corporation. All rights reserved. +// Licensed under the MIT License. + +#include + +#include "model_builder.h" +#include "model.h" +#include "helper.h" +#include "op_builder_factory.h" + +#include "core/common/safeint.h" +#include "core/framework/tensorprotoutils.h" +#include "core/providers/common.h" +#include "core/providers/shared/utils/utils.h" + +namespace onnxruntime { +namespace webnn { + +ModelBuilder::ModelBuilder(const GraphViewer& graph_viewer, const logging::Logger& logger, + const emscripten::val& context, const emscripten::val& builder) + : graph_viewer_(graph_viewer), + logger_(logger), + wnn_context_(context), + wnn_builder_(builder) {} + +Status ModelBuilder::Initialize() { + PreprocessInitializers(); + PreprocessActivations(); + ORT_RETURN_IF_ERROR(RegisterInitializers()); + ORT_RETURN_IF_ERROR(RegisterModelInputs()); + ORT_RETURN_IF_ERROR(AddOperations()); + ORT_RETURN_IF_ERROR(RegisterModelOutputs()); + + return Status::OK(); +} + +/* static */ const IOpBuilder* ModelBuilder::GetOpBuilder(const Node& node) { + const auto& op_builders = GetOpBuilders(); + const auto it = op_builders.find(node.OpType()); + if (it != op_builders.cend()) + return it->second; + + return nullptr; +} + +void ModelBuilder::PreprocessInitializers() { + const auto& node_indices = graph_viewer_.GetNodesInTopologicalOrder(); + for (size_t i = 0; i < node_indices.size(); i++) { + const auto* node(graph_viewer_.GetNode(node_indices[i])); + if (const auto* op_builder = GetOpBuilder(*node)) { + op_builder->AddInitializersToSkip(*this, *node); + } + } +} + +emscripten::val GetClampOperator( + const emscripten::val& builder, float min_value, float max_value) { + emscripten::val options = emscripten::val::object(); + options.set("minValue", min_value); + options.set("maxValue", max_value); + return builder.call("clamp", options); +} + +void ModelBuilder::PreprocessActivations() { + const auto& node_indices = graph_viewer_.GetNodesInTopologicalOrder(); + for (size_t i = 0; i < node_indices.size(); i++) { + const auto* node(graph_viewer_.GetNode(node_indices[i])); + const auto& op_type(node->OpType()); + + if (op_type == "Relu") { + activation_nodes_.emplace(node->Index(), wnn_builder_.call("relu")); + } else if (op_type == "LeakyRelu") { + NodeAttrHelper helper(*node); + emscripten::val options = emscripten::val::object(); + options.set("alpha", helper.Get("alpha", (float)0.0)); + activation_nodes_.emplace(node->Index(), wnn_builder_.call("leakyRelu", options)); + } else if (op_type == "Sigmoid") { + activation_nodes_.emplace(node->Index(), wnn_builder_.call("sigmoid")); + } else if (op_type == "Clip") { + float minValue, maxValue; + GetClipMinMax(GetInitializerTensors(), *node, minValue, maxValue, logger_); + activation_nodes_.emplace(node->Index(), GetClampOperator(wnn_builder_, minValue, maxValue)); + } + } +} + +Status ModelBuilder::RegisterInitializers() { + for (const auto& pair : GetInitializerTensors()) { + const auto& tensor = *pair.second; + const auto& name = tensor.name(); + if (Contains(skipped_initializers_, name)) + continue; + + const auto& shape = tensor.dims(); + std::vector dims; + if (shape.empty()) { + // This is a scalar initializer, WebNN requires a shape, make this a {1} tensor. + dims = {1}; + } else { + std::transform(shape.cbegin(), shape.cend(), + std::back_inserter(dims), + [](int64_t dim) -> int32_t { return SafeInt(dim); }); + } + + emscripten::val desc = emscripten::val::object(); + desc.set("dimensions", emscripten::val::array(dims)); + auto data_type = tensor.data_type(); + emscripten::val operand = emscripten::val::object(); + if (data_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { + unpacked_tensors_.push_back({}); + std::vector& unpacked_tensor = unpacked_tensors_.back(); + ORT_RETURN_IF_ERROR(onnxruntime::utils::UnpackInitializerData(tensor, unpacked_tensor)); + auto num_elements = SafeInt(Product(tensor.dims())); + desc.set("type", emscripten::val("float32")); + emscripten::val view{emscripten::typed_memory_view(num_elements, + reinterpret_cast(unpacked_tensor.data()))}; +#ifdef ENABLE_WEBASSEMBLY_THREADS + // Workaround for WebAssembly multi-threads enabled since WebNN API only accepts non-shared ArrayBufferView. + // https://www.w3.org/TR/webnn/#typedefdef-mlnamedarraybufferviews + operand = wnn_builder_.call("constant", desc, view.call("slice")); +#else + operand = wnn_builder_.call("constant", desc, view); +#endif + + } else { + // TODO: support other type. + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "The initializer of graph has unsupported type, name: ", + tensor.name(), " type: ", data_type); + } + wnn_operands_.insert(std::make_pair(name, operand)); + } + + return Status::OK(); +} + +Status ModelBuilder::RegisterModelInputOutput(const NodeArg& node_arg, bool is_input) { + const auto& name = node_arg.Name(); + const std::string input_output_type = is_input ? "input" : "output"; + + if (is_input) { + // Input should not be an initializer. + if (Contains(GetInitializerTensors(), name)) + return Status::OK(); + + // This input will not be used. + if (Contains(skipped_inputs_, name)) + return Status::OK(); + } + + std::vector dims; + { // input_output shape. + const auto* shape_proto = node_arg.Shape(); + ORT_RETURN_IF(shape_proto == nullptr, + "shape_proto cannot be null for ", input_output_type, ": ", name); + const auto& shape = shape_proto->dim(); + if (shape.empty()) { + // If we have an empty shape, this is a scalar input. + dims.push_back(1); + + // We need to change the shapes of these scalar outputs back to {} + // when WebNN EP returns these values to ORT. + if (!is_input) { + AddScalarOutput(name); + } + } else { + dims.reserve(shape.size()); + for (const auto& dim : shape) { + if (!dim.has_dim_value()) { + // FIXME: support dyanmic shape. + dims.push_back(1); + } else { + dims.push_back(SafeInt(dim.dim_value())); + } + } + } + } + + emscripten::val desc = emscripten::val::object(); + + desc.set("dimensions", emscripten::val::array(dims)); + + int32_t data_type; + { // type + const auto* type_proto = node_arg.TypeAsProto(); + if (!type_proto || !type_proto->tensor_type().has_elem_type()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "The ", input_output_type, " of graph doesn't have elem_type: ", name); + } + + data_type = type_proto->tensor_type().elem_type(); + switch (data_type) { + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: + desc.set("type", emscripten::val("float32")); + break; + default: { + // TODO: support other type. + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "The ", input_output_type, " of graph doesn't have valid type, name: ", name, + " type: ", type_proto->tensor_type().elem_type()); + } + } + } + + if (is_input) { + wnn_operands_.insert(std::make_pair(name, wnn_builder_.call("input", name, desc))); + input_names_.push_back(name); + } else { + output_names_.push_back(name); + } + + std::vector shape; + std::transform(dims.cbegin(), dims.cend(), + std::back_inserter(shape), + [](int32_t dim) -> int64_t { return SafeInt(dim); }); + input_output_info_.emplace(name, OnnxTensorInfo{data_type, shape}); + + return Status::OK(); +} + +Status ModelBuilder::RegisterModelInputs() { + for (const auto* node_arg : graph_viewer_.GetInputs()) { + ORT_RETURN_IF_ERROR(RegisterModelInputOutput(*node_arg, true /* is_input */)); + } + + return Status::OK(); +} + +Status ModelBuilder::AddOperations() { + const auto& node_indices = graph_viewer_.GetNodesInTopologicalOrder(); + for (size_t i = 0; i < node_indices.size(); i++) { + const auto* node(graph_viewer_.GetNode(node_indices[i])); + if (const auto* op_builder = GetOpBuilder(*node)) { + ORT_RETURN_IF_ERROR(op_builder->AddToModelBuilder(*this, *node, logger_)); + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Node [", node->Name(), "], type [", node->OpType(), "] is not supported"); + } + } + + return Status::OK(); +} + +Status ModelBuilder::AddOperandFromPersistMemoryBuffer( + const std::string& name, const void* buffer, const size_t size, + const std::vector shape, const size_t element_size) { + auto persist_buffer = std::make_unique(size); + uint8_t* dest = persist_buffer.get(); + memcpy(dest, buffer, size); + emscripten::val view{emscripten::typed_memory_view(size / element_size, reinterpret_cast(dest))}; + emscripten::val desc = emscripten::val::object(); + desc.set("dimensions", emscripten::val::array(shape)); + desc.set("type", emscripten::val("float32")); + emscripten::val operand = emscripten::val::object(); +#ifdef ENABLE_WEBASSEMBLY_THREADS + // Workaround for WebAssembly multi-threads enabled since WebNN API only accepts non-shared ArrayBufferView. + // https://www.w3.org/TR/webnn/#typedefdef-mlnamedarraybufferviews + operand = wnn_builder_.call("constant", desc, view.call("slice")); +#else + operand = wnn_builder_.call("constant", desc, view); +#endif + AddOperand(name, operand); + mem_persist_buffers_.push_back(std::move(persist_buffer)); + return Status::OK(); +} + +Status ModelBuilder::RegisterModelOutputs() { + for (const auto* node_arg : graph_viewer_.GetOutputs()) { + ORT_RETURN_IF_ERROR(RegisterModelInputOutput(*node_arg, false /* is_input */)); + } + + return Status::OK(); +} + +Status ModelBuilder::Compile(std::unique_ptr& model) { + ORT_RETURN_IF_ERROR(Initialize()); + emscripten::val named_operands = emscripten::val::object(); + for (auto& name : output_names_) { + named_operands.set(name, wnn_operands_.at(name)); + } + emscripten::val wnn_graph = wnn_builder_.call("buildSync", named_operands); + if (!wnn_graph.as()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to build WebNN graph."); + } + model.reset(new Model(std::move(wnn_context_), std::move(wnn_graph), logger_)); + model->SetInputs(std::move(input_names_)); + model->SetOutputs(std::move(output_names_)); + model->SetScalarOutputs(std::move(scalar_outputs_)); + model->SetInputOutputInfo(std::move(input_output_info_)); +#ifdef ENABLE_WEBASSEMBLY_THREADS + // Pre-allocate the input and output tensors for the WebNN graph + // when WebAssembly multi-threads is enabled since WebNN API only + // accepts non-shared ArrayBufferView. + // https://www.w3.org/TR/webnn/#typedefdef-mlnamedarraybufferviews + model->AllocateInputOutputBuffers(); +#endif + return Status::OK(); +} + +// supported_nodes is provided by the op to indicate whether it can be fused with the activation node. +emscripten::val ModelBuilder::FindActivation(const Node& node, const NodeArg& output, + const InlinedHashSet supported_nodes) { + emscripten::val fused_op = emscripten::val::null(); + for (auto it = node.OutputEdgesBegin(), end = node.OutputEdgesEnd(); it != end; ++it) { + const auto& dst_node = it->GetNode(); + const auto* dst_input = dst_node.InputDefs()[it->GetDstArgIndex()]; + if (!Contains(supported_nodes, dst_node.OpType())) { + return emscripten::val::null(); + } + if (Contains(activation_nodes_, dst_node.Index())) { + if (&output == dst_input) { + fused_op = activation_nodes_.at(dst_node.Index()); + } + } else { + // If there is any other non-relu node using the output + // will add relu separately. + if (&output == dst_input) { + return emscripten::val::null(); + } + } + } + + // If output is a graph output, will add relu separately. + if (fused_op != emscripten::val::null()) { + for (const auto* graph_output : graph_viewer_.GetOutputs()) { + if (&output == graph_output) { + return emscripten::val::null(); + } + } + + LOGS_DEFAULT(VERBOSE) << "Node [" << node.Name() << "] type [" << node.OpType() + << "], fused the output [" << output.Name() << "]"; + + fused_activations_.insert(output.Name()); + } + + return fused_op; +} + +void ModelBuilder::AddScalarOutput(const std::string& output_name) { + scalar_outputs_.insert(output_name); +} + +void ModelBuilder::AddOperand(const std::string& name, const emscripten::val& operand) { + wnn_operands_.insert(std::make_pair(name, operand)); +} + +void ModelBuilder::AddInitializerToSkip(const std::string& tensor_name) { + skipped_initializers_.insert(tensor_name); +} + +void ModelBuilder::AddInputToSkip(const std::string& input_name) { + skipped_inputs_.insert(input_name); +} + +std::string ModelBuilder::GetUniqueName(const std::string& base_name) { + std::string unique_name; + do { + std::ostringstream os; + os << base_name << "_token_" << name_token_++; + unique_name = os.str(); + } while (Contains(unique_names_, unique_name)); + + return unique_name; +} + +} // namespace webnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webnn/builders/model_builder.h b/onnxruntime/core/providers/webnn/builders/model_builder.h new file mode 100644 index 0000000000000..56cd756688da5 --- /dev/null +++ b/onnxruntime/core/providers/webnn/builders/model_builder.h @@ -0,0 +1,106 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) Intel Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/common/inlined_containers.h" +#include + +#include "model.h" + +#include +#include + +namespace onnxruntime { +namespace webnn { + +class IOpBuilder; + +class ModelBuilder { + public: + ModelBuilder(const GraphViewer& graph_viewer, const logging::Logger& logger, + const emscripten::val& context, const emscripten::val& builder); + ~ModelBuilder() = default; + + Status Compile(std::unique_ptr& model) ORT_MUST_USE_RESULT; + + // Accessors for members. + const GraphViewer& GetGraphViewer() const { return graph_viewer_; } + const InitializedTensorSet& GetInitializerTensors() const { return graph_viewer_.GetAllInitializedTensors(); } + + const emscripten::val& GetBuilder() const { return wnn_builder_; } + const emscripten::val& GetContext() const { return wnn_context_; } + const emscripten::val& GetOperand(const std::string& name) const { return wnn_operands_.at(name); } + void AddOperand(const std::string& name, const emscripten::val& operand); + // Use the buffers to persist WebNN allocated data like transposed weight. + // It ensures the validity during inference session. + std::vector> mem_persist_buffers_; + // Add a constant operand (allocate persist buffer and move the ownership to mem_persist_buffers_). + Status AddOperandFromPersistMemoryBuffer( + const std::string& name, const void* buffer, + const size_t size, const std::vector shape, const size_t element_size = 4); + // Find if an output has a fuseable activation (e.g., Relu). + emscripten::val FindActivation(const Node& node, const NodeArg& output, + const InlinedHashSet supported_nodes = {}); + + const InlinedHashSet& + GetFusedActivations() const { return fused_activations_; } + + // The initializer will be processed separately, skip it as an initializer. + void AddInitializerToSkip(const std::string& tensor_name); + + // There are some input which will not be used, add it to a list which will not + // be added to CoreML model, since CoreML does not like input unused. + void AddInputToSkip(const std::string& input_name); + + std::string GetUniqueName(const std::string& base_name); + + private: + const GraphViewer& graph_viewer_; + const logging::Logger& logger_; + + emscripten::val wnn_context_ = emscripten::val::object(); + emscripten::val wnn_builder_ = emscripten::val::object(); + std::vector> unpacked_tensors_; + InlinedHashMap wnn_operands_; + std::vector input_names_; + std::vector output_names_; + + InlinedHashSet scalar_outputs_; + InlinedHashMap input_output_info_; + + InlinedHashSet skipped_initializers_; + InlinedHashSet skipped_inputs_; + + InlinedHashSet fused_activations_; + + uint32_t name_token_{0}; + InlinedHashSet unique_names_; + + // All activation nodes (e.g., Relu) as a map . + InlinedHashMap activation_nodes_; + + // Convert the onnx model to WebNN operands + Status Initialize() ORT_MUST_USE_RESULT; + + void PreprocessInitializers(); + // Preprocess all the activation nodes (e.g., Relu) for easy query later. + void PreprocessActivations(); + + // Copy and process all the initializers to WebNN constants. + Status RegisterInitializers() ORT_MUST_USE_RESULT; + + Status AddOperations() ORT_MUST_USE_RESULT; + Status RegisterModelInputs() ORT_MUST_USE_RESULT; + Status RegisterModelOutputs() ORT_MUST_USE_RESULT; + Status RegisterModelInputOutput(const NodeArg& node_arg, bool is_input) ORT_MUST_USE_RESULT; + + // Record the onnx scalar output names. + void AddScalarOutput(const std::string& output_name); + + static const IOpBuilder* GetOpBuilder(const Node& node); +}; + +} // namespace webnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webnn/builders/op_builder.h b/onnxruntime/core/providers/webnn/builders/op_builder.h new file mode 100644 index 0000000000000..efa70ab3d5455 --- /dev/null +++ b/onnxruntime/core/providers/webnn/builders/op_builder.h @@ -0,0 +1,34 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) Intel Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +namespace onnxruntime { +namespace webnn { + +class ModelBuilder; + +class IOpBuilder { + public: + virtual ~IOpBuilder() = default; + + // Add operator related. + public: + // Check if the initializers of this operator need preprocess, + // which will not be copied. + virtual void AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const = 0; + + // Add the operator to WebNN model. + virtual Status AddToModelBuilder(ModelBuilder& model_builder, const Node& node, + const logging::Logger& logger) const ORT_MUST_USE_RESULT = 0; + + // Operator support related. + public: + // Check if an operator is supported. + virtual bool IsOpSupported(const InitializedTensorSet& initializers, const Node& node, + const logging::Logger& logger) const = 0; +}; + +} // namespace webnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc b/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc new file mode 100644 index 0000000000000..c13677547ff3b --- /dev/null +++ b/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc @@ -0,0 +1,76 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) Intel Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include + +#include + +#include "op_builder_factory.h" + +namespace onnxruntime { +namespace webnn { + +static OpBuilderRegistrations CreateOpBuilderRegistrations() { + OpBuilderRegistrations op_registrations; + + { // Binary + CreateBinaryOpBuilder("Add", op_registrations); + CreateBinaryOpBuilder("Sub", op_registrations); + CreateBinaryOpBuilder("Mul", op_registrations); + CreateBinaryOpBuilder("Div", op_registrations); + } + + { // Activations + CreateActivationOpBuilder("Relu", op_registrations); + CreateActivationOpBuilder("LeakyRelu", op_registrations); + CreateActivationOpBuilder("Sigmoid", op_registrations); + } + + { // Clip + CreateClipOpBuilder("Clip", op_registrations); + } + + { // Conv + CreateConvOpBuilder("Conv", op_registrations); + CreateConvOpBuilder("ConvTranspose", op_registrations); + } + + { // Concat + CreateConcatOpBuilder("Concat", op_registrations); + } + + { // Gemm + CreateGemmOpBuilder("Gemm", op_registrations); + } + + { // Pool + CreatePoolOpBuilder("GlobalAveragePool", op_registrations); + CreatePoolOpBuilder("GlobalMaxPool", op_registrations); + CreatePoolOpBuilder("AveragePool", op_registrations); + CreatePoolOpBuilder("MaxPool", op_registrations); + } + + { // Reshape + CreateReshapeOpBuilder("Reshape", op_registrations); + } + + { // Resize + CreateResizeOpBuilder("Resize", op_registrations); + } + + { // Transpose + CreateTransposeOpBuilder("Transpose", op_registrations); + } + + return op_registrations; +} + +const InlinedHashMap& GetOpBuilders() { + static const OpBuilderRegistrations op_registrations = CreateOpBuilderRegistrations(); + return op_registrations.op_builder_map; +} + +} // namespace webnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webnn/builders/op_builder_factory.h b/onnxruntime/core/providers/webnn/builders/op_builder_factory.h new file mode 100644 index 0000000000000..ffbbf2d92d432 --- /dev/null +++ b/onnxruntime/core/providers/webnn/builders/op_builder_factory.h @@ -0,0 +1,34 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) Intel Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "op_builder.h" + +namespace onnxruntime { +namespace webnn { + +struct OpBuilderRegistrations { + std::vector> builders; + InlinedHashMap op_builder_map; +}; + +// Get the lookup table with IOpBuilder delegates for different onnx operators. +// Note, the lookup table should have same number of entries as the result of CreateOpSupportCheckers() +// in op_support_checker.h. +const InlinedHashMap& GetOpBuilders(); + +void CreateActivationOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); +void CreateBinaryOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); +void CreateClipOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); +void CreateConvOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); +void CreateConcatOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); +void CreateGemmOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); +void CreatePoolOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); +void CreateReshapeOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); +void CreateResizeOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); +void CreateTransposeOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); + +} // namespace webnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webnn/webnn_execution_provider.cc b/onnxruntime/core/providers/webnn/webnn_execution_provider.cc new file mode 100644 index 0000000000000..78b32f57a74f5 --- /dev/null +++ b/onnxruntime/core/providers/webnn/webnn_execution_provider.cc @@ -0,0 +1,388 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) Intel Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "webnn_execution_provider.h" + +#include "core/framework/allocatormgr.h" +#include "core/framework/compute_capability.h" +#include "core/framework/memcpy.h" +#include "core/framework/kernel_registry.h" +#include "core/graph/graph_viewer.h" +#include "core/session/onnxruntime_cxx_api.h" +#include "core/common/safeint.h" + +#include "builders/model.h" +#include "builders/helper.h" +#include "builders/model_builder.h" + +namespace onnxruntime { + +constexpr const char* WEBNN = "WebNN"; + +WebNNExecutionProvider::WebNNExecutionProvider(uint32_t webnn_device_flags, uint32_t webnn_power_flags) + : IExecutionProvider{onnxruntime::kWebNNExecutionProvider, true} { + AllocatorCreationInfo device_info( + [](int) { + return std::make_unique(OrtMemoryInfo(WEBNN, OrtAllocatorType::OrtDeviceAllocator)); + }); + + InsertAllocator(CreateAllocator(device_info)); + + AllocatorCreationInfo cpu_memory_info( + [](int) { + return std::make_unique( + OrtMemoryInfo(WEBNN, OrtAllocatorType::OrtDeviceAllocator, OrtDevice(), 0, OrtMemTypeCPUOutput)); + }); + + InsertAllocator(CreateAllocator(cpu_memory_info)); + + // Create WebNN context and graph builder. + InlinedHashMap device_type_name_s = { + {0, "auto"}, {1, "gpu"}, {2, "cpu"}}; + InlinedHashMap power_preference_name_s = { + {0, "auto"}, {1, "high-performance"}, {2, "low-power"}}; + std::string device_type_name_ = device_type_name_s[webnn_device_flags]; + std::string power_preference_name_ = power_preference_name_s[webnn_power_flags]; + const emscripten::val ml = emscripten::val::global("navigator")["ml"]; + if (!ml.as()) { + ORT_THROW("Failed to get ml from navigator."); + } + emscripten::val context_options = emscripten::val::object(); + // Currently WebNN implementation in Chromium temporarily reuses the MLContextOptions + // defined in Model Loader API, which uses MLDevicePreference instead of MLDeviceType + // defined in WebNN. Because there's an ongoing spec discussion to simplify this API at + // https://github.com/webmachinelearning/webnn/issues/302. + context_options.set("devicePreference", emscripten::val(device_type_name_)); + context_options.set("powerPreference", emscripten::val(power_preference_name_)); + wnn_context_ = ml.call("createContextSync", context_options); + if (!wnn_context_.as()) { + ORT_THROW("Failed to create WebNN context."); + } + wnn_builder_ = emscripten::val::global("MLGraphBuilder").new_(wnn_context_); + if (!wnn_builder_.as()) { + ORT_THROW("Failed to create WebNN builder."); + } +} + +WebNNExecutionProvider::~WebNNExecutionProvider() {} + +std::vector> +WebNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer, + const IKernelLookup& /*kernel_registries*/) const { + std::vector> result; + + // We do not run WebNN EP on subgraph, instead we cover this in the control flow nodes. + // TODO investigate whether we want to support subgraph using WebNN EP. + if (graph_viewer.IsSubgraph()) { + return result; + } + + /* + Very basic search for groups of nodes that can be handled by the EP. + This doesn't work perfectly if you have a scenario like the following where A and D could be handled by the EP + but B is between them in the topological sort as you'll get two single node capabilities. However if can also + be advantageous if C and E could be handled by the EP as they would be combined with D even though not connected. + Not sure how often each of these scenarios happens. + + A B C + | / | + D E + | | + + Would probably be better to walk the edges for each node the EP can handle as they are iterated in topological order, + accumulating nodes (and saving which ones have been taken) until you run out. This would guarantee all + connected nodes that can be handled are grouped together. + */ + + const auto& logger = *GetLogger(); + + const auto node_groups = webnn::GetSupportedNodes(graph_viewer, wnn_builder_, logger); + + if (node_groups.empty()) { + return result; + } + + const auto& graph_output_list = graph_viewer.GetOutputs(); + InlinedHashSet graph_outputs(graph_output_list.cbegin(), graph_output_list.cend()); + + size_t num_of_supported_nodes = 0; + for (const auto& group : node_groups) { + if (group.empty()) + continue; + + num_of_supported_nodes += group.size(); + LOGS(logger, VERBOSE) << "WebNNExecutionProvider::GetCapability, current supported node group size: " + << group.size(); + + InlinedHashSet node_set; + node_set.reserve(group.size()); + for (const auto& index : group) { + node_set.insert(index); + } + + std::unique_ptr sub_graph = std::make_unique(); + + InlinedHashSet node_outputs; + InlinedHashSet subgraph_inputs; + InlinedHashSet subgraph_outputs; + std::vector ordered_subgraph_inputs; + std::vector ordered_subgraph_outputs; + + for (const auto& index : group) { + sub_graph->nodes.push_back(index); + const auto* node = graph_viewer.GetNode(index); + + for (const auto* input : node->InputDefs()) { + // if the node input was not produced by this subgraph, add it to the subgraph inputs. + if (node_outputs.count(input) == 0) { + if (subgraph_inputs.count(input) == 0) { + subgraph_inputs.insert(input); + ordered_subgraph_inputs.push_back(input); + } + } + } + + const auto& output_defs = node->OutputDefs(); + for (const auto* output_def : output_defs) { + node_outputs.insert(output_def); + // if output is overall graph output we need to produce it. + if (graph_outputs.count(output_def) != 0) { + ordered_subgraph_outputs.push_back(output_def); + } + } + + // if output connects to a node not in this subgraph we need to produce it. + for (auto it = node->OutputEdgesBegin(), end = node->OutputEdgesEnd(); it != end; ++it) { + if (node_set.count(it->GetNode().Index()) == 0) { + const auto* output_def = output_defs[it->GetSrcArgIndex()]; + if (subgraph_outputs.count(output_def) == 0) { + subgraph_outputs.insert(output_def); + ordered_subgraph_outputs.push_back(output_def); + } + } + } + } + + // Assign inputs and outputs to subgraph's meta_def. + uint64_t model_hash; + int metadef_id = GenerateMetaDefId(graph_viewer, model_hash); + auto meta_def = std::make_unique<::onnxruntime::IndexedSubGraph::MetaDef>(); + meta_def->name = "WEBNN_" + std::to_string(model_hash) + "_" + std::to_string(metadef_id); + meta_def->domain = kMSDomain; + meta_def->since_version = 1; + meta_def->status = ONNX_NAMESPACE::EXPERIMENTAL; + + for (const auto& input : ordered_subgraph_inputs) { + meta_def->inputs.push_back(input->Name()); + } + + for (const auto& output : ordered_subgraph_outputs) { + meta_def->outputs.push_back(output->Name()); + } + + sub_graph->SetMetaDef(std::move(meta_def)); + + result.push_back(std::make_unique(std::move(sub_graph))); + } + + auto num_of_partitions = result.size(); + const auto summary_msg = MakeString( + "WebNNExecutionProvider::GetCapability,", + " number of partitions supported by WebNN: ", num_of_partitions, + " number of nodes in the graph: ", graph_viewer.NumberOfNodes(), + " number of nodes supported by WebNN: ", num_of_supported_nodes); + + // If the graph is partitioned in multiple subgraphs, and this may impact performance, + // we want to give users a summary message at warning level. + if (num_of_partitions > 1) { + LOGS(logger, WARNING) << summary_msg; + } else { + LOGS(logger, INFO) << summary_msg; + } + + return result; +} + +common::Status WebNNExecutionProvider::Compile(const std::vector& fused_nodes_and_graphs, + std::vector& node_compute_funcs) { + for (const auto& fused_node_and_graph : fused_nodes_and_graphs) { + Node& fused_node = fused_node_and_graph.fused_node; + const onnxruntime::GraphViewer& graph_viewer(fused_node_and_graph.filtered_graph); + + webnn::ModelBuilder builder(graph_viewer, *GetLogger(), wnn_context_, wnn_builder_); + std::unique_ptr model; + ORT_RETURN_IF_ERROR(builder.Compile(model)); + // Build map from input name to its index in input definitions. + { + InlinedHashMap input_map; + const auto& input_defs = fused_node.InputDefs(); + input_map.reserve(input_defs.size()); + for (size_t i = 0, end = input_defs.size(); i < end; ++i) { + input_map[input_defs[i]->Name()] = i; + } + model->SetInputMap(std::move(input_map)); + } + // Build map from output name to its index in output definitions. + { + InlinedHashMap output_map; + const auto& output_defs = fused_node.OutputDefs(); + output_map.reserve(output_defs.size()); + for (size_t i = 0, end = output_defs.size(); i < end; ++i) { + output_map[output_defs[i]->Name()] = i; + } + model->SetOutputMap(std::move(output_map)); + } + models_.emplace(fused_node.Name(), std::move(model)); + + NodeComputeInfo compute_info; + compute_info.create_state_func = [&](ComputeContext* context, FunctionState* state) { + *state = models_[context->node_name].get(); + return 0; + }; + + compute_info.release_state_func = [](FunctionState state) { + // The `state` is a webnn::model managed by unique_ptr. + ORT_UNUSED_PARAMETER(state); + }; + + compute_info.compute_func = [](FunctionState state, const OrtApi* api, OrtKernelContext* context) { + Ort::KernelContext ctx(context); + + const size_t num_inputs = ctx.GetInputCount(); + const size_t num_outputs = ctx.GetOutputCount(); + + // Ort::CustomOpApi ort{*api}; + webnn::Model* model = reinterpret_cast(state); + + const auto& model_inputs = model->GetInputs(); + const auto& model_outputs = model->GetOutputs(); + + ORT_RETURN_IF_NOT(model_inputs.size() <= num_inputs, "Inconsistent input sizes"); + ORT_RETURN_IF_NOT(model_outputs.size() == num_outputs, "Inconsistent output sizes"); + + InlinedHashMap inputs; + inputs.reserve(model_inputs.size()); + for (size_t i = 0; i < model_inputs.size(); i++) { + const auto& input_name = model_inputs[i]; + auto input_idx = model->GetMappedInputIdx(input_name); + auto input_tensor = ctx.GetInput(input_idx); + auto tensor_info = input_tensor.GetTensorTypeAndShapeInfo(); + auto shape = tensor_info.GetShape(); + // If we have an empty shape, this is a scalar input, + // Since all the input output of WebNN EP is MultiArray, we will make the scalar input as a {1} MultiArray. + if (shape.empty()) + shape.push_back(1); + std::vector temp(shape.size()); + transform(shape.begin(), shape.end(), temp.begin(), + [](int64_t dim) -> uint32_t { return SafeInt(dim); }); + const void* inputBuffer = const_cast(input_tensor.GetTensorRawData()); + inputs.emplace( + input_name, + webnn::OnnxTensorData{ + webnn::OnnxTensorInfo{tensor_info.GetElementType(), shape}, + const_cast(inputBuffer), + }); + } + + // From this point we will need to take the exclusive lock on the model until the Predict is + // performed, to block other threads to perform Predict on the same model. + // TODO, investigate concurrent runs for different executions from the same model. + { + std::unique_lock lock(model->GetMutex()); + InlinedHashMap outputs; + outputs.reserve(model_outputs.size()); + for (size_t i = 0; i < model_outputs.size(); i++) { + const auto& output_name = model_outputs[i]; + const auto& output_info = model->GetInputOutputInfo(output_name); + auto output_shape = output_info.shape; + auto output_type = output_info.data_type; + + // Since WebNN EP use {1} tensor as scalar, if the model output should have empty shape. + // We are going to replace the {1} shape of the output back to {}. + if (model->IsScalarOutput(output_name)) + output_shape.clear(); + + auto output_tensor = + ctx.GetOutput(i, output_shape.data(), output_shape.size()); + + void* output_buffer; + switch (output_type) { + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: + output_buffer = output_tensor.GetTensorMutableRawData(); + break; + default: + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "Unsupported type: ", output_type, " for output: ", output_name); + break; + } + + outputs.emplace(output_name, + webnn::OnnxTensorData{ + webnn::OnnxTensorInfo{output_type, output_shape}, + output_buffer, + }); + } + + return model->Predict(inputs, outputs); + } + }; + + node_compute_funcs.push_back(compute_info); + } + + return Status::OK(); +} + +ONNX_OPERATOR_KERNEL_EX( + MemcpyFromHost, + kOnnxDomain, + 1, + kWebNNExecutionProvider, + KernelDefBuilder() + .InputMemoryType(OrtMemTypeCPUInput, 0) + .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()), + Memcpy); + +ONNX_OPERATOR_KERNEL_EX( + MemcpyToHost, + kOnnxDomain, + 1, + kWebNNExecutionProvider, + KernelDefBuilder() + .OutputMemoryType(OrtMemTypeCPUOutput, 0) + .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()), + Memcpy); + +class ONNX_OPERATOR_KERNEL_CLASS_NAME( + kWebNNExecutionProvider, kOnnxDomain, 1, MemcpyFromHost); +class ONNX_OPERATOR_KERNEL_CLASS_NAME( + kWebNNExecutionProvider, kOnnxDomain, 1, MemcpyToHost); + +static void RegisterWebNNKernels(KernelRegistry& kernel_registry) { + static const BuildKernelCreateInfoFn function_table[] = { + BuildKernelCreateInfo, + BuildKernelCreateInfo, + }; + + for (auto& function_table_entry : function_table) { + ORT_ENFORCE(kernel_registry.Register(function_table_entry()).IsOK()); + } +} + +std::shared_ptr GetWebNNKernelRegistry() { + std::shared_ptr kernel_registry = + std::make_shared(); + RegisterWebNNKernels(*kernel_registry); + + return kernel_registry; +} + +std::shared_ptr +WebNNExecutionProvider::GetKernelRegistry() const { + static std::shared_ptr kernel_registry = + onnxruntime::GetWebNNKernelRegistry(); + return kernel_registry; +} + +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webnn/webnn_execution_provider.h b/onnxruntime/core/providers/webnn/webnn_execution_provider.h new file mode 100644 index 0000000000000..a67f3793a326d --- /dev/null +++ b/onnxruntime/core/providers/webnn/webnn_execution_provider.h @@ -0,0 +1,48 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) Intel Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/common/inlined_containers.h" +#include "core/framework/execution_provider.h" + +#include +#include + +namespace onnxruntime { +namespace webnn { +class Model; +} + +class WebNNExecutionProvider : public IExecutionProvider { + public: + WebNNExecutionProvider(uint32_t webnn_device_flags, uint32_t webnn_power_flags); + virtual ~WebNNExecutionProvider(); + + std::vector> + GetCapability(const onnxruntime::GraphViewer& graph_viewer, + const IKernelLookup& /*kernel_registries*/) const override; + + DataLayout GetPreferredLayout() const override { return DataLayout::NHWC; } + + // We implement the Compile that takes FusedNodeAndGraph instances. + FusionStyle GetFusionStyle() const override { return FusionStyle::FilteredGraphViewer; } + + // WebNN does not support concurrent execution of a kernel. + bool ConcurrentRunSupported() const override { return false; } + +#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) + common::Status Compile(const std::vector& fused_nodes, + std::vector& node_compute_funcs) override; +#endif + + std::shared_ptr GetKernelRegistry() const override; + + private: + emscripten::val wnn_context_ = emscripten::val::object(); + emscripten::val wnn_builder_ = emscripten::val::object(); + + InlinedHashMap> models_; +}; +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webnn/webnn_provider_factory.cc b/onnxruntime/core/providers/webnn/webnn_provider_factory.cc new file mode 100644 index 0000000000000..8294852479a99 --- /dev/null +++ b/onnxruntime/core/providers/webnn/webnn_provider_factory.cc @@ -0,0 +1,39 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) Intel Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/webnn/webnn_provider_factory_creator.h" +#include "core/session/abi_session_options_impl.h" +#include "webnn_execution_provider.h" + +using namespace onnxruntime; + +namespace onnxruntime { +struct WebNNProviderFactory : IExecutionProviderFactory { + WebNNProviderFactory(uint32_t webnn_device_flags, uint32_t webnn_power_flags) + : webnn_device_flags_(webnn_device_flags), webnn_power_flags_(webnn_power_flags) {} + ~WebNNProviderFactory() override {} + + std::unique_ptr CreateProvider() override; + + uint32_t webnn_device_flags_; + uint32_t webnn_power_flags_; +}; + +std::unique_ptr WebNNProviderFactory::CreateProvider() { + return std::make_unique(webnn_device_flags_, webnn_power_flags_); +} + +std::shared_ptr WebNNProviderFactoryCreator::Create( + const ProviderOptions& provider_options) { + uint32_t webnn_device_flags = 2, webnn_power_flags = 0; + if (auto it = provider_options.find("deviceType"); it != provider_options.end()) { + webnn_device_flags = std::stoi(it->second); + } + if (auto it = provider_options.find("powerPreference"); it != provider_options.end()) { + webnn_power_flags = std::stoi(it->second); + } + return std::make_shared(webnn_device_flags, webnn_power_flags); +} + +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webnn/webnn_provider_factory_creator.h b/onnxruntime/core/providers/webnn/webnn_provider_factory_creator.h new file mode 100644 index 0000000000000..d2d21358e0e48 --- /dev/null +++ b/onnxruntime/core/providers/webnn/webnn_provider_factory_creator.h @@ -0,0 +1,18 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) Intel Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include + +#include "core/framework/provider_options.h" +#include "core/providers/providers.h" + +namespace onnxruntime { + +struct WebNNProviderFactoryCreator { + static std::shared_ptr Create(const ProviderOptions& provider_options); +}; + +} // namespace onnxruntime diff --git a/onnxruntime/core/session/provider_registration.cc b/onnxruntime/core/session/provider_registration.cc index 8cadccd0ef376..28bd00a7047cc 100644 --- a/onnxruntime/core/session/provider_registration.cc +++ b/onnxruntime/core/session/provider_registration.cc @@ -83,6 +83,16 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider, options->provider_factories.push_back(XnnpackProviderFactoryCreator::Create(provider_options, &(options->value))); #else status = create_not_supported_status(); +#endif + } else if (strcmp(provider_name, "WEBNN") == 0) { +#if defined(USE_WEBNN) + std::string deviceType = options->value.config_options.GetConfigOrDefault("deviceType", "2"); + std::string powerPreference = options->value.config_options.GetConfigOrDefault("powerPreference", "0"); + provider_options["deviceType"] = deviceType; + provider_options["powerPreference"] = powerPreference; + options->provider_factories.push_back(WebNNProviderFactoryCreator::Create(provider_options)); +#else + status = create_not_supported_status(); #endif } else if (strcmp(provider_name, "AZURE") == 0) { #if defined(USE_AZURE) diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index 228f39beae84b..39ac971616908 100644 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -477,6 +477,7 @@ def convert_arg_line_to_args(self, arg_line): help="Build with OpenVINO for specific hardware.", ) parser.add_argument("--use_coreml", action="store_true", help="Build with CoreML support.") + parser.add_argument("--use_webnn", action="store_true", help="Build with WebNN support.") parser.add_argument("--use_snpe", action="store_true", help="Build with SNPE support.") parser.add_argument("--snpe_root", help="Path to SNPE SDK root.") parser.add_argument("--use_nnapi", action="store_true", help="Build with NNAPI support.") @@ -979,6 +980,7 @@ def generate_build_tree( "-Donnxruntime_ENABLE_CUDA_PROFILING=" + ("ON" if args.enable_cuda_profiling else "OFF"), "-Donnxruntime_ENABLE_ROCM_PROFILING=" + ("ON" if args.enable_rocm_profiling else "OFF"), "-Donnxruntime_USE_XNNPACK=" + ("ON" if args.use_xnnpack else "OFF"), + "-Donnxruntime_USE_WEBNN=" + ("ON" if args.use_webnn else "OFF"), "-Donnxruntime_USE_CANN=" + ("ON" if args.use_cann else "OFF"), ] @@ -1194,6 +1196,11 @@ def generate_build_tree( if args.use_coreml: cmake_args += ["-Donnxruntime_USE_COREML=ON"] + if args.use_webnn: + if not args.build_wasm: + raise BuildError("WebNN is only available for WASM build.") + cmake_args += ["-Donnxruntime_USE_WEBNN=ON"] + if args.use_snpe: cmake_args += ["-Donnxruntime_USE_SNPE=ON"] From 74ee55a4ccaf1c5f085057100c2f9139e6f05617 Mon Sep 17 00:00:00 2001 From: Wanming Lin Date: Thu, 27 Apr 2023 10:09:03 +0800 Subject: [PATCH 2/7] Fixed lint errors --- js/web/lib/wasm/session-options.ts | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/js/web/lib/wasm/session-options.ts b/js/web/lib/wasm/session-options.ts index abd27e8f7f5ac..fba1e9790d95d 100644 --- a/js/web/lib/wasm/session-options.ts +++ b/js/web/lib/wasm/session-options.ts @@ -69,17 +69,23 @@ const setExecutionProviders = if (typeof ep !== 'string') { const webnnOptions = ep as InferenceSession.WebNNExecutionProviderOption; if (webnnOptions?.deviceType) { - const keyDataOffset = allocWasmString("deviceType", allocs); + const keyDataOffset = allocWasmString('deviceType', allocs); const valueDataOffset = allocWasmString(webnnOptions.deviceType.toString(), allocs); - if (getInstance()._OrtAddSessionConfigEntry(sessionOptionsHandle, keyDataOffset, valueDataOffset) !== 0) { - throw new Error(`Can't set a session config entry: "deviceType" - ${webnnOptions.deviceType}`); + if (getInstance()._OrtAddSessionConfigEntry( + sessionOptionsHandle, keyDataOffset, valueDataOffset) !== 0) { + throw new Error( + `Can't set a session config entry: 'deviceType' - ${webnnOptions.deviceType}` + ); } } if (webnnOptions?.powerPreference) { - const keyDataOffset = allocWasmString("powerPreference", allocs); + const keyDataOffset = allocWasmString('powerPreference', allocs); const valueDataOffset = allocWasmString(webnnOptions.powerPreference.toString(), allocs); - if (getInstance()._OrtAddSessionConfigEntry(sessionOptionsHandle, keyDataOffset, valueDataOffset) !== 0) { - throw new Error(`Can't set a session config entry: "powerPreference" - ${webnnOptions.powerPreference}`); + if (getInstance()._OrtAddSessionConfigEntry( + sessionOptionsHandle, keyDataOffset, valueDataOffset) !== 0) { + throw new Error( + `Can't set a session config entry: 'powerPreference' - ${webnnOptions.powerPreference}` + ); } } } From 6db9969bb9ab0a2f2193853ecbd89af1f0bd5620 Mon Sep 17 00:00:00 2001 From: Wanming Lin Date: Thu, 27 Apr 2023 12:13:43 +0800 Subject: [PATCH 3/7] Did npm run format --- js/common/lib/inference-session.ts | 2 +- js/web/lib/wasm/session-options.ts | 15 ++++++--------- js/web/script/test-runner-cli-args.ts | 2 +- 3 files changed, 8 insertions(+), 11 deletions(-) diff --git a/js/common/lib/inference-session.ts b/js/common/lib/inference-session.ts index f629858f99b72..2dd15a9c3e282 100644 --- a/js/common/lib/inference-session.ts +++ b/js/common/lib/inference-session.ts @@ -203,7 +203,7 @@ export declare namespace InferenceSession { } export interface WebNNExecutionProviderOption extends ExecutionProviderOption { readonly name: 'webnn'; - deviceType?: number; // 0 - auto, 1 - gpu, 2 - cpu + deviceType?: number; // 0 - auto, 1 - gpu, 2 - cpu powerPreference?: number; // 0 - auto, 1 - high-performance, 2 - low-power } // #endregion diff --git a/js/web/lib/wasm/session-options.ts b/js/web/lib/wasm/session-options.ts index fba1e9790d95d..4cce53f85725e 100644 --- a/js/web/lib/wasm/session-options.ts +++ b/js/web/lib/wasm/session-options.ts @@ -71,21 +71,18 @@ const setExecutionProviders = if (webnnOptions?.deviceType) { const keyDataOffset = allocWasmString('deviceType', allocs); const valueDataOffset = allocWasmString(webnnOptions.deviceType.toString(), allocs); - if (getInstance()._OrtAddSessionConfigEntry( - sessionOptionsHandle, keyDataOffset, valueDataOffset) !== 0) { - throw new Error( - `Can't set a session config entry: 'deviceType' - ${webnnOptions.deviceType}` - ); + if (getInstance()._OrtAddSessionConfigEntry(sessionOptionsHandle, keyDataOffset, valueDataOffset) !== + 0) { + throw new Error(`Can't set a session config entry: 'deviceType' - ${webnnOptions.deviceType}`); } } if (webnnOptions?.powerPreference) { const keyDataOffset = allocWasmString('powerPreference', allocs); const valueDataOffset = allocWasmString(webnnOptions.powerPreference.toString(), allocs); - if (getInstance()._OrtAddSessionConfigEntry( - sessionOptionsHandle, keyDataOffset, valueDataOffset) !== 0) { + if (getInstance()._OrtAddSessionConfigEntry(sessionOptionsHandle, keyDataOffset, valueDataOffset) !== + 0) { throw new Error( - `Can't set a session config entry: 'powerPreference' - ${webnnOptions.powerPreference}` - ); + `Can't set a session config entry: 'powerPreference' - ${webnnOptions.powerPreference}`); } } } diff --git a/js/web/script/test-runner-cli-args.ts b/js/web/script/test-runner-cli-args.ts index c189c22cfdf91..12788b7b545c0 100644 --- a/js/web/script/test-runner-cli-args.ts +++ b/js/web/script/test-runner-cli-args.ts @@ -365,7 +365,7 @@ export function parseTestRunnerCliArgs(cmdlineArgs: string[]): TestRunnerCliArgs // TODO: remove this when Chrome support WebGPU or WebNN. // we need this for now because Chrome does not support webgpu and webnn yet, // and ChromeCanary is not in CI. - const defaultBrowserBackends = ['webgl', /* 'webgpu', */ 'wasm', 'xnnpack'/*, 'webnn'*/]; + const defaultBrowserBackends = ['webgl', /* 'webgpu', */ 'wasm', 'xnnpack' /*, 'webnn'*/]; const nodejsBackends = ['cpu', 'wasm']; const backendArgs = args.backend || args.b; const backend = (typeof backendArgs !== 'string') ? (env === 'node' ? nodejsBackends : defaultBrowserBackends) : From ad461ad8788114235132ea1d76cbb093d62f093d Mon Sep 17 00:00:00 2001 From: Wanming Lin Date: Thu, 27 Apr 2023 15:06:05 +0800 Subject: [PATCH 4/7] Throw error when wasm-enable-proxy is false for webnn ep --- js/web/script/test-runner-cli-args.ts | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/js/web/script/test-runner-cli-args.ts b/js/web/script/test-runner-cli-args.ts index 12788b7b545c0..61f85f2e67666 100644 --- a/js/web/script/test-runner-cli-args.ts +++ b/js/web/script/test-runner-cli-args.ts @@ -376,6 +376,11 @@ export function parseTestRunnerCliArgs(cmdlineArgs: string[]): TestRunnerCliArgs } } + if (backend.includes('webnn') && args['wasm-enable-proxy'] !== 'true') { + throw new Error( + 'backend webnn is restricted in the dedicated worker, set "--wasm-enable-proxy true" to enable proxy worker'); + } + const globalEnvFlags = parseGlobalEnvFlags(args); // Options: From cc8084eb8de37c99a6e8676a0c8b9d65bbc6df2a Mon Sep 17 00:00:00 2001 From: Wanming Lin Date: Fri, 28 Apr 2023 16:30:23 +0800 Subject: [PATCH 5/7] Addressed @fs-eire's comments --- cmake/adjust_global_compile_flags.cmake | 6 +----- js/common/lib/inference-session.ts | 4 ++-- js/web/lib/wasm/session-options.ts | 4 ++-- js/web/script/test-runner-cli-args.ts | 6 +++--- .../providers/webnn/webnn_execution_provider.cc | 15 ++++++--------- .../providers/webnn/webnn_execution_provider.h | 2 +- .../providers/webnn/webnn_provider_factory.cc | 12 ++++++------ onnxruntime/core/session/provider_registration.cc | 4 ++-- tools/ci_build/build.py | 6 +++++- 9 files changed, 28 insertions(+), 31 deletions(-) diff --git a/cmake/adjust_global_compile_flags.cmake b/cmake/adjust_global_compile_flags.cmake index 58027b2cf2e96..58a9271d26e7f 100644 --- a/cmake/adjust_global_compile_flags.cmake +++ b/cmake/adjust_global_compile_flags.cmake @@ -131,11 +131,7 @@ if (onnxruntime_DISABLE_RTTI) # Disable RTTI and turn usage of dynamic_cast and typeid into errors add_compile_options("$<$:/GR->" "$<$:/we4541>") else() - # Avoid unboundTypeError for WebNN EP since unbound type names are illegal with RTTI disabled - # in Embind API, relevant issue: https://github.com/emscripten-core/emscripten/issues/16911 - if(NOT onnxruntime_USE_WEBNN) - add_compile_options("$<$:-fno-rtti>") - endif() + add_compile_options("$<$:-fno-rtti>") endif() else() #MSVC RTTI flag /GR is not added to CMAKE_CXX_FLAGS by default. But, anyway VC++2019 treats "/GR" default on. diff --git a/js/common/lib/inference-session.ts b/js/common/lib/inference-session.ts index 2dd15a9c3e282..e32cd0dfa067a 100644 --- a/js/common/lib/inference-session.ts +++ b/js/common/lib/inference-session.ts @@ -203,8 +203,8 @@ export declare namespace InferenceSession { } export interface WebNNExecutionProviderOption extends ExecutionProviderOption { readonly name: 'webnn'; - deviceType?: number; // 0 - auto, 1 - gpu, 2 - cpu - powerPreference?: number; // 0 - auto, 1 - high-performance, 2 - low-power + deviceType?: 'cpu'|'gpu'; + powerPreference?: 'default'|'low-power'|'high-performance'; } // #endregion diff --git a/js/web/lib/wasm/session-options.ts b/js/web/lib/wasm/session-options.ts index 4cce53f85725e..c27d2f1e17e21 100644 --- a/js/web/lib/wasm/session-options.ts +++ b/js/web/lib/wasm/session-options.ts @@ -70,7 +70,7 @@ const setExecutionProviders = const webnnOptions = ep as InferenceSession.WebNNExecutionProviderOption; if (webnnOptions?.deviceType) { const keyDataOffset = allocWasmString('deviceType', allocs); - const valueDataOffset = allocWasmString(webnnOptions.deviceType.toString(), allocs); + const valueDataOffset = allocWasmString(webnnOptions.deviceType, allocs); if (getInstance()._OrtAddSessionConfigEntry(sessionOptionsHandle, keyDataOffset, valueDataOffset) !== 0) { throw new Error(`Can't set a session config entry: 'deviceType' - ${webnnOptions.deviceType}`); @@ -78,7 +78,7 @@ const setExecutionProviders = } if (webnnOptions?.powerPreference) { const keyDataOffset = allocWasmString('powerPreference', allocs); - const valueDataOffset = allocWasmString(webnnOptions.powerPreference.toString(), allocs); + const valueDataOffset = allocWasmString(webnnOptions.powerPreference, allocs); if (getInstance()._OrtAddSessionConfigEntry(sessionOptionsHandle, keyDataOffset, valueDataOffset) !== 0) { throw new Error( diff --git a/js/web/script/test-runner-cli-args.ts b/js/web/script/test-runner-cli-args.ts index 61f85f2e67666..d0d3f8cf906a1 100644 --- a/js/web/script/test-runner-cli-args.ts +++ b/js/web/script/test-runner-cli-args.ts @@ -376,13 +376,13 @@ export function parseTestRunnerCliArgs(cmdlineArgs: string[]): TestRunnerCliArgs } } - if (backend.includes('webnn') && args['wasm-enable-proxy'] !== 'true') { + const globalEnvFlags = parseGlobalEnvFlags(args); + + if (backend.includes('webnn') && !globalEnvFlags.wasm.proxy) { throw new Error( 'backend webnn is restricted in the dedicated worker, set "--wasm-enable-proxy true" to enable proxy worker'); } - const globalEnvFlags = parseGlobalEnvFlags(args); - // Options: // --log-verbose=<...> // --log-info=<...> diff --git a/onnxruntime/core/providers/webnn/webnn_execution_provider.cc b/onnxruntime/core/providers/webnn/webnn_execution_provider.cc index 78b32f57a74f5..a8589f1d4d1aa 100644 --- a/onnxruntime/core/providers/webnn/webnn_execution_provider.cc +++ b/onnxruntime/core/providers/webnn/webnn_execution_provider.cc @@ -20,7 +20,8 @@ namespace onnxruntime { constexpr const char* WEBNN = "WebNN"; -WebNNExecutionProvider::WebNNExecutionProvider(uint32_t webnn_device_flags, uint32_t webnn_power_flags) +WebNNExecutionProvider::WebNNExecutionProvider( + const std::string& webnn_device_flags, const std::string& webnn_power_flags) : IExecutionProvider{onnxruntime::kWebNNExecutionProvider, true} { AllocatorCreationInfo device_info( [](int) { @@ -38,12 +39,6 @@ WebNNExecutionProvider::WebNNExecutionProvider(uint32_t webnn_device_flags, uint InsertAllocator(CreateAllocator(cpu_memory_info)); // Create WebNN context and graph builder. - InlinedHashMap device_type_name_s = { - {0, "auto"}, {1, "gpu"}, {2, "cpu"}}; - InlinedHashMap power_preference_name_s = { - {0, "auto"}, {1, "high-performance"}, {2, "low-power"}}; - std::string device_type_name_ = device_type_name_s[webnn_device_flags]; - std::string power_preference_name_ = power_preference_name_s[webnn_power_flags]; const emscripten::val ml = emscripten::val::global("navigator")["ml"]; if (!ml.as()) { ORT_THROW("Failed to get ml from navigator."); @@ -53,8 +48,10 @@ WebNNExecutionProvider::WebNNExecutionProvider(uint32_t webnn_device_flags, uint // defined in Model Loader API, which uses MLDevicePreference instead of MLDeviceType // defined in WebNN. Because there's an ongoing spec discussion to simplify this API at // https://github.com/webmachinelearning/webnn/issues/302. - context_options.set("devicePreference", emscripten::val(device_type_name_)); - context_options.set("powerPreference", emscripten::val(power_preference_name_)); + context_options.set("devicePreference", emscripten::val(webnn_device_flags)); + if (webnn_power_flags.compare("default") != 0) { + context_options.set("powerPreference", emscripten::val(webnn_power_flags)); + } wnn_context_ = ml.call("createContextSync", context_options); if (!wnn_context_.as()) { ORT_THROW("Failed to create WebNN context."); diff --git a/onnxruntime/core/providers/webnn/webnn_execution_provider.h b/onnxruntime/core/providers/webnn/webnn_execution_provider.h index a67f3793a326d..697b7ecd6d462 100644 --- a/onnxruntime/core/providers/webnn/webnn_execution_provider.h +++ b/onnxruntime/core/providers/webnn/webnn_execution_provider.h @@ -17,7 +17,7 @@ class Model; class WebNNExecutionProvider : public IExecutionProvider { public: - WebNNExecutionProvider(uint32_t webnn_device_flags, uint32_t webnn_power_flags); + WebNNExecutionProvider(const std::string& webnn_device_flags, const std::string& webnn_power_flags); virtual ~WebNNExecutionProvider(); std::vector> diff --git a/onnxruntime/core/providers/webnn/webnn_provider_factory.cc b/onnxruntime/core/providers/webnn/webnn_provider_factory.cc index 8294852479a99..dc5f5a4c7da93 100644 --- a/onnxruntime/core/providers/webnn/webnn_provider_factory.cc +++ b/onnxruntime/core/providers/webnn/webnn_provider_factory.cc @@ -10,14 +10,14 @@ using namespace onnxruntime; namespace onnxruntime { struct WebNNProviderFactory : IExecutionProviderFactory { - WebNNProviderFactory(uint32_t webnn_device_flags, uint32_t webnn_power_flags) + WebNNProviderFactory(const std::string& webnn_device_flags, const std::string& webnn_power_flags) : webnn_device_flags_(webnn_device_flags), webnn_power_flags_(webnn_power_flags) {} ~WebNNProviderFactory() override {} std::unique_ptr CreateProvider() override; - uint32_t webnn_device_flags_; - uint32_t webnn_power_flags_; + std::string webnn_device_flags_; + std::string webnn_power_flags_; }; std::unique_ptr WebNNProviderFactory::CreateProvider() { @@ -26,12 +26,12 @@ std::unique_ptr WebNNProviderFactory::CreateProvider() { std::shared_ptr WebNNProviderFactoryCreator::Create( const ProviderOptions& provider_options) { - uint32_t webnn_device_flags = 2, webnn_power_flags = 0; + std::string webnn_device_flags = "cpu", webnn_power_flags = "default"; if (auto it = provider_options.find("deviceType"); it != provider_options.end()) { - webnn_device_flags = std::stoi(it->second); + webnn_device_flags = it->second; } if (auto it = provider_options.find("powerPreference"); it != provider_options.end()) { - webnn_power_flags = std::stoi(it->second); + webnn_power_flags = it->second; } return std::make_shared(webnn_device_flags, webnn_power_flags); } diff --git a/onnxruntime/core/session/provider_registration.cc b/onnxruntime/core/session/provider_registration.cc index 28bd00a7047cc..3d712a81cd5a2 100644 --- a/onnxruntime/core/session/provider_registration.cc +++ b/onnxruntime/core/session/provider_registration.cc @@ -86,8 +86,8 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider, #endif } else if (strcmp(provider_name, "WEBNN") == 0) { #if defined(USE_WEBNN) - std::string deviceType = options->value.config_options.GetConfigOrDefault("deviceType", "2"); - std::string powerPreference = options->value.config_options.GetConfigOrDefault("powerPreference", "0"); + std::string deviceType = options->value.config_options.GetConfigOrDefault("deviceType", "cpu"); + std::string powerPreference = options->value.config_options.GetConfigOrDefault("powerPreference", "default"); provider_options["deviceType"] = deviceType; provider_options["powerPreference"] = powerPreference; options->provider_factories.push_back(WebNNProviderFactoryCreator::Create(provider_options)); diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index 39ac971616908..8817e6242cd53 100644 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -1198,7 +1198,11 @@ def generate_build_tree( if args.use_webnn: if not args.build_wasm: - raise BuildError("WebNN is only available for WASM build.") + raise BuildError("WebNN is only available for WebAssembly build.") + if args.disable_rtti: + # Avoid unboundTypeError for WebNN EP since unbound type names are illegal with RTTI disabled + # in Embind API, relevant issue: https://github.com/emscripten-core/emscripten/issues/16911 + raise BuildError("WebNN is not supported with RTTI disabled.") cmake_args += ["-Donnxruntime_USE_WEBNN=ON"] if args.use_snpe: From 46532468696f4ef6dc25c42cd95b936596562ed8 Mon Sep 17 00:00:00 2001 From: Wanming Lin Date: Thu, 4 May 2023 15:17:48 +0800 Subject: [PATCH 6/7] Addressed further comments --- js/web/script/test-runner-cli-args.ts | 4 +- .../webnn/builders/impl/concat_op_builder.cc | 23 ------------ .../webnn/builders/impl/gemm_op_builder.cc | 37 ++++++++----------- .../providers/webnn/webnn_provider_factory.cc | 10 +---- 4 files changed, 20 insertions(+), 54 deletions(-) diff --git a/js/web/script/test-runner-cli-args.ts b/js/web/script/test-runner-cli-args.ts index d0d3f8cf906a1..c0ced63be6852 100644 --- a/js/web/script/test-runner-cli-args.ts +++ b/js/web/script/test-runner-cli-args.ts @@ -379,8 +379,8 @@ export function parseTestRunnerCliArgs(cmdlineArgs: string[]): TestRunnerCliArgs const globalEnvFlags = parseGlobalEnvFlags(args); if (backend.includes('webnn') && !globalEnvFlags.wasm.proxy) { - throw new Error( - 'backend webnn is restricted in the dedicated worker, set "--wasm-enable-proxy true" to enable proxy worker'); + // Backend webnn is restricted in the dedicated worker. + globalEnvFlags.wasm.proxy = true; } // Options: diff --git a/onnxruntime/core/providers/webnn/builders/impl/concat_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/concat_op_builder.cc index 1ed516a3c4d9e..c39927b3cc26b 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/concat_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/concat_op_builder.cc @@ -19,11 +19,6 @@ class ConcatOpBuilder : public BaseOpBuilder { private: Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override ORT_MUST_USE_RESULT; - - // Operator support related. - private: - bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, - const logging::Logger& logger) const override; }; // Add operator related. @@ -53,24 +48,6 @@ Status ConcatOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, return Status::OK(); } -// Operator support related. -bool ConcatOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node, - const logging::Logger& logger) const { - std::vector input_shape; - const auto& input_defs(node.InputDefs()); - if (!GetShape(*input_defs[0], input_shape, logger)) - return false; - - const auto input_size = input_shape.size(); - if (input_size > 4 || input_size == 0) { - LOGS_DEFAULT(VERBOSE) << "Concat only supports up to 1-4d shape, input is " - << input_size << "d shape"; - return false; - } - - return true; -} - void CreateConcatOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { op_registrations.builders.push_back(std::make_unique()); op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get()); diff --git a/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc index 6a41b75e26a14..29710ed040caa 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc @@ -36,29 +36,24 @@ Status GemmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N emscripten::val a = model_builder.GetOperand(node.InputDefs()[a_idx]->Name()); emscripten::val b = model_builder.GetOperand(node.InputDefs()[b_idx]->Name()); - emscripten::val output = emscripten::val::object(); - if (op_type == "MatMul") { - output = model_builder.GetBuilder().call("matmul", a, b); - } else { // Gemm - emscripten::val options = emscripten::val::object(); - NodeAttrHelper helper(node); - const auto transA = helper.Get("transA", 0); - options.set("aTranspose", emscripten::val(transA == 1)); - const auto transB = helper.Get("transB", 0); - options.set("bTranspose", emscripten::val(transB == 1)); - const auto alpha = helper.Get("alpha", 1.0f); - const auto beta = helper.Get("beta", 1.0f); - options.set("alpha", alpha); - options.set("beta", beta); - - // Add bias if present. - if (input_defs.size() > 2) { - options.set("c", model_builder.GetOperand(node.InputDefs()[c_idx]->Name())); - } - - output = model_builder.GetBuilder().call("gemm", a, b, options); + emscripten::val options = emscripten::val::object(); + NodeAttrHelper helper(node); + const auto transA = helper.Get("transA", 0); + options.set("aTranspose", emscripten::val(transA == 1)); + const auto transB = helper.Get("transB", 0); + options.set("bTranspose", emscripten::val(transB == 1)); + const auto alpha = helper.Get("alpha", 1.0f); + const auto beta = helper.Get("beta", 1.0f); + options.set("alpha", alpha); + options.set("beta", beta); + + // Add bias if present. + if (input_defs.size() > 2) { + options.set("c", model_builder.GetOperand(node.InputDefs()[c_idx]->Name())); } + emscripten::val output = model_builder.GetBuilder().call("gemm", a, b, options); + model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output)); return Status::OK(); } diff --git a/onnxruntime/core/providers/webnn/webnn_provider_factory.cc b/onnxruntime/core/providers/webnn/webnn_provider_factory.cc index dc5f5a4c7da93..4d6b04c8e76d8 100644 --- a/onnxruntime/core/providers/webnn/webnn_provider_factory.cc +++ b/onnxruntime/core/providers/webnn/webnn_provider_factory.cc @@ -26,14 +26,8 @@ std::unique_ptr WebNNProviderFactory::CreateProvider() { std::shared_ptr WebNNProviderFactoryCreator::Create( const ProviderOptions& provider_options) { - std::string webnn_device_flags = "cpu", webnn_power_flags = "default"; - if (auto it = provider_options.find("deviceType"); it != provider_options.end()) { - webnn_device_flags = it->second; - } - if (auto it = provider_options.find("powerPreference"); it != provider_options.end()) { - webnn_power_flags = it->second; - } - return std::make_shared(webnn_device_flags, webnn_power_flags); + return std::make_shared(provider_options.at("deviceType"), + provider_options.at("powerPreference")); } } // namespace onnxruntime From 58a8bec3425d29297e9a487a44ab87e31db6efae Mon Sep 17 00:00:00 2001 From: Wanming Lin Date: Fri, 5 May 2023 10:16:16 +0800 Subject: [PATCH 7/7] Fixed nit --- .../webnn/builders/impl/gemm_op_builder.cc | 98 +++++++++---------- 1 file changed, 47 insertions(+), 51 deletions(-) diff --git a/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc index 29710ed040caa..2330d85f911e5 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc @@ -30,7 +30,6 @@ class GemmOpBuilder : public BaseOpBuilder { // Add operator related. Status GemmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& /* logger */) const { - const auto& op_type = node.OpType(); const auto& input_defs = node.InputDefs(); const size_t a_idx = 0, b_idx = 1, c_idx = 2; // A*B+C @@ -63,71 +62,68 @@ Status GemmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N bool GemmOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, const logging::Logger& logger) const { (void)initializers; - const auto& op_type = node.OpType(); const auto& input_defs(node.InputDefs()); const size_t a_idx = 0, b_idx = 1, c_idx = 2; // A*B+C - if (op_type == "Gemm") { - std::vector a_shape; - { - if (!GetShape(*input_defs[a_idx], a_shape, logger)) - return false; + std::vector a_shape; + { + if (!GetShape(*input_defs[a_idx], a_shape, logger)) + return false; - if (a_shape.size() != 2) { - LOGS(logger, VERBOSE) << "A must be 2D"; - return false; - } + if (a_shape.size() != 2) { + LOGS(logger, VERBOSE) << "A must be 2D"; + return false; + } - if (Product(a_shape) == 0) { - LOGS(logger, VERBOSE) << "A must be non-empty"; - return false; - } + if (Product(a_shape) == 0) { + LOGS(logger, VERBOSE) << "A must be non-empty"; + return false; } + } - std::vector b_shape; - { - if (!GetShape(*input_defs[b_idx], b_shape, logger)) - return false; + std::vector b_shape; + { + if (!GetShape(*input_defs[b_idx], b_shape, logger)) + return false; - if (b_shape.size() != 2) { - LOGS(logger, VERBOSE) << "B must be 2D"; - return false; - } + if (b_shape.size() != 2) { + LOGS(logger, VERBOSE) << "B must be 2D"; + return false; + } - if (Product(b_shape) == 0) { - LOGS(logger, VERBOSE) << "B must be non-empty"; - return false; - } + if (Product(b_shape) == 0) { + LOGS(logger, VERBOSE) << "B must be non-empty"; + return false; } + } - // C of Gemm. - if (input_defs.size() == 3) { - std::vector c_shape; - if (!GetShape(*input_defs[c_idx], c_shape, logger)) - return false; + // C of Gemm. + if (input_defs.size() == 3) { + std::vector c_shape; + if (!GetShape(*input_defs[c_idx], c_shape, logger)) + return false; + + size_t c_dim = c_shape.size(); - size_t c_dim = c_shape.size(); + if (c_dim > 1) { + // TODO: Supports other shape of C. + // Currently WebNN implementation in Chromium only supports 1-D C. + return false; + } + if (c_dim == 0) { + LOGS(logger, VERBOSE) << "C of Gemm is a scalar"; + } else { + auto c_size = c_shape[c_dim - 1]; + NodeAttrHelper helper(node); + const auto transB = helper.Get("transB", 0); + if (c_size != (transB == 0 ? b_shape[1] : b_shape[0])) { + LOGS(logger, VERBOSE) << "C of Gemm must be a vector of b_shape[" + << (transB == 0 ? "1" : "0") << "]" + << " b_shape: [" << b_shape[0] << ", " << b_shape[1] << "]" + << " c_size: " << c_size; - if (c_dim > 1) { - // TODO: Supports other shape of C. - // Currently WebNN implementation in Chromium only supports 1-D C. return false; } - if (c_dim == 0) { - LOGS(logger, VERBOSE) << "C of Gemm is a scalar"; - } else { - auto c_size = c_shape[c_dim - 1]; - NodeAttrHelper helper(node); - const auto transB = helper.Get("transB", 0); - if (c_size != (transB == 0 ? b_shape[1] : b_shape[0])) { - LOGS(logger, VERBOSE) << "C of Gemm must be a vector of b_shape[" - << (transB == 0 ? "1" : "0") << "]" - << " b_shape: [" << b_shape[0] << ", " << b_shape[1] << "]" - << " c_size: " << c_size; - - return false; - } - } } }