@@ -204,6 +204,7 @@ export class ModelTestContext {
204204 readonly perfData : ModelTestContext . ModelTestPerfData ,
205205 readonly ioBinding : Test . IOBindingMode ,
206206 private readonly profile : boolean ,
207+ public readonly mlContext ?: MLContext ,
207208 ) { }
208209
209210 /**
@@ -254,7 +255,25 @@ export class ModelTestContext {
254255
255256 const initStart = now ( ) ;
256257 const executionProviderConfig =
257- modelTest . backend === 'webnn' ? ( testOptions ?. webnnOptions || 'webnn' ) : modelTest . backend ! ;
258+ modelTest . backend === 'webnn' ? ( testOptions ?. webnnOptions || { name : 'webnn' } ) : modelTest . backend ! ;
259+ let mlContext : MLContext | undefined ;
260+ if ( modelTest . ioBinding . includes ( 'ml-tensor' ) || modelTest . ioBinding . includes ( 'ml-location' ) ) {
261+
262+ const webnnOptions = executionProviderConfig as ort . InferenceSession . WebNNExecutionProviderOption ;
263+ const deviceType = ( webnnOptions as ort . InferenceSession . WebNNContextOptions ) ?. deviceType ;
264+ const numThreads = ( webnnOptions as ort . InferenceSession . WebNNContextOptions ) ?. numThreads ;
265+ const powerPreference = ( webnnOptions as ort . InferenceSession . WebNNContextOptions ) ?. powerPreference ;
266+
267+ mlContext = await navigator . ml . createContext ( {
268+ deviceType,
269+ numThreads,
270+ powerPreference,
271+ } ) ;
272+ ( executionProviderConfig as ort . InferenceSession . WebNNExecutionProviderOption ) . context = mlContext ;
273+ if ( ! deviceType ) {
274+ ( executionProviderConfig as ort . InferenceSession . WebNNContextOptions ) . deviceType = deviceType ;
275+ }
276+ }
258277 const session = await initializeSession (
259278 modelTest . modelUrl , executionProviderConfig , modelTest . ioBinding , profile , modelTest . externalData ,
260279 testOptions ?. sessionOptions || { } , this . cache ) ;
@@ -271,6 +290,7 @@ export class ModelTestContext {
271290 { init : initEnd - initStart , firstRun : - 1 , runs : [ ] , count : 0 } ,
272291 modelTest . ioBinding ,
273292 profile ,
293+ mlContext ,
274294 ) ;
275295 } finally {
276296 this . initializing = false ;
@@ -565,46 +585,34 @@ function createGpuTensorForOutput(type: ort.Tensor.Type, dims: readonly number[]
565585 } ) ;
566586}
567587
568- const getContext = ( ( ) => {
569- let context : MLContext | undefined ;
570-
571- return async ( ) : Promise < MLContext > => {
572- if ( ! context ) {
573- context = await navigator . ml . createContext ( ) ;
574- }
575- return context ;
576- } ;
577- } ) ( ) ;
578588
579- async function createMlTensorForOutput ( type : ort . Tensor . Type , dims : readonly number [ ] ) {
589+ async function createMLTensorForOutput ( mlContext : MLContext , type : ort . Tensor . Type , dims : readonly number [ ] ) {
580590 if ( ! isMLBufferSupportedType ( type ) ) {
581- throw new Error ( `createMlTensorForOutput can not work with ${ type } tensor` ) ;
591+ throw new Error ( `createMLTensorForOutput can not work with ${ type } tensor` ) ;
582592 }
583593
584594 const dataType = type === 'bool' ? 'uint8' : type ;
585595
586- const context = await getContext ( ) ;
587- const mlBuffer = context . createBuffer ( { dataType, dimensions : dims as number [ ] } ) ;
596+ const mlBuffer = mlContext . createBuffer ( { dataType, dimensions : dims as number [ ] } ) ;
588597
589598 return ort . Tensor . fromMLBuffer ( mlBuffer , {
590599 dataType : type ,
591600 dims,
592601 dispose : ( ) => mlBuffer . destroy ( ) ,
593602 download : async ( ) => {
594- const arrayBuffer = await context . readBuffer ( mlBuffer ) ;
595- return createView ( arrayBuffer , type ) as ort . Tensor . DataTypeMap [ ort . Tensor . GpuBufferDataTypes ] ;
603+ const arrayBuffer = await mlContext . readBuffer ( mlBuffer ) ;
604+ return createView ( arrayBuffer , type ) as ort . Tensor . DataTypeMap [ ort . Tensor . MLBufferDataTypes ] ;
596605 }
597606 } ) ;
598607}
599608
600- async function createMlTensorForInput ( cpuTensor : ort . Tensor ) : Promise < ort . Tensor > {
609+ async function createMLTensorForInput ( mlContext : MLContext , cpuTensor : ort . Tensor ) : Promise < ort . Tensor > {
601610 if ( ! isMLBufferSupportedType ( cpuTensor . type ) || Array . isArray ( cpuTensor . data ) ) {
602- throw new Error ( `createMlTensorForInput can not work with ${ cpuTensor . type } tensor` ) ;
611+ throw new Error ( `createMLTensorForInput can not work with ${ cpuTensor . type } tensor` ) ;
603612 }
604- const context = await getContext ( ) ;
605613 const dataType = cpuTensor . type === 'bool' ? 'uint8' : cpuTensor . type ;
606- const mlBuffer = context . createBuffer ( { dataType, dimensions : cpuTensor . dims as number [ ] } ) ;
607- context . writeBuffer ( mlBuffer , cpuTensor . data ) ;
614+ const mlBuffer = mlContext . createBuffer ( { dataType, dimensions : cpuTensor . dims as number [ ] } ) ;
615+ mlContext . writeBuffer ( mlBuffer , cpuTensor . data ) ;
608616 return ort . Tensor . fromMLBuffer (
609617 mlBuffer , { dataType : cpuTensor . type , dims : cpuTensor . dims , dispose : ( ) => mlBuffer . destroy ( ) } ) ;
610618}
@@ -613,6 +621,7 @@ export async function sessionRun(options: {
613621 session : ort . InferenceSession ; feeds : Record < string , ort . Tensor > ;
614622 outputsMetaInfo : Record < string , Pick < ort . Tensor , 'dims' | 'type' > > ;
615623 ioBinding : Test . IOBindingMode ;
624+ mlContext ?: MLContext ;
616625} ) : Promise < [ number , number , ort . InferenceSession . OnnxValueMapType ] > {
617626 const session = options . session ;
618627 const feeds = options . feeds ;
@@ -633,7 +642,7 @@ export async function sessionRun(options: {
633642 if ( Object . hasOwnProperty . call ( feeds , name ) ) {
634643 if ( feeds [ name ] . size > 0 ) {
635644 if ( options . ioBinding === 'ml-location' || options . ioBinding === 'ml-tensor' ) {
636- feeds [ name ] = await createMlTensorForInput ( feeds [ name ] ) ;
645+ feeds [ name ] = await createMLTensorForInput ( options . mlContext ! , feeds [ name ] ) ;
637646 } else {
638647 feeds [ name ] = createGpuTensorForInput ( feeds [ name ] ) ;
639648 }
@@ -650,7 +659,7 @@ export async function sessionRun(options: {
650659 fetches [ name ] = new ort . Tensor ( type , [ ] , dims ) ;
651660 } else {
652661 if ( options . ioBinding === 'ml-tensor' ) {
653- fetches [ name ] = await createMlTensorForOutput ( type , dims ) ;
662+ fetches [ name ] = await createMLTensorForOutput ( options . mlContext ! , type , dims ) ;
654663 } else {
655664 fetches [ name ] = createGpuTensorForOutput ( type , dims ) ;
656665 }
@@ -701,8 +710,8 @@ export async function runModelTestSet(
701710 const outputsMetaInfo : Record < string , ort . Tensor > = { } ;
702711 testCase . inputs ! . forEach ( ( tensor ) => feeds [ tensor . name ] = tensor ) ;
703712 testCase . outputs ! . forEach ( ( tensor ) => outputsMetaInfo [ tensor . name ] = tensor ) ;
704- const [ start , end , outputs ] =
705- await sessionRun ( { session : context . session , feeds, outputsMetaInfo, ioBinding : context . ioBinding } ) ;
713+ const [ start , end , outputs ] = await sessionRun (
714+ { session : context . session , feeds, outputsMetaInfo, ioBinding : context . ioBinding , mlContext : context . mlContext } ) ;
706715 if ( context . perfData . count === 0 ) {
707716 context . perfData . firstRun = end - start ;
708717 } else {
0 commit comments