diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index d5ab7f4b74a9a..2804063e1fc68 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -122,6 +122,7 @@ option(onnxruntime_TVM_CUDA_RUNTIME "Build TVM with CUDA support" OFF) option(onnxruntime_TVM_USE_LLVM "Build TVM with LLVM. Set customized path to llvm-config.exe here if need" OFF) option(onnxruntime_TVM_USE_HASH "Build ipp-crypto library for support hash algorithm. It is defined for TVM only") option(onnxruntime_USE_XNNPACK "Build with XNNPACK support. Provides an alternative math library on ARM, WebAssembly and x86." OFF) +option(onnxruntime_USE_WEBNN "Build with WebNN support. Enable hardware acceleration in web browsers." OFF) # Options related to reducing the binary size produced by the build # XNNPACK EP requires the internal NHWC contrib ops to be available, so this option must be OFF when onnxruntime_USE_XNNPACK is ON @@ -722,6 +723,11 @@ if (onnxruntime_USE_XNNPACK) list(APPEND ORT_PROVIDER_CMAKE_FLAGS -Donnxruntime_USE_XNNPACK=1) list(APPEND ONNXRUNTIME_PROVIDER_NAMES xnnpack) endif() +if (onnxruntime_USE_WEBNN) + list(APPEND ORT_PROVIDER_FLAGS -DUSE_WEBNN=1) + list(APPEND ORT_PROVIDER_CMAKE_FLAGS -Donnxruntime_USE_WEBNN=1) + list(APPEND ONNXRUNTIME_PROVIDER_NAMES webnn) +endif() if (onnxruntime_USE_CANN) list(APPEND ORT_PROVIDER_FLAGS -DUSE_CANN=1) list(APPEND ORT_PROVIDER_CMAKE_FLAGS -Donnxruntime_USE_CANN=1) diff --git a/cmake/onnxruntime.cmake b/cmake/onnxruntime.cmake index 9f34d1f46751d..571ceec193e8b 100644 --- a/cmake/onnxruntime.cmake +++ b/cmake/onnxruntime.cmake @@ -185,6 +185,7 @@ set(onnxruntime_INTERNAL_LIBRARIES ${PROVIDERS_ROCM} ${PROVIDERS_VITISAI} ${PROVIDERS_XNNPACK} + ${PROVIDERS_WEBNN} ${PROVIDERS_AZURE} ${PROVIDERS_INTERNAL_TESTING} ${onnxruntime_winml} diff --git a/cmake/onnxruntime_providers.cmake b/cmake/onnxruntime_providers.cmake index caae2aacfd582..73469afa6116b 100644 --- a/cmake/onnxruntime_providers.cmake +++ b/cmake/onnxruntime_providers.cmake @@ -147,6 +147,9 @@ endif() if (onnxruntime_USE_XNNPACK) set(PROVIDERS_XNNPACK onnxruntime_providers_xnnpack) endif() +if(onnxruntime_USE_WEBNN) + set(PROVIDERS_WEBNN onnxruntime_providers_webnn) +endif() if(onnxruntime_USE_SNPE) include(onnxruntime_snpe_provider.cmake) endif() @@ -983,6 +986,31 @@ if (onnxruntime_USE_COREML) endif() endif() +if (onnxruntime_USE_WEBNN) + if (onnxruntime_MINIMAL_BUILD AND NOT onnxruntime_EXTENDED_MINIMAL_BUILD) + message(FATAL_ERROR "WebNN EP can not be used in a basic minimal build. 
Please build with '--minimal_build extended'") + endif() + + add_compile_definitions(USE_WEBNN=1) + if (onnxruntime_ENABLE_WEBASSEMBLY_THREADS) + add_definitions(-DENABLE_WEBASSEMBLY_THREADS=1) + endif() + file(GLOB_RECURSE onnxruntime_providers_webnn_cc_srcs CONFIGURE_DEPENDS + "${ONNXRUNTIME_ROOT}/core/providers/webnn/*.h" + "${ONNXRUNTIME_ROOT}/core/providers/webnn/*.cc" + "${ONNXRUNTIME_ROOT}/core/providers/shared/utils/utils.h" + "${ONNXRUNTIME_ROOT}/core/providers/shared/utils/utils.cc" + ) + + source_group(TREE ${REPO_ROOT} FILES ${onnxruntime_providers_webnn_cc_srcs}) + onnxruntime_add_static_library(onnxruntime_providers_webnn ${onnxruntime_providers_webnn_cc_srcs}) + onnxruntime_add_include_to_target(onnxruntime_providers_webnn onnxruntime_common onnx onnx_proto Boost::mp11) + + add_dependencies(onnxruntime_providers_webnn onnx ${onnxruntime_EXTERNAL_DEPENDENCIES}) + set_target_properties(onnxruntime_providers_webnn PROPERTIES FOLDER "ONNXRuntime") + set_target_properties(onnxruntime_providers_webnn PROPERTIES LINKER_LANGUAGE CXX) +endif() + if (onnxruntime_USE_NNAPI_BUILTIN) if (onnxruntime_MINIMAL_BUILD AND NOT onnxruntime_EXTENDED_MINIMAL_BUILD) message(FATAL_ERROR "NNAPI can not be used in a basic minimal build. Please build with '--minimal_build extended'") diff --git a/cmake/onnxruntime_webassembly.cmake b/cmake/onnxruntime_webassembly.cmake index 193315f541b85..0d6ef7a9711b5 100644 --- a/cmake/onnxruntime_webassembly.cmake +++ b/cmake/onnxruntime_webassembly.cmake @@ -110,6 +110,7 @@ if (onnxruntime_BUILD_WEBASSEMBLY_STATIC_LIB) onnxruntime_providers ${PROVIDERS_JS} ${PROVIDERS_XNNPACK} + ${PROVIDERS_WEBNN} onnxruntime_session onnxruntime_util re2::re2 @@ -186,6 +187,7 @@ else() onnxruntime_providers ${PROVIDERS_JS} ${PROVIDERS_XNNPACK} + ${PROVIDERS_WEBNN} onnxruntime_session onnxruntime_util re2::re2 @@ -194,6 +196,10 @@ else() target_link_libraries(onnxruntime_webassembly PRIVATE XNNPACK) endif() + if(onnxruntime_USE_WEBNN) + target_link_libraries(onnxruntime_webassembly PRIVATE onnxruntime_providers_webnn) + endif() + if (onnxruntime_ENABLE_TRAINING) target_link_libraries(onnxruntime_webassembly PRIVATE tensorboard) endif() @@ -255,6 +261,10 @@ else() ) endif() + if (onnxruntime_USE_WEBNN) + set_property(TARGET onnxruntime_webassembly APPEND_STRING PROPERTY LINK_FLAGS " --bind") + endif() + # Set link flag to enable exceptions support, this will override default disabling exception throwing behavior when disable exceptions. 
target_link_options(onnxruntime_webassembly PRIVATE "SHELL:-s DISABLE_EXCEPTION_THROWING=0") diff --git a/include/onnxruntime/core/graph/constants.h b/include/onnxruntime/core/graph/constants.h index 6fc9ef6e1c8c3..7e59aad80cc47 100644 --- a/include/onnxruntime/core/graph/constants.h +++ b/include/onnxruntime/core/graph/constants.h @@ -48,6 +48,7 @@ constexpr const char* kJsExecutionProvider = "JsExecutionProvider"; constexpr const char* kSnpeExecutionProvider = "SNPEExecutionProvider"; constexpr const char* kTvmExecutionProvider = "TvmExecutionProvider"; constexpr const char* kXnnpackExecutionProvider = "XnnpackExecutionProvider"; +constexpr const char* kWebNNExecutionProvider = "WebNNExecutionProvider"; constexpr const char* kCannExecutionProvider = "CANNExecutionProvider"; constexpr const char* kAzureExecutionProvider = "AzureExecutionProvider"; diff --git a/js/common/lib/inference-session.ts b/js/common/lib/inference-session.ts index 638cb90f36716..e32cd0dfa067a 100644 --- a/js/common/lib/inference-session.ts +++ b/js/common/lib/inference-session.ts @@ -165,7 +165,7 @@ export declare namespace InferenceSession { // Currently, we have the following backends to support execution providers: // Backend Node.js binding: supports 'cpu' and 'cuda'. - // Backend WebAssembly: supports 'cpu', 'wasm' and 'xnnpack'. + // Backend WebAssembly: supports 'cpu', 'wasm', 'xnnpack' and 'webnn'. // Backend ONNX.js: supports 'webgl'. interface ExecutionProviderOptionMap { cpu: CpuExecutionProviderOption; @@ -173,6 +173,7 @@ export declare namespace InferenceSession { wasm: WebAssemblyExecutionProviderOption; webgl: WebGLExecutionProviderOption; xnnpack: XnnpackExecutionProviderOption; + webnn: WebNNExecutionProviderOption; } type ExecutionProviderName = keyof ExecutionProviderOptionMap; @@ -200,6 +201,11 @@ export declare namespace InferenceSession { export interface XnnpackExecutionProviderOption extends ExecutionProviderOption { readonly name: 'xnnpack'; } + export interface WebNNExecutionProviderOption extends ExecutionProviderOption { + readonly name: 'webnn'; + deviceType?: 'cpu'|'gpu'; + powerPreference?: 'default'|'low-power'|'high-performance'; + } // #endregion // #endregion diff --git a/js/web/karma.conf.js b/js/web/karma.conf.js index 2a4e71e064632..3d71df7687b01 100644 --- a/js/web/karma.conf.js +++ b/js/web/karma.conf.js @@ -86,13 +86,56 @@ module.exports = function (config) { hostname, listenAddress, customLaunchers: { - ChromeTest: { base: 'ChromeHeadless', flags: ['--enable-features=SharedArrayBuffer'] }, - ChromePerf: { base: 'Chrome', flags: ['--window-size=1,1', '--enable-features=SharedArrayBuffer'] }, - ChromeDebug: { debug: true, base: 'Chrome', flags: ['--remote-debugging-port=9333', '--enable-features=SharedArrayBuffer'] }, - ChromeCanaryTest: { base: 'ChromeCanary', flags: ['--window-size=1,1', '--enable-features=SharedArrayBuffer', '--enable-unsafe-webgpu'] }, - ChromeCanaryProfileTest: { base: 'ChromeCanary', flags: ['--window-size=1,1', '--enable-features=SharedArrayBuffer', '--enable-unsafe-webgpu', '--disable-dawn-features=disallow_unsafe_apis'] }, - ChromeCanaryDebug: { debug: true, base: 'ChromeCanary', flags: ['--remote-debugging-port=9333', '--enable-features=SharedArrayBuffer', '--enable-unsafe-webgpu'] }, - ChromeCanaryProfileDebug: { debug: true, base: 'ChromeCanary', flags: ['--remote-debugging-port=9333', '--enable-features=SharedArrayBuffer', '--enable-unsafe-webgpu', '--disable-dawn-features=disallow_unsafe_apis'] }, + ChromeTest: { + base: 'ChromeHeadless', + 
flags: ['--enable-features=SharedArrayBuffer'] + }, + ChromePerf: { + base: 'Chrome', + flags: ['--window-size=1,1', '--enable-features=SharedArrayBuffer'] + }, + ChromeDebug: { + debug: true, + base: 'Chrome', flags: ['--remote-debugging-port=9333', '--enable-features=SharedArrayBuffer'] + }, + ChromeCanaryTest: { + base: 'ChromeCanary', + flags: [ + '--window-size=1,1', + '--enable-features=SharedArrayBuffer', + '--enable-unsafe-webgpu', + '--enable-experimental-web-platform-features' + ] + }, + ChromeCanaryProfileTest: { + base: 'ChromeCanary', + flags: [ + '--window-size=1,1', + '--enable-features=SharedArrayBuffer', + '--enable-unsafe-webgpu', + '--disable-dawn-features=disallow_unsafe_apis' + ] + }, + ChromeCanaryDebug: { + debug: true, + base: 'ChromeCanary', + flags: [ + '--remote-debugging-port=9333', + '--enable-features=SharedArrayBuffer', + '--enable-unsafe-webgpu', + '--enable-experimental-web-platform-features' + ] + }, + ChromeCanaryProfileDebug: { + debug: true, + base: 'ChromeCanary', + flags: [ + '--remote-debugging-port=9333', + '--enable-features=SharedArrayBuffer', + '--enable-unsafe-webgpu', + '--disable-dawn-features=disallow_unsafe_apis', + ] + }, // // ==== BrowserStack browsers ==== // diff --git a/js/web/lib/index.ts b/js/web/lib/index.ts index 749331058cc4a..41024de108acb 100644 --- a/js/web/lib/index.ts +++ b/js/web/lib/index.ts @@ -22,4 +22,5 @@ if (!BUILD_DEFS.DISABLE_WASM) { registerBackend('cpu', wasmBackend, 10); registerBackend('wasm', wasmBackend, 10); registerBackend('xnnpack', wasmBackend, 9); + registerBackend('webnn', wasmBackend, 9); } diff --git a/js/web/lib/wasm/session-options.ts b/js/web/lib/wasm/session-options.ts index 10cc48257dc52..c27d2f1e17e21 100644 --- a/js/web/lib/wasm/session-options.ts +++ b/js/web/lib/wasm/session-options.ts @@ -64,6 +64,29 @@ const setExecutionProviders = case 'xnnpack': epName = 'XNNPACK'; break; + case 'webnn': + epName = 'WEBNN'; + if (typeof ep !== 'string') { + const webnnOptions = ep as InferenceSession.WebNNExecutionProviderOption; + if (webnnOptions?.deviceType) { + const keyDataOffset = allocWasmString('deviceType', allocs); + const valueDataOffset = allocWasmString(webnnOptions.deviceType, allocs); + if (getInstance()._OrtAddSessionConfigEntry(sessionOptionsHandle, keyDataOffset, valueDataOffset) !== + 0) { + throw new Error(`Can't set a session config entry: 'deviceType' - ${webnnOptions.deviceType}`); + } + } + if (webnnOptions?.powerPreference) { + const keyDataOffset = allocWasmString('powerPreference', allocs); + const valueDataOffset = allocWasmString(webnnOptions.powerPreference, allocs); + if (getInstance()._OrtAddSessionConfigEntry(sessionOptionsHandle, keyDataOffset, valueDataOffset) !== + 0) { + throw new Error( + `Can't set a session config entry: 'powerPreference' - ${webnnOptions.powerPreference}`); + } + } + } + break; case 'webgpu': epName = 'JS'; break; diff --git a/js/web/script/test-runner-cli-args.ts b/js/web/script/test-runner-cli-args.ts index e20c391513c67..c0ced63be6852 100644 --- a/js/web/script/test-runner-cli-args.ts +++ b/js/web/script/test-runner-cli-args.ts @@ -37,6 +37,7 @@ Options: webgpu wasm xnnpack + webnn -e=<...>, --env=<...> Specify the environment to run the test. 
Should be one of the following: chrome (default) edge (Windows only) @@ -104,7 +105,7 @@ Examples: export declare namespace TestRunnerCliArgs { type Mode = 'suite0'|'suite1'|'model'|'unittest'|'op'; - type Backend = 'cpu'|'webgl'|'webgpu'|'wasm'|'onnxruntime'|'xnnpack'; + type Backend = 'cpu'|'webgl'|'webgpu'|'wasm'|'onnxruntime'|'xnnpack'|'webnn'; type Environment = 'chrome'|'edge'|'firefox'|'electron'|'safari'|'node'|'bs'; type BundleMode = 'prod'|'dev'|'perf'; } @@ -359,12 +360,12 @@ export function parseTestRunnerCliArgs(cmdlineArgs: string[]): TestRunnerCliArgs } // Option: -b=<...>, --backend=<...> - const browserBackends = ['webgl', 'webgpu', 'wasm', 'xnnpack']; + const browserBackends = ['webgl', 'webgpu', 'wasm', 'xnnpack', 'webnn']; - // TODO: remove this when Chrome support WebGPU. - // we need this for now because Chrome does not support webgpu yet, + // TODO: remove this when Chrome support WebGPU or WebNN. + // we need this for now because Chrome does not support webgpu and webnn yet, // and ChromeCanary is not in CI. - const defaultBrowserBackends = ['webgl', /* 'webgpu', */ 'wasm', 'xnnpack']; + const defaultBrowserBackends = ['webgl', /* 'webgpu', */ 'wasm', 'xnnpack' /*, 'webnn'*/]; const nodejsBackends = ['cpu', 'wasm']; const backendArgs = args.backend || args.b; const backend = (typeof backendArgs !== 'string') ? (env === 'node' ? nodejsBackends : defaultBrowserBackends) : @@ -377,6 +378,11 @@ export function parseTestRunnerCliArgs(cmdlineArgs: string[]): TestRunnerCliArgs const globalEnvFlags = parseGlobalEnvFlags(args); + if (backend.includes('webnn') && !globalEnvFlags.wasm.proxy) { + // Backend webnn is restricted in the dedicated worker. + globalEnvFlags.wasm.proxy = true; + } + // Options: // --log-verbose=<...> // --log-info=<...> diff --git a/js/web/script/test-runner-cli.ts b/js/web/script/test-runner-cli.ts index 72938789bc2df..90eb32ddaddd8 100644 --- a/js/web/script/test-runner-cli.ts +++ b/js/web/script/test-runner-cli.ts @@ -53,7 +53,7 @@ async function main() { // The default backends and opset version lists. Those will be used in suite tests. const DEFAULT_BACKENDS: readonly TestRunnerCliArgs.Backend[] = - args.env === 'node' ? ['cpu', 'wasm'] : ['wasm', 'webgl', 'webgpu']; + args.env === 'node' ? ['cpu', 'wasm'] : ['wasm', 'webgl', 'webgpu', 'webnn']; const DEFAULT_OPSET_VERSIONS = fs.readdirSync(TEST_DATA_MODEL_NODE_ROOT, {withFileTypes: true}) .filter(dir => dir.isDirectory() && dir.name.startsWith('opset')) .map(dir => dir.name.slice(5)); @@ -459,12 +459,13 @@ async function main() { // STEP 5. use Karma to run test npmlog.info('TestRunnerCli.Run', '(4/4) Running karma to start test runner...'); const webgpu = args.backends.indexOf('webgpu') > -1; + const webnn = args.backends.indexOf('webnn') > -1; const browser = getBrowserNameFromEnv( args.env, args.bundleMode === 'perf' ? 'perf' : args.debug ? 
'debug' : 'test', - webgpu, config.options.globalEnvFlags?.webgpu?.profilingMode === 'default'); + webgpu, webnn, config.options.globalEnvFlags?.webgpu?.profilingMode === 'default'); const karmaArgs = ['karma', 'start', `--browsers ${browser}`]; if (args.debug) { karmaArgs.push('--log-level info --timeout-mocha 9999999'); @@ -474,7 +475,7 @@ async function main() { if (args.noSandbox) { karmaArgs.push('--no-sandbox'); } - if (webgpu) { + if (webgpu || webnn) { karmaArgs.push('--force-localhost'); } karmaArgs.push(`--bundle-mode=${args.bundleMode}`); @@ -569,10 +570,10 @@ async function main() { } function getBrowserNameFromEnv( - env: TestRunnerCliArgs['env'], mode: 'debug'|'perf'|'test', webgpu: boolean, profile: boolean) { + env: TestRunnerCliArgs['env'], mode: 'debug'|'perf'|'test', webgpu: boolean, webnn: boolean, profile: boolean) { switch (env) { case 'chrome': - return selectChromeBrowser(mode, webgpu, profile); + return selectChromeBrowser(mode, webgpu, webnn, profile); case 'edge': return 'Edge'; case 'firefox': @@ -588,7 +589,7 @@ async function main() { } } - function selectChromeBrowser(mode: 'debug'|'perf'|'test', webgpu: boolean, profile: boolean) { + function selectChromeBrowser(mode: 'debug'|'perf'|'test', webgpu: boolean, webnn: boolean, profile: boolean) { if (webgpu) { switch (mode) { case 'debug': @@ -596,6 +597,13 @@ async function main() { default: return profile ? 'ChromeCanaryProfileTest' : 'ChromeCanaryDebug'; } + } else if (webnn) { + switch (mode) { + case 'debug': + return 'ChromeCanaryDebug'; + default: + return 'ChromeCanaryTest'; + } } else { switch (mode) { case 'debug': diff --git a/js/web/test/suite-test-list.jsonc b/js/web/test/suite-test-list.jsonc index 17928899c91b1..f2965c506b2bf 100644 --- a/js/web/test/suite-test-list.jsonc +++ b/js/web/test/suite-test-list.jsonc @@ -1419,5 +1419,104 @@ "test_instancenorm_example" ], "ops": [] + }, + "webnn": { + "onnx": ["resnet50", "squeezenet", "tiny_yolov2", "emotion_ferplus"], + "node": [ + // Check in node tests that have native Wasm implementations. + // (i.e.) not tests that rely on the fallback cpu implementations. + // Use the 'cpu' level of node tests to test those implementations. 
+ "test_add_bcast", + "test_add", + "test_sub_bcast", + "test_sub_example", + "test_sub", + "test_mul_bcast", + "test_mul_example", + "test_mul", + "test_div_bcast", + "test_div_example", + "test_div", + "test_xor_bcast3v1d", + "test_xor_bcast3v2d", + "test_xor_bcast4v2d", + "test_xor_bcast4v3d", + "test_xor_bcast4v4d", + "test_xor2d", + "test_xor3d", + "test_xor4d", + "test_or_bcast3v1d", + "test_or_bcast3v2d", + "test_or_bcast4v2d", + "test_or_bcast4v3d", + "test_or_bcast4v4d", + "test_and_bcast3v1d", + "test_and_bcast3v2d", + "test_and_bcast4v2d", + "test_and_bcast4v3d", + "test_and_bcast4v4d", + "test_and2d", + "test_and3d", + "test_and4d", + "test_prelu_broadcast", + "test_prelu_example", + "test_basic_conv_with_padding", + "test_basic_conv_without_padding", + "test_batchnorm_epsilon", + "test_batchnorm_example", + "opset{10,11,12}/test_cast_STRING_to_FLOAT", + "test_clip_splitbounds", + "test_clip_outbounds", + "test_clip_inbounds", + "test_clip_example", + "test_clip_default_min", + "test_clip_default_max", + "test_clip_default_inbounds", + "test_clip", + "test_conv_with_strides_and_asymmetric_padding", + "test_conv_with_strides_no_padding", + "test_conv_with_strides_padding", + "test_gemm_nobroadcast", + "test_gemm_broadcast", + "test_matmul_2d", + "test_matmul_3d", + "test_matmul_4d", + "test_softmax_axis_0", + "test_softmax_axis_1", + "test_softmax_axis_2", + "test_softmax_default_axis", + "test_softmax_example", + "test_softmax_large_number", + "test_sum_example", + "test_sum_one_input", + "test_sum_two_inputs", + "test_averagepool_1d_default", + "test_averagepool_2d_default", + "test_averagepool_2d_pads", + "test_averagepool_2d_precomputed_pads", + "test_averagepool_2d_precomputed_same_upper", + "test_averagepool_2d_precomputed_strides", + "test_averagepool_2d_same_upper", + "test_averagepool_2d_same_lower", + "test_averagepool_2d_strides", + "test_averagepool_3d_default", + "test_maxpool_1d_default", + "test_maxpool_2d_default", + "test_maxpool_2d_pads", + "test_maxpool_2d_precomputed_pads", + "test_maxpool_2d_precomputed_same_upper", + "test_maxpool_2d_precomputed_strides", + "test_maxpool_2d_same_lower", + "test_maxpool_2d_same_upper", + "test_maxpool_2d_strides", + "test_maxpool_3d_default", + "test_globalaveragepool_precomputed", + "test_globalaveragepool", + "test_globalmaxpool_precomputed", + "test_globalmaxpool", + "test_instancenorm_epsilon", + "test_instancenorm_example" + ], + "ops": [] } } diff --git a/js/web/test/test-runner.ts b/js/web/test/test-runner.ts index 26ebcbbd6e212..6035b32a8f149 100644 --- a/js/web/test/test-runner.ts +++ b/js/web/test/test-runner.ts @@ -311,7 +311,7 @@ export class TensorResultValidator { } else if (backend === 'webgpu') { this.absoluteThreshold = WEBGPU_THRESHOLD_ABSOLUTE_ERROR; this.relativeThreshold = WEBGPU_THRESHOLD_RELATIVE_ERROR; - } else if (backend === 'wasm' || backend === 'xnnpack') { + } else if (backend === 'wasm' || backend === 'xnnpack' || backend === 'webnn') { this.absoluteThreshold = WASM_THRESHOLD_ABSOLUTE_ERROR; this.relativeThreshold = WASM_THRESHOLD_RELATIVE_ERROR; } else if (backend === 'onnxruntime') { diff --git a/onnxruntime/core/optimizer/transpose_optimizer/optimizer_api_impl.cc b/onnxruntime/core/optimizer/transpose_optimizer/optimizer_api_impl.cc index 65a037cfa0a0f..b04b02e74c1bf 100644 --- a/onnxruntime/core/optimizer/transpose_optimizer/optimizer_api_impl.cc +++ b/onnxruntime/core/optimizer/transpose_optimizer/optimizer_api_impl.cc @@ -873,7 +873,7 @@ const std::unordered_set& GetORTLayoutSensitiveOps() 
{ { "FusedConv", "QLinearAveragePool", "QLinearGlobalAveragePool" -#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_QNN) +#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_QNN) || defined(USE_WEBNN) // The CUDA/ROCM Resize kernel is layout sensitive as it only handles NCHW input. // The CPU kernel and ONNX spec are not limited to handling NCHW input so are not layout sensitive, and // onnx_layout_transformation::HandleResize is used. diff --git a/onnxruntime/core/providers/get_execution_providers.cc b/onnxruntime/core/providers/get_execution_providers.cc index f7b372d1eff74..b0f510f054a03 100644 --- a/onnxruntime/core/providers/get_execution_providers.cc +++ b/onnxruntime/core/providers/get_execution_providers.cc @@ -146,6 +146,14 @@ constexpr ProviderInfo kProvidersInPriorityOrder[] = true, #else false, +#endif + }, + { + kWebNNExecutionProvider, +#ifdef USE_WEBNN + true, +#else + false, #endif }, { diff --git a/onnxruntime/core/providers/provider_factory_creators.h b/onnxruntime/core/providers/provider_factory_creators.h index b019ede434b83..42a58097e1635 100644 --- a/onnxruntime/core/providers/provider_factory_creators.h +++ b/onnxruntime/core/providers/provider_factory_creators.h @@ -86,6 +86,10 @@ #include "core/providers/xnnpack/xnnpack_provider_factory_creator.h" #endif +#if defined(USE_WEBNN) +#include "core/providers/webnn/webnn_provider_factory_creator.h" +#endif + #if defined(USE_CANN) #include "core/providers/cann/cann_provider_factory_creator.h" #endif diff --git a/onnxruntime/core/providers/webnn/builders/helper.cc b/onnxruntime/core/providers/webnn/builders/helper.cc new file mode 100644 index 0000000000000..c52a6b18c696f --- /dev/null +++ b/onnxruntime/core/providers/webnn/builders/helper.cc @@ -0,0 +1,107 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) Intel Corporation. All rights reserved. +// Licensed under the MIT License. + +#include + +#include "helper.h" +#include + +#include "op_builder_factory.h" + +namespace onnxruntime { +namespace webnn { + +bool GetShape(const NodeArg& node_arg, std::vector& shape, const logging::Logger& logger) { + const auto* shape_proto = node_arg.Shape(); + if (!shape_proto) { + LOGS(logger, WARNING) << "NodeArg [" << node_arg.Name() << "] has no shape info"; + return false; + } + + // We already checked the shape has no dynamic dimension. + for (const auto& dim : shape_proto->dim()) { + shape.push_back(dim.dim_value()); + } + + return true; +} + +bool IsNodeSupported(const Node& node, const GraphViewer& graph_viewer, const logging::Logger& logger) { + const auto& op_builders = GetOpBuilders(); + if (Contains(op_builders, node.OpType())) { + const auto* op_builder = op_builders.at(node.OpType()); + return op_builder->IsOpSupported(graph_viewer.GetAllInitializedTensors(), node, logger); + } else { + return false; + } +} + +bool IsInputSupported(const NodeArg& input, const std::string& parent_name, const logging::Logger& logger) { + const auto& input_name = input.Name(); + const auto* shape_proto = input.Shape(); + // We do not support input with no shape. + if (!shape_proto) { + LOGS(logger, VERBOSE) << "Input [" << input_name << "] of [" << parent_name + << "] has not shape"; + return false; + } + + for (const auto& dim : shape_proto->dim()) { + // For now we workaround dynamic shape support by assuming 1. 
+ if (!dim.has_dim_value()) { + LOGS(logger, VERBOSE) << "Dynamic shape is not supported for now, assume to be 1, for input:" << input_name; + } + } + + return true; +} + +std::vector> GetSupportedNodes(const GraphViewer& graph_viewer, + const emscripten::val& wnn_builder_, + const logging::Logger& logger) { + std::vector> supported_node_groups; + + for (const auto* input : graph_viewer.GetInputs()) { + if (!IsInputSupported(*input, "graph", logger)) { + return supported_node_groups; + } + } + + std::vector supported_node_group; + const auto& node_indices = graph_viewer.GetNodesInTopologicalOrder(); + + for (size_t i = 0; i < node_indices.size(); i++) { + auto node_idx = node_indices[i]; + const auto* node(graph_viewer.GetNode(node_idx)); + bool supported = false; + // Firstly check if platform supports the WebNN op. + if (CheckSingleOp(node->OpType(), wnn_builder_)) { + LOGS(logger, VERBOSE) << "Operator type: [" << node->OpType() << "] is supported by browser"; + supported = IsNodeSupported(*node, graph_viewer, logger); + } + + LOGS(logger, VERBOSE) << "Operator type: [" << node->OpType() + << "] index: [" << node_idx + << "] name: [" << node->Name() + << "] supported: [" << supported + << "]"; + if (supported) { + supported_node_group.push_back(node_idx); + } else { + if (!supported_node_group.empty()) { + supported_node_groups.push_back(supported_node_group); + supported_node_group.clear(); + } + } + } + + if (!supported_node_group.empty()) { + supported_node_groups.push_back(supported_node_group); + } + + return supported_node_groups; +} + +} // namespace webnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webnn/builders/helper.h b/onnxruntime/core/providers/webnn/builders/helper.h new file mode 100644 index 0000000000000..9a386e6e34f25 --- /dev/null +++ b/onnxruntime/core/providers/webnn/builders/helper.h @@ -0,0 +1,60 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) Intel Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include "core/common/inlined_containers.h" +#include +#include "core/providers/common.h" + +#include +#include + +namespace onnxruntime { + +class GraphViewer; +class NodeArg; + +namespace logging { +class Logger; +} + +namespace webnn { + +bool GetShape(const NodeArg& node_arg, std::vector& shape, const logging::Logger& logger); + +bool IsInputSupported(const NodeArg& node_arg, const std::string& parent_name, const logging::Logger& logger); + +// Get a list of groups of supported nodes, each group represents a subgraph supported by WebNN EP. 
+std::vector> GetSupportedNodes(const GraphViewer& graph_viewer, + const emscripten::val& wnn_builder_, + const logging::Logger& logger); +static const InlinedHashMap op_map = { + {"Add", "add"}, + {"Sub", "sub"}, + {"Mul", "mul"}, + {"Div", "div"}, + {"Relu", "relu"}, + {"LeakyRelu", "leakyRelu"}, + {"Sigmoid", "sigmoid"}, + {"Clip", "clamp"}, + {"Conv", "conv2d"}, + {"ConvTranspose", "convTranspose2d"}, + {"Concat", "concat"}, + {"Gemm", "gemm"}, + {"GlobalAveragePool", "averagePool2d"}, + {"GlobalMaxPool", "maxPool2d"}, + {"AveragePool", "averagePool2d"}, + {"MaxPool", "maxPool2d"}, + {"Reshape", "reshape"}, + {"Resize", "resample2d"}, + {"Transpose", "transpose"}}; + +inline bool CheckSingleOp(const std::string& op_type, const emscripten::val& wnn_builder_) { + return op_map.find(op_type) != op_map.end() && wnn_builder_[op_map.find(op_type)->second].as(); +} + +} // namespace webnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webnn/builders/impl/activation_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/activation_op_builder.cc new file mode 100644 index 0000000000000..0dcbf9b527ab3 --- /dev/null +++ b/onnxruntime/core/providers/webnn/builders/impl/activation_op_builder.cc @@ -0,0 +1,87 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) Intel Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/common.h" +#include "core/providers/shared/utils/utils.h" +#include "core/providers/webnn/builders/model_builder.h" +#include "core/providers/webnn/builders/op_builder_factory.h" + +#include "base_op_builder.h" + +namespace onnxruntime { +namespace webnn { + +class ActivationOpBuilder : public BaseOpBuilder { + // Add operator related. + private: + Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, + const logging::Logger& logger) const override ORT_MUST_USE_RESULT; + + // Operator support related. + private: + int GetMinSupportedOpSet(const Node& node) const override; +}; + +// Add operator related. 
+ +Status ActivationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, + const Node& node, + const logging::Logger& /* logger */) const { + const auto& op_type(node.OpType()); + emscripten::val input = model_builder.GetOperand(node.InputDefs()[0]->Name()); + emscripten::val output = emscripten::val::object(); + if (op_type == "Relu") { + if (Contains(model_builder.GetFusedActivations(), node.InputDefs()[0]->Name())) { + LOGS_DEFAULT(VERBOSE) << "Relu Node [" << node.Name() << "] fused"; + output = input; + } else { + output = model_builder.GetBuilder().call("relu", input); + } + } else if (op_type == "LeakyRelu") { + if (Contains(model_builder.GetFusedActivations(), node.InputDefs()[0]->Name())) { + LOGS_DEFAULT(VERBOSE) << "LeakyRelu Node [" << node.Name() << "] fused"; + output = input; + } else { + NodeAttrHelper helper(node); + emscripten::val options = emscripten::val::object(); + options.set("alpha", helper.Get("alpha", (float)0.0)); + output = model_builder.GetBuilder().call("leakyRelu", input, options); + } + } else if (op_type == "Sigmoid") { + if (Contains(model_builder.GetFusedActivations(), node.InputDefs()[0]->Name())) { + LOGS_DEFAULT(VERBOSE) << "Sigmoid Node [" << node.Name() << "] fused"; + output = input; + } else { + output = model_builder.GetBuilder().call("sigmoid", input); + } + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "ActivationOpBuilder::AddToModelBuilderImpl, unknown op: ", op_type); + } + + model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output)); + return Status::OK(); +} + +// Operator support related. + +int ActivationOpBuilder::GetMinSupportedOpSet(const Node& /* node */) const { + // All ops opset 5- uses consumed_inputs attribute which is not supported for now. + return 6; +} + +void CreateActivationOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { + if (op_registrations.op_builder_map.find(op_type) != op_registrations.op_builder_map.cend()) + return; + + static std::vector op_types = {"Relu", "LeakyRelu", "Sigmoid"}; + + op_registrations.builders.push_back(std::make_unique()); + for (const auto& type : op_types) { + op_registrations.op_builder_map.emplace(type, op_registrations.builders.back().get()); + } +} + +} // namespace webnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.cc new file mode 100644 index 0000000000000..13c0f0131f2a6 --- /dev/null +++ b/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.cc @@ -0,0 +1,110 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) Intel Corporation. All rights reserved. +// Licensed under the MIT License. + +#include + +#include "core/providers/webnn/builders/model_builder.h" +#include "core/providers/webnn/builders/helper.h" +#include "core/providers/shared/utils/utils.h" + +#include "base_op_builder.h" + +namespace onnxruntime { +namespace webnn { + +// Shared functions. 
+bool HasExternalInitializer(const InitializedTensorSet& initializers, const Node& node, + const logging::Logger& logger) { + for (const auto* node_arg : node.InputDefs()) { + const auto& input_name(node_arg->Name()); + if (!Contains(initializers, input_name)) + continue; + + const auto& tensor = *initializers.at(input_name); + if (tensor.has_data_location() && + tensor.data_location() == ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL) { + LOGS(logger, VERBOSE) << "Initializer [" << input_name + << "] with external data location are not currently supported"; + return true; + } + } + + return false; +} + +// Add operator related. + +Status BaseOpBuilder::AddToModelBuilder(ModelBuilder& model_builder, const Node& node, + const logging::Logger& logger) const { + ORT_RETURN_IF_NOT( + IsOpSupported(model_builder.GetInitializerTensors(), node, logger), + "Unsupported operator ", + node.OpType()); + ORT_RETURN_IF_ERROR(AddToModelBuilderImpl(model_builder, node, logger)); + LOGS(logger, VERBOSE) << "Operator name: [" << node.Name() + << "] type: [" << node.OpType() << "] was added"; + return Status::OK(); +} + +// Operator support related. + +bool BaseOpBuilder::IsOpSupported(const InitializedTensorSet& initializers, const Node& node, + const logging::Logger& logger) const { + if (!HasSupportedInputs(node, logger)) + return false; + + // We do not support external initializers for now. + if (HasExternalInitializer(initializers, node, logger)) + return false; + + if (!HasSupportedOpSet(node, logger)) + return false; + + return IsOpSupportedImpl(initializers, node, logger); +} + +bool BaseOpBuilder::HasSupportedInputs(const Node& node, const logging::Logger& logger) const { + const auto node_name = MakeString("Node [", node.Name(), "] type [", node.OpType(), "]"); + for (const auto* input : node.InputDefs()) { + if (!IsInputSupported(*input, node_name, logger)) { + return false; + } + } + + return HasSupportedInputsImpl(node, logger); +} + +bool BaseOpBuilder::HasSupportedInputsImpl(const Node& node, const logging::Logger& logger) const { + // We only check the type of input 0 by default, specific op builder can override this. + const auto& input = *node.InputDefs()[0]; + + int32_t input_type; + if (!GetType(input, input_type, logger)) + return false; + + if (input_type != ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { + LOGS(logger, VERBOSE) << "[" << node.OpType() + << "] Input type: [" << input_type + << "] is not supported for now"; + return false; + } + + return true; +} + +bool BaseOpBuilder::HasSupportedOpSet(const Node& node, + const logging::Logger& logger) const { + auto since_version = node.SinceVersion(); + if (since_version < GetMinSupportedOpSet(node) || since_version > GetMaxSupportedOpSet(node)) { + LOGS(logger, VERBOSE) << node.OpType() << "is only supported for opset [" + << GetMinSupportedOpSet(node) << ", " + << GetMaxSupportedOpSet(node) << "]"; + return false; + } + + return true; +} + +} // namespace webnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.h b/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.h new file mode 100644 index 0000000000000..7b4148c34d4a7 --- /dev/null +++ b/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.h @@ -0,0 +1,50 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) Intel Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once + +#include "core/providers/webnn/builders/op_builder.h" + +namespace onnxruntime { +namespace webnn { + +class ModelBuilder; + +class BaseOpBuilder : public IOpBuilder { + public: + virtual ~BaseOpBuilder() = default; + + // Add operator related. + public: + virtual void AddInitializersToSkip(ModelBuilder& /* model_builder */, const Node& /* node */) const override {} + Status AddToModelBuilder(ModelBuilder& model_builder, const Node& node, + const logging::Logger& logger) const override final ORT_MUST_USE_RESULT; + + protected: + virtual Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, + const logging::Logger& logger) const ORT_MUST_USE_RESULT = 0; + + // Operator support related. + public: + bool IsOpSupported(const InitializedTensorSet& initializers, const Node& node, + const logging::Logger& logger) const override; + + protected: + virtual bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& /* node */, + const logging::Logger& /* logger */) const { + return true; + } + + virtual bool HasSupportedInputsImpl(const Node& node, const logging::Logger& logger) const; + + virtual int GetMinSupportedOpSet(const Node& /* node */) const { return 1; } + virtual int GetMaxSupportedOpSet(const Node& /* node */) const { return 14; } + + private: + bool HasSupportedOpSet(const Node& node, const logging::Logger& logger) const; + bool HasSupportedInputs(const Node& node, const logging::Logger& logger) const; +}; + +} // namespace webnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webnn/builders/impl/binary_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/binary_op_builder.cc new file mode 100644 index 0000000000000..96ee9cf3f11d9 --- /dev/null +++ b/onnxruntime/core/providers/webnn/builders/impl/binary_op_builder.cc @@ -0,0 +1,79 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) Intel Corporation. All rights reserved. +// Licensed under the MIT License. + +#include + +#include "core/providers/webnn/builders/model_builder.h" +#include "core/providers/webnn/builders/op_builder_factory.h" +#include "core/providers/webnn/builders/helper.h" + +#include "base_op_builder.h" + +namespace onnxruntime { +namespace webnn { + +class BinaryOpBuilder : public BaseOpBuilder { + // Add operator related. + private: + Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, + const logging::Logger& logger) const override ORT_MUST_USE_RESULT; + + // Operator support related. + private: + int GetMinSupportedOpSet(const Node& node) const override; +}; + +// Add operator related. 
+ +Status BinaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, + const logging::Logger& /* logger */) const { + const auto& op_type(node.OpType()); + + emscripten::val input0 = model_builder.GetOperand(node.InputDefs()[0]->Name()); + emscripten::val input1 = model_builder.GetOperand(node.InputDefs()[1]->Name()); + emscripten::val output = emscripten::val::object(); + if (op_type == "Add") { + output = model_builder.GetBuilder().call("add", input0, input1); + } else if (op_type == "Sub") { + output = model_builder.GetBuilder().call("sub", input0, input1); + } else if (op_type == "Mul") { + output = model_builder.GetBuilder().call("mul", input0, input1); + } else if (op_type == "Div") { + output = model_builder.GetBuilder().call("div", input0, input1); + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "BinaryOpBuilder::AddToModelBuilderImpl, unknown op: ", op_type); + } + + model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output)); + return Status::OK(); +} + +// Operator support related. + +int BinaryOpBuilder::GetMinSupportedOpSet(const Node& /* node */) const { + // Add/Sub/Mul/Div opset 6- has broadcast attributes we do not support now. + return 7; +} + +void CreateBinaryOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { + if (op_registrations.op_builder_map.find(op_type) != op_registrations.op_builder_map.cend()) + return; + + static std::vector op_types = + { + "Add", + "Sub", + "Mul", + "Div", + }; + + op_registrations.builders.push_back(std::make_unique()); + for (const auto& type : op_types) { + op_registrations.op_builder_map.emplace(type, op_registrations.builders.back().get()); + } +} + +} // namespace webnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webnn/builders/impl/builder_utils.cc b/onnxruntime/core/providers/webnn/builders/impl/builder_utils.cc new file mode 100644 index 0000000000000..516ac7464345b --- /dev/null +++ b/onnxruntime/core/providers/webnn/builders/impl/builder_utils.cc @@ -0,0 +1,86 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) Intel Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include +#include +#include "core/providers/shared/utils/utils.h" + +#include "builder_utils.h" +#include "core/providers/webnn/builders/helper.h" + +namespace onnxruntime { +namespace webnn { + +common::Status ComputeConvPads(const std::vector input_shape, + const int64_t weight_size_y, + const int64_t weight_size_x, + const std::vector& onnx_pads, + const std::vector& onnx_strides, + const std::vector& onnx_dilations, + AutoPadType auto_pad_type, + std::vector& pads_out) { + const int64_t input_size_y = input_shape[2]; + const int64_t input_size_x = input_shape[3]; + const int64_t stride_y = onnx_strides[0]; + const int64_t stride_x = onnx_strides[1]; + const int64_t dilation_y = onnx_dilations[0]; + const int64_t dilation_x = onnx_dilations[1]; + + int64_t padding_top = onnx_pads[0]; + int64_t padding_bottom = onnx_pads[2]; + int64_t padding_left = onnx_pads[1]; + int64_t padding_right = onnx_pads[3]; + + ORT_RETURN_IF_ERROR(ComputePad(input_size_y, + stride_y, weight_size_y, dilation_y, + auto_pad_type, + padding_top, padding_bottom)); + ORT_RETURN_IF_ERROR(ComputePad(input_size_x, + stride_x, weight_size_x, dilation_x, + auto_pad_type, + padding_left, padding_right)); + + pads_out = {padding_top, padding_left, padding_bottom, padding_right}; + + return Status::OK(); +} + +common::Status HandleAutoPad(const std::vector input_shape, + const int64_t weight_size_y, + const int64_t weight_size_x, + const std::vector& onnx_pads, + const std::vector& onnx_strides, + const std::vector& onnx_dilations, + AutoPadType auto_pad_type, + AutoPadType& auto_pad_type_out) { + auto_pad_type_out = auto_pad_type; + if (auto_pad_type == AutoPadType::NOTSET && onnx_dilations == std::vector{1, 1}) { + { + std::vector same_upper_pads; + ORT_RETURN_IF_ERROR(ComputeConvPads(input_shape, weight_size_y, weight_size_x, + onnx_pads, onnx_strides, onnx_dilations, + AutoPadType::SAME_UPPER, same_upper_pads)); + if (onnx_pads == same_upper_pads) { + auto_pad_type_out = AutoPadType::SAME_UPPER; + return Status::OK(); + } + } + + { + std::vector same_lower_pads; + ORT_RETURN_IF_ERROR(ComputeConvPads(input_shape, weight_size_y, weight_size_x, + onnx_pads, onnx_strides, onnx_dilations, + AutoPadType::SAME_LOWER, same_lower_pads)); + if (onnx_pads == same_lower_pads) { + auto_pad_type_out = AutoPadType::SAME_LOWER; + return Status::OK(); + } + } + } + + return Status::OK(); +} + +} // namespace webnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webnn/builders/impl/builder_utils.h b/onnxruntime/core/providers/webnn/builders/impl/builder_utils.h new file mode 100644 index 0000000000000..76acbca0536ea --- /dev/null +++ b/onnxruntime/core/providers/webnn/builders/impl/builder_utils.h @@ -0,0 +1,27 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) Intel Corporation. All rights reserved. +// Licensed under the MIT License. + +// This contains the utility functions which will be used to build a webnn model + +#pragma once + +#include "core/common/status.h" +#include "core/graph/basic_types.h" + +namespace onnxruntime { +namespace webnn { + +// Try to see if we can map explicit padding to auto padding for Conv/Pool. +// Since usually use auto padding is more efficient. 
+common::Status HandleAutoPad(const std::vector input_shape, + const int64_t weight_size_y, + const int64_t weight_size_x, + const std::vector& onnx_pads, + const std::vector& onnx_strides, + const std::vector& onnx_dilations, + AutoPadType auto_pad_type, + AutoPadType& auto_pad_type_out) ORT_MUST_USE_RESULT; + +} // namespace webnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webnn/builders/impl/clip_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/clip_op_builder.cc new file mode 100644 index 0000000000000..1c07439c8e8cc --- /dev/null +++ b/onnxruntime/core/providers/webnn/builders/impl/clip_op_builder.cc @@ -0,0 +1,81 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) Intel Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/common.h" +#include "core/providers/webnn/builders/model_builder.h" +#include "core/providers/webnn/builders/op_builder_factory.h" +#include "core/providers/shared/utils/utils.h" + +#include "base_op_builder.h" + +namespace onnxruntime { +namespace webnn { + +class ClipOpBuilder : public BaseOpBuilder { + // Add operator related. + public: + void AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const override; + + private: + Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, + const logging::Logger& logger) const override ORT_MUST_USE_RESULT; + + // Operator support related. + private: + bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, + const logging::Logger& logger) const override; +}; + +// Add operator related. + +void ClipOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const { + // Both min and max values will be injected into the layer, no need to add to the model. + if (node.SinceVersion() >= 11) { + if (node.InputDefs().size() > 1) + model_builder.AddInitializerToSkip(node.InputDefs()[1]->Name()); + + if (node.InputDefs().size() > 2) + model_builder.AddInitializerToSkip(node.InputDefs()[2]->Name()); + } +} + +Status ClipOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, + const Node& node, + const logging::Logger& logger) const { + const auto& input_name = node.InputDefs()[0]->Name(); + const auto& output_name = node.OutputDefs()[0]->Name(); + emscripten::val options = emscripten::val::object(); + float minValue, maxValue; + ORT_RETURN_IF_NOT(GetClipMinMax(model_builder.GetInitializerTensors(), node, minValue, maxValue, logger), + "GetClipMinMax failed"); + options.set("minValue", minValue); + options.set("maxValue", maxValue); + emscripten::val input = model_builder.GetOperand(input_name); + emscripten::val output = emscripten::val::object(); + if (Contains(model_builder.GetFusedActivations(), input_name)) { + LOGS_DEFAULT(VERBOSE) << "Clip Node [" << node.Name() << "] fused"; + output = input; + } else { + output = model_builder.GetBuilder().call("clamp", input, options); + } + + model_builder.AddOperand(output_name, std::move(output)); + return Status::OK(); +} + +// Operator support related. 
+ +bool ClipOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, + const logging::Logger& logger) const { + float min, max; + return GetClipMinMax(initializers, node, min, max, logger); +} + +void CreateClipOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { + op_registrations.builders.push_back(std::make_unique()); + op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get()); +} + +} // namespace webnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webnn/builders/impl/concat_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/concat_op_builder.cc new file mode 100644 index 0000000000000..c39927b3cc26b --- /dev/null +++ b/onnxruntime/core/providers/webnn/builders/impl/concat_op_builder.cc @@ -0,0 +1,57 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) Intel Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/common/safeint.h" +#include "core/providers/common.h" +#include "core/providers/shared/utils/utils.h" +#include "core/providers/webnn/builders/helper.h" +#include "core/providers/webnn/builders/model_builder.h" +#include "core/providers/webnn/builders/op_builder_factory.h" + +#include "base_op_builder.h" + +namespace onnxruntime { +namespace webnn { + +class ConcatOpBuilder : public BaseOpBuilder { + // Add operator related. + private: + Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, + const logging::Logger& logger) const override ORT_MUST_USE_RESULT; +}; + +// Add operator related. + +Status ConcatOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, + const Node& node, + const logging::Logger& logger) const { + const auto& input_defs = node.InputDefs(); + std::vector input_shape; + if (!GetShape(*input_defs[0], input_shape, logger)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "ConcatOpBuilder::AddToModelBuilderImpl, cannot get input shape"); + } + auto rank = input_shape.size(); + NodeAttrHelper helper(node); + uint32_t axis = static_cast(HandleNegativeAxis(helper.Get("axis", 1), rank)); + + emscripten::val inputs = emscripten::val::array(); + for (const auto* input : node.InputDefs()) { + LOGS(logger, VERBOSE) << "input name " << input->Name(); + inputs.call("push", model_builder.GetOperand(input->Name())); + } + + emscripten::val output = model_builder.GetBuilder().call("concat", inputs, axis); + + model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output)); + return Status::OK(); +} + +void CreateConcatOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { + op_registrations.builders.push_back(std::make_unique()); + op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get()); +} + +} // namespace webnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc new file mode 100644 index 0000000000000..93746090a29bd --- /dev/null +++ b/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc @@ -0,0 +1,287 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) Intel Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "core/common/safeint.h" +#include "core/optimizer/initializer.h" +#include "core/providers/common.h" +#include "core/providers/shared/utils/utils.h" +#include "core/providers/webnn/builders/helper.h" +#include "core/providers/webnn/builders/model_builder.h" +#include "core/providers/webnn/builders/op_builder_factory.h" + +#include "base_op_builder.h" +#include "builder_utils.h" + +namespace onnxruntime { +namespace webnn { + +class ConvOpBuilder : public BaseOpBuilder { + // Add operator related. + public: + void AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const override; + + private: + Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, + const logging::Logger& logger) const override ORT_MUST_USE_RESULT; + + // Operator support related. + private: + bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& /* node */, + const logging::Logger& /* logger */) const override; +}; + +void ConvOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const { + // skip the weight for conv as we need to transpose. + model_builder.AddInitializerToSkip(node.InputDefs()[1]->Name()); // W + model_builder.AddInputToSkip(node.InputDefs()[1]->Name()); +} + +// Helper functions +common::Status SetConvBaseOptions(ModelBuilder& model_builder, + const Node& node, emscripten::val& options, + const std::vector& strides, + const std::vector& dilations, + const std::vector& pads, + const logging::Logger& logger) { + NodeAttrHelper helper(node); + const auto group = helper.Get("group", static_cast(1)); + const auto& input_defs = node.InputDefs(); + const auto& weight_tensor = *model_builder.GetInitializerTensors().at(input_defs[1]->Name()); + const auto& weight_shape = weight_tensor.dims(); + + options.set("strides", emscripten::val::array(strides)); + options.set("dilations", emscripten::val::array(dilations)); + options.set("inputLayout", emscripten::val("nhwc")); + options.set("groups", group); + // Add Padding. + // Usually using autopadding is more efficient than using explicit padding. + // Try to see if we can map explicit padding to auto padding. + std::vector input_shape; + ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get shape"); + AutoPadType auto_pad_type; + ORT_RETURN_IF_ERROR(HandleAutoPad(input_shape, weight_shape[2], weight_shape[3], + helper.Get("pads", std::vector{0, 0, 0, 0}), + helper.Get("strides", std::vector{1, 1}), + helper.Get("dilations", std::vector{1, 1}), + StringToAutoPadType(helper.Get("auto_pad", "NOTSET")), + auto_pad_type)); + if (AutoPadType::SAME_UPPER == auto_pad_type || AutoPadType::SAME_LOWER == auto_pad_type) { + if (AutoPadType::SAME_LOWER == auto_pad_type) { // default is SAME_UPPER + options.set("autoPad", emscripten::val("same-lower")); + } else { + options.set("autoPad", emscripten::val("same-upper")); + } + } else { + options.set("padding", emscripten::val::array(pads)); + } + + // Add bias if present. + if (input_defs.size() > 2) { + options.set("bias", model_builder.GetOperand(input_defs[2]->Name())); + } + InlinedHashSet supported_nodes{"Clip", "Relu"}; + emscripten::val activation = model_builder.FindActivation(node, *node.OutputDefs()[0], supported_nodes); + if (emscripten::val::null() != activation) { + options.set("activation", activation); + } + + return Status::OK(); +} + +// Both depthwise Conv and ConvTranspose share the same logic to add the layout. 
+Status AddInitializerInNewLayout(ModelBuilder& model_builder, + const std::string& name, + bool is_conv) { + const auto& tensor = *model_builder.GetInitializerTensors().at(name); + auto data_type = tensor.data_type(); + if (data_type != ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "The initializer of graph has unsupported type, name: ", + tensor.name(), " type: ", data_type); + } + + const auto& shape = tensor.dims(); + std::vector dims; + std::transform(shape.cbegin(), shape.cend(), + std::back_inserter(dims), + [](int64_t dim) -> int32_t { return SafeInt(dim); }); + + ORT_RETURN_IF_NOT(dims.size() == 4, + "The initializer is not 4D: ", name, " actual dim ", dims.size()); + const uint8_t* src = nullptr; + Initializer unpacked_tensor(tensor, model_builder.GetGraphViewer().ModelPath()); + src = unpacked_tensor.DataAsByteSpan().data(); + const auto out_t = dims[0], in_t = dims[1], + h_t = dims[2], w_t = dims[3]; + std::vector dest_shape; + if (is_conv == 1) + dest_shape = {out_t, h_t, w_t, in_t}; // L_0231 + else + dest_shape = {in_t, h_t, w_t, out_t}; // L_1230 for depthwise conv weight + + SafeInt num_elements = SafeInt(Product(dest_shape)); + + size_t element_size = 4; + std::unique_ptr buffer_holder(new uint8_t[element_size * num_elements]); + uint8_t* buffer = buffer_holder.get(); + + for (uint32_t out = 0; out < out_t; out++) { + for (uint32_t in = 0; in < in_t; in++) { + for (uint32_t h = 0; h < h_t; h++) { + for (uint32_t w = 0; w < w_t; w++) { + auto onnx_idx = out * in_t * h_t * w_t + + in * h_t * w_t + + h * w_t + + w; + + uint32_t nnapi_idx; + if (is_conv == 1) { // L_0231 + nnapi_idx = out * h_t * w_t * in_t + + h * w_t * in_t + + w * in_t + + in; + } else { // L_1230 for depthwise conv weight + nnapi_idx = in * h_t * w_t * out_t + + h * w_t * out_t + + w * out_t + + out; + } + + for (size_t i = 0; i < element_size; i++) { + buffer[element_size * nnapi_idx + i] = src[element_size * onnx_idx + i]; + } + } + } + } + } + ORT_RETURN_IF_ERROR(model_builder.AddOperandFromPersistMemoryBuffer(name, buffer, num_elements * element_size, + dest_shape, 4)); + return Status::OK(); +} + +// Add operator related. 
+ +Status ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, + const logging::Logger& logger) const { + const auto& input_defs = node.InputDefs(); + const auto& op_type = node.OpType(); + emscripten::val input = model_builder.GetOperand(input_defs[0]->Name()); + emscripten::val output = emscripten::val::object(); + + NodeAttrHelper helper(node); + const auto strides = helper.Get("strides", std::vector{1, 1}); + const auto dilations = helper.Get("dilations", std::vector{1, 1}); + auto pads = helper.Get("pads", std::vector{0, 0, 0, 0}); + const auto& weight = input_defs[1]->Name(); + + if (op_type == "Conv") { + emscripten::val options = emscripten::val::object(); + ORT_RETURN_IF_ERROR(SetConvBaseOptions(model_builder, node, options, strides, dilations, pads, logger)); + int groups = options["groups"].as(); + std::vector input_shape; + ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get shape"); + bool depthwise = (groups == input_shape[3] && groups != 1); + ORT_RETURN_IF_ERROR(AddInitializerInNewLayout( + model_builder, weight, !depthwise)); + if (!depthwise) { + options.set("filterLayout", emscripten::val("ohwi")); + } else { + options.set("filterLayout", emscripten::val("ihwo")); + } + emscripten::val filter = model_builder.GetOperand(input_defs[1]->Name()); + output = model_builder.GetBuilder().call("conv2d", input, filter, options); + } else { + emscripten::val options = emscripten::val::object(); + ORT_RETURN_IF_ERROR(SetConvBaseOptions(model_builder, node, options, strides, dilations, pads, logger)); + ORT_RETURN_IF_ERROR(AddInitializerInNewLayout( + model_builder, weight, false)); + options.set("filterLayout", emscripten::val("ohwi")); + // When 'output_shape' is specified, the 'output_padding' values + // in options.outputPadding are ignored. + std::vector dim; + std::vector output_padding{0, 0}; + if (helper.HasAttr("output_shape")) { + // The default value of 'output_shape' will be ignored as we already checked + // that it exists. + dim = helper.Get("output_shape", std::vector{-1, -1}); + // Extract the height and width. + std::vector output_shape; + if (dim.size() == 2) { + output_shape = dim; + } else if (dim.size() == 4) { + output_shape = {dim[2], dim[3]}; + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid output shape"); + } + // Padding values are auto generated. 
+ if (helper.HasAttr("kernel_shape")) { + std::vector kernel_shape = helper.Get("kernel_shape", std::vector{-1, -1}); + std::vector total_padding(2); + std::vector input_shape; + ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get shape"); + for (size_t i = 0; i < 2; i++) { + total_padding[i] = strides[i] * (input_shape[i + 1] - 1) + output_padding[i] + ((kernel_shape[i] - 1) * dilations[i] + 1) - output_shape[i]; + } + pads[0] = total_padding[0] - (total_padding[0] / 2); + pads[1] = total_padding[0] / 2; + pads[2] = total_padding[1] - (total_padding[1] / 2); + pads[3] = total_padding[1] / 2; + options.set("padding", emscripten::val::array(pads)); + } + options.set("outputSizes", emscripten::val::array(output_shape)); + } else { + output_padding = helper.Get("output_padding", std::vector{0, 0}); + options.set("outputPadding", emscripten::val::array(output_padding)); + } + emscripten::val filter = model_builder.GetOperand(input_defs[1]->Name()); + + output = model_builder.GetBuilder().call("convTranspose2d", input, filter, options); + } + + model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output)); + return Status::OK(); +} + +// Operator support related. + +bool ConvOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, + const logging::Logger& logger) const { + const auto& name = node.Name(); + const auto& op_type = node.OpType(); + const auto& input_defs = node.InputDefs(); + + const auto& weight_name = input_defs[1]->Name(); + if (Contains(initializers, weight_name)) { + const auto& tensor = *initializers.at(weight_name); + if (tensor.dims().size() != 4) { + LOGS(logger, VERBOSE) << op_type << " [" << name << "] dimension: " << tensor.dims().size() + << " Only conv 2d is supported."; + return false; + } + } else { + LOGS(logger, VERBOSE) << "The weight of " << op_type << " [" << name << "] must be known"; + return false; + } + + return true; +} + +void CreateConvOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { + if (op_registrations.op_builder_map.find(op_type) != op_registrations.op_builder_map.cend()) + return; + + static std::vector op_types = + { + "Conv", + "ConvTranspose", + }; + + op_registrations.builders.push_back(std::make_unique()); + for (const auto& type : op_types) { + op_registrations.op_builder_map.emplace(type, op_registrations.builders.back().get()); + } +} + +} // namespace webnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc new file mode 100644 index 0000000000000..2330d85f911e5 --- /dev/null +++ b/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc @@ -0,0 +1,138 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) Intel Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/common/safeint.h" +#include "core/providers/common.h" +#include "core/providers/shared/utils/utils.h" +#include "core/providers/webnn/builders/helper.h" +#include "core/providers/webnn/builders/model_builder.h" +#include "core/providers/webnn/builders/op_builder_factory.h" + +#include "base_op_builder.h" +#include "builder_utils.h" + +namespace onnxruntime { +namespace webnn { + +class GemmOpBuilder : public BaseOpBuilder { + // Add operator related. 
+ private: + Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, + const logging::Logger& logger) const override ORT_MUST_USE_RESULT; + + // Operator support related. + private: + bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& /* node */, + const logging::Logger& /* logger */) const override; +}; + +// Add operator related. +Status GemmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, + const logging::Logger& /* logger */) const { + const auto& input_defs = node.InputDefs(); + const size_t a_idx = 0, b_idx = 1, c_idx = 2; // A*B+C + + emscripten::val a = model_builder.GetOperand(node.InputDefs()[a_idx]->Name()); + emscripten::val b = model_builder.GetOperand(node.InputDefs()[b_idx]->Name()); + emscripten::val options = emscripten::val::object(); + NodeAttrHelper helper(node); + const auto transA = helper.Get("transA", 0); + options.set("aTranspose", emscripten::val(transA == 1)); + const auto transB = helper.Get("transB", 0); + options.set("bTranspose", emscripten::val(transB == 1)); + const auto alpha = helper.Get("alpha", 1.0f); + const auto beta = helper.Get("beta", 1.0f); + options.set("alpha", alpha); + options.set("beta", beta); + + // Add bias if present. + if (input_defs.size() > 2) { + options.set("c", model_builder.GetOperand(node.InputDefs()[c_idx]->Name())); + } + + emscripten::val output = model_builder.GetBuilder().call("gemm", a, b, options); + + model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output)); + return Status::OK(); +} + +// Operator support related. + +bool GemmOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, + const logging::Logger& logger) const { + (void)initializers; + const auto& input_defs(node.InputDefs()); + const size_t a_idx = 0, b_idx = 1, c_idx = 2; // A*B+C + + std::vector a_shape; + { + if (!GetShape(*input_defs[a_idx], a_shape, logger)) + return false; + + if (a_shape.size() != 2) { + LOGS(logger, VERBOSE) << "A must be 2D"; + return false; + } + + if (Product(a_shape) == 0) { + LOGS(logger, VERBOSE) << "A must be non-empty"; + return false; + } + } + + std::vector b_shape; + { + if (!GetShape(*input_defs[b_idx], b_shape, logger)) + return false; + + if (b_shape.size() != 2) { + LOGS(logger, VERBOSE) << "B must be 2D"; + return false; + } + + if (Product(b_shape) == 0) { + LOGS(logger, VERBOSE) << "B must be non-empty"; + return false; + } + } + + // C of Gemm. + if (input_defs.size() == 3) { + std::vector c_shape; + if (!GetShape(*input_defs[c_idx], c_shape, logger)) + return false; + + size_t c_dim = c_shape.size(); + + if (c_dim > 1) { + // TODO: Supports other shape of C. + // Currently WebNN implementation in Chromium only supports 1-D C. + return false; + } + if (c_dim == 0) { + LOGS(logger, VERBOSE) << "C of Gemm is a scalar"; + } else { + auto c_size = c_shape[c_dim - 1]; + NodeAttrHelper helper(node); + const auto transB = helper.Get("transB", 0); + if (c_size != (transB == 0 ? b_shape[1] : b_shape[0])) { + LOGS(logger, VERBOSE) << "C of Gemm must be a vector of b_shape[" + << (transB == 0 ? 
"1" : "0") << "]" + << " b_shape: [" << b_shape[0] << ", " << b_shape[1] << "]" + << " c_size: " << c_size; + + return false; + } + } + } + + return true; +} + +void CreateGemmOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { + op_registrations.builders.push_back(std::make_unique()); + op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get()); +} +} // namespace webnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webnn/builders/impl/pool_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/pool_op_builder.cc new file mode 100644 index 0000000000000..a5200f612e860 --- /dev/null +++ b/onnxruntime/core/providers/webnn/builders/impl/pool_op_builder.cc @@ -0,0 +1,165 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) Intel Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/common/safeint.h" +#include "core/providers/common.h" +#include "core/providers/shared/utils/utils.h" +#include "core/providers/webnn/builders/helper.h" +#include "core/providers/webnn/builders/model_builder.h" +#include "core/providers/webnn/builders/op_builder_factory.h" + +#include "base_op_builder.h" +#include "builder_utils.h" + +namespace onnxruntime { +namespace webnn { + +class PoolOpBuilder : public BaseOpBuilder { + // Add operator related. + private: + Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, + const logging::Logger& logger) const override ORT_MUST_USE_RESULT; + + // Operator support related. + private: + bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, + const logging::Logger& logger) const override; +}; + +// Add operator related. + +Status PoolOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, + const Node& node, + const logging::Logger& logger) const { + const auto& op_type = node.OpType(); + const auto& input_defs = node.InputDefs(); + + emscripten::val input = model_builder.GetOperand(input_defs[0]->Name()); + + bool is_global_pooling = false; + bool is_average_pool = false; + if (op_type == "GlobalAveragePool") { + is_global_pooling = true; + is_average_pool = true; + } else if (op_type == "GlobalMaxPool") { + is_global_pooling = true; + } else if (op_type == "AveragePool") { + is_average_pool = true; + } else if (op_type == "MaxPool") { + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "PoolOpBuilder, unknown op: ", op_type); + } + + emscripten::val options = emscripten::val::object(); + NodeAttrHelper helper(node); + + const auto kernel_shape = helper.Get("kernel_shape", std::vector{0, 0}); + if (!is_global_pooling) { + options.set("windowDimensions", emscripten::val::array(kernel_shape)); + } + const auto strides = helper.Get("strides", std::vector{1, 1}); + options.set("strides", emscripten::val::array(strides)); + const auto dilations = helper.Get("dilations", std::vector{1, 1}); + options.set("dilations", emscripten::val::array(dilations)); + options.set("layout", emscripten::val("nhwc")); + + // Add Padding. + // Usually using autopadding is more efficient than using explicit padding. + // Try to see if we can map explicit padding to auto padding. 
+ const auto onnx_kernel_shape = helper.Get("kernel_shape", std::vector{0, 0}); + const auto onnx_strides = helper.Get("strides", std::vector{1, 1}); + const auto onnx_pads = helper.Get("pads", std::vector{0, 0, 0, 0}); + const auto pads = helper.Get("pads", std::vector{0, 0, 0, 0}); + std::vector input_shape; + ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get shape"); + AutoPadType auto_pad_type; + ORT_RETURN_IF_ERROR(HandleAutoPad(input_shape, onnx_kernel_shape[0], onnx_kernel_shape[1], + onnx_pads, onnx_strides, {1, 1} /* dilations */, + StringToAutoPadType(helper.Get("auto_pad", "NOTSET")), + auto_pad_type)); + + if (AutoPadType::SAME_UPPER == auto_pad_type || AutoPadType::SAME_LOWER == auto_pad_type) { + if (AutoPadType::SAME_LOWER == auto_pad_type) { // default is SAME_UPPER + options.set("autoPad", "same-lower"); + } else { + options.set("autoPad", "same-upper"); + } + } else { + options.set("padding", emscripten::val::array(pads)); + } + + const auto ceil_mode = helper.Get("ceil_mode", 0); + options.set("roundingType", ceil_mode == 0 ? emscripten::val("floor") + : emscripten::val("ceil")); + + emscripten::val output = emscripten::val::object(); + if (is_average_pool) { + output = model_builder.GetBuilder().call("averagePool2d", input, options); + } else { + output = model_builder.GetBuilder().call("maxPool2d", input, options); + } + model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output)); + return Status::OK(); +} + +// Operator support related. +bool PoolOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const logging::Logger& logger) const { + const auto& op_type = node.OpType(); + const auto& input_defs = node.InputDefs(); + + std::vector input_shape; + if (!GetShape(*input_defs[0], input_shape, logger)) + return false; + + const auto input_size = input_shape.size(); + if (input_size != 4) { + LOGS(logger, VERBOSE) + << op_type << " only supports rank-4 tensor, input [" + << input_defs[0]->Name() << "] has actual dim count " << input_size; + return false; + } + + if (op_type == "AveragePool" || op_type == "MaxPool") { + NodeAttrHelper helper(node); + const auto storage_order = helper.Get("storage_order", 0); + if (storage_order == 1) { + LOGS(logger, VERBOSE) << "storage_order == 1 is not supported"; + return false; + } + + if (helper.Get("kernel_shape", std::vector{1, 1}).size() != 2) { + LOGS(logger, VERBOSE) << "Only pooling 2d is supported"; + return false; + } + + if (node.OutputDefs().size() != 1) { + LOGS(logger, VERBOSE) << "Argmax in maxpooling is not supported"; + return false; + } + } + + return true; +} + +void CreatePoolOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { + if (op_registrations.op_builder_map.find(op_type) != op_registrations.op_builder_map.cend()) + return; + + static std::vector op_types = + { + "GlobalAveragePool", + "GlobalMaxPool", + "AveragePool", + "MaxPool", + }; + + op_registrations.builders.push_back(std::make_unique()); + for (const auto& type : op_types) { + op_registrations.op_builder_map.emplace(type, op_registrations.builders.back().get()); + } +} + +} // namespace webnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webnn/builders/impl/reshape_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/reshape_op_builder.cc new file mode 100644 index 0000000000000..2e4dc3f4addbb --- /dev/null +++ b/onnxruntime/core/providers/webnn/builders/impl/reshape_op_builder.cc @@ -0,0 +1,126 @@ +// 
Copyright (c) Microsoft Corporation. All rights reserved.
+// Copyright (c) Intel Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "core/common/safeint.h"
+#include "core/framework/tensorprotoutils.h"
+#include "core/optimizer/initializer.h"
+#include "core/providers/common.h"
+#include "core/providers/cpu/tensor/reshape_helper.h"
+#include "core/providers/shared/utils/utils.h"
+#include "core/providers/webnn/builders/helper.h"
+#include "core/providers/webnn/builders/model_builder.h"
+#include "core/providers/webnn/builders/op_builder_factory.h"
+
+#include "base_op_builder.h"
+
+namespace onnxruntime {
+namespace webnn {
+
+class ReshapeOpBuilder : public BaseOpBuilder {
+  // Add operator related.
+ public:
+  void AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const override;
+
+ private:
+  Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
+                               const logging::Logger& logger) const override ORT_MUST_USE_RESULT;
+
+  // Operator support related.
+ private:
+  bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node,
+                         const logging::Logger& logger) const override;
+
+  // Reshape opset 4- uses attributes for new shape which we do not support for now.
+  int GetMinSupportedOpSet(const Node& /* node */) const override { return 5; }
+};
+
+// Add operator related.
+
+void ReshapeOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const {
+  model_builder.AddInitializerToSkip(node.InputDefs()[1]->Name());
+}
+
+Status ReshapeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
+                                               const Node& node,
+                                               const logging::Logger& logger) const {
+  const auto& input_defs = node.InputDefs();
+  const auto& initializers(model_builder.GetInitializerTensors());
+  const auto& target_shape_tensor = *initializers.at(input_defs[1]->Name());
+  const int64_t* raw_target_shape = target_shape_tensor.int64_data().empty()
+                                        ? reinterpret_cast<const int64_t*>(target_shape_tensor.raw_data().data())
+                                        : target_shape_tensor.int64_data().data();
+
+  const auto size = target_shape_tensor.dims()[0];
+  TensorShapeVector target_shape{raw_target_shape, raw_target_shape + size};
+  std::vector<int64_t> input_shape;
+  ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get shape");
+  ReshapeHelper helper(TensorShape(input_shape), target_shape);
+  emscripten::val input = model_builder.GetOperand(input_defs[0]->Name());
+  std::vector<uint32_t> new_shape;
+  std::transform(target_shape.cbegin(), target_shape.cend(),
+                 std::back_inserter(new_shape),
+                 [](int64_t dim) -> uint32_t { return SafeInt<uint32_t>(dim); });
+  emscripten::val output = model_builder.GetBuilder().call<emscripten::val>("reshape",
+                                                                            input, emscripten::val::array(new_shape));
+  model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output));
+  return Status::OK();
+}
+
+// Operator support related.
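ReshapeHelper above is what turns an ONNX target shape that may contain -1 into the fully concrete dimensions that WebNN's reshape expects. A simplified standalone sketch of that resolution (the helper name is illustrative, and the allowzero/0 semantics are deliberately ignored here):

#include <cstddef>
#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

// Resolve a single -1 in an ONNX Reshape target shape against the input's element count.
std::vector<uint32_t> ResolveTargetShape(const std::vector<int64_t>& input_shape,
                                         std::vector<int64_t> target) {
  const int64_t num_elements =
      std::accumulate(input_shape.begin(), input_shape.end(), int64_t{1}, std::multiplies<int64_t>());
  int64_t known = 1;
  int unknown_idx = -1;
  for (std::size_t i = 0; i < target.size(); ++i) {
    if (target[i] == -1) {
      unknown_idx = static_cast<int>(i);
    } else {
      known *= target[i];
    }
  }
  if (unknown_idx >= 0) target[unknown_idx] = num_elements / known;
  // WebNN's reshape wants concrete unsigned dimensions.
  return std::vector<uint32_t>(target.begin(), target.end());
}

int main() {
  // A {1, 3, 4, 4} input reshaped with target {1, -1} becomes {1, 48}.
  auto dims = ResolveTargetShape({1, 3, 4, 4}, {1, -1});
  return dims[1] == 48 ? 0 : 1;
}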
+ +bool ReshapeOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, + const logging::Logger& logger) const { + const auto& input_defs = node.InputDefs(); + const auto& perm_name = input_defs[1]->Name(); + if (!Contains(initializers, perm_name)) { + LOGS(logger, VERBOSE) << "New shape of reshape must be a constant initializer"; + return false; + } + + const auto& perm_tensor = *initializers.at(perm_name); + std::vector unpacked_tensor; + auto status = onnxruntime::utils::UnpackInitializerData(perm_tensor, unpacked_tensor); + if (!status.IsOK()) { + LOGS(logger, ERROR) << "Error while unpacking perm_tensor: " << status.ErrorMessage(); + return false; + } + + const int64_t* raw_new_shape = reinterpret_cast(unpacked_tensor.data()); + const auto& perm_dims = perm_tensor.dims(); + if (perm_dims.empty() || perm_dims[0] == 0) { + LOGS(logger, VERBOSE) << "New shape of reshape cannot be empty"; + return false; + } + + std::vector input_shape; + if (!GetShape(*input_defs[0], input_shape, logger)) + return false; + + if (input_shape.empty()) { + LOGS(logger, VERBOSE) << "Reshape does not support empty input shape"; + return false; + } + + // WebNN reshape does not support 0 as dimension. + NodeAttrHelper helper(node); + const bool allow_zero = helper.Get("allowzero ", 0) == 1; + if (allow_zero) { + for (int64_t i = 0; i < perm_dims[0]; i++) { + if (raw_new_shape[i] == 0) { + LOGS_DEFAULT(VERBOSE) << "Reshape doesn't support 0 reshape dimension when allowzero is enabled"; + return false; + } + } + } + + return true; +} + +void CreateReshapeOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { + op_registrations.builders.push_back(std::make_unique()); + op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get()); +} + +} // namespace webnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webnn/builders/impl/resize_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/resize_op_builder.cc new file mode 100644 index 0000000000000..27e7310002564 --- /dev/null +++ b/onnxruntime/core/providers/webnn/builders/impl/resize_op_builder.cc @@ -0,0 +1,275 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) Intel Corporation. All rights reserved. +// Licensed under the MIT License. + +#include + +#include "core/common/safeint.h" +#include "core/providers/common.h" +#include "core/framework/tensorprotoutils.h" +#include "core/providers/webnn/builders/helper.h" +#include "core/providers/cpu/tensor/reshape_helper.h" +#include "core/providers/shared/utils/utils.h" +#include "core/providers/webnn/builders/model_builder.h" +#include "core/providers/webnn/builders/op_builder_factory.h" + +#include "base_op_builder.h" + +namespace onnxruntime { +namespace webnn { + +class ResizeOpBuilder : public BaseOpBuilder { + // Add operator related. + public: + void AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const override; + + private: + Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, + const logging::Logger& logger) const override ORT_MUST_USE_RESULT; + + // Operator support related. + private: + bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, + const logging::Logger& logger) const override; + + // Resize opset 10- is very different than Resize opset 11+, with many key attributes missing. + // We only support Resize opset 11+ here. 
+ int GetMinSupportedOpSet(const Node& /* node */) const override { return 11; } +}; + +// Helper functions +bool GetResizeScales(const InitializedTensorSet& initializers, + const Node& node, std::vector& scales, + const logging::Logger& logger) { + const auto& input_defs = node.InputDefs(); + if (input_defs.size() < 3) + return false; + + const auto& scales_tensor = *initializers.at(input_defs[2]->Name()); + if (scales_tensor.dims_size() != 1 || scales_tensor.dims()[0] != 4) + return false; + + std::vector unpacked_tensor; + auto status = onnxruntime::utils::UnpackInitializerData(scales_tensor, unpacked_tensor); + if (!status.IsOK()) { + LOGS(logger, ERROR) << "Error while unpacking scales_tensor: " << status.ErrorMessage(); + return false; + } + const float* scales_data = reinterpret_cast(unpacked_tensor.data()); + scales = std::vector{scales_data, scales_data + 4}; + return true; +} + +bool GetResizeOutputSizes(const InitializedTensorSet& initializers, + const Node& node, std::vector& sizes, + const logging::Logger& logger) { + const auto& input_defs = node.InputDefs(); + if (input_defs.size() < 4) + return false; + + const auto& sizes_tensor = *initializers.at(input_defs[3]->Name()); + if (sizes_tensor.dims_size() != 1 || sizes_tensor.dims()[0] != 4) + return false; + + std::vector unpacked_tensor; + auto status = onnxruntime::utils::UnpackInitializerData(sizes_tensor, unpacked_tensor); + if (!status.IsOK()) { + LOGS(logger, ERROR) << "Error while unpacking sizes_tensor: " << status.ErrorMessage(); + return false; + } + const int64_t* sizes_data = reinterpret_cast(unpacked_tensor.data()); + sizes = std::vector{sizes_data, sizes_data + 4}; + return true; +} + +// Add operator related. + +void ResizeOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const { + // We don't really use ROI here, so add it to skipped list if it's an initializer tensor. + model_builder.AddInitializerToSkip(node.InputDefs()[1]->Name()); // ROI + model_builder.AddInputToSkip(node.InputDefs()[1]->Name()); // ROI + + // We will still add scales to the skipped list even sizes are present, + // since there is no use of it, we will not process it later. + model_builder.AddInitializerToSkip(node.InputDefs()[2]->Name()); // scales + model_builder.AddInputToSkip(node.InputDefs()[2]->Name()); // scales + + if (node.InputDefs().size() > 3) { + model_builder.AddInitializerToSkip(node.InputDefs()[3]->Name()); // sizes + model_builder.AddInputToSkip(node.InputDefs()[3]->Name()); // sizes + } +} + +Status ResizeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, + const Node& node, + const logging::Logger& logger) const { + emscripten::val options = emscripten::val::object(); + NodeAttrHelper helper(node); + const auto mode = helper.Get("mode", "nearest"); + if (mode == "linear") { + options.set("mode", emscripten::val("linear")); + } else { // we already checked the mode must be NN or Bilinear in IsOpSupportedImpl. + options.set("mode", emscripten::val("nearest-neighbor")); + } + + const auto& input_defs = node.InputDefs(); + const auto& initializers(model_builder.GetInitializerTensors()); + + std::vector scales; + std::vector sizes; + std::vector scales_hw; + std::vector sizes_hw; + if (input_defs.size() == 3) { // Use scales. 
+ ORT_RETURN_IF_NOT(GetResizeScales(initializers, node, scales, logger), "Error getting resize scales"); + scales_hw = {scales[2], scales[3]}; + options.set("scales", emscripten::val::array(scales_hw)); + } else { // We already checked number of inputs in IsOpSupportedImpl. + std::vector output_sizes; + ORT_RETURN_IF_NOT(GetResizeOutputSizes(initializers, node, output_sizes, logger), + "Error getting resize output_sizes"); + std::transform(output_sizes.cbegin(), output_sizes.cend(), + std::back_inserter(sizes), + [](int64_t dim) -> int32_t { return SafeInt(dim); }); + sizes_hw = {sizes[1], sizes[2]}; + options.set("sizes", emscripten::val::array(sizes_hw)); + } + + std::vector axes = {1, 2}; + options.set("axes", emscripten::val::array(axes)); + emscripten::val input = model_builder.GetOperand(input_defs[0]->Name()); + emscripten::val output = model_builder.GetBuilder().call("resample2d", input, options); + model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output)); + return Status::OK(); +} + +// Operator support related. + +bool ResizeOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, + const logging::Logger& logger) const { + const auto& input_defs = node.InputDefs(); + + std::vector input_shape; + if (!GetShape(*input_defs[0], input_shape, logger)) + return false; + + const auto input_size = input_shape.size(); + if (input_size != 4) { + LOGS(logger, VERBOSE) << "Resize only support 4d shape, input is " + << input_size << "d shape"; + return false; + } + + { // Check attributes. + NodeAttrHelper helper(node); + const auto mode = helper.Get("mode", "nearest"); + bool is_linear_resize = mode == "linear"; + bool is_nearest_resize = mode == "nearest"; + if (!is_linear_resize && !is_nearest_resize) { + LOGS(logger, VERBOSE) << "Resize unsupported input mode, " << mode; + return false; + } + + const auto exclude_outside = helper.Get("exclude_outside", 0); + if (exclude_outside != 0) { + LOGS(logger, VERBOSE) << "Resize does not support exclude_outside for now"; + return false; + } + } + + { // scales and sizes (if present) must be initializers. + if (input_defs.size() < 3) { + LOGS(logger, VERBOSE) << "Input scales or sizes of Resize must be known"; + return false; + } + + // scales + if (input_defs.size() == 3 && !Contains(initializers, input_defs[2]->Name())) { + LOGS(logger, VERBOSE) << "Input scales of Resize must be known"; + return false; + } + + // sizes + if (input_defs.size() > 3 && !Contains(initializers, input_defs[3]->Name())) { + LOGS(logger, VERBOSE) << "Input sizes of Resize must be known"; + return false; + } + + // We want to check if the scales or sizes are not trying to resize on N/C channels here. + if (input_defs.size() == 3) { // We are using scales. + std::vector scales; + if (!GetResizeScales(initializers, node, scales, logger)) + return false; + + float scale_n = scales[0]; + float scale_c = scales[1]; + if (scale_n != 1.0f || scale_c != 1.0f) { + LOGS(logger, VERBOSE) << "Scales of N/C channel should be 1" + << "Resize of N/C channels are not supported" + << ", scale_n, " << scale_n << ", scale_c, " << scale_c; + return false; + } + + // For now we only support upscale, so the scale_h and scale_w should be an integer >= 1. + // TODO support ResizeBilinear. + float scale_h = scales[2]; + float scale_w = scales[3]; + + // Onnx spec requires scale to be a positive float, so we are not checking that here. 
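For instance (illustrative values), a scales input of {1.0, 1.0, 2.0, 2.0} passes the whole-number checks below, while {1.0, 1.0, 1.5, 1.5} is rejected because roundf(1.5) != 1.5.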
+ if (roundf(scale_h) != scale_h) { + LOGS(logger, VERBOSE) << "Resize: scale_h: " << scale_h << " is not a whole number"; + return false; + } + + if (roundf(scale_w) != scale_w) { + LOGS(logger, VERBOSE) << "Resize: scale_w: " << scale_w << " is not a whole number"; + return false; + } + } else { + // We are using sizes. + std::vector output_sizes; + if (!GetResizeOutputSizes(initializers, node, output_sizes, logger)) + return false; + + bool is_NHWC = input_shape[3] == output_sizes[3]; + auto output_size_n = output_sizes[0]; + const int c_idx = is_NHWC ? 3 : 1; + if (output_size_n != input_shape[0] || output_sizes[c_idx] != input_shape[c_idx]) { + LOGS(logger, VERBOSE) << "Output sizes of N/C chanel should match the input sizes, " + << "Resize of N/C channels are not supported" + << ", input_size_n, " << input_shape[0] << ", output_size_n, " << output_size_n + << ". input_size_c, " << input_shape[c_idx] << ", output_size_c, " << output_sizes[c_idx]; + return false; + } + + // For now we only support upscale, so the output_size_h and output_size_w should be an integer >= 1. + // TODO support ResizeBilinear + auto output_size_h = output_sizes[2]; + auto output_size_w = output_sizes[3]; + auto input_size_h = input_shape[2]; + auto input_size_w = input_shape[3]; + + // Onnx spec requires output sizes to be a positive integer, so we are not checking that here. + if (output_size_h % input_size_h != 0) { + LOGS(logger, VERBOSE) << "Resize: output_size_h: " << output_size_h + << " is not a multiple of input_size_h: " << input_size_h; + return false; + } + + if (output_size_w % input_size_w != 0) { + LOGS(logger, VERBOSE) << "Resize: output_size_w: " << output_size_w + << " is not a multiple of input_size_w: " << input_size_w; + return false; + } + } + } + + return true; +} + +void CreateResizeOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { + op_registrations.builders.push_back(std::make_unique()); + op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get()); +} + +} // namespace webnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webnn/builders/impl/transpose_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/transpose_op_builder.cc new file mode 100644 index 0000000000000..eca1521384643 --- /dev/null +++ b/onnxruntime/core/providers/webnn/builders/impl/transpose_op_builder.cc @@ -0,0 +1,59 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) Intel Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/common/safeint.h" +#include "core/providers/shared/utils/utils.h" +#include "core/providers/webnn/builders/helper.h" +#include "core/providers/webnn/builders/model_builder.h" +#include "core/providers/webnn/builders/op_builder_factory.h" + +#include "base_op_builder.h" + +namespace onnxruntime { +namespace webnn { + +class TransposeOpBuilder : public BaseOpBuilder { + // Add operator related. + private: + Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, + const logging::Logger& logger) const override ORT_MUST_USE_RESULT; +}; + +// Add operator related. 
+ +Status TransposeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, + const Node& node, + const logging::Logger& logger) const { + const auto& input_defs = node.InputDefs(); + NodeAttrHelper helper(node); + std::vector perm = helper.Get("perm", std::vector()); + std::vector input_shape; + ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get shape"); + auto input_dims = input_shape.size(); + if (perm.empty()) { + for (int64_t i = input_dims - 1; i >= 0; i--) + perm.push_back(i); + } else { + ORT_RETURN_IF_NOT(perm.size() == input_dims, "Perm and input should have same dimension"); + } + + emscripten::val input = model_builder.GetOperand(input_defs[0]->Name()); + emscripten::val options = emscripten::val::object(); + std::vector permutation; + std::transform(perm.cbegin(), perm.cend(), + std::back_inserter(permutation), + [](int64_t dim) -> int32_t { return SafeInt(dim); }); + options.set("permutation", emscripten::val::array(permutation)); + emscripten::val output = model_builder.GetBuilder().call("transpose", input, options); + model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output)); + return Status::OK(); +} + +void CreateTransposeOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { + op_registrations.builders.push_back(std::make_unique()); + op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get()); +} + +} // namespace webnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webnn/builders/model.cc b/onnxruntime/core/providers/webnn/builders/model.cc new file mode 100644 index 0000000000000..bc385dea6c2a8 --- /dev/null +++ b/onnxruntime/core/providers/webnn/builders/model.cc @@ -0,0 +1,127 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) Intel Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include + +#include "core/common/common.h" +#include "core/common/inlined_containers.h" +#include "core/common/logging/logging.h" +#include "core/common/safeint.h" +#include "core/graph/onnx_protobuf.h" +#include "core/providers/common.h" +#include "core/providers/webnn/builders/helper.h" +#include "model.h" + +namespace onnxruntime { +namespace webnn { + +Model::Model(const emscripten::val& context, const emscripten::val& graph, const logging::Logger& logger) + : wnn_context_(context), + wnn_graph_(graph), + logger_(logger) {} + +Model::~Model() {} + +Status Model::Predict(const InlinedHashMap& inputs, + const InlinedHashMap& outputs) { + for (const auto& input : inputs) { + const std::string& name = input.first; + const struct OnnxTensorData tensor = input.second; + if (tensor.tensor_info.data_type != ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "The input of graph has unsupported type, name: ", + name, " type: ", tensor.tensor_info.data_type); + } + auto num_elements = SafeInt(Product(tensor.tensor_info.shape)); + emscripten::val view{emscripten::typed_memory_view(num_elements, static_cast(tensor.buffer))}; +#ifdef ENABLE_WEBASSEMBLY_THREADS + // Copy the inputs from Wasm SharedArrayBuffer to the pre-allocated ArrayBuffers. 
+ wnn_inputs_[name].call("set", view); +#else + wnn_inputs_.set(name, view); +#endif + } + +#ifdef ENABLE_WEBASSEMBLY_THREADS + // This vector uses for recording output buffers from WebNN graph compution when WebAssembly + // multi-threads is enabled, since WebNN API only accepts non-shared ArrayBufferView, + // https://www.w3.org/TR/webnn/#typedefdef-mlnamedarraybufferviews + // and at this time the 'view' defined by Emscripten is shared ArrayBufferView, the memory + // address is different from the non-shared one, additional memory copy is required here. + InlinedHashMap output_views; +#endif + for (const auto& output : outputs) { + const std::string& name = output.first; + const struct OnnxTensorData tensor = output.second; + if (tensor.tensor_info.data_type != ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "The input of graph has unsupported type, name: ", + name, " type: ", tensor.tensor_info.data_type); + } + auto num_elements = SafeInt(Product(tensor.tensor_info.shape)); + emscripten::val view{emscripten::typed_memory_view(num_elements, static_cast(tensor.buffer))}; +#ifdef ENABLE_WEBASSEMBLY_THREADS + output_views.insert({name, view}); +#else + wnn_outputs_.set(name, view); +#endif + } + wnn_context_.call("computeSync", wnn_graph_, wnn_inputs_, wnn_outputs_); + +#ifdef ENABLE_WEBASSEMBLY_THREADS + // Copy the outputs from pre-allocated ArrayBuffers back to the Wasm SharedArrayBuffer. + for (const auto& output : outputs) { + const std::string& name = output.first; + emscripten::val view = output_views.at(name); + view.call("set", wnn_outputs_[name]); + } +#endif + return Status::OK(); +} + +bool Model::IsScalarOutput(const std::string& output_name) const { + return Contains(scalar_outputs_, output_name); +} + +const OnnxTensorInfo& Model::GetInputOutputInfo(const std::string& name) const { + return input_output_info_.at(name); +} + +void Model::SetInputMap(InlinedHashMap&& input_map) { + input_map_ = std::move(input_map); +} + +void Model::SetOutputMap(InlinedHashMap&& output_map) { + output_map_ = std::move(output_map); +} + +// Pre-allocate the input and output buffers for the WebNN graph. +void Model::AllocateInputOutputBuffers() { + for (const auto& input : inputs_) { + const auto& input_info = input_output_info_.at(input); + const auto input_shape = input_info.shape; + const auto num_elements = SafeInt(Product(input_shape)); + wnn_inputs_.set(input, + emscripten::val::global("Float32Array").new_(static_cast(num_elements))); + } + for (const auto& output : outputs_) { + const auto& output_info = input_output_info_.at(output); + const auto output_shape = output_info.shape; + const auto num_elements = SafeInt(Product(output_shape)); + wnn_outputs_.set(output, + emscripten::val::global("Float32Array").new_(static_cast(num_elements))); + } +} + +size_t Model::GetMappedInputIdx(const std::string& name) const { + return input_map_.at(name); +} + +size_t Model::GetMappedOutputIdx(const std::string& name) const { + return output_map_.at(name); +} + +} // namespace webnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webnn/builders/model.h b/onnxruntime/core/providers/webnn/builders/model.h new file mode 100644 index 0000000000000..4af82a2675691 --- /dev/null +++ b/onnxruntime/core/providers/webnn/builders/model.h @@ -0,0 +1,94 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) Intel Corporation. All rights reserved. +// Licensed under the MIT License. 
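The copies in Predict above exist because a typed_memory_view taken over the Wasm heap is backed by a SharedArrayBuffer when threads are enabled, and the WebNN API only accepts non-shared ArrayBufferViews. A minimal sketch of that copy pattern (built with emcc and Embind; the function name is hypothetical):

#include <emscripten/val.h>

#include <cstddef>
#include <vector>

// Copy a C++ float buffer (a view over the shared Wasm heap) into a freshly
// allocated, non-shared Float32Array that can be handed to WebNN.
emscripten::val CopyToNonSharedFloat32Array(const float* data, size_t num_elements) {
  emscripten::val view{emscripten::typed_memory_view(num_elements, data)};
  emscripten::val non_shared =
      emscripten::val::global("Float32Array").new_(static_cast<uint32_t>(num_elements));
  non_shared.call<void>("set", view);  // element-wise copy out of the (Shared)ArrayBuffer
  return non_shared;
}

int main() {
  std::vector<float> input{1.0f, 2.0f, 3.0f};
  emscripten::val copy = CopyToNonSharedFloat32Array(input.data(), input.size());
  return copy["length"].as<int>() == 3 ? 0 : 1;
}

Predict applies the same pattern in reverse for outputs, copying from the pre-allocated Float32Arrays back into the shared Wasm heap after computeSync returns.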
+ +#pragma once + +#include "core/common/inlined_containers.h" +#include "core/common/status.h" +#include "core/platform/ort_mutex.h" + +#include +#include + +namespace onnxruntime { +namespace webnn { + +struct OnnxTensorInfo { + const int32_t data_type; // Uses TensorProto::DataType. + const std::vector shape; +}; + +struct OnnxTensorData { + OnnxTensorInfo tensor_info; + void* buffer{nullptr}; +}; + +class Model { + friend class ModelBuilder; + + public: + ~Model(); + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(Model); + + onnxruntime::common::Status Predict(const InlinedHashMap& inputs, + const InlinedHashMap& outputs); + + bool IsScalarOutput(const std::string& output_name) const; + + // Mutex for exclusive lock to this model object. + OrtMutex& GetMutex() { return mutex_; } + + // Input and output names in the onnx model's order. + const std::vector& GetInputs() const { return inputs_; } + void SetInputs(std::vector&& inputs) { inputs_ = std::move(inputs); } + + const std::vector& GetOutputs() const { return outputs_; } + void SetOutputs(std::vector&& outputs) { outputs_ = std::move(outputs); } + + const OnnxTensorInfo& GetInputOutputInfo(const std::string& name) const; + + // Set the mapping between input/output name and ORT kernel context + // input/output index, at execution time. + void SetInputMap(InlinedHashMap&& input_map); + void SetOutputMap(InlinedHashMap&& output_map); + + // Get the ORT kernel context input/output index with given name. + size_t GetMappedInputIdx(const std::string& name) const; + size_t GetMappedOutputIdx(const std::string& name) const; + + private: + emscripten::val wnn_context_ = emscripten::val::object(); + emscripten::val wnn_graph_ = emscripten::val::object(); + const logging::Logger& logger_; + + emscripten::val wnn_inputs_ = emscripten::val::object(); + emscripten::val wnn_outputs_ = emscripten::val::object(); + + InlinedHashSet scalar_outputs_; + + std::vector inputs_; + std::vector outputs_; + + InlinedHashMap input_output_info_; + + InlinedHashMap input_map_; + InlinedHashMap output_map_; + + OrtMutex mutex_; + + Model(const emscripten::val& context, const emscripten::val& path, const logging::Logger& logger); + + void SetInputOutputInfo(InlinedHashMap&& input_output_info) { + input_output_info_ = std::move(input_output_info); + } + + void SetScalarOutputs(InlinedHashSet&& scalar_outputs) { + scalar_outputs_ = std::move(scalar_outputs); + } + + void AllocateInputOutputBuffers(); +}; + +} // namespace webnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webnn/builders/model_builder.cc b/onnxruntime/core/providers/webnn/builders/model_builder.cc new file mode 100644 index 0000000000000..1b3ba5fcde88e --- /dev/null +++ b/onnxruntime/core/providers/webnn/builders/model_builder.cc @@ -0,0 +1,369 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) Intel Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include + +#include "model_builder.h" +#include "model.h" +#include "helper.h" +#include "op_builder_factory.h" + +#include "core/common/safeint.h" +#include "core/framework/tensorprotoutils.h" +#include "core/providers/common.h" +#include "core/providers/shared/utils/utils.h" + +namespace onnxruntime { +namespace webnn { + +ModelBuilder::ModelBuilder(const GraphViewer& graph_viewer, const logging::Logger& logger, + const emscripten::val& context, const emscripten::val& builder) + : graph_viewer_(graph_viewer), + logger_(logger), + wnn_context_(context), + wnn_builder_(builder) {} + +Status ModelBuilder::Initialize() { + PreprocessInitializers(); + PreprocessActivations(); + ORT_RETURN_IF_ERROR(RegisterInitializers()); + ORT_RETURN_IF_ERROR(RegisterModelInputs()); + ORT_RETURN_IF_ERROR(AddOperations()); + ORT_RETURN_IF_ERROR(RegisterModelOutputs()); + + return Status::OK(); +} + +/* static */ const IOpBuilder* ModelBuilder::GetOpBuilder(const Node& node) { + const auto& op_builders = GetOpBuilders(); + const auto it = op_builders.find(node.OpType()); + if (it != op_builders.cend()) + return it->second; + + return nullptr; +} + +void ModelBuilder::PreprocessInitializers() { + const auto& node_indices = graph_viewer_.GetNodesInTopologicalOrder(); + for (size_t i = 0; i < node_indices.size(); i++) { + const auto* node(graph_viewer_.GetNode(node_indices[i])); + if (const auto* op_builder = GetOpBuilder(*node)) { + op_builder->AddInitializersToSkip(*this, *node); + } + } +} + +emscripten::val GetClampOperator( + const emscripten::val& builder, float min_value, float max_value) { + emscripten::val options = emscripten::val::object(); + options.set("minValue", min_value); + options.set("maxValue", max_value); + return builder.call("clamp", options); +} + +void ModelBuilder::PreprocessActivations() { + const auto& node_indices = graph_viewer_.GetNodesInTopologicalOrder(); + for (size_t i = 0; i < node_indices.size(); i++) { + const auto* node(graph_viewer_.GetNode(node_indices[i])); + const auto& op_type(node->OpType()); + + if (op_type == "Relu") { + activation_nodes_.emplace(node->Index(), wnn_builder_.call("relu")); + } else if (op_type == "LeakyRelu") { + NodeAttrHelper helper(*node); + emscripten::val options = emscripten::val::object(); + options.set("alpha", helper.Get("alpha", (float)0.0)); + activation_nodes_.emplace(node->Index(), wnn_builder_.call("leakyRelu", options)); + } else if (op_type == "Sigmoid") { + activation_nodes_.emplace(node->Index(), wnn_builder_.call("sigmoid")); + } else if (op_type == "Clip") { + float minValue, maxValue; + GetClipMinMax(GetInitializerTensors(), *node, minValue, maxValue, logger_); + activation_nodes_.emplace(node->Index(), GetClampOperator(wnn_builder_, minValue, maxValue)); + } + } +} + +Status ModelBuilder::RegisterInitializers() { + for (const auto& pair : GetInitializerTensors()) { + const auto& tensor = *pair.second; + const auto& name = tensor.name(); + if (Contains(skipped_initializers_, name)) + continue; + + const auto& shape = tensor.dims(); + std::vector dims; + if (shape.empty()) { + // This is a scalar initializer, WebNN requires a shape, make this a {1} tensor. 
+ dims = {1}; + } else { + std::transform(shape.cbegin(), shape.cend(), + std::back_inserter(dims), + [](int64_t dim) -> int32_t { return SafeInt(dim); }); + } + + emscripten::val desc = emscripten::val::object(); + desc.set("dimensions", emscripten::val::array(dims)); + auto data_type = tensor.data_type(); + emscripten::val operand = emscripten::val::object(); + if (data_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { + unpacked_tensors_.push_back({}); + std::vector& unpacked_tensor = unpacked_tensors_.back(); + ORT_RETURN_IF_ERROR(onnxruntime::utils::UnpackInitializerData(tensor, unpacked_tensor)); + auto num_elements = SafeInt(Product(tensor.dims())); + desc.set("type", emscripten::val("float32")); + emscripten::val view{emscripten::typed_memory_view(num_elements, + reinterpret_cast(unpacked_tensor.data()))}; +#ifdef ENABLE_WEBASSEMBLY_THREADS + // Workaround for WebAssembly multi-threads enabled since WebNN API only accepts non-shared ArrayBufferView. + // https://www.w3.org/TR/webnn/#typedefdef-mlnamedarraybufferviews + operand = wnn_builder_.call("constant", desc, view.call("slice")); +#else + operand = wnn_builder_.call("constant", desc, view); +#endif + + } else { + // TODO: support other type. + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "The initializer of graph has unsupported type, name: ", + tensor.name(), " type: ", data_type); + } + wnn_operands_.insert(std::make_pair(name, operand)); + } + + return Status::OK(); +} + +Status ModelBuilder::RegisterModelInputOutput(const NodeArg& node_arg, bool is_input) { + const auto& name = node_arg.Name(); + const std::string input_output_type = is_input ? "input" : "output"; + + if (is_input) { + // Input should not be an initializer. + if (Contains(GetInitializerTensors(), name)) + return Status::OK(); + + // This input will not be used. + if (Contains(skipped_inputs_, name)) + return Status::OK(); + } + + std::vector dims; + { // input_output shape. + const auto* shape_proto = node_arg.Shape(); + ORT_RETURN_IF(shape_proto == nullptr, + "shape_proto cannot be null for ", input_output_type, ": ", name); + const auto& shape = shape_proto->dim(); + if (shape.empty()) { + // If we have an empty shape, this is a scalar input. + dims.push_back(1); + + // We need to change the shapes of these scalar outputs back to {} + // when WebNN EP returns these values to ORT. + if (!is_input) { + AddScalarOutput(name); + } + } else { + dims.reserve(shape.size()); + for (const auto& dim : shape) { + if (!dim.has_dim_value()) { + // FIXME: support dyanmic shape. + dims.push_back(1); + } else { + dims.push_back(SafeInt(dim.dim_value())); + } + } + } + } + + emscripten::val desc = emscripten::val::object(); + + desc.set("dimensions", emscripten::val::array(dims)); + + int32_t data_type; + { // type + const auto* type_proto = node_arg.TypeAsProto(); + if (!type_proto || !type_proto->tensor_type().has_elem_type()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "The ", input_output_type, " of graph doesn't have elem_type: ", name); + } + + data_type = type_proto->tensor_type().elem_type(); + switch (data_type) { + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: + desc.set("type", emscripten::val("float32")); + break; + default: { + // TODO: support other type. 
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "The ", input_output_type, " of graph doesn't have valid type, name: ", name, + " type: ", type_proto->tensor_type().elem_type()); + } + } + } + + if (is_input) { + wnn_operands_.insert(std::make_pair(name, wnn_builder_.call("input", name, desc))); + input_names_.push_back(name); + } else { + output_names_.push_back(name); + } + + std::vector shape; + std::transform(dims.cbegin(), dims.cend(), + std::back_inserter(shape), + [](int32_t dim) -> int64_t { return SafeInt(dim); }); + input_output_info_.emplace(name, OnnxTensorInfo{data_type, shape}); + + return Status::OK(); +} + +Status ModelBuilder::RegisterModelInputs() { + for (const auto* node_arg : graph_viewer_.GetInputs()) { + ORT_RETURN_IF_ERROR(RegisterModelInputOutput(*node_arg, true /* is_input */)); + } + + return Status::OK(); +} + +Status ModelBuilder::AddOperations() { + const auto& node_indices = graph_viewer_.GetNodesInTopologicalOrder(); + for (size_t i = 0; i < node_indices.size(); i++) { + const auto* node(graph_viewer_.GetNode(node_indices[i])); + if (const auto* op_builder = GetOpBuilder(*node)) { + ORT_RETURN_IF_ERROR(op_builder->AddToModelBuilder(*this, *node, logger_)); + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Node [", node->Name(), "], type [", node->OpType(), "] is not supported"); + } + } + + return Status::OK(); +} + +Status ModelBuilder::AddOperandFromPersistMemoryBuffer( + const std::string& name, const void* buffer, const size_t size, + const std::vector shape, const size_t element_size) { + auto persist_buffer = std::make_unique(size); + uint8_t* dest = persist_buffer.get(); + memcpy(dest, buffer, size); + emscripten::val view{emscripten::typed_memory_view(size / element_size, reinterpret_cast(dest))}; + emscripten::val desc = emscripten::val::object(); + desc.set("dimensions", emscripten::val::array(shape)); + desc.set("type", emscripten::val("float32")); + emscripten::val operand = emscripten::val::object(); +#ifdef ENABLE_WEBASSEMBLY_THREADS + // Workaround for WebAssembly multi-threads enabled since WebNN API only accepts non-shared ArrayBufferView. 
+ // https://www.w3.org/TR/webnn/#typedefdef-mlnamedarraybufferviews + operand = wnn_builder_.call("constant", desc, view.call("slice")); +#else + operand = wnn_builder_.call("constant", desc, view); +#endif + AddOperand(name, operand); + mem_persist_buffers_.push_back(std::move(persist_buffer)); + return Status::OK(); +} + +Status ModelBuilder::RegisterModelOutputs() { + for (const auto* node_arg : graph_viewer_.GetOutputs()) { + ORT_RETURN_IF_ERROR(RegisterModelInputOutput(*node_arg, false /* is_input */)); + } + + return Status::OK(); +} + +Status ModelBuilder::Compile(std::unique_ptr& model) { + ORT_RETURN_IF_ERROR(Initialize()); + emscripten::val named_operands = emscripten::val::object(); + for (auto& name : output_names_) { + named_operands.set(name, wnn_operands_.at(name)); + } + emscripten::val wnn_graph = wnn_builder_.call("buildSync", named_operands); + if (!wnn_graph.as()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to build WebNN graph."); + } + model.reset(new Model(std::move(wnn_context_), std::move(wnn_graph), logger_)); + model->SetInputs(std::move(input_names_)); + model->SetOutputs(std::move(output_names_)); + model->SetScalarOutputs(std::move(scalar_outputs_)); + model->SetInputOutputInfo(std::move(input_output_info_)); +#ifdef ENABLE_WEBASSEMBLY_THREADS + // Pre-allocate the input and output tensors for the WebNN graph + // when WebAssembly multi-threads is enabled since WebNN API only + // accepts non-shared ArrayBufferView. + // https://www.w3.org/TR/webnn/#typedefdef-mlnamedarraybufferviews + model->AllocateInputOutputBuffers(); +#endif + return Status::OK(); +} + +// supported_nodes is provided by the op to indicate whether it can be fused with the activation node. +emscripten::val ModelBuilder::FindActivation(const Node& node, const NodeArg& output, + const InlinedHashSet supported_nodes) { + emscripten::val fused_op = emscripten::val::null(); + for (auto it = node.OutputEdgesBegin(), end = node.OutputEdgesEnd(); it != end; ++it) { + const auto& dst_node = it->GetNode(); + const auto* dst_input = dst_node.InputDefs()[it->GetDstArgIndex()]; + if (!Contains(supported_nodes, dst_node.OpType())) { + return emscripten::val::null(); + } + if (Contains(activation_nodes_, dst_node.Index())) { + if (&output == dst_input) { + fused_op = activation_nodes_.at(dst_node.Index()); + } + } else { + // If there is any other non-relu node using the output + // will add relu separately. + if (&output == dst_input) { + return emscripten::val::null(); + } + } + } + + // If output is a graph output, will add relu separately. 
+ if (fused_op != emscripten::val::null()) { + for (const auto* graph_output : graph_viewer_.GetOutputs()) { + if (&output == graph_output) { + return emscripten::val::null(); + } + } + + LOGS_DEFAULT(VERBOSE) << "Node [" << node.Name() << "] type [" << node.OpType() + << "], fused the output [" << output.Name() << "]"; + + fused_activations_.insert(output.Name()); + } + + return fused_op; +} + +void ModelBuilder::AddScalarOutput(const std::string& output_name) { + scalar_outputs_.insert(output_name); +} + +void ModelBuilder::AddOperand(const std::string& name, const emscripten::val& operand) { + wnn_operands_.insert(std::make_pair(name, operand)); +} + +void ModelBuilder::AddInitializerToSkip(const std::string& tensor_name) { + skipped_initializers_.insert(tensor_name); +} + +void ModelBuilder::AddInputToSkip(const std::string& input_name) { + skipped_inputs_.insert(input_name); +} + +std::string ModelBuilder::GetUniqueName(const std::string& base_name) { + std::string unique_name; + do { + std::ostringstream os; + os << base_name << "_token_" << name_token_++; + unique_name = os.str(); + } while (Contains(unique_names_, unique_name)); + + return unique_name; +} + +} // namespace webnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webnn/builders/model_builder.h b/onnxruntime/core/providers/webnn/builders/model_builder.h new file mode 100644 index 0000000000000..56cd756688da5 --- /dev/null +++ b/onnxruntime/core/providers/webnn/builders/model_builder.h @@ -0,0 +1,106 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) Intel Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/common/inlined_containers.h" +#include + +#include "model.h" + +#include +#include + +namespace onnxruntime { +namespace webnn { + +class IOpBuilder; + +class ModelBuilder { + public: + ModelBuilder(const GraphViewer& graph_viewer, const logging::Logger& logger, + const emscripten::val& context, const emscripten::val& builder); + ~ModelBuilder() = default; + + Status Compile(std::unique_ptr& model) ORT_MUST_USE_RESULT; + + // Accessors for members. + const GraphViewer& GetGraphViewer() const { return graph_viewer_; } + const InitializedTensorSet& GetInitializerTensors() const { return graph_viewer_.GetAllInitializedTensors(); } + + const emscripten::val& GetBuilder() const { return wnn_builder_; } + const emscripten::val& GetContext() const { return wnn_context_; } + const emscripten::val& GetOperand(const std::string& name) const { return wnn_operands_.at(name); } + void AddOperand(const std::string& name, const emscripten::val& operand); + // Use the buffers to persist WebNN allocated data like transposed weight. + // It ensures the validity during inference session. + std::vector> mem_persist_buffers_; + // Add a constant operand (allocate persist buffer and move the ownership to mem_persist_buffers_). + Status AddOperandFromPersistMemoryBuffer( + const std::string& name, const void* buffer, + const size_t size, const std::vector shape, const size_t element_size = 4); + // Find if an output has a fuseable activation (e.g., Relu). + emscripten::val FindActivation(const Node& node, const NodeArg& output, + const InlinedHashSet supported_nodes = {}); + + const InlinedHashSet& + GetFusedActivations() const { return fused_activations_; } + + // The initializer will be processed separately, skip it as an initializer. 
+ void AddInitializerToSkip(const std::string& tensor_name); + + // There are some input which will not be used, add it to a list which will not + // be added to CoreML model, since CoreML does not like input unused. + void AddInputToSkip(const std::string& input_name); + + std::string GetUniqueName(const std::string& base_name); + + private: + const GraphViewer& graph_viewer_; + const logging::Logger& logger_; + + emscripten::val wnn_context_ = emscripten::val::object(); + emscripten::val wnn_builder_ = emscripten::val::object(); + std::vector> unpacked_tensors_; + InlinedHashMap wnn_operands_; + std::vector input_names_; + std::vector output_names_; + + InlinedHashSet scalar_outputs_; + InlinedHashMap input_output_info_; + + InlinedHashSet skipped_initializers_; + InlinedHashSet skipped_inputs_; + + InlinedHashSet fused_activations_; + + uint32_t name_token_{0}; + InlinedHashSet unique_names_; + + // All activation nodes (e.g., Relu) as a map . + InlinedHashMap activation_nodes_; + + // Convert the onnx model to WebNN operands + Status Initialize() ORT_MUST_USE_RESULT; + + void PreprocessInitializers(); + // Preprocess all the activation nodes (e.g., Relu) for easy query later. + void PreprocessActivations(); + + // Copy and process all the initializers to WebNN constants. + Status RegisterInitializers() ORT_MUST_USE_RESULT; + + Status AddOperations() ORT_MUST_USE_RESULT; + Status RegisterModelInputs() ORT_MUST_USE_RESULT; + Status RegisterModelOutputs() ORT_MUST_USE_RESULT; + Status RegisterModelInputOutput(const NodeArg& node_arg, bool is_input) ORT_MUST_USE_RESULT; + + // Record the onnx scalar output names. + void AddScalarOutput(const std::string& output_name); + + static const IOpBuilder* GetOpBuilder(const Node& node); +}; + +} // namespace webnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webnn/builders/op_builder.h b/onnxruntime/core/providers/webnn/builders/op_builder.h new file mode 100644 index 0000000000000..efa70ab3d5455 --- /dev/null +++ b/onnxruntime/core/providers/webnn/builders/op_builder.h @@ -0,0 +1,34 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) Intel Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +namespace onnxruntime { +namespace webnn { + +class ModelBuilder; + +class IOpBuilder { + public: + virtual ~IOpBuilder() = default; + + // Add operator related. + public: + // Check if the initializers of this operator need preprocess, + // which will not be copied. + virtual void AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const = 0; + + // Add the operator to WebNN model. + virtual Status AddToModelBuilder(ModelBuilder& model_builder, const Node& node, + const logging::Logger& logger) const ORT_MUST_USE_RESULT = 0; + + // Operator support related. + public: + // Check if an operator is supported. + virtual bool IsOpSupported(const InitializedTensorSet& initializers, const Node& node, + const logging::Logger& logger) const = 0; +}; + +} // namespace webnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc b/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc new file mode 100644 index 0000000000000..c13677547ff3b --- /dev/null +++ b/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc @@ -0,0 +1,76 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) Intel Corporation. All rights reserved. +// Licensed under the MIT License. 
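The op builder factory that follows keeps two parallel structures: a vector of unique_ptrs that owns the builder objects, and a name-to-pointer map that lets several ONNX op types share a single builder instance. A self-contained miniature of that ownership pattern (the class names here are invented for illustration):

#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

struct IOpBuilder {
  virtual ~IOpBuilder() = default;
  virtual const char* Name() const = 0;
};

struct ConvLikeBuilder : IOpBuilder {
  const char* Name() const override { return "ConvLikeBuilder"; }
};

struct Registrations {
  std::vector<std::unique_ptr<IOpBuilder>> builders;            // owns the builders
  std::unordered_map<std::string, const IOpBuilder*> op_builder_map;  // aliases them by op type
};

int main() {
  Registrations regs;
  regs.builders.push_back(std::make_unique<ConvLikeBuilder>());
  for (const char* op_type : {"Conv", "ConvTranspose"}) {
    regs.op_builder_map.emplace(op_type, regs.builders.back().get());
  }
  // Both op types resolve to the same builder instance.
  return regs.op_builder_map.at("Conv") == regs.op_builder_map.at("ConvTranspose") ? 0 : 1;
}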
+ +#include +#include + +#include + +#include "op_builder_factory.h" + +namespace onnxruntime { +namespace webnn { + +static OpBuilderRegistrations CreateOpBuilderRegistrations() { + OpBuilderRegistrations op_registrations; + + { // Binary + CreateBinaryOpBuilder("Add", op_registrations); + CreateBinaryOpBuilder("Sub", op_registrations); + CreateBinaryOpBuilder("Mul", op_registrations); + CreateBinaryOpBuilder("Div", op_registrations); + } + + { // Activations + CreateActivationOpBuilder("Relu", op_registrations); + CreateActivationOpBuilder("LeakyRelu", op_registrations); + CreateActivationOpBuilder("Sigmoid", op_registrations); + } + + { // Clip + CreateClipOpBuilder("Clip", op_registrations); + } + + { // Conv + CreateConvOpBuilder("Conv", op_registrations); + CreateConvOpBuilder("ConvTranspose", op_registrations); + } + + { // Concat + CreateConcatOpBuilder("Concat", op_registrations); + } + + { // Gemm + CreateGemmOpBuilder("Gemm", op_registrations); + } + + { // Pool + CreatePoolOpBuilder("GlobalAveragePool", op_registrations); + CreatePoolOpBuilder("GlobalMaxPool", op_registrations); + CreatePoolOpBuilder("AveragePool", op_registrations); + CreatePoolOpBuilder("MaxPool", op_registrations); + } + + { // Reshape + CreateReshapeOpBuilder("Reshape", op_registrations); + } + + { // Resize + CreateResizeOpBuilder("Resize", op_registrations); + } + + { // Transpose + CreateTransposeOpBuilder("Transpose", op_registrations); + } + + return op_registrations; +} + +const InlinedHashMap& GetOpBuilders() { + static const OpBuilderRegistrations op_registrations = CreateOpBuilderRegistrations(); + return op_registrations.op_builder_map; +} + +} // namespace webnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webnn/builders/op_builder_factory.h b/onnxruntime/core/providers/webnn/builders/op_builder_factory.h new file mode 100644 index 0000000000000..ffbbf2d92d432 --- /dev/null +++ b/onnxruntime/core/providers/webnn/builders/op_builder_factory.h @@ -0,0 +1,34 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) Intel Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "op_builder.h" + +namespace onnxruntime { +namespace webnn { + +struct OpBuilderRegistrations { + std::vector> builders; + InlinedHashMap op_builder_map; +}; + +// Get the lookup table with IOpBuilder delegates for different onnx operators. +// Note, the lookup table should have same number of entries as the result of CreateOpSupportCheckers() +// in op_support_checker.h. 
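// [Editor's note - illustrative sketch, not part of this patch] Ownership of the builder
// objects lives in 'builders' while 'op_builder_map' holds non-owning pointers keyed by
// ONNX op type, so one builder instance can serve several op types (e.g. Add/Sub/Mul/Div).
// A registration helper is, in simplified form (BinaryOpBuilder is the assumed concrete
// type, and the sharing between op types of one family is elided):
//
//   void CreateBinaryOpBuilder(const std::string& op_type, OpBuilderRegistrations& r) {
//     r.builders.push_back(std::make_unique<BinaryOpBuilder>());
//     r.op_builder_map.emplace(op_type, r.builders.back().get());
//   }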
+const InlinedHashMap& GetOpBuilders(); + +void CreateActivationOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); +void CreateBinaryOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); +void CreateClipOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); +void CreateConvOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); +void CreateConcatOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); +void CreateGemmOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); +void CreatePoolOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); +void CreateReshapeOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); +void CreateResizeOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); +void CreateTransposeOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); + +} // namespace webnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webnn/webnn_execution_provider.cc b/onnxruntime/core/providers/webnn/webnn_execution_provider.cc new file mode 100644 index 0000000000000..a8589f1d4d1aa --- /dev/null +++ b/onnxruntime/core/providers/webnn/webnn_execution_provider.cc @@ -0,0 +1,385 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) Intel Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "webnn_execution_provider.h" + +#include "core/framework/allocatormgr.h" +#include "core/framework/compute_capability.h" +#include "core/framework/memcpy.h" +#include "core/framework/kernel_registry.h" +#include "core/graph/graph_viewer.h" +#include "core/session/onnxruntime_cxx_api.h" +#include "core/common/safeint.h" + +#include "builders/model.h" +#include "builders/helper.h" +#include "builders/model_builder.h" + +namespace onnxruntime { + +constexpr const char* WEBNN = "WebNN"; + +WebNNExecutionProvider::WebNNExecutionProvider( + const std::string& webnn_device_flags, const std::string& webnn_power_flags) + : IExecutionProvider{onnxruntime::kWebNNExecutionProvider, true} { + AllocatorCreationInfo device_info( + [](int) { + return std::make_unique(OrtMemoryInfo(WEBNN, OrtAllocatorType::OrtDeviceAllocator)); + }); + + InsertAllocator(CreateAllocator(device_info)); + + AllocatorCreationInfo cpu_memory_info( + [](int) { + return std::make_unique( + OrtMemoryInfo(WEBNN, OrtAllocatorType::OrtDeviceAllocator, OrtDevice(), 0, OrtMemTypeCPUOutput)); + }); + + InsertAllocator(CreateAllocator(cpu_memory_info)); + + // Create WebNN context and graph builder. + const emscripten::val ml = emscripten::val::global("navigator")["ml"]; + if (!ml.as()) { + ORT_THROW("Failed to get ml from navigator."); + } + emscripten::val context_options = emscripten::val::object(); + // Currently WebNN implementation in Chromium temporarily reuses the MLContextOptions + // defined in Model Loader API, which uses MLDevicePreference instead of MLDeviceType + // defined in WebNN. Because there's an ongoing spec discussion to simplify this API at + // https://github.com/webmachinelearning/webnn/issues/302. 
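// [Editor's note - illustrative sketch, not part of this patch] The Embind calls below build
// the JavaScript MLContextOptions dictionary and then create the context and graph builder,
// roughly equivalent to this JavaScript:
//
//   const options = { devicePreference: deviceFlags };          // e.g. "gpu" or "cpu"
//   if (powerFlags !== "default") options.powerPreference = powerFlags;
//   const context = navigator.ml.createContextSync(options);
//   const builder = new MLGraphBuilder(context);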
+ context_options.set("devicePreference", emscripten::val(webnn_device_flags)); + if (webnn_power_flags.compare("default") != 0) { + context_options.set("powerPreference", emscripten::val(webnn_power_flags)); + } + wnn_context_ = ml.call("createContextSync", context_options); + if (!wnn_context_.as()) { + ORT_THROW("Failed to create WebNN context."); + } + wnn_builder_ = emscripten::val::global("MLGraphBuilder").new_(wnn_context_); + if (!wnn_builder_.as()) { + ORT_THROW("Failed to create WebNN builder."); + } +} + +WebNNExecutionProvider::~WebNNExecutionProvider() {} + +std::vector> +WebNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer, + const IKernelLookup& /*kernel_registries*/) const { + std::vector> result; + + // We do not run WebNN EP on subgraph, instead we cover this in the control flow nodes. + // TODO investigate whether we want to support subgraph using WebNN EP. + if (graph_viewer.IsSubgraph()) { + return result; + } + + /* + Very basic search for groups of nodes that can be handled by the EP. + This doesn't work perfectly if you have a scenario like the following where A and D could be handled by the EP + but B is between them in the topological sort as you'll get two single node capabilities. However if can also + be advantageous if C and E could be handled by the EP as they would be combined with D even though not connected. + Not sure how often each of these scenarios happens. + + A B C + | / | + D E + | | + + Would probably be better to walk the edges for each node the EP can handle as they are iterated in topological order, + accumulating nodes (and saving which ones have been taken) until you run out. This would guarantee all + connected nodes that can be handled are grouped together. + */ + + const auto& logger = *GetLogger(); + + const auto node_groups = webnn::GetSupportedNodes(graph_viewer, wnn_builder_, logger); + + if (node_groups.empty()) { + return result; + } + + const auto& graph_output_list = graph_viewer.GetOutputs(); + InlinedHashSet graph_outputs(graph_output_list.cbegin(), graph_output_list.cend()); + + size_t num_of_supported_nodes = 0; + for (const auto& group : node_groups) { + if (group.empty()) + continue; + + num_of_supported_nodes += group.size(); + LOGS(logger, VERBOSE) << "WebNNExecutionProvider::GetCapability, current supported node group size: " + << group.size(); + + InlinedHashSet node_set; + node_set.reserve(group.size()); + for (const auto& index : group) { + node_set.insert(index); + } + + std::unique_ptr sub_graph = std::make_unique(); + + InlinedHashSet node_outputs; + InlinedHashSet subgraph_inputs; + InlinedHashSet subgraph_outputs; + std::vector ordered_subgraph_inputs; + std::vector ordered_subgraph_outputs; + + for (const auto& index : group) { + sub_graph->nodes.push_back(index); + const auto* node = graph_viewer.GetNode(index); + + for (const auto* input : node->InputDefs()) { + // if the node input was not produced by this subgraph, add it to the subgraph inputs. + if (node_outputs.count(input) == 0) { + if (subgraph_inputs.count(input) == 0) { + subgraph_inputs.insert(input); + ordered_subgraph_inputs.push_back(input); + } + } + } + + const auto& output_defs = node->OutputDefs(); + for (const auto* output_def : output_defs) { + node_outputs.insert(output_def); + // if output is overall graph output we need to produce it. 
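// [Editor's note - not part of this patch] Together with the edge walk below, this means an
// output becomes a subgraph output in two cases: (a) it is also an output of the overall
// graph, or (b) it feeds a node that was left outside the group. For example, with a
// supported group {A, D} where D also feeds an unsupported node E, D's output is added to
// ordered_subgraph_outputs even though it is not a graph output.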
+ if (graph_outputs.count(output_def) != 0) { + ordered_subgraph_outputs.push_back(output_def); + } + } + + // if output connects to a node not in this subgraph we need to produce it. + for (auto it = node->OutputEdgesBegin(), end = node->OutputEdgesEnd(); it != end; ++it) { + if (node_set.count(it->GetNode().Index()) == 0) { + const auto* output_def = output_defs[it->GetSrcArgIndex()]; + if (subgraph_outputs.count(output_def) == 0) { + subgraph_outputs.insert(output_def); + ordered_subgraph_outputs.push_back(output_def); + } + } + } + } + + // Assign inputs and outputs to subgraph's meta_def. + uint64_t model_hash; + int metadef_id = GenerateMetaDefId(graph_viewer, model_hash); + auto meta_def = std::make_unique<::onnxruntime::IndexedSubGraph::MetaDef>(); + meta_def->name = "WEBNN_" + std::to_string(model_hash) + "_" + std::to_string(metadef_id); + meta_def->domain = kMSDomain; + meta_def->since_version = 1; + meta_def->status = ONNX_NAMESPACE::EXPERIMENTAL; + + for (const auto& input : ordered_subgraph_inputs) { + meta_def->inputs.push_back(input->Name()); + } + + for (const auto& output : ordered_subgraph_outputs) { + meta_def->outputs.push_back(output->Name()); + } + + sub_graph->SetMetaDef(std::move(meta_def)); + + result.push_back(std::make_unique(std::move(sub_graph))); + } + + auto num_of_partitions = result.size(); + const auto summary_msg = MakeString( + "WebNNExecutionProvider::GetCapability,", + " number of partitions supported by WebNN: ", num_of_partitions, + " number of nodes in the graph: ", graph_viewer.NumberOfNodes(), + " number of nodes supported by WebNN: ", num_of_supported_nodes); + + // If the graph is partitioned in multiple subgraphs, and this may impact performance, + // we want to give users a summary message at warning level. + if (num_of_partitions > 1) { + LOGS(logger, WARNING) << summary_msg; + } else { + LOGS(logger, INFO) << summary_msg; + } + + return result; +} + +common::Status WebNNExecutionProvider::Compile(const std::vector& fused_nodes_and_graphs, + std::vector& node_compute_funcs) { + for (const auto& fused_node_and_graph : fused_nodes_and_graphs) { + Node& fused_node = fused_node_and_graph.fused_node; + const onnxruntime::GraphViewer& graph_viewer(fused_node_and_graph.filtered_graph); + + webnn::ModelBuilder builder(graph_viewer, *GetLogger(), wnn_context_, wnn_builder_); + std::unique_ptr model; + ORT_RETURN_IF_ERROR(builder.Compile(model)); + // Build map from input name to its index in input definitions. + { + InlinedHashMap input_map; + const auto& input_defs = fused_node.InputDefs(); + input_map.reserve(input_defs.size()); + for (size_t i = 0, end = input_defs.size(); i < end; ++i) { + input_map[input_defs[i]->Name()] = i; + } + model->SetInputMap(std::move(input_map)); + } + // Build map from output name to its index in output definitions. + { + InlinedHashMap output_map; + const auto& output_defs = fused_node.OutputDefs(); + output_map.reserve(output_defs.size()); + for (size_t i = 0, end = output_defs.size(); i < end; ++i) { + output_map[output_defs[i]->Name()] = i; + } + model->SetOutputMap(std::move(output_map)); + } + models_.emplace(fused_node.Name(), std::move(model)); + + NodeComputeInfo compute_info; + compute_info.create_state_func = [&](ComputeContext* context, FunctionState* state) { + *state = models_[context->node_name].get(); + return 0; + }; + + compute_info.release_state_func = [](FunctionState state) { + // The `state` is a webnn::model managed by unique_ptr. 
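// [Editor's note - illustrative sketch, not part of this patch] The EP keeps ownership of the
// compiled webnn::Model in 'models_', so the three hooks form a borrow-only protocol:
//
//   create_state_func  -> *state = models_[node_name].get();        // hand out a raw pointer
//   compute_func       -> reinterpret_cast<webnn::Model*>(state)->Predict(inputs, outputs);
//   release_state_func -> no-op; the model is freed with the provider, not per inference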
+      ORT_UNUSED_PARAMETER(state);
+    };
+
+    compute_info.compute_func = [](FunctionState state, const OrtApi* api, OrtKernelContext* context) {
+      Ort::KernelContext ctx(context);
+
+      const size_t num_inputs = ctx.GetInputCount();
+      const size_t num_outputs = ctx.GetOutputCount();
+
+      ORT_UNUSED_PARAMETER(api);
+      webnn::Model* model = reinterpret_cast<webnn::Model*>(state);
+
+      const auto& model_inputs = model->GetInputs();
+      const auto& model_outputs = model->GetOutputs();
+
+      ORT_RETURN_IF_NOT(model_inputs.size() <= num_inputs, "Inconsistent input sizes");
+      ORT_RETURN_IF_NOT(model_outputs.size() == num_outputs, "Inconsistent output sizes");
+
+      InlinedHashMap<std::string, webnn::OnnxTensorData> inputs;
+      inputs.reserve(model_inputs.size());
+      for (size_t i = 0; i < model_inputs.size(); i++) {
+        const auto& input_name = model_inputs[i];
+        auto input_idx = model->GetMappedInputIdx(input_name);
+        auto input_tensor = ctx.GetInput(input_idx);
+        auto tensor_info = input_tensor.GetTensorTypeAndShapeInfo();
+        auto shape = tensor_info.GetShape();
+        // If the shape is empty, this is a scalar input. The WebNN EP passes all inputs and
+        // outputs around as tensors, so represent the scalar input as a {1} tensor here.
+        if (shape.empty())
+          shape.push_back(1);
+        std::vector<uint32_t> temp(shape.size());
+        std::transform(shape.begin(), shape.end(), temp.begin(),
+                       [](int64_t dim) -> uint32_t { return SafeInt<uint32_t>(dim); });
+        const void* inputBuffer = input_tensor.GetTensorRawData();
+        inputs.emplace(
+            input_name,
+            webnn::OnnxTensorData{
+                webnn::OnnxTensorInfo{tensor_info.GetElementType(), shape},
+                const_cast<void*>(inputBuffer),
+            });
+      }
+
+      // From this point on we take an exclusive lock on the model until Predict() has
+      // finished, to block other threads from running Predict() on the same model.
+      // TODO: investigate concurrent runs for different executions of the same model.
+      {
+        std::unique_lock lock(model->GetMutex());
+        InlinedHashMap<std::string, webnn::OnnxTensorData> outputs;
+        outputs.reserve(model_outputs.size());
+        for (size_t i = 0; i < model_outputs.size(); i++) {
+          const auto& output_name = model_outputs[i];
+          const auto& output_info = model->GetInputOutputInfo(output_name);
+          auto output_shape = output_info.shape;
+          auto output_type = output_info.data_type;
+
+          // The WebNN EP uses a {1} tensor to represent a scalar, so if the model output
+          // should have an empty shape, replace the {1} shape with {} here.
+ if (model->IsScalarOutput(output_name)) + output_shape.clear(); + + auto output_tensor = + ctx.GetOutput(i, output_shape.data(), output_shape.size()); + + void* output_buffer; + switch (output_type) { + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: + output_buffer = output_tensor.GetTensorMutableRawData(); + break; + default: + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "Unsupported type: ", output_type, " for output: ", output_name); + break; + } + + outputs.emplace(output_name, + webnn::OnnxTensorData{ + webnn::OnnxTensorInfo{output_type, output_shape}, + output_buffer, + }); + } + + return model->Predict(inputs, outputs); + } + }; + + node_compute_funcs.push_back(compute_info); + } + + return Status::OK(); +} + +ONNX_OPERATOR_KERNEL_EX( + MemcpyFromHost, + kOnnxDomain, + 1, + kWebNNExecutionProvider, + KernelDefBuilder() + .InputMemoryType(OrtMemTypeCPUInput, 0) + .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()), + Memcpy); + +ONNX_OPERATOR_KERNEL_EX( + MemcpyToHost, + kOnnxDomain, + 1, + kWebNNExecutionProvider, + KernelDefBuilder() + .OutputMemoryType(OrtMemTypeCPUOutput, 0) + .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()), + Memcpy); + +class ONNX_OPERATOR_KERNEL_CLASS_NAME( + kWebNNExecutionProvider, kOnnxDomain, 1, MemcpyFromHost); +class ONNX_OPERATOR_KERNEL_CLASS_NAME( + kWebNNExecutionProvider, kOnnxDomain, 1, MemcpyToHost); + +static void RegisterWebNNKernels(KernelRegistry& kernel_registry) { + static const BuildKernelCreateInfoFn function_table[] = { + BuildKernelCreateInfo, + BuildKernelCreateInfo, + }; + + for (auto& function_table_entry : function_table) { + ORT_ENFORCE(kernel_registry.Register(function_table_entry()).IsOK()); + } +} + +std::shared_ptr GetWebNNKernelRegistry() { + std::shared_ptr kernel_registry = + std::make_shared(); + RegisterWebNNKernels(*kernel_registry); + + return kernel_registry; +} + +std::shared_ptr +WebNNExecutionProvider::GetKernelRegistry() const { + static std::shared_ptr kernel_registry = + onnxruntime::GetWebNNKernelRegistry(); + return kernel_registry; +} + +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webnn/webnn_execution_provider.h b/onnxruntime/core/providers/webnn/webnn_execution_provider.h new file mode 100644 index 0000000000000..697b7ecd6d462 --- /dev/null +++ b/onnxruntime/core/providers/webnn/webnn_execution_provider.h @@ -0,0 +1,48 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) Intel Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/common/inlined_containers.h" +#include "core/framework/execution_provider.h" + +#include +#include + +namespace onnxruntime { +namespace webnn { +class Model; +} + +class WebNNExecutionProvider : public IExecutionProvider { + public: + WebNNExecutionProvider(const std::string& webnn_device_flags, const std::string& webnn_power_flags); + virtual ~WebNNExecutionProvider(); + + std::vector> + GetCapability(const onnxruntime::GraphViewer& graph_viewer, + const IKernelLookup& /*kernel_registries*/) const override; + + DataLayout GetPreferredLayout() const override { return DataLayout::NHWC; } + + // We implement the Compile that takes FusedNodeAndGraph instances. + FusionStyle GetFusionStyle() const override { return FusionStyle::FilteredGraphViewer; } + + // WebNN does not support concurrent execution of a kernel. 
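// [Editor's note - not part of this patch] Returning false below tells the session not to run
// this EP's fused kernels concurrently; together with the per-model mutex taken around
// Predict() in compute_func, this serializes execution on a single compiled MLGraph.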
+  bool ConcurrentRunSupported() const override { return false; }
+
+#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
+  common::Status Compile(const std::vector<FusedNodeAndGraph>& fused_nodes,
+                         std::vector<NodeComputeInfo>& node_compute_funcs) override;
+#endif
+
+  std::shared_ptr<KernelRegistry> GetKernelRegistry() const override;
+
+ private:
+  emscripten::val wnn_context_ = emscripten::val::object();
+  emscripten::val wnn_builder_ = emscripten::val::object();
+
+  InlinedHashMap<std::string, std::unique_ptr<webnn::Model>> models_;
+};
+} // namespace onnxruntime
diff --git a/onnxruntime/core/providers/webnn/webnn_provider_factory.cc b/onnxruntime/core/providers/webnn/webnn_provider_factory.cc
new file mode 100644
index 0000000000000..4d6b04c8e76d8
--- /dev/null
+++ b/onnxruntime/core/providers/webnn/webnn_provider_factory.cc
@@ -0,0 +1,33 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Copyright (c) Intel Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "core/providers/webnn/webnn_provider_factory_creator.h"
+#include "core/session/abi_session_options_impl.h"
+#include "webnn_execution_provider.h"
+
+using namespace onnxruntime;
+
+namespace onnxruntime {
+struct WebNNProviderFactory : IExecutionProviderFactory {
+  WebNNProviderFactory(const std::string& webnn_device_flags, const std::string& webnn_power_flags)
+      : webnn_device_flags_(webnn_device_flags), webnn_power_flags_(webnn_power_flags) {}
+  ~WebNNProviderFactory() override {}
+
+  std::unique_ptr<IExecutionProvider> CreateProvider() override;
+
+  std::string webnn_device_flags_;
+  std::string webnn_power_flags_;
+};
+
+std::unique_ptr<IExecutionProvider> WebNNProviderFactory::CreateProvider() {
+  return std::make_unique<WebNNExecutionProvider>(webnn_device_flags_, webnn_power_flags_);
+}
+
+std::shared_ptr<IExecutionProviderFactory> WebNNProviderFactoryCreator::Create(
+    const ProviderOptions& provider_options) {
+  return std::make_shared<WebNNProviderFactory>(provider_options.at("deviceType"),
+                                                provider_options.at("powerPreference"));
+}
+
+} // namespace onnxruntime
diff --git a/onnxruntime/core/providers/webnn/webnn_provider_factory_creator.h b/onnxruntime/core/providers/webnn/webnn_provider_factory_creator.h
new file mode 100644
index 0000000000000..d2d21358e0e48
--- /dev/null
+++ b/onnxruntime/core/providers/webnn/webnn_provider_factory_creator.h
@@ -0,0 +1,18 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Copyright (c) Intel Corporation. All rights reserved.
+// Licensed under the MIT License.
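// [Editor's note - illustrative sketch, not part of this patch] The creator declared below
// turns a ProviderOptions map into the factory defined above. The registration code in
// provider_registration.cc (further down) always fills in both keys before calling Create(),
// which is why the .at() lookups above are safe:
//
//   ProviderOptions opts;
//   opts["deviceType"] = "gpu";
//   opts["powerPreference"] = "default";
//   auto factory = WebNNProviderFactoryCreator::Create(opts);
//   auto ep = factory->CreateProvider();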
+ +#pragma once + +#include + +#include "core/framework/provider_options.h" +#include "core/providers/providers.h" + +namespace onnxruntime { + +struct WebNNProviderFactoryCreator { + static std::shared_ptr Create(const ProviderOptions& provider_options); +}; + +} // namespace onnxruntime diff --git a/onnxruntime/core/session/provider_registration.cc b/onnxruntime/core/session/provider_registration.cc index 8cadccd0ef376..3d712a81cd5a2 100644 --- a/onnxruntime/core/session/provider_registration.cc +++ b/onnxruntime/core/session/provider_registration.cc @@ -83,6 +83,16 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider, options->provider_factories.push_back(XnnpackProviderFactoryCreator::Create(provider_options, &(options->value))); #else status = create_not_supported_status(); +#endif + } else if (strcmp(provider_name, "WEBNN") == 0) { +#if defined(USE_WEBNN) + std::string deviceType = options->value.config_options.GetConfigOrDefault("deviceType", "cpu"); + std::string powerPreference = options->value.config_options.GetConfigOrDefault("powerPreference", "default"); + provider_options["deviceType"] = deviceType; + provider_options["powerPreference"] = powerPreference; + options->provider_factories.push_back(WebNNProviderFactoryCreator::Create(provider_options)); +#else + status = create_not_supported_status(); #endif } else if (strcmp(provider_name, "AZURE") == 0) { #if defined(USE_AZURE) diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index 228f39beae84b..8817e6242cd53 100644 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -477,6 +477,7 @@ def convert_arg_line_to_args(self, arg_line): help="Build with OpenVINO for specific hardware.", ) parser.add_argument("--use_coreml", action="store_true", help="Build with CoreML support.") + parser.add_argument("--use_webnn", action="store_true", help="Build with WebNN support.") parser.add_argument("--use_snpe", action="store_true", help="Build with SNPE support.") parser.add_argument("--snpe_root", help="Path to SNPE SDK root.") parser.add_argument("--use_nnapi", action="store_true", help="Build with NNAPI support.") @@ -979,6 +980,7 @@ def generate_build_tree( "-Donnxruntime_ENABLE_CUDA_PROFILING=" + ("ON" if args.enable_cuda_profiling else "OFF"), "-Donnxruntime_ENABLE_ROCM_PROFILING=" + ("ON" if args.enable_rocm_profiling else "OFF"), "-Donnxruntime_USE_XNNPACK=" + ("ON" if args.use_xnnpack else "OFF"), + "-Donnxruntime_USE_WEBNN=" + ("ON" if args.use_webnn else "OFF"), "-Donnxruntime_USE_CANN=" + ("ON" if args.use_cann else "OFF"), ] @@ -1194,6 +1196,15 @@ def generate_build_tree( if args.use_coreml: cmake_args += ["-Donnxruntime_USE_COREML=ON"] + if args.use_webnn: + if not args.build_wasm: + raise BuildError("WebNN is only available for WebAssembly build.") + if args.disable_rtti: + # Avoid unboundTypeError for WebNN EP since unbound type names are illegal with RTTI disabled + # in Embind API, relevant issue: https://github.com/emscripten-core/emscripten/issues/16911 + raise BuildError("WebNN is not supported with RTTI disabled.") + cmake_args += ["-Donnxruntime_USE_WEBNN=ON"] + if args.use_snpe: cmake_args += ["-Donnxruntime_USE_SNPE=ON"]
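For reference, a minimal sketch of how an embedder could opt into the new execution provider
through the generic provider-registration path added above. The API calls are the standard
ONNX Runtime C++ bindings and are assumed here for illustration; in the web bundle this is
normally driven from the JavaScript/WASM bindings rather than from C++ directly.

  #include "onnxruntime_cxx_api.h"

  int main() {
    Ort::Env env{ORT_LOGGING_LEVEL_WARNING, "webnn-demo"};
    Ort::SessionOptions so;
    // provider_registration.cc reads these session config entries, falling back to
    // "cpu" / "default" when they are absent.
    so.AddConfigEntry("deviceType", "gpu");
    so.AddConfigEntry("powerPreference", "default");
    // "WEBNN" matches the provider name handled in the registration hunk above.
    so.AppendExecutionProvider("WEBNN", {});
    Ort::Session session{env, "model.onnx", so};  // hypothetical model path
    return 0;
  }

Note that the build wiring above only supports this EP in a WebAssembly build: build.py rejects
--use_webnn without --build_wasm, and also rejects it when RTTI is disabled because the Embind
API requires RTTI.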