Commit 1324a98

Updating MLBuffer specification
* CPU devices now support MLBuffer
* MLContext.createBuffer now returns a Promise
1 parent b2d0110 commit 1324a98

12 files changed (+34, -30 lines)

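For orientation, a minimal before/after sketch of the createBuffer change (hypothetical caller code, not taken from this commit; the context options and descriptor values are illustrative):

// Before this commit, createBuffer returned the MLBuffer synchronously:
//   const buffer = context.createBuffer({dataType: 'float32', dimensions: [2, 2]});
// After it, the call resolves asynchronously, and CPU contexts are allowed too:
const context = await navigator.ml.createContext({deviceType: 'cpu'});
const buffer = await context.createBuffer({dataType: 'float32', dimensions: [2, 2]});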

js/web/lib/wasm/jsep/backend-webnn.ts (+2 -1)

@@ -99,7 +99,8 @@ export class WebNNBackend {
     this.bufferManager.releaseBufferId(bufferId);
   }

-  public ensureBuffer(bufferId: BufferId, onnxDataType: number|MLOperandDataType, dimensions: number[]): MLBuffer {
+  public async ensureBuffer(bufferId: BufferId, onnxDataType: number|MLOperandDataType, dimensions: number[]):
+      Promise<MLBuffer> {
     let dataType: MLOperandDataType;
     if (typeof onnxDataType === 'number') {
       const webnnDataType = onnxDataTypeToWebnnDataType.get(onnxDataType)!;

js/web/lib/wasm/jsep/init.ts (+1 -1)

@@ -246,7 +246,7 @@ export const init =
         // jsepReleaseBufferId,
         (bufferId: number) => backend.releaseBufferId(bufferId),
         // jsepEnsureBuffer
-        (bufferId: number, onnxDataType: number, dimensions: number[]) =>
+        async (bufferId: number, onnxDataType: number, dimensions: number[]) =>
             backend.ensureBuffer(bufferId, onnxDataType, dimensions),
         // jsepUploadBuffer
         (bufferId: number, data: Uint8Array) => {

js/web/lib/wasm/jsep/webnn/buffer-manager.ts (+4 -4)

@@ -25,7 +25,7 @@ export interface BufferManager {
   /**
    * Ensure a MLBuffer is created for the BufferId.
    */
-  ensureBuffer(bufferId: BufferId, dataType: MLOperandDataType, dimensions: number[]): MLBuffer;
+  ensureBuffer(bufferId: BufferId, dataType: MLOperandDataType, dimensions: number[]): Promise<MLBuffer>;
   /**
    * Upload data to a MLBuffer.
    */
@@ -85,12 +85,12 @@ class BufferTracker {
     this.mlBuffer = undefined;
   }

-  public ensureBuffer(dataType: MLOperandDataType, dimensions: number[]): MLBuffer {
+  public async ensureBuffer(dataType: MLOperandDataType, dimensions: number[]): Promise<MLBuffer> {
     if (this.mlBuffer) {
       return this.mlBuffer;
     }

-    const buffer = this.context.createBuffer({dataType, dimensions});
+    const buffer = await this.context.createBuffer({dataType, dimensions});
     this.mlBuffer = buffer;

     if (this.activeUpload) {
@@ -151,7 +151,7 @@ class BufferManagerImpl implements BufferManager {
     }
   }

-  public ensureBuffer(bufferId: BufferId, dataType: MLOperandDataType, dimensions: number[]): MLBuffer {
+  public async ensureBuffer(bufferId: BufferId, dataType: MLOperandDataType, dimensions: number[]): Promise<MLBuffer> {
     const buffer = this.buffersById.get(bufferId);
     if (!buffer) {
       throw new Error('Buffer not found.');
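
Callers of the manager now await the resolved buffer; a minimal sketch (only ensureBuffer and releaseBufferId appear in this diff; the bufferManager instance, ID, and shape are illustrative):

// Hypothetical caller: ensureBuffer now resolves to the MLBuffer rather
// than returning it directly.
const mlBuffer = await bufferManager.ensureBuffer(bufferId, 'float32', [1, 3, 224, 224]);
// ... bind mlBuffer for dispatch ...
bufferManager.releaseBufferId(bufferId);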

js/web/lib/wasm/jsep/webnn/webnn.d.ts (+1 -1)

@@ -387,7 +387,7 @@ interface MLBuffer {

 type MLNamedBuffers = Record<string, MLBuffer>;
 interface MLContext {
-  createBuffer(descriptor: MLOperandDescriptor): MLBuffer;
+  createBuffer(descriptor: MLOperandDescriptor): Promise<MLBuffer>;
   writeBuffer(
       dstBuffer: MLBuffer, srcData: ArrayBufferView|ArrayBuffer, srcElementOffset?: number,
       srcElementSize?: number): void;
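
Per these declarations, a create-and-fill round trip now mixes async and sync calls; a sketch (the mlContext instance and values are illustrative):

// createBuffer must now be awaited; writeBuffer remains synchronous.
const dst = await mlContext.createBuffer({dataType: 'float32', dimensions: [4]});
mlContext.writeBuffer(dst, new Float32Array([0, 1, 2, 3]));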

js/web/lib/wasm/wasm-core-impl.ts (+1 -1)

@@ -704,7 +704,7 @@ export const run = async(

         // If the graph has been partitioned, the output tensor may have not been created. For this reason, we use
         // ensureBuffer to get/create the MLBuffer.
-        const mlBuffer = ensureBuffer(dataOffset, dataType, dims);
+        const mlBuffer = await ensureBuffer(dataOffset, dataType, dims);

         // do not release the tensor right now. it will be released when user calls tensor.dispose().
         keepOutputTensor = true;

js/web/lib/wasm/wasm-types.ts (+3 -2)

@@ -25,7 +25,8 @@ export declare namespace JSEP {
   type ReplayFunction = () => void;
   type ReserveBufferIdFunction = () => number;
   type ReleaseBufferIdFunction = (bufferId: number) => void;
-  type EnsureBufferFunction = (bufferId: number, dataType: number|MLOperandDataType, dimensions: number[]) => MLBuffer;
+  type EnsureBufferFunction = (bufferId: number, dataType: number|MLOperandDataType, dimensions: number[]) =>
+      Promise<MLBuffer>;
   type UploadBufferFunction = (bufferId: number, data: Uint8Array) => void;
   type DownloadBufferFunction = (bufferId: number) => Promise<ArrayBuffer>;

@@ -154,7 +155,7 @@ export declare namespace JSEP {
    * @param bufferId - specify the MLBuffer ID.
    * @returns the MLBuffer.
    */
-  jsepEnsureBuffer: (bufferId: number, dataType: number|MLOperandDataType, dimensions: number[]) => MLBuffer;
+  jsepEnsureBuffer: (bufferId: number, dataType: number|MLOperandDataType, dimensions: number[]) => Promise<MLBuffer>;
   /**
    * [exported from pre-jsep.js] Upload data to MLBuffer.
    * @param bufferId - specify the MLBuffer ID.

js/web/test/test-runner.ts (+3 -4)

@@ -257,8 +257,7 @@ export class ModelTestContext {
     const executionProviderConfig =
         modelTest.backend === 'webnn' ? (testOptions?.webnnOptions || {name: 'webnn'}) : modelTest.backend!;
     let mlContext: MLContext|undefined;
-    if(['ml-tensor', 'ml-location'].includes(modelTest.ioBinding)) {
-
+    if (['ml-tensor', 'ml-location'].includes(modelTest.ioBinding)) {
       const webnnOptions = executionProviderConfig as ort.InferenceSession.WebNNExecutionProviderOption;
       const deviceType = (webnnOptions as ort.InferenceSession.WebNNContextOptions)?.deviceType;
       const numThreads = (webnnOptions as ort.InferenceSession.WebNNContextOptions)?.numThreads;
@@ -593,7 +592,7 @@ async function createMLTensorForOutput(mlContext: MLContext, type: ort.Tensor.Ty

   const dataType = type === 'bool' ? 'uint8' : type;

-  const mlBuffer = mlContext.createBuffer({dataType, dimensions: dims as number[]});
+  const mlBuffer = await mlContext.createBuffer({dataType, dimensions: dims as number[]});

   return ort.Tensor.fromMLBuffer(mlBuffer, {
     dataType: type,
@@ -611,7 +610,7 @@ async function createMLTensorForInput(mlContext: MLContext, cpuTensor: ort.Tenso
     throw new Error(`createMLTensorForInput can not work with ${cpuTensor.type} tensor`);
   }
   const dataType = cpuTensor.type === 'bool' ? 'uint8' : cpuTensor.type;
-  const mlBuffer = mlContext.createBuffer({dataType, dimensions: cpuTensor.dims as number[]});
+  const mlBuffer = await mlContext.createBuffer({dataType, dimensions: cpuTensor.dims as number[]});
   mlContext.writeBuffer(mlBuffer, cpuTensor.data);
   return ort.Tensor.fromMLBuffer(
       mlBuffer, {dataType: cpuTensor.type, dims: cpuTensor.dims, dispose: () => mlBuffer.destroy()});

onnxruntime/core/providers/webnn/builders/helper.cc (+2 -3)

@@ -211,10 +211,9 @@ bool SetWebnnDataType(emscripten::val& desc, const int32_t data_type) {
   }
 }

-bool IsMLBufferSupported(WebnnDeviceType device_type) {
+bool IsMLBufferSupported() {
   static bool is_supported = !emscripten::val::global("MLBuffer").isUndefined();
-  // The current MLBuffer implementation only supports GPU and NPU devices.
-  return is_supported && device_type != WebnnDeviceType::CPU;
+  return is_supported;
 }

 }  // namespace webnn
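
The cached check only tests whether the global MLBuffer constructor exists; the JavaScript-level equivalent of the emscripten expression is roughly (a sketch, not code from this commit):

// Rough JS equivalent of !emscripten::val::global("MLBuffer").isUndefined()
const isMLBufferSupported = 'MLBuffer' in globalThis;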

onnxruntime/core/providers/webnn/builders/helper.h (+1 -1)

@@ -285,7 +285,7 @@ bool GetBidirectionalBroadcastShape(std::vector<int64_t>& shape_a,

 bool SetWebnnDataType(emscripten::val& desc, const int32_t data_type);

-bool IsMLBufferSupported(WebnnDeviceType device_type);
+bool IsMLBufferSupported();

 }  // namespace webnn
 }  // namespace onnxruntime

onnxruntime/core/providers/webnn/builders/model.cc (+12 -8)

@@ -155,27 +155,31 @@ onnxruntime::common::Status Model::Compute(const InlinedHashMap<std::string, Onn
 onnxruntime::common::Status Model::Dispatch(const InlinedHashMap<std::string, OnnxTensorData>& inputs,
                                             const InlinedHashMap<std::string, OnnxTensorData>& outputs) {
   auto jsepEnsureBuffer = emscripten::val::module_property("jsepEnsureBuffer");
-  for (const auto& input : inputs) {
-    const std::string& name = input.first;
-    const struct OnnxTensorData tensor = input.second;
+  auto promises = emscripten::val::array();
+  for (const auto& [_, tensor] : inputs) {
     emscripten::val shape = emscripten::val::array();
     for (const auto& dim : tensor.tensor_info.shape) {
       uint32_t dim_val = SafeInt<uint32_t>(dim);
       shape.call<void>("push", dim_val);
     }
     auto buffer = jsepEnsureBuffer(reinterpret_cast<intptr_t>(tensor.buffer), tensor.tensor_info.data_type, shape);
-    wnn_inputs_.set(name, buffer);
+    promises.call<void>("push", buffer);
   }
-  for (const auto& output : outputs) {
-    const std::string& name = output.first;
-    const struct OnnxTensorData tensor = output.second;
+  for (const auto& [_, tensor] : outputs) {
     emscripten::val shape = emscripten::val::array();
     for (const auto& dim : tensor.tensor_info.shape) {
       uint32_t dim_val = SafeInt<uint32_t>(dim);
       shape.call<void>("push", dim_val);
     }
     auto buffer = jsepEnsureBuffer(reinterpret_cast<intptr_t>(tensor.buffer), tensor.tensor_info.data_type, shape);
-    wnn_outputs_.set(name, buffer);
+    promises.call<void>("push", buffer);
+  }
+  auto buffers = emscripten::val::global("Promise").call<emscripten::val>("all", promises).await();
+  for (const auto& [name, _] : inputs) {
+    wnn_inputs_.set(name, buffers.call<emscripten::val>("shift"));
+  }
+  for (const auto& [name, _] : outputs) {
+    wnn_outputs_.set(name, buffers.call<emscripten::val>("shift"));
   }
   wnn_context_.call<void>("dispatch", wnn_graph_, wnn_inputs_, wnn_outputs_);
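
Because jsepEnsureBuffer now returns a promise, Dispatch pushes every buffer promise into an array, awaits Promise.all, and then binds the resolved buffers by name in the original order. The same pattern in plain TypeScript, as a sketch (the tensor map and ensureBuffer callback are stand-ins for the emscripten::val plumbing above):

// Start all ensureBuffer calls first, await them together, then bind the
// resolved buffers by name in the same order they were requested.
async function bindAll<Buffer>(
    tensors: Map<string, {id: number; dataType: number; shape: number[]}>,
    ensureBuffer: (id: number, dataType: number, shape: number[]) => Promise<Buffer>):
    Promise<Map<string, Buffer>> {
  const names = [...tensors.keys()];
  const buffers = await Promise.all(names.map((name) => {
    const {id, dataType, shape} = tensors.get(name)!;
    return ensureBuffer(id, dataType, shape);
  }));
  return new Map(names.map((name, i) => [name, buffers[i]]));
}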

onnxruntime/core/providers/webnn/builders/model_builder.cc (+1 -1)

@@ -332,7 +332,7 @@ Status ModelBuilder::Compile(std::unique_ptr<Model>& model) {
   if (!wnn_graph.as<bool>()) {
     return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to build WebNN graph.");
   }
-  model.reset(new Model(std::move(wnn_context_), std::move(wnn_graph), logger_, IsMLBufferSupported(wnn_device_type_)));
+  model.reset(new Model(std::move(wnn_context_), std::move(wnn_graph), logger_, IsMLBufferSupported()));
   model->SetInputs(std::move(input_names_));
   model->SetOutputs(std::move(output_names_));
   model->SetScalarOutputs(std::move(scalar_outputs_));

onnxruntime/core/providers/webnn/webnn_execution_provider.cc (+3 -3)

@@ -24,7 +24,7 @@ WebNNExecutionProvider::WebNNExecutionProvider(const std::string& webnn_device_f
                 onnxruntime::kWebNNExecutionProvider,
                 // If MLBuffer is supported, we force all the tensors to be allocated as MLBuffer.
                 OrtDevice(
-                    webnn::IsMLBufferSupported(webnn::DeviceTypeFromString(webnn_device_flags)) ? OrtDevice::GPU : OrtDevice::CPU,
+                    webnn::IsMLBufferSupported() ? OrtDevice::GPU : OrtDevice::CPU,
                     OrtDevice::MemType::DEFAULT,
                     0)},
       wnn_device_type_(webnn::DeviceTypeFromString(webnn_device_flags)) {
@@ -378,14 +378,14 @@ WebNNExecutionProvider::GetKernelRegistry() const {
 }

 std::unique_ptr<onnxruntime::IDataTransfer> WebNNExecutionProvider::GetDataTransfer() const {
-  if (!webnn::IsMLBufferSupported(wnn_device_type_)) {
+  if (!webnn::IsMLBufferSupported()) {
     return nullptr;
   }
   return std::make_unique<webnn::DataTransfer>();
 }

 std::vector<AllocatorPtr> WebNNExecutionProvider::CreatePreferredAllocators() {
-  if (!webnn::IsMLBufferSupported(wnn_device_type_)) {
+  if (!webnn::IsMLBufferSupported()) {
     return {};
   }
   AllocatorCreationInfo customAllocatorCreationInfo([&](OrtDevice::DeviceId) {
