1717
1818import '@tensorflow/tfjs-backend-cpu' ;
1919import '@tensorflow/tfjs-backend-webgl' ;
20+ import '@tensorflow/tfjs-backend-webgpu' ;
2021
2122import * as tfc from '@tensorflow/tfjs-core' ;
2223// tslint:disable-next-line: no-imports-from-dist
@@ -27,27 +28,28 @@ import {SMOKE} from './constants';
2728/**
2829 * This file tests backend switching scenario.
2930 */
30-
31+ // TODO: Support backend switching between wasm and cpu.
32+ // https://github.com/tensorflow/tfjs/issues/7623
3133describeWithFlags (
3234 `${ SMOKE } backend switching` , {
33- predicate : testEnv => testEnv . backendName === 'webgl' &&
34- tfc . findBackend ( 'webgl' ) !== null && tfc . findBackend ( 'cpu' ) !== null
35+ predicate : testEnv =>
36+ testEnv . backendName !== 'cpu' && testEnv . backendName !== 'wasm'
3537 } ,
3638
37- ( ) => {
38- it ( `from webgl to cpu.` , async ( ) => {
39- await tfc . setBackend ( 'webgl' ) ;
39+ ( env ) => {
40+ it ( `from ${ env . name } to cpu.` , async ( ) => {
41+ await tfc . setBackend ( env . name ) ;
4042
41- const webglBefore = tfc . engine ( ) . backend . numDataIds ( ) ;
43+ const backendBefore = tfc . engine ( ) . backend . numDataIds ( ) ;
4244
4345 const input = tfc . tensor2d ( [ 1 , 1 , 1 , 1 ] , [ 2 , 2 ] , 'float32' ) ;
44- // input is stored in webgl backend.
46+ // input is stored in backend.
4547
4648 const inputReshaped = tfc . reshape ( input , [ 2 , 2 ] ) ;
4749
48- const webglAfter = tfc . engine ( ) . backend . numDataIds ( ) ;
50+ const backendAfter = tfc . engine ( ) . backend . numDataIds ( ) ;
4951
50- expect ( webglAfter ) . toEqual ( webglBefore + 1 ) ;
52+ expect ( backendAfter ) . toEqual ( backendBefore + 1 ) ;
5153
5254 await tfc . setBackend ( 'cpu' ) ;
5355
@@ -56,8 +58,9 @@ describeWithFlags(
5658 const inputReshaped2 = tfc . reshape ( inputReshaped , [ 2 , 2 ] ) ;
5759 // input moved to cpu.
5860
59- // Because input is moved to cpu, data should be deleted from webgl
60- expect ( tfc . findBackend ( 'webgl' ) . numDataIds ( ) ) . toEqual ( webglAfter - 1 ) ;
61+ // Because input is moved to cpu, data should be deleted from backend.
62+ expect ( tfc . findBackend ( env . name ) . numDataIds ( ) )
63+ . toEqual ( backendAfter - 1 ) ;
6164
6265 const cpuAfter = tfc . engine ( ) . backend . numDataIds ( ) ;
6366
@@ -77,7 +80,7 @@ describeWithFlags(
7780 expect ( after ) . toBe ( cpuBefore ) ;
7881 } ) ;
7982
80- it ( `from cpu to webgl .` , async ( ) => {
83+ it ( `from cpu to ${ env . name } .` , async ( ) => {
8184 await tfc . setBackend ( 'cpu' ) ;
8285
8386 const cpuBefore = tfc . engine ( ) . backend . numDataIds ( ) ;
@@ -91,46 +94,47 @@ describeWithFlags(
9194
9295 expect ( cpuAfter ) . toEqual ( cpuBefore + 1 ) ;
9396
94- await tfc . setBackend ( 'webgl' ) ;
97+ await tfc . setBackend ( env . name ) ;
9598
96- const webglBefore = tfc . engine ( ) . backend . numDataIds ( ) ;
99+ const backendBefore = tfc . engine ( ) . backend . numDataIds ( ) ;
97100
98101 const inputReshaped2 = tfc . reshape ( inputReshaped , [ 2 , 2 ] ) ;
99- // input moved to webgl.
102+ // input moved to webgl or webgpu .
100103
101- // Because input is moved to webgl, data should be deleted from cpu
104+ // Because input is moved to backend, data should be deleted
105+ // from cpu.
102106 expect ( tfc . findBackend ( 'cpu' ) . numDataIds ( ) ) . toEqual ( cpuAfter - 1 ) ;
103107
104- const webglAfter = tfc . engine ( ) . backend . numDataIds ( ) ;
108+ const backendAfter = tfc . engine ( ) . backend . numDataIds ( ) ;
105109
106- expect ( webglAfter ) . toEqual ( webglBefore + 1 ) ;
110+ expect ( backendAfter ) . toEqual ( backendBefore + 1 ) ;
107111
108112 input . dispose ( ) ;
109113
110- expect ( tfc . engine ( ) . backend . numDataIds ( ) ) . toEqual ( webglAfter ) ;
114+ expect ( tfc . engine ( ) . backend . numDataIds ( ) ) . toEqual ( backendAfter ) ;
111115
112116 inputReshaped . dispose ( ) ;
113117
114- expect ( tfc . engine ( ) . backend . numDataIds ( ) ) . toEqual ( webglAfter ) ;
118+ expect ( tfc . engine ( ) . backend . numDataIds ( ) ) . toEqual ( backendAfter ) ;
115119
116120 inputReshaped2 . dispose ( ) ;
117121
118122 const after = tfc . engine ( ) . backend . numDataIds ( ) ;
119123
120- expect ( after ) . toBe ( webglBefore ) ;
124+ expect ( after ) . toBe ( backendBefore ) ;
121125 } ) ;
122126
123127 it ( 'can execute op with data from mixed backends' , async ( ) => {
124128 const numTensors = tfc . memory ( ) . numTensors ;
125- const webglNumDataIds = tfc . findBackend ( 'webgl' ) . numDataIds ( ) ;
129+ const backendNumDataIds = tfc . findBackend ( env . name ) . numDataIds ( ) ;
126130 const cpuNumDataIds = tfc . findBackend ( 'cpu' ) . numDataIds ( ) ;
127131
128132 await tfc . setBackend ( 'cpu' ) ;
129133 // This scalar lives in cpu.
130134 const a = tfc . scalar ( 5 ) ;
131135
132- await tfc . setBackend ( 'webgl' ) ;
133- // This scalar lives in webgl.
136+ await tfc . setBackend ( env . name ) ;
137+ // This scalar lives in webgl or webgpu .
134138 const b = tfc . scalar ( 3 ) ;
135139
136140 // Verify that ops can execute with mixed backend data.
@@ -141,32 +145,34 @@ describeWithFlags(
141145 tfc . test_util . expectArraysClose ( await result . data ( ) , [ 8 ] ) ;
142146 expect ( tfc . findBackend ( 'cpu' ) . numDataIds ( ) ) . toBe ( cpuNumDataIds + 3 ) ;
143147
144- await tfc . setBackend ( 'webgl' ) ;
148+ await tfc . setBackend ( env . name ) ;
145149 tfc . test_util . expectArraysClose ( await tfc . add ( a , b ) . data ( ) , [ 8 ] ) ;
146- expect ( tfc . findBackend ( 'webgl' ) . numDataIds ( ) ) . toBe ( webglNumDataIds + 3 ) ;
150+ expect ( tfc . findBackend ( env . name ) . numDataIds ( ) )
151+ . toBe ( backendNumDataIds + 3 ) ;
147152
148153 tfc . engine ( ) . endScope ( ) ;
149154
150155 expect ( tfc . memory ( ) . numTensors ) . toBe ( numTensors + 2 ) ;
151- expect ( tfc . findBackend ( 'webgl' ) . numDataIds ( ) ) . toBe ( webglNumDataIds + 2 ) ;
156+ expect ( tfc . findBackend ( env . name ) . numDataIds ( ) )
157+ . toBe ( backendNumDataIds + 2 ) ;
152158 expect ( tfc . findBackend ( 'cpu' ) . numDataIds ( ) ) . toBe ( cpuNumDataIds ) ;
153159
154160 tfc . dispose ( [ a , b ] ) ;
155161
156162 expect ( tfc . memory ( ) . numTensors ) . toBe ( numTensors ) ;
157- expect ( tfc . findBackend ( 'webgl' ) . numDataIds ( ) ) . toBe ( webglNumDataIds ) ;
163+ expect ( tfc . findBackend ( env . name ) . numDataIds ( ) ) . toBe ( backendNumDataIds ) ;
158164 expect ( tfc . findBackend ( 'cpu' ) . numDataIds ( ) ) . toBe ( cpuNumDataIds ) ;
159165 } ) ;
160166
161167 // tslint:disable-next-line: ban
162- xit ( ' can move complex tensor from cpu to webgl.' , async ( ) => {
168+ xit ( ` can move complex tensor from cpu to ${ env . name } .` , async ( ) => {
163169 await tfc . setBackend ( 'cpu' ) ;
164170
165171 const real1 = tfc . tensor1d ( [ 1 ] ) ;
166172 const imag1 = tfc . tensor1d ( [ 2 ] ) ;
167173 const complex1 = tfc . complex ( real1 , imag1 ) ;
168174
169- await tfc . setBackend ( 'webgl' ) ;
175+ await tfc . setBackend ( env . name ) ;
170176
171177 const real2 = tfc . tensor1d ( [ 3 ] ) ;
172178 const imag2 = tfc . tensor1d ( [ 4 ] ) ;
@@ -178,8 +184,8 @@ describeWithFlags(
178184 } ) ;
179185
180186 // tslint:disable-next-line: ban
181- xit ( ' can move complex tensor from webgl to cpu.' , async ( ) => {
182- await tfc . setBackend ( 'webgl' ) ;
187+ xit ( ` can move complex tensor from ${ env . name } to cpu.` , async ( ) => {
188+ await tfc . setBackend ( env . name ) ;
183189
184190 const real1 = tfc . tensor1d ( [ 1 ] ) ;
185191 const imag1 = tfc . tensor1d ( [ 2 ] ) ;
0 commit comments