1
1
import { Niivue } from '@niivue/niivue'
2
- import * as ort from 'onnxruntime-web' ;
2
+ // IMPORTANT: we need to import this specific file.
3
+ import * as ort from "./node_modules/onnxruntime-web/dist/ort.all.mjs"
4
+ console . log ( ort ) ;
3
5
async function main ( ) {
4
6
aboutBtn . onclick = function ( ) {
5
7
let url = "https://github.com/axinging/mlmodel-convension-demo/blob/main/onnx/onnx-brainchop.html"
@@ -17,79 +19,103 @@ async function main() {
17
19
// Set up the Niivue viewer and load the input volume.
// NOTE(review): `defaults` and `gl1` are defined earlier in main() (outside this view).
const nv1 = new Niivue(defaults)
nv1.attachToCanvas(gl1)
await nv1.loadVolumes([{ url: './t1_crop.nii.gz' }])
// FIXME: Do we want to conform?
// Conform the loaded volume (presumably resampling to the canonical
// 256^3 grid the model expects — TODO confirm against Niivue docs).
const conformed = await nv1.conform(
  nv1.volumes[0],
  false,
  true,
  true
)
// Replace the raw volume with its conformed version so index 0 is the model input.
nv1.removeVolume(nv1.volumes[0])
nv1.addVolume(conformed)
20
31
21
32
// Accumulator shared across getFeedInfo() calls: one Map per run iteration,
// mapping feed name -> [type, data, dims, byteSizeAlignedTo16].
let feedsInfo = [];

/**
 * Registers one model input ("feed") for each warmup/run iteration.
 *
 * @param {string} feed - Feed (input tensor) name, e.g. "input.1".
 * @param {string} type - Element type: 'bool', 'int8', 'float16', 'int32',
 *   'uint32', 'float32', or 'int64'.
 * @param {*} data - An array/TypedArray used as-is, the string 'random' to
 *   fill with getRandom(type), or a scalar used as a fill value.
 * @param {number[]} dims - Tensor shape (ignored for 'bool', which is forced
 *   to a single-element tensor).
 * @returns {Map[]} The shared feedsInfo accumulator (also mutated in place).
 * @throws {Error} If `type` is not one of the supported element types.
 */
function getFeedInfo(feed, type, data, dims) {
  const warmupTimes = 0;
  const runTimes = 1;
  // Loop-invariant: element type -> backing TypedArray constructor.
  // float16 has no native JS array type, so its raw bits live in a Uint16Array.
  const TYPED_ARRAYS = {
    int8: Int8Array,
    float16: Uint16Array,
    int32: Int32Array,
    uint32: Uint32Array,
    float32: Float32Array,
    int64: BigInt64Array,
  };
  for (let i = 0; i < warmupTimes + runTimes; i++) {
    let typedArray;
    let typeBytes;
    if (type === 'bool') {
      // Booleans are modelled as a single-element, one-byte tensor.
      data = [data];
      dims = [1];
      typeBytes = 1;
    } else {
      typedArray = TYPED_ARRAYS[type];
      if (typedArray === undefined) {
        // Fail loudly up front instead of crashing later with a cryptic
        // "Cannot read properties of undefined" TypeError.
        throw new Error(`getFeedInfo: unsupported tensor type "${type}"`);
      }
      typeBytes = typedArray.BYTES_PER_ELEMENT;
    }

    let size, _data;
    if (Array.isArray(data) || ArrayBuffer.isView(data)) {
      // Caller supplied the buffer directly; trust its length.
      size = data.length;
      _data = data;
    } else {
      size = dims.reduce((a, b) => a * b);
      if (data === 'random') {
        _data = typedArray.from({ length: size }, () => getRandom(type));
      } else {
        // Scalar fill value.
        _data = typedArray.from({ length: size }, () => data);
      }
    }

    if (i > feedsInfo.length - 1) {
      feedsInfo.push(new Map());
    }
    // Byte size is padded up to a 16-byte boundary (GPU buffer alignment).
    feedsInfo[i].set(feed, [type, _data, dims, Math.ceil((size * typeBytes) / 16) * 16]);
  }
  return feedsInfo;
}
69
80
// ONNX Runtime session options: run on the WebGPU execution provider and let
// ORT apply "extended" graph optimizations, dumping the optimized model.
const option = {
  executionProviders: [{ name: 'webgpu' }],
  graphOptimizationLevel: 'extended',
  optimizedModelFilepath: 'opt.onnx',
};
79
89
80
90
// Create the inference session and run the brainchop model on the loaded volume.
const session = await ort.InferenceSession.create('./model_5_channels.onnx', option);
// Model input shape: [batch, channel, 256, 256, 256].
const shape = [1, 1, 256, 256, 256];
// FIXME: Do we want to use a real image for inference?
// NOTE(review): assumes volumes[0].img is a flat voxel buffer of exactly
// 256^3 elements (hence the conform step earlier) — TODO confirm.
const imgData = nv1.volumes[0].img;
const expectedLength = shape.reduce((a, b) => a * b);
// FIXME: Do we need want this?
// Guard: fail fast if the volume does not match the tensor size.
if (imgData.length !== expectedLength) {
  throw new Error(`imgData length (${imgData.length}) does not match expected tensor length (${expectedLength})`);
}

// Package the voxel data as the "input.1" feed and wrap it in an ort.Tensor.
const temp = getFeedInfo("input.1", "float32", imgData, shape);
let dataA = temp[0].get('input.1')[1];
const tensorA = new ort.Tensor('float32', dataA, shape);

const feeds = { "input.1": tensorA };
// feed inputs and run
console.log("before run");
const results = await session.run(feeds);
console.log(results);
console.log("after run")
// FIXME: is this really the output data? It doesn't make sense when rendered,
// but then again, maybe the input was wrong?
// NOTE(review): session.run() resolves to an object keyed by output *names*;
// indexing with 39 presumes an output literally named "39" — verify via
// session.outputNames before trusting this.
const outData = results[39].data
// Display the model output as an overlay on the input volume.
const newImg = nv1.cloneVolume(0);
newImg.img = outData
// Add the output to niivue
nv1.addVolume(newImg)
nv1.setColormap(newImg.id, "red")
// Opacity 0.5 on volume index 1 (the overlay just added).
nv1.setOpacity(1, 0.5)
119
}
94
120
95
121
main ( )
0 commit comments