|
8 | 8 |
|
9 | 9 | import {Tensor} from 'onnxruntime-common'; |
10 | 10 |
|
| 11 | +import {DataType} from '../wasm-common'; |
| 12 | +import {getInstance} from '../wasm-factory'; |
| 13 | + |
11 | 14 | import {createView} from './tensor-view'; |
12 | 15 | import {BufferId, BufferManager, createBufferManager} from './webnn/buffer-manager'; |
13 | 16 |
|
14 | | -/* |
15 | | - * TensorProto::data_type from the ONNX specification. |
16 | | - */ |
17 | | -enum TensorProtoDataType { |
18 | | - float = 1, |
19 | | - uint8 = 2, |
20 | | - int8 = 3, |
21 | | - int32 = 6, |
22 | | - int64 = 7, |
23 | | - bool = 9, |
24 | | - float16 = 10, |
25 | | - uint32 = 12, |
26 | | - uint64 = 13, |
27 | | -} |
28 | | - |
29 | 17 | /* |
30 | 18 | * TensorProto::data_type to WebNN OperandType mapping. |
31 | 19 | */ |
32 | | -const onnxDataTypeToWebnnDataType = new Map<TensorProtoDataType, MLOperandDataType>([ |
33 | | - [TensorProtoDataType.float, 'float32'], |
34 | | - [TensorProtoDataType.float16, 'float16'], |
35 | | - [TensorProtoDataType.int32, 'int32'], |
36 | | - [TensorProtoDataType.uint32, 'uint32'], |
37 | | - [TensorProtoDataType.int64, 'int64'], |
38 | | - [TensorProtoDataType.uint64, 'uint64'], |
39 | | - [TensorProtoDataType.int8, 'int8'], |
40 | | - [TensorProtoDataType.uint8, 'uint8'], |
41 | | - [TensorProtoDataType.bool, 'uint8'], |
| 20 | +const onnxDataTypeToWebnnDataType = new Map<DataType, MLOperandDataType>([ |
| 21 | + [DataType.float, 'float32'], |
| 22 | + [DataType.float16, 'float16'], |
| 23 | + [DataType.int32, 'int32'], |
| 24 | + [DataType.uint32, 'uint32'], |
| 25 | + [DataType.int64, 'int64'], |
| 26 | + [DataType.uint64, 'uint64'], |
| 27 | + [DataType.int8, 'int8'], |
| 28 | + [DataType.uint8, 'uint8'], |
| 29 | + [DataType.bool, 'uint8'], |
42 | 30 | ]); |
43 | 31 |
|
44 | 32 | /** |
@@ -130,6 +118,10 @@ export class WebNNBackend { |
130 | 118 | } |
131 | 119 |
|
132 | 120 | public uploadBuffer(bufferId: BufferId, data: Uint8Array): void { |
| 121 | + const wasm = getInstance(); |
| 122 | + if (!wasm.shouldTransferToMLBuffer) { |
| 123 | + throw new Error('Trying to upload to a MLBuffer while shouldTransferToMLBuffer is false'); |
| 124 | + } |
133 | 125 | this.bufferManager.upload(bufferId, data); |
134 | 126 | } |
135 | 127 |
|
|
0 commit comments