@@ -6,16 +6,19 @@ import { TensorToDataUrlOptions, TensorToImageDataOptions } from './tensor-conve
6
6
import {
7
7
tensorFromGpuBuffer ,
8
8
tensorFromImage ,
9
+ tensorFromMLTensor ,
9
10
tensorFromPinnedBuffer ,
10
11
tensorFromTexture ,
11
12
} from './tensor-factory-impl.js' ;
12
13
import {
13
14
CpuPinnedConstructorParameters ,
14
15
GpuBufferConstructorParameters ,
16
+ MLTensorConstructorParameters ,
15
17
TensorFromGpuBufferOptions ,
16
18
TensorFromImageBitmapOptions ,
17
19
TensorFromImageDataOptions ,
18
20
TensorFromImageElementOptions ,
21
+ TensorFromMLTensorOptions ,
19
22
TensorFromTextureOptions ,
20
23
TensorFromUrlOptions ,
21
24
TextureConstructorParameters ,
@@ -37,6 +40,7 @@ type TensorDataType = TensorInterface.DataType;
37
40
type TensorDataLocation = TensorInterface . DataLocation ;
38
41
type TensorTextureType = TensorInterface . TextureType ;
39
42
type TensorGpuBufferType = TensorInterface . GpuBufferType ;
43
+ type TensorMLTensorType = TensorInterface . MLTensorType ;
40
44
41
45
/**
42
46
* the implementation of Tensor interface.
@@ -86,6 +90,15 @@ export class Tensor implements TensorInterface {
86
90
*/
87
91
constructor ( params : GpuBufferConstructorParameters ) ;
88
92
93
+ /**
94
+ * Construct a new tensor object from the WebNN MLTensor with the given type and dims.
95
+ *
96
+ * Tensor's location will be set to 'ml-tensor'.
97
+ *
98
+ * @param params - Specify the parameters to construct the tensor.
99
+ */
100
+ constructor ( params : MLTensorConstructorParameters ) ;
101
+
89
102
/**
90
103
* implementation.
91
104
*/
@@ -98,7 +111,8 @@ export class Tensor implements TensorInterface {
98
111
| readonly boolean [ ]
99
112
| CpuPinnedConstructorParameters
100
113
| TextureConstructorParameters
101
- | GpuBufferConstructorParameters ,
114
+ | GpuBufferConstructorParameters
115
+ | MLTensorConstructorParameters ,
102
116
arg1 ?: TensorDataType | Uint8ClampedArray | readonly number [ ] | readonly string [ ] | readonly boolean [ ] ,
103
117
arg2 ?: readonly number [ ] ,
104
118
) {
@@ -155,6 +169,25 @@ export class Tensor implements TensorInterface {
155
169
this . disposer = arg0 . dispose ;
156
170
break ;
157
171
}
172
+ case 'ml-tensor' : {
173
+ if (
174
+ type !== 'float32' &&
175
+ type !== 'float16' &&
176
+ type !== 'int32' &&
177
+ type !== 'int64' &&
178
+ type !== 'uint32' &&
179
+ type !== 'uint64' &&
180
+ type !== 'int8' &&
181
+ type !== 'uint8' &&
182
+ type !== 'bool'
183
+ ) {
184
+ throw new TypeError ( `unsupported type "${ type } " to create tensor from MLTensor` ) ;
185
+ }
186
+ this . mlTensorData = arg0 . mlTensor ;
187
+ this . downloader = arg0 . download ;
188
+ this . disposer = arg0 . dispose ;
189
+ break ;
190
+ }
158
191
default :
159
192
throw new Error ( `Tensor constructor: unsupported location '${ this . dataLocation } '` ) ;
160
193
}
@@ -325,6 +358,13 @@ export class Tensor implements TensorInterface {
325
358
return tensorFromGpuBuffer ( gpuBuffer , options ) ;
326
359
}
327
360
361
+ static fromMLTensor < T extends TensorInterface . MLTensorDataTypes > (
362
+ mlTensor : TensorMLTensorType ,
363
+ options : TensorFromMLTensorOptions < T > ,
364
+ ) : TensorInterface {
365
+ return tensorFromMLTensor ( mlTensor , options ) ;
366
+ }
367
+
328
368
static fromPinnedBuffer < T extends TensorInterface . CpuPinnedDataTypes > (
329
369
type : T ,
330
370
buffer : TensorInterface . DataTypeMap [ T ] ,
@@ -373,6 +413,11 @@ export class Tensor implements TensorInterface {
373
413
*/
374
414
private gpuBufferData ?: TensorGpuBufferType ;
375
415
416
+ /**
417
+ * stores the underlying WebNN MLTensor when location is 'ml-tensor'. otherwise empty.
418
+ */
419
+ private mlTensorData ?: TensorMLTensorType ;
420
+
376
421
/**
377
422
* stores an optional downloader function to download data from GPU to CPU.
378
423
*/
@@ -420,6 +465,14 @@ export class Tensor implements TensorInterface {
420
465
}
421
466
return this . gpuBufferData ;
422
467
}
468
+
469
+ get mlTensor ( ) : TensorMLTensorType {
470
+ this . ensureValid ( ) ;
471
+ if ( ! this . mlTensorData ) {
472
+ throw new Error ( 'The data is not stored as a WebNN MLTensor.' ) ;
473
+ }
474
+ return this . mlTensorData ;
475
+ }
423
476
// #endregion
424
477
425
478
// #region methods
@@ -431,7 +484,8 @@ export class Tensor implements TensorInterface {
431
484
case 'cpu-pinned' :
432
485
return this . data ;
433
486
case 'texture' :
434
- case 'gpu-buffer' : {
487
+ case 'gpu-buffer' :
488
+ case 'ml-tensor' : {
435
489
if ( ! this . downloader ) {
436
490
throw new Error ( 'The current tensor is not created with a specified data downloader.' ) ;
437
491
}
@@ -472,6 +526,7 @@ export class Tensor implements TensorInterface {
472
526
this . cpuData = undefined ;
473
527
this . gpuTextureData = undefined ;
474
528
this . gpuBufferData = undefined ;
529
+ this . mlTensorData = undefined ;
475
530
this . downloader = undefined ;
476
531
this . isDownloading = undefined ;
477
532
0 commit comments