Commit cd1b01a

PR feedback
* Added shouldTransferToMLBuffer to avoid creating MLBuffers for initializers
* Switches from getMLBuffer to ensureBuffer to fix issues when graph is partitioned
* Switches from custom DataType enum to the common one.
1 parent 1e07f85 commit cd1b01a

File tree (5 files changed, +39 −29 lines):

  js/web/lib/wasm/jsep/backend-webnn.ts
  js/web/lib/wasm/wasm-core-impl.ts
  js/web/lib/wasm/wasm-types.ts
  onnxruntime/core/providers/webnn/allocator.cc
  onnxruntime/core/providers/webnn/data_transfer.cc
js/web/lib/wasm/jsep/backend-webnn.ts (+17 −25)

@@ -8,37 +8,25 @@

 import {Tensor} from 'onnxruntime-common';

+import {DataType} from '../wasm-common';
+import {getInstance} from '../wasm-factory';
+
 import {createView} from './tensor-view';
 import {BufferId, BufferManager, createBufferManager} from './webnn/buffer-manager';

-/*
- * TensorProto::data_type from the ONNX specification.
- */
-enum TensorProtoDataType {
-  float = 1,
-  uint8 = 2,
-  int8 = 3,
-  int32 = 6,
-  int64 = 7,
-  bool = 9,
-  float16 = 10,
-  uint32 = 12,
-  uint64 = 13,
-}
-
 /*
  * TensorProto::data_type to WebNN OperandType mapping.
  */
-const onnxDataTypeToWebnnDataType = new Map<TensorProtoDataType, MLOperandDataType>([
-  [TensorProtoDataType.float, 'float32'],
-  [TensorProtoDataType.float16, 'float16'],
-  [TensorProtoDataType.int32, 'int32'],
-  [TensorProtoDataType.uint32, 'uint32'],
-  [TensorProtoDataType.int64, 'int64'],
-  [TensorProtoDataType.uint64, 'uint64'],
-  [TensorProtoDataType.int8, 'int8'],
-  [TensorProtoDataType.uint8, 'uint8'],
-  [TensorProtoDataType.bool, 'uint8'],
+const onnxDataTypeToWebnnDataType = new Map<DataType, MLOperandDataType>([
+  [DataType.float, 'float32'],
+  [DataType.float16, 'float16'],
+  [DataType.int32, 'int32'],
+  [DataType.uint32, 'uint32'],
+  [DataType.int64, 'int64'],
+  [DataType.uint64, 'uint64'],
+  [DataType.int8, 'int8'],
+  [DataType.uint8, 'uint8'],
+  [DataType.bool, 'uint8'],
 ]);

 /**

@@ -130,6 +118,10 @@ export class WebNNBackend {
   }

   public uploadBuffer(bufferId: BufferId, data: Uint8Array): void {
+    const wasm = getInstance();
+    if (!wasm.shouldTransferToMLBuffer) {
+      throw new Error('Trying to upload to a MLBuffer while shouldTransferToMLBuffer is false');
+    }
     this.bufferManager.upload(bufferId, data);
   }

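For context on the new map: a minimal sketch of how the shared DataType enum would be translated into a WebNN operand type before an MLBuffer descriptor is built. The map itself comes from the diff above; the convertDataType helper and its error message are assumptions for illustration, not code from this commit.

    // Hypothetical helper, assuming onnxDataTypeToWebnnDataType and MLOperandDataType from the
    // file above are in scope; not part of this commit.
    const convertDataType = (dataType: DataType): MLOperandDataType => {
      const webnnType = onnxDataTypeToWebnnDataType.get(dataType);
      if (!webnnType) {
        throw new Error(`Unsupported ONNX data type for WebNN: ${dataType}`);
      }
      return webnnType;
    };

    // e.g. convertDataType(DataType.float) === 'float32'; DataType.bool (9) maps to 'uint8' per the map above.

Note that uploadBuffer now throws while shouldTransferToMLBuffer is false, which pairs with the early returns added to allocator.cc and data_transfer.cc below.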
js/web/lib/wasm/wasm-core-impl.ts (+8 −3)

@@ -261,6 +261,7 @@ export const createSession = async(
   for (const provider of options?.executionProviders ?? []) {
     const providerName = typeof provider === 'string' ? provider : provider.name;
     if (providerName === 'webnn') {
+      wasm.shouldTransferToMLBuffer = false;
       if (wasm.currentContext) {
         throw new Error('WebNN execution provider is already set.');
       }

@@ -294,6 +295,7 @@ export const createSession = async(
     if (wasm.currentContext) {
       wasm.jsepRegisterMLContext!(sessionHandle, wasm.currentContext);
       wasm.currentContext = undefined;
+      wasm.shouldTransferToMLBuffer = true;
     }

     const [inputCount, outputCount] = getSessionInputOutputCount(sessionHandle);

@@ -691,16 +693,19 @@ export const run = async(
             'gpu-buffer'
           ]);
         } else if (preferredLocation === 'ml-buffer' && size > 0) {
-          const getMLBuffer = wasm.jsepGetMLBuffer;
-          if (!getMLBuffer) {
+          const ensureBuffer = wasm.jsepEnsureBuffer;
+          if (!ensureBuffer) {
            throw new Error('preferredLocation "ml-buffer" is not supported without using WebNN.');
          }
-          const mlBuffer = getMLBuffer(dataOffset);
          const elementSize = getTensorElementSize(dataType);
          if (elementSize === undefined || !isMLBufferSupportedType(type)) {
            throw new Error(`Unsupported data type: ${type}`);
          }

+          // If the graph has been partitioned, the output tensor may have not been created. For this reason, we use
+          // ensureBuffer to get/create the MLBuffer.
+          const mlBuffer = ensureBuffer(dataOffset, dataType, dims);
+
          // do not release the tensor right now. it will be released when user calls tensor.dispose().
          keepOutputTensor = true;

js/web/lib/wasm/wasm-types.ts (+5 −0)

@@ -126,6 +126,11 @@ export declare namespace JSEP {
    */
   currentContext: MLContext;

+  /**
+   * Disables creating MLBuffers. This is used to avoid creating MLBuffers for graph initializers.
+   */
+  shouldTransferToMLBuffer: boolean;
+
   /**
    * [exported from pre-jsep.js] Register MLContext for a session.
    * @param sessionId - specify the session ID.
onnxruntime/core/providers/webnn/allocator.cc (+4 −0)

@@ -12,6 +12,10 @@ void* WebNNBufferAllocator::Alloc(size_t size) {
   if (size == 0) {
     return nullptr;
   }
+  if (!emscripten::val::module_property("shouldTransferToMLBuffer").as<bool>()) {
+    // We don't need to transfer the buffer to an MLBuffer, so we don't need to allocate buffer id.
+    return nullptr;
+  }
   void* p = EM_ASM_PTR({ return Module.jsepReserveBufferId(); });
   allocations_[p] = size;
   stats_.num_allocs++;

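The EM_ASM_PTR call above relies on a jsepReserveBufferId function being registered on the Emscripten Module by the WebNN backend. A rough sketch of that contract follows; the implementation shown is an assumption for illustration (only the function name and the nullptr shortcut come from this commit):

    // Assumed module-side counterpart: hand out an opaque id now; the actual MLBuffer is
    // created later (e.g. by ensureBuffer or uploadBuffer) once data type and shape are known.
    let nextBufferId = 1;
    const jsepReserveBufferId = (): number => nextBufferId++;

    // allocator.cc: when shouldTransferToMLBuffer is false, Alloc() returns nullptr before this
    // function is ever reached, so no buffer id is reserved for graph initializers.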
onnxruntime/core/providers/webnn/data_transfer.cc (+5 −1)

@@ -6,7 +6,6 @@
 #include <emscripten.h>
 #include "core/framework/tensor.h"

-
 namespace onnxruntime {
 namespace webnn {

@@ -24,6 +23,11 @@ common::Status DataTransfer::CopyTensor(const Tensor& src, Tensor& dst) const {

   const auto& dst_device = dst.Location().device;

+  if (!emscripten::val::module_property("shouldTransferToMLBuffer").as<bool>()) {
+    // We don't need to transfer the buffer to an MLBuffer, so we don't need to copy the buffer.
+    return Status::OK();
+  }
+
   if (dst_device.Type() == OrtDevice::GPU) {
     EM_ASM({
       Module.jsepUploadBuffer($0, HEAPU8.subarray($1, $1 + $2));
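For completeness, Module.jsepUploadBuffer in the GPU branch above presumably routes to WebNNBackend.uploadBuffer from the first file in this commit; the exact registration is not part of this diff, so the glue below is an assumption:

    // Assumed wiring between the EM_ASM call and the backend method shown earlier;
    // `backend` stands for the WebNNBackend instance created at JSEP init time.
    declare const backend: {uploadBuffer(bufferId: number, data: Uint8Array): void};

    const jsepUploadBuffer = (bufferId: number, data: Uint8Array): void => {
      backend.uploadBuffer(bufferId, data);  // throws if shouldTransferToMLBuffer is false
    };

    // The early return added to CopyTensor means this path is never taken while the flag is
    // false, so initializer copies no longer trip uploadBuffer's new guard.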
