Skip to content

Commit 1e07f85

Browse files
committed
PR feedback
* Fixed issues when building under debug
* Disabled MLBuffer on CPU device types
* Renamed MlBuffer and MlContext to MLBuffer and MLContext to match the specification
1 parent d823108 commit 1e07f85

24 files changed

+154
-120
lines changed

js/common/lib/tensor-factory-impl.ts

+3-3
Original file line numberDiff line numberDiff line change
@@ -275,10 +275,10 @@ export const tensorFromGpuBuffer = <T extends TensorInterface.GpuBufferDataTypes
275275
};
276276

277277
/**
278-
* implementation of Tensor.fromMlBuffer().
278+
* implementation of Tensor.fromMLBuffer().
279279
*/
280-
export const tensorFromMlBuffer = <T extends TensorInterface.GpuBufferDataTypes>(
281-
mlBuffer: TensorInterface.MlBufferType, options: TensorFromGpuBufferOptions<T>): Tensor => {
280+
export const tensorFromMLBuffer = <T extends TensorInterface.GpuBufferDataTypes>(
281+
mlBuffer: TensorInterface.MLBufferType, options: TensorFromGpuBufferOptions<T>): Tensor => {
282282
const {dataType, dims, download, dispose} = options;
283283
return new Tensor({location: 'ml-buffer', type: dataType ?? 'float32', mlBuffer, dims, download, dispose});
284284
};

js/common/lib/tensor-factory.ts

+4-4
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ export interface GpuBufferConstructorParameters<T extends Tensor.GpuBufferDataTy
8484
readonly gpuBuffer: Tensor.GpuBufferType;
8585
}
8686

87-
export interface MlBufferConstructorParameters<T extends Tensor.MlBufferDataTypes = Tensor.MlBufferDataTypes> extends
87+
export interface MLBufferConstructorParameters<T extends Tensor.MLBufferDataTypes = Tensor.MLBufferDataTypes> extends
8888
CommonConstructorParameters<T>, GpuResourceConstructorParameters<T> {
8989
/**
9090
* Specify the location of the data to be 'ml-buffer'.
@@ -94,7 +94,7 @@ export interface MlBufferConstructorParameters<T extends Tensor.MlBufferDataType
9494
/**
9595
* Specify the WebNN buffer that holds the tensor data.
9696
*/
97-
readonly mlBuffer: Tensor.MlBufferType;
97+
readonly mlBuffer: Tensor.MLBufferType;
9898
}
9999

100100
// #endregion
@@ -212,7 +212,7 @@ export interface TensorFromGpuBufferOptions<T extends Tensor.GpuBufferDataTypes>
212212
dataType?: T;
213213
}
214214

215-
export interface TensorFromMlBufferOptions<T extends Tensor.MlBufferDataTypes> extends
215+
export interface TensorFromMLBufferOptions<T extends Tensor.MLBufferDataTypes> extends
216216
Pick<Tensor, 'dims'>, GpuResourceConstructorParameters<T> {
217217
/**
218218
* Describes the data type of the tensor.
@@ -345,7 +345,7 @@ export interface TensorFactory {
345345
*
346346
* @returns a tensor object
347347
*/
348-
fromMlBuffer<T extends Tensor.MlBufferDataTypes>(buffer: Tensor.MlBufferType, options: TensorFromMlBufferOptions<T>):
348+
fromMLBuffer<T extends Tensor.MLBufferDataTypes>(buffer: Tensor.MLBufferType, options: TensorFromMLBufferOptions<T>):
349349
TypedTensor<T>;
350350

351351
/**

js/common/lib/tensor-impl.ts

+10-10
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33

44
import {tensorToDataURL, tensorToImageData} from './tensor-conversion-impl.js';
55
import {TensorToDataUrlOptions, TensorToImageDataOptions} from './tensor-conversion.js';
6-
import {tensorFromGpuBuffer, tensorFromImage, tensorFromMlBuffer, tensorFromPinnedBuffer, tensorFromTexture} from './tensor-factory-impl.js';
7-
import {CpuPinnedConstructorParameters, GpuBufferConstructorParameters, MlBufferConstructorParameters, TensorFromGpuBufferOptions, TensorFromImageBitmapOptions, TensorFromImageDataOptions, TensorFromImageElementOptions, TensorFromTextureOptions, TensorFromUrlOptions, TextureConstructorParameters} from './tensor-factory.js';
6+
import {tensorFromGpuBuffer, tensorFromImage, tensorFromMLBuffer, tensorFromPinnedBuffer, tensorFromTexture} from './tensor-factory-impl.js';
7+
import {CpuPinnedConstructorParameters, GpuBufferConstructorParameters, MLBufferConstructorParameters, TensorFromGpuBufferOptions, TensorFromImageBitmapOptions, TensorFromImageDataOptions, TensorFromImageElementOptions, TensorFromTextureOptions, TensorFromUrlOptions, TextureConstructorParameters} from './tensor-factory.js';
88
import {checkTypedArray, NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP, NUMERIC_TENSOR_TYPEDARRAY_TO_TYPE_MAP, SupportedTypedArray, SupportedTypedArrayConstructors} from './tensor-impl-type-mapping.js';
99
import {calculateSize, tensorReshape} from './tensor-utils-impl.js';
1010
import {Tensor as TensorInterface} from './tensor.js';
@@ -16,7 +16,7 @@ type TensorDataType = TensorInterface.DataType;
1616
type TensorDataLocation = TensorInterface.DataLocation;
1717
type TensorTextureType = TensorInterface.TextureType;
1818
type TensorGpuBufferType = TensorInterface.GpuBufferType;
19-
type TensorMlBufferType = TensorInterface.MlBufferType;
19+
type TensorMLBufferType = TensorInterface.MLBufferType;
2020

2121
/**
2222
* the implementation of Tensor interface.
@@ -68,14 +68,14 @@ export class Tensor implements TensorInterface {
6868
*
6969
* @param params - Specify the parameters to construct the tensor.
7070
*/
71-
constructor(params: MlBufferConstructorParameters);
71+
constructor(params: MLBufferConstructorParameters);
7272

7373
/**
7474
* implementation.
7575
*/
7676
constructor(
7777
arg0: TensorType|TensorDataType|readonly string[]|readonly boolean[]|CpuPinnedConstructorParameters|
78-
TextureConstructorParameters|GpuBufferConstructorParameters|MlBufferConstructorParameters,
78+
TextureConstructorParameters|GpuBufferConstructorParameters|MLBufferConstructorParameters,
7979
arg1?: TensorDataType|readonly number[]|readonly string[]|readonly boolean[], arg2?: readonly number[]) {
8080
// perform one-time check for BigInt/Float16Array support
8181
checkTypedArray();
@@ -273,9 +273,9 @@ export class Tensor implements TensorInterface {
273273
return tensorFromGpuBuffer(gpuBuffer, options);
274274
}
275275

276-
static fromMlBuffer<T extends TensorInterface.GpuBufferDataTypes>(
277-
mlBuffer: TensorMlBufferType, options: TensorFromGpuBufferOptions<T>): TensorInterface {
278-
return tensorFromMlBuffer(mlBuffer, options);
276+
static fromMLBuffer<T extends TensorInterface.GpuBufferDataTypes>(
277+
mlBuffer: TensorMLBufferType, options: TensorFromGpuBufferOptions<T>): TensorInterface {
278+
return tensorFromMLBuffer(mlBuffer, options);
279279
}
280280

281281
static fromPinnedBuffer<T extends TensorInterface.CpuPinnedDataTypes>(
@@ -326,7 +326,7 @@ export class Tensor implements TensorInterface {
326326
/**
327327
* stores the underlying WebNN MLBuffer when location is 'ml-buffer'. otherwise empty.
328328
*/
329-
private mlBufferData?: TensorMlBufferType;
329+
private mlBufferData?: TensorMLBufferType;
330330

331331

332332
/**
@@ -376,7 +376,7 @@ export class Tensor implements TensorInterface {
376376
return this.gpuBufferData;
377377
}
378378

379-
get mlBuffer(): TensorMlBufferType {
379+
get mlBuffer(): TensorMLBufferType {
380380
this.ensureValid();
381381
if (!this.mlBufferData) {
382382
throw new Error('The data is not stored as a WebNN buffer.');

js/common/lib/tensor-utils-impl.ts

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
// Copyright (c) Microsoft Corporation. All rights reserved.
22
// Licensed under the MIT License.
33

4-
import {CpuPinnedConstructorParameters, GpuBufferConstructorParameters, MlBufferConstructorParameters, TextureConstructorParameters} from './tensor-factory.js';
4+
import {CpuPinnedConstructorParameters, GpuBufferConstructorParameters, MLBufferConstructorParameters, TextureConstructorParameters} from './tensor-factory.js';
55
import {Tensor} from './tensor-impl.js';
66

77
/**
@@ -56,7 +56,7 @@ export const tensorReshape = (tensor: Tensor, dims: readonly number[]): Tensor =
5656
return new Tensor({
5757
location: 'ml-buffer',
5858
mlBuffer: tensor.mlBuffer,
59-
type: tensor.type as MlBufferConstructorParameters['type'],
59+
type: tensor.type as MLBufferConstructorParameters['type'],
6060
dims,
6161
});
6262
default:

js/common/lib/tensor.ts

+3-3
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ interface TypedTensorBase<T extends Tensor.Type> {
4747
*
4848
* If the data is not in a WebNN MLBuffer, throw error.
4949
*/
50-
readonly mlBuffer: Tensor.MlBufferType;
50+
readonly mlBuffer: Tensor.MLBufferType;
5151

5252
/**
5353
* Get the buffer data of the tensor.
@@ -144,7 +144,7 @@ export declare namespace Tensor {
144144
*
145145
* The specification for WebNN's ML Buffer is currently in flux.
146146
*/
147-
export type MlBufferType = unknown;
147+
export type MLBufferType = unknown;
148148

149149
/**
150150
* supported data types for constructing a tensor from a WebGPU buffer
@@ -154,7 +154,7 @@ export declare namespace Tensor {
154154
/**
155155
* supported data types for constructing a tensor from a WebNN MLBuffer
156156
*/
157-
export type MlBufferDataTypes = 'float32'|'float16'|'int8'|'uint8'|'int32'|'uint32'|'int64'|'uint64'|'bool';
157+
export type MLBufferDataTypes = 'float32'|'float16'|'int8'|'uint8'|'int32'|'uint32'|'int64'|'uint64'|'bool';
158158

159159
/**
160160
* represent where the tensor data is stored

js/web/lib/wasm/jsep/backend-webnn.ts

+13-13
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ export class WebNNBackend {
5454
/**
5555
* Maps from MLContext to session ids.
5656
*/
57-
private sessionIdsByMlContext = new Map<MLContext, Set<number>>();
57+
private sessionIdsByMLContext = new Map<MLContext, Set<number>>();
5858
/**
5959
* Current session id.
6060
*/
@@ -68,38 +68,38 @@ export class WebNNBackend {
6868
if (this.currentSessionId === undefined) {
6969
throw new Error('No active session');
7070
}
71-
return this.getMlContext(this.currentSessionId);
71+
return this.getMLContext(this.currentSessionId);
7272
}
7373

74-
public registerMlContext(sessionId: number, mlContext: MLContext): void {
74+
public registerMLContext(sessionId: number, mlContext: MLContext): void {
7575
this.mlContextBySessionId.set(sessionId, mlContext);
76-
let sessionIds = this.sessionIdsByMlContext.get(mlContext);
76+
let sessionIds = this.sessionIdsByMLContext.get(mlContext);
7777
if (!sessionIds) {
7878
sessionIds = new Set();
79-
this.sessionIdsByMlContext.set(mlContext, sessionIds);
79+
this.sessionIdsByMLContext.set(mlContext, sessionIds);
8080
}
8181
sessionIds.add(sessionId);
8282
}
8383

84-
public unregisterMlContext(sessionId: number): void {
84+
public unregisterMLContext(sessionId: number): void {
8585
const mlContext = this.mlContextBySessionId.get(sessionId)!;
8686
if (!mlContext) {
8787
throw new Error(`No MLContext found for session ${sessionId}`);
8888
}
8989
this.mlContextBySessionId.delete(sessionId);
90-
const sessionIds = this.sessionIdsByMlContext.get(mlContext)!;
90+
const sessionIds = this.sessionIdsByMLContext.get(mlContext)!;
9191
sessionIds.delete(sessionId);
9292
if (sessionIds.size === 0) {
93-
this.sessionIdsByMlContext.delete(mlContext);
93+
this.sessionIdsByMLContext.delete(mlContext);
9494
}
9595
}
9696

9797
public onReleaseSession(sessionId: number): void {
98-
this.unregisterMlContext(sessionId);
99-
this.bufferManager.releaseBuffersForContext(this.getMlContext(sessionId));
98+
this.unregisterMLContext(sessionId);
99+
this.bufferManager.releaseBuffersForContext(this.getMLContext(sessionId));
100100
}
101101

102-
public getMlContext(sessionId: number): MLContext {
102+
public getMLContext(sessionId: number): MLContext {
103103
return this.mlContextBySessionId.get(sessionId)!;
104104
}
105105

@@ -137,14 +137,14 @@ export class WebNNBackend {
137137
return this.bufferManager.download(bufferId);
138138
}
139139

140-
public createMlBufferDownloader(bufferId: BufferId, type: Tensor.GpuBufferDataTypes): () => Promise<Tensor.DataType> {
140+
public createMLBufferDownloader(bufferId: BufferId, type: Tensor.GpuBufferDataTypes): () => Promise<Tensor.DataType> {
141141
return async () => {
142142
const data = await this.bufferManager.download(bufferId);
143143
return createView(data, type);
144144
};
145145
}
146146

147-
public registerMlBuffer(buffer: MLBuffer): BufferId {
147+
public registerMLBuffer(buffer: MLBuffer): BufferId {
148148
return this.bufferManager.registerBuffer(this.currentContext, buffer);
149149
}
150150

js/web/lib/wasm/jsep/init.ts

+16-1
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,21 @@ export const init =
239239
]);
240240
} else {
241241
const backend = new WebNNBackend();
242-
jsepInit('webnn', [backend]);
242+
jsepInit('webnn', [
243+
backend,
244+
// jsepReserveBufferId
245+
() => backend.reserveBufferId(),
246+
// jsepReleaseBufferId,
247+
(bufferId: number) => backend.releaseBufferId(bufferId),
248+
// jsepEnsureBuffer
249+
(bufferId: number, onnxDataType: number, dimensions: number[]) =>
250+
backend.ensureBuffer(bufferId, onnxDataType, dimensions),
251+
// jsepUploadBuffer
252+
(bufferId: number, data: Uint8Array) => {
253+
backend.uploadBuffer(bufferId, data);
254+
},
255+
// jsepDownloadBuffer
256+
async (bufferId: number) => backend.downloadBuffer(bufferId),
257+
]);
243258
}
244259
};

js/web/lib/wasm/proxy-messages.ts

+5-5
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@ export type GpuBufferMetadata = {
1515
dispose?: () => void;
1616
};
1717

18-
export type MlBufferMetadata = {
19-
mlBuffer: Tensor.MlBufferType;
20-
download?: () => Promise<Tensor.DataTypeMap[Tensor.MlBufferDataTypes]>;
18+
export type MLBufferMetadata = {
19+
mlBuffer: Tensor.MLBufferType;
20+
download?: () => Promise<Tensor.DataTypeMap[Tensor.MLBufferDataTypes]>;
2121
dispose?: () => void;
2222
};
2323

@@ -26,7 +26,7 @@ export type MlBufferMetadata = {
2626
*/
2727
export type UnserializableTensorMetadata =
2828
[dataType: Tensor.Type, dims: readonly number[], data: GpuBufferMetadata, location: 'gpu-buffer']|
29-
[dataType: Tensor.Type, dims: readonly number[], data: MlBufferMetadata, location: 'ml-buffer']|
29+
[dataType: Tensor.Type, dims: readonly number[], data: MLBufferMetadata, location: 'ml-buffer']|
3030
[dataType: Tensor.Type, dims: readonly number[], data: Tensor.DataType, location: 'cpu-pinned'];
3131

3232
/**
@@ -37,7 +37,7 @@ export type UnserializableTensorMetadata =
3737
* - cpu: Uint8Array
3838
* - cpu-pinned: Uint8Array
3939
* - gpu-buffer: GpuBufferMetadata
40-
* - ml-buffer: MlBufferMetadata
40+
* - ml-buffer: MLBufferMetadata
4141
* - location: tensor data location
4242
*/
4343
export type TensorMetadata = SerializableTensorMetadata|UnserializableTensorMetadata;

js/web/lib/wasm/session-handler-inference.ts

+3-3
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ import {InferenceSession, InferenceSessionHandler, SessionHandler, Tensor, TRACE
55

66
import {SerializableInternalBuffer, TensorMetadata} from './proxy-messages';
77
import {copyFromExternalBuffer, createSession, endProfiling, releaseSession, run} from './proxy-wrapper';
8-
import {isGpuBufferSupportedType, isMlBufferSupportedType} from './wasm-common';
8+
import {isGpuBufferSupportedType, isMLBufferSupportedType} from './wasm-common';
99
import {isNode} from './wasm-utils-env';
1010
import {loadFile} from './wasm-utils-load-file';
1111

@@ -36,11 +36,11 @@ export const decodeTensorMetadata = (tensor: TensorMetadata): Tensor => {
3636
}
3737
case 'ml-buffer': {
3838
const dataType = tensor[0];
39-
if (!isMlBufferSupportedType(dataType)) {
39+
if (!isMLBufferSupportedType(dataType)) {
4040
throw new Error(`not supported data type: ${dataType} for deserializing GPU tensor`);
4141
}
4242
const {mlBuffer, download, dispose} = tensor[2];
43-
return Tensor.fromMlBuffer(mlBuffer, {dataType, dims: tensor[1], download, dispose});
43+
return Tensor.fromMLBuffer(mlBuffer, {dataType, dims: tensor[1], download, dispose});
4444
}
4545
default:
4646
throw new Error(`invalid data location: ${tensor[3]}`);

js/web/lib/wasm/wasm-common.ts

+1-1
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,7 @@ export const isGpuBufferSupportedType = (type: Tensor.Type): type is Tensor.GpuB
182182
/**
183183
* Check whether the given tensor type is supported by WebNN MLBuffer
184184
*/
185-
export const isMlBufferSupportedType = (type: Tensor.Type): type is Tensor.MlBufferDataTypes => type === 'float32' ||
185+
export const isMLBufferSupportedType = (type: Tensor.Type): type is Tensor.MLBufferDataTypes => type === 'float32' ||
186186
type === 'float16' || type === 'int32' || type === 'int64' || type === 'uint32' || type === 'uint64' ||
187187
type === 'int8' || type === 'uint8' || type === 'bool';
188188

js/web/lib/wasm/wasm-core-impl.ts

+10-10
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ import {Env, InferenceSession, Tensor} from 'onnxruntime-common';
1111
import {SerializableInternalBuffer, SerializableSessionMetadata, SerializableTensorMetadata, TensorMetadata} from './proxy-messages';
1212
import {setRunOptions} from './run-options';
1313
import {setSessionOptions} from './session-options';
14-
import {dataLocationStringToEnum, getTensorElementSize, isGpuBufferSupportedType, isMlBufferSupportedType, logLevelStringToEnum, tensorDataTypeEnumToString, tensorDataTypeStringToEnum, tensorTypeToTypedArrayConstructor} from './wasm-common';
14+
import {dataLocationStringToEnum, getTensorElementSize, isGpuBufferSupportedType, isMLBufferSupportedType, logLevelStringToEnum, tensorDataTypeEnumToString, tensorDataTypeStringToEnum, tensorTypeToTypedArrayConstructor} from './wasm-common';
1515
import {getInstance} from './wasm-factory';
1616
import {allocWasmString, checkLastError} from './wasm-utils';
1717
import {loadFile} from './wasm-utils-load-file';
@@ -292,7 +292,7 @@ export const createSession = async(
292292

293293
// clear current MLContext after session creation
294294
if (wasm.currentContext) {
295-
wasm.jsepRegisterMlContext!(sessionHandle, wasm.currentContext);
295+
wasm.jsepRegisterMLContext!(sessionHandle, wasm.currentContext);
296296
wasm.currentContext = undefined;
297297
}
298298

@@ -446,11 +446,11 @@ export const prepareInputOutputTensor =
446446
const elementSizeInBytes = getTensorElementSize(tensorDataTypeStringToEnum(dataType))!;
447447
dataByteLength = dims.reduce((a, b) => a * b, 1) * elementSizeInBytes;
448448

449-
const registerMlBuffer = wasm.jsepRegisterMlBuffer;
450-
if (!registerMlBuffer) {
449+
const registerMLBuffer = wasm.jsepRegisterMLBuffer;
450+
if (!registerMLBuffer) {
451451
throw new Error('Tensor location "ml-buffer" is not supported without using WebNN.');
452452
}
453-
rawData = registerMlBuffer(mlBuffer);
453+
rawData = registerMLBuffer(mlBuffer);
454454
} else {
455455
const data = tensor[2];
456456

@@ -691,13 +691,13 @@ export const run = async(
691691
'gpu-buffer'
692692
]);
693693
} else if (preferredLocation === 'ml-buffer' && size > 0) {
694-
const getMlBuffer = wasm.jsepGetMlBuffer;
695-
if (!getMlBuffer) {
694+
const getMLBuffer = wasm.jsepGetMLBuffer;
695+
if (!getMLBuffer) {
696696
throw new Error('preferredLocation "ml-buffer" is not supported without using WebNN.');
697697
}
698-
const mlBuffer = getMlBuffer(dataOffset);
698+
const mlBuffer = getMLBuffer(dataOffset);
699699
const elementSize = getTensorElementSize(dataType);
700-
if (elementSize === undefined || !isMlBufferSupportedType(type)) {
700+
if (elementSize === undefined || !isMLBufferSupportedType(type)) {
701701
throw new Error(`Unsupported data type: ${type}`);
702702
}
703703

@@ -707,7 +707,7 @@ export const run = async(
707707
output.push([
708708
type, dims, {
709709
mlBuffer,
710-
download: wasm.jsepCreateMlBufferDownloader!(dataOffset, type),
710+
download: wasm.jsepCreateMLBufferDownloader!(dataOffset, type),
711711
dispose: () => {
712712
wasm.jsepReleaseBufferId!(dataOffset);
713713
wasm._OrtReleaseTensor(tensor);

0 commit comments

Comments
 (0)