Commit 5df670e (1 parent: 99be40a). Showing 4 changed files with 96 additions and 138 deletions.
@@ -1,155 +1,113 @@
import { Niivue } from '@niivue/niivue'
// IMPORTANT: we need to import this specific file, the all-in-one bundle that ships the WebGPU execution provider.
import * as ort from "./node_modules/onnxruntime-web/dist/ort.all.mjs"
console.log(ort);
async function main() {
  aboutBtn.onclick = function () {
    let url = "https://github.com/axinging/mlmodel-convension-demo/blob/main/onnx/onnx-brainchop.html"
    window.open(url, '_blank').focus();
  }
  async function ensureConformed() {
    const nii = nv1.volumes[0]
    let isConformed = nii.dims[1] === 256 && nii.dims[2] === 256 && nii.dims[3] === 256
    if (nii.permRAS[0] !== -1 || nii.permRAS[1] !== 3 || nii.permRAS[2] !== -2) {
      isConformed = false
    }
    if (isConformed) {
      return
    }
    const nii2 = await nv1.conform(nii, false)
    await nv1.removeVolume(nv1.volumes[0])
    await nv1.addVolume(nii2)
  }
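  // Editorial note (assumption): conform() resamples the volume to the 256^3,
  // 1 mm isotropic grid the network expects, and the permRAS test above matches
  // the fixed axis permutation that conformed volumes carry, so an
  // already-conformed image is not resampled a second time.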
  async function closeAllOverlays() {
    while (nv1.volumes.length > 1) {
      await nv1.removeVolume(nv1.volumes[1])
    }
  }
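  // NiiVue keeps the background image at volumes[0]; every later entry is an
  // overlay, so this loop leaves only the background loaded.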
  segmentBtn.onclick = async function () {
    if (nv1.volumes.length < 1) {
      window.alert('Please open a voxel-based image')
      return
    }
    await closeAllOverlays()
    await ensureConformed()
    let img32 = new Float32Array(nv1.volumes[0].img)
    // normalize input data to range 0..1
    // TODO: ONNX not JavaScript https://onnx.ai/onnx/operators/onnx_aionnxml_Normalizer.html
    let mx = img32[0]
    let mn = mx
    for (let i = 0; i < img32.length; i++) {
      mx = Math.max(mx, img32[i])
      mn = Math.min(mn, img32[i])
    }
    let scale32 = 1 / (mx - mn)
    for (let i = 0; i < img32.length; i++) {
      img32[i] = (img32[i] - mn) * scale32
    }
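    // img32 is now min-max scaled into 0..1. A uniform image (mx === mn) would
    // make scale32 Infinity; a minimal guard (editorial sketch) would be:
    //   if (mx === mn) { scale32 = 0 }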
    // load onnx model
    const option = {
      executionProviders: [
        {
          name: 'webgpu',
        },
      ],
      graphOptimizationLevel: 'extended',
      optimizedModelFilePath: 'opt.onnx' // note: 'FilePath' capitalization, per the SessionOptions typings
    }
    const session = await ort.InferenceSession.create('./model.onnx', option)
    const shape = [1, 1, 256, 256, 256]
    const nvox = shape.reduce((a, b) => a * b)
    if (img32.length !== nvox) {
      throw new Error(`img32 length (${img32.length}) does not match expected tensor length (${nvox})`)
    }
    const imgTensor = new ort.Tensor('float32', img32, shape)
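    // shape is NCDHW: batch 1, one input channel, then 256^3 voxels, so nvox
    // doubles as both the tensor element count and the voxel count per volume.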
    const feeds = { "input": imgTensor }
    // run onnx inference
    const results = await session.run(feeds)
    const classImg = results.output.data // public accessor; cpuData is an internal field
    // classImg will have one volume per class
    const nvol = Math.floor(classImg.length / nvox)
    if ((nvol < 2) || (classImg.length != (nvol * nvox))) {
      throw new Error(`unexpected output length ${classImg.length} for ${nvox} voxels`)
    }
    console.log(`${nvol} volumes each with ${nvox} voxels`)
    // argmax should identify correct class for each voxel
    // TODO: ONNX not JavaScript https://onnx.ai/onnx/operators/onnx__ArgMax.html
    const argMaxImg = new Float32Array(nvox)
    for (let vox = 0; vox < nvox; vox++) {
      let mxVal = classImg[vox]
      let mxVol = 0
      for (let vol = 1; vol < nvol; vol++) {
        const val = classImg[vox + (vol * nvox)]
        if (val > mxVal) {
          mxVol = vol
          mxVal = val
        }
      }
      // store the index of the brightest volume (the class label), not its value
      argMaxImg[vox] = mxVol
    }
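    // This loop mirrors ONNX ArgMax over the class axis (axis 1 of NCDHW); the
    // TODO above suggests folding it into the exported graph so the model emits
    // the label map directly.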
    const newImg = nv1.cloneVolume(0)
    newImg.img = argMaxImg
    newImg.hdr.datatypeCode = 16 // = float32 (NIfTI DT_FLOAT32)
    newImg.hdr.dims[4] = 1
    newImg.trustCalMinMax = false
    console.log(newImg)
    // Add the output to niivue
    nv1.addVolume(newImg)
    nv1.setColormap(newImg.id, "actc")
    nv1.setOpacity(1, 0.5)
  }
  function handleLocationChange(data) {
    document.getElementById("intensity").innerHTML = data.string
  }
  const defaults = {
    backColor: [0.4, 0.4, 0.4, 1],
    show3Dcrosshair: true,
    onLocationChange: handleLocationChange,
    dragAndDropEnabled: false,
  }
  const nv1 = new Niivue(defaults)
  nv1.attachToCanvas(gl1)
  await nv1.loadVolumes([{ url: './t1_crop.nii.gz' }])
  // FIXME: Do we want to conform?
  /*const conformed = await nv1.conform(
    nv1.volumes[0],
    false,
    true,
    true
  )
  nv1.removeVolume(nv1.volumes[0])
  nv1.addVolume(conformed)*/
  let img32 = new Float32Array(nv1.volumes[0].img)
  let mx = img32[0]
  let mn = mx
  for (let i = 0; i < img32.length; i++) {
    mx = Math.max(mx, img32[i])
    mn = Math.min(mn, img32[i])
  }
  let scale32 = 1 / (mx - mn)
  for (let i = 0; i < img32.length; i++) {
    img32[i] = (img32[i] - mn) * scale32
  }
  // sanity check: report that the image is now normalized 0..1
  mx = img32[0]
  mn = mx
  for (let i = 0; i < img32.length; i++) {
    mx = Math.max(mx, img32[i])
    mn = Math.min(mn, img32[i])
  }
  console.log(`Normalized image intensity is ${mn}..${mx}`)
  let feedsInfo = [];
  function getFeedInfo(feed, type, data, dims) {
    const warmupTimes = 0;
    const runTimes = 1;
    for (let i = 0; i < warmupTimes + runTimes; i++) {
      let typedArray;
      let typeBytes;
      if (type === 'bool') {
        // booleans are wrapped as a one-element array, one byte each
        data = [data];
        dims = [1];
        typeBytes = 1;
      } else if (type === 'int8') {
        typedArray = Int8Array;
      } else if (type === 'float16') {
        typedArray = Uint16Array; // float16 bit patterns are carried in a Uint16Array
      } else if (type === 'int32') {
        typedArray = Int32Array;
      } else if (type === 'uint32') {
        typedArray = Uint32Array;
      } else if (type === 'float32') {
        typedArray = Float32Array;
      } else if (type === 'int64') {
        typedArray = BigInt64Array;
      }
      if (typeBytes === undefined) {
        typeBytes = typedArray.BYTES_PER_ELEMENT;
      }

      let size, _data;
      if (Array.isArray(data) || ArrayBuffer.isView(data)) {
        size = data.length;
        _data = data;
      } else {
        size = dims.reduce((a, b) => a * b);
        if (data === 'random') {
          // getRandom() is assumed to be defined elsewhere in the demo
          _data = typedArray.from({ length: size }, () => getRandom(type));
        } else {
          _data = typedArray.from({ length: size }, () => data);
        }
      }

      if (i > feedsInfo.length - 1) {
        feedsInfo.push(new Map());
      }
      feedsInfo[i].set(feed, [type, _data, dims, Math.ceil(size * typeBytes / 16) * 16]);
    }
    return feedsInfo;
  }
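  // The fourth tuple entry rounds the byte size up to a 16-byte multiple,
  // presumably to satisfy GPU buffer alignment; one Map of feeds is kept per
  // warmup/measured run.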
  const option = {
    executionProviders: [
      {
        name: 'webgpu',
      },
    ],
    graphOptimizationLevel: 'extended',
    optimizedModelFilePath: 'opt.onnx'
  };
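  // These session options mirror the ones inside segmentBtn.onclick: WebGPU as
  // the execution provider with ORT's 'extended' graph optimizations.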
  const session = await ort.InferenceSession.create('./model_5_channels.onnx', option);
  const shape = [1, 1, 256, 256, 256];
  // FIXME: Do we want to use a real image for inference?
  const imgData = img32;
  const expectedLength = shape.reduce((a, b) => a * b);
  if (img32.length !== expectedLength) {
    throw new Error(`img32 length (${img32.length}) does not match expected tensor length (${expectedLength})`);
  }

  const temp = getFeedInfo("input.1", "float32", img32, shape);
  let dataA = temp[0].get('input.1')[1];
  const tensorA = new ort.Tensor('float32', dataA, shape);
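  // Since img32 is already a Float32Array, getFeedInfo passes it through
  // unchanged here; the helper's fill paths only matter for 'random' or
  // constant-valued feeds.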
  const feeds = { "input.1": tensorA };
  // feed inputs and run
  const results = await session.run(feeds);
  console.log(results);
  const aiVox = results[39].data // '39' appears to be the name of this model's output tensor
  const outDims = results[39].dims
  const vols = outDims[1]
  const vox = outDims[2] * outDims[3] * outDims[4]
  if ((img32.length != vox) || (vols != 3) || (aiVox.length != (vols * vox))) {
    throw new Error(`unexpected output dims ${outDims}`)
  }
  const outData = new Float32Array(vox)
  for (let i = 0; i < vox; i++) {
    /*let mx = 2
    if ((aiVox[i+vox] > aiVox[i]) && (aiVox[i+vox] > aiVox[i+vox+vox]))
      mx = 1
    else if ((aiVox[i] > aiVox[i+vox]) && (aiVox[i] > aiVox[i+vox+vox]))
      mx = 0*/
    outData[i] = aiVox[i]
  }
  const newImg = nv1.cloneVolume(0);
  newImg.img = outData
  newImg.hdr.datatypeCode = 16 // float32
  newImg.hdr.dims[4] = 1
  newImg.trustCalMinMax = false
  console.log(newImg)
  // Add the output to niivue
  nv1.addVolume(newImg)
  nv1.setColormap(newImg.id, "actc")
  nv1.setOpacity(1, 0.5)
  segmentBtn.onclick()
}
main()
Binary file not shown.
Binary file not shown.