|
8 | 8 |
|
9 | 9 | import {Tensor} from 'onnxruntime-common';
|
10 | 10 |
|
| 11 | +import {DataType} from '../wasm-common'; |
| 12 | +import {getInstance} from '../wasm-factory'; |
| 13 | + |
11 | 14 | import {createView} from './tensor-view';
|
12 | 15 | import {BufferId, BufferManager, createBufferManager} from './webnn/buffer-manager';
|
13 | 16 |
|
14 |
| -/* |
15 |
| - * TensorProto::data_type from the ONNX specification. |
16 |
| - */ |
17 |
| -enum TensorProtoDataType { |
18 |
| - float = 1, |
19 |
| - uint8 = 2, |
20 |
| - int8 = 3, |
21 |
| - int32 = 6, |
22 |
| - int64 = 7, |
23 |
| - bool = 9, |
24 |
| - float16 = 10, |
25 |
| - uint32 = 12, |
26 |
| - uint64 = 13, |
27 |
| -} |
28 |
| - |
29 | 17 | /*
|
30 | 18 | * TensorProto::data_type to WebNN OperandType mapping.
|
31 | 19 | */
|
32 |
| -const onnxDataTypeToWebnnDataType = new Map<TensorProtoDataType, MLOperandDataType>([ |
33 |
| - [TensorProtoDataType.float, 'float32'], |
34 |
| - [TensorProtoDataType.float16, 'float16'], |
35 |
| - [TensorProtoDataType.int32, 'int32'], |
36 |
| - [TensorProtoDataType.uint32, 'uint32'], |
37 |
| - [TensorProtoDataType.int64, 'int64'], |
38 |
| - [TensorProtoDataType.uint64, 'uint64'], |
39 |
| - [TensorProtoDataType.int8, 'int8'], |
40 |
| - [TensorProtoDataType.uint8, 'uint8'], |
41 |
| - [TensorProtoDataType.bool, 'uint8'], |
| 20 | +const onnxDataTypeToWebnnDataType = new Map<DataType, MLOperandDataType>([ |
| 21 | + [DataType.float, 'float32'], |
| 22 | + [DataType.float16, 'float16'], |
| 23 | + [DataType.int32, 'int32'], |
| 24 | + [DataType.uint32, 'uint32'], |
| 25 | + [DataType.int64, 'int64'], |
| 26 | + [DataType.uint64, 'uint64'], |
| 27 | + [DataType.int8, 'int8'], |
| 28 | + [DataType.uint8, 'uint8'], |
| 29 | + [DataType.bool, 'uint8'], |
42 | 30 | ]);
|
43 | 31 |
|
44 | 32 | /**
|
@@ -130,6 +118,10 @@ export class WebNNBackend {
|
130 | 118 | }
|
131 | 119 |
|
132 | 120 | public uploadBuffer(bufferId: BufferId, data: Uint8Array): void {
|
| 121 | + const wasm = getInstance(); |
| 122 | + if (!wasm.shouldTransferToMLBuffer) { |
| 123 | + throw new Error('Trying to upload to a MLBuffer while shouldTransferToMLBuffer is false'); |
| 124 | + } |
133 | 125 | this.bufferManager.upload(bufferId, data);
|
134 | 126 | }
|
135 | 127 |
|
|
0 commit comments