Skip to content

Commit

Permalink
New onnx model
Browse files Browse the repository at this point in the history
  • Loading branch information
neurolabusc committed Jul 17, 2024
1 parent 99be40a commit 5df670e
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 138 deletions.
2 changes: 1 addition & 1 deletion index.html
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

<body>
<header>
<button id="aboutBtn">About</button>
<button id="segmentBtn">Segment</button>
</header>
<main id="canvas-container">
<canvas id="gl1"></canvas>
Expand Down
232 changes: 95 additions & 137 deletions main.js
Original file line number Diff line number Diff line change
@@ -1,155 +1,113 @@
import { Niivue } from '@niivue/niivue'
// IMPORTANT: we need to import this specific file.
import * as ort from "./node_modules/onnxruntime-web/dist/ort.all.mjs"
console.log(ort);
async function main() {
aboutBtn.onclick = function () {
let url = "https://github.com/axinging/mlmodel-convension-demo/blob/main/onnx/onnx-brainchop.html"
window.open(url, '_blank').focus();
async function ensureConformed() {
const nii = nv1.volumes[0]
let isConformed = nii.dims[1] === 256 && nii.dims[2] === 256 && nii.dims[3] === 256
if (nii.permRAS[0] !== -1 || nii.permRAS[1] !== 3 || nii.permRAS[2] !== -2) {
isConformed = false
}
if (isConformed) {
return
}
const nii2 = await nv1.conform(nii, false)
await nv1.removeVolume(nv1.volumes[0])
await nv1.addVolume(nii2)
}
async function closeAllOverlays() {
while (nv1.volumes.length > 1) {
await nv1.removeVolume(nv1.volumes[1])
}
}
segmentBtn.onclick = async function () {
if (nv1.volumes.length < 1) {
window.alert('Please open a voxel-based image')
return
}
await closeAllOverlays()
await ensureConformed()
let img32 = new Float32Array(nv1.volumes[0].img)
// normalize input data to range 0..1
// TODO: ONNX not JavaScript https://onnx.ai/onnx/operators/onnx_aionnxml_Normalizer.html
let mx = img32[0]
let mn = mx
for (let i = 0; i < img32.length; i++) {
mx = Math.max(mx, img32[i])
mn = Math.min(mn, img32[i])
}
let scale32 = 1 / (mx - mn)
for (let i = 0; i < img32.length; i++) {
img32[i] = (img32[i] - mn) * scale32
}
// load onnx model
const option = {
executionProviders: [
{
name: 'webgpu',
},
],
graphOptimizationLevel: 'extended',
optimizedModelFilepath: 'opt.onnx'
}
const session = await ort.InferenceSession.create('./model.onnx', option)
const shape = [1, 1, 256, 256, 256]
const nvox = shape.reduce((a, b) => a * b)
if (img32.length !== nvox) {
throw new Error(`img32 length (${img32.length}) does not match expected tensor length (${expectedLength})`)
}
const imgTensor = new ort.Tensor('float32', img32, shape)
const feeds = { "input": imgTensor }
// run onnx inference
const results = await session.run(feeds)
const classImg = results.output.cpuData
// classImg will have one volume per class
const nvol = Math.floor(classImg.length / nvox)
if ((nvol < 2) || (classImg.length != (nvol * nvox))) {
console.log('Fatal error')
}
console.log(`${nvol} volumes each with ${nvox} voxels`)
// argmax should identify correct class for each voxel
// TODO: ONNX not JavaScript https://onnx.ai/onnx/operators/onnx__ArgMax.html
const argMaxImg = new Float32Array(nvox)
for (let vox = 0; vox < nvox; vox++) {
let mxVal = classImg[vox]
let mxVol = 0
for (let vol = 1; vol < nvol; vol++) {
const val = classImg[vox + (vol * nvox)]
if (val > mxVal) {
mxVol = vol
mxVal = val
}
}
// Next line incorrect!
argMaxImg[vox] = mxVal
// Next line should be correct: brightest volume
// argMaxImg[vox] = mxVol
}
//
const newImg = nv1.cloneVolume(0)
newImg.img = argMaxImg
newImg.hdr.datatypeCode = 16 // = float32
newImg.hdr.dims[4] = 1
newImg.trustCalMinMax = false
console.log(newImg)
// Add the output to niivue
nv1.addVolume(newImg)
nv1.setColormap(newImg.id, "actc")
nv1.setOpacity(1, 0.5)
}
function handleLocationChange(data) {
document.getElementById("intensity").innerHTML = data.string
}
const defaults = {
backColor: [0.4, 0.4, 0.4, 1],
show3Dcrosshair: true,
onLocationChange: handleLocationChange,
dragAndDropEnabled: false,
}
const nv1 = new Niivue(defaults)
nv1.attachToCanvas(gl1)
await nv1.loadVolumes([{ url: './t1_crop.nii.gz' }])
// FIXME: Do we want to conform?
/*const conformed = await nv1.conform(
nv1.volumes[0],
false,
true,
true
)
nv1.removeVolume(nv1.volumes[0])
nv1.addVolume(conformed)*/

let img32 = new Float32Array(nv1.volumes[0].img)
let mx = img32[0]
let mn = mx
for (let i = 0; i < img32.length; i++) {
mx = Math.max(mx, img32[i])
mn = Math.min(mn, img32[i])
}
let scale32 = 1 / (mx - mn)
for (let i = 0; i < img32.length; i++) {
img32[i] = (img32[i] - mn) * scale32
}
//sanity check: report that image now normalized 0..1
mx = img32[0]
mn = mx
for (let i = 0; i < img32.length; i++) {
mx = Math.max(mx, img32[i])
mn = Math.min(mn, img32[i])
}
console.log(`Normalized image intensity is ${mn}..${mx}`)

let feedsInfo = [];
function getFeedInfo(feed, type, data, dims) {
const warmupTimes = 0;
const runTimes = 1;
for (let i = 0; i < warmupTimes + runTimes; i++) {
let typedArray;
let typeBytes;
if (type === 'bool') {
data = [data];
dims = [1];
typeBytes = 1;
} else if (type === 'int8') {
typedArray = Int8Array;
} else if (type === 'float16') {
typedArray = Uint16Array;
} else if (type === 'int32') {
typedArray = Int32Array;
} else if (type === 'uint32') {
typedArray = Uint32Array;
} else if (type === 'float32') {
typedArray = Float32Array;
} else if (type === 'int64') {
typedArray = BigInt64Array;
}
if (typeBytes === undefined) {
typeBytes = typedArray.BYTES_PER_ELEMENT;
}

let size, _data;
if (Array.isArray(data) || ArrayBuffer.isView(data)) {
size = data.length;
_data = data;
} else {
size = dims.reduce((a, b) => a * b);
if (data === 'random') {
_data = typedArray.from({ length: size }, () => getRandom(type));
} else {
_data = typedArray.from({ length: size }, () => data);
}
}

if (i > feedsInfo.length - 1) {
feedsInfo.push(new Map());
}
feedsInfo[i].set(feed, [type, _data, dims, Math.ceil(size * typeBytes / 16) * 16]);
}
return feedsInfo;
}
const option = {
executionProviders: [
{
name: 'webgpu',
},
],
graphOptimizationLevel: 'extended',
optimizedModelFilepath: 'opt.onnx'
};

const session = await ort.InferenceSession.create('./model_5_channels.onnx', option);
const shape = [1, 1, 256, 256, 256];
// FIXME: Do we want to use a real image for inference?
const imgData = img32;
const expectedLength = shape.reduce((a, b) => a * b);
if (img32.length !== expectedLength) {
throw new Error(`img32 length (${img32.length}) does not match expected tensor length (${expectedLength})`);
}

const temp = getFeedInfo("input.1", "float32", img32, shape);
let dataA = temp[0].get('input.1')[1];
const tensorA = new ort.Tensor('float32', dataA, shape);

const feeds = { "input.1": tensorA };
// feed inputs and run
const results = await session.run(feeds);
console.log(results);
const aiVox = results[39].data
const outDims = results[39].dims
const vols = outDims[1]
const vox = outDims[2] * outDims[3] * outDims[4]
if ((img32.length != vox) || (vols != 3) || (aiVox.length != (vols * vox))) {
console.log('Fatal error')
}
const outData = new Float32Array(vox)
for (let i = 0; i < vox; i++) {
/*let mx = 2
if ((aiVox[i+vox] > aiVox[i]) && (aiVox[i+vox] > aiVox[i+vox+vox]))
mx = 1
else if ((aiVox[i] > aiVox[i+vox]) && (aiVox[i] > aiVox[i+vox+vox]))
mx = 0*/
outData[i] = aiVox[i]
}
const newImg = nv1.cloneVolume(0);
newImg.img = outData
newImg.hdr.datatypeCode = 16 //float32
newImg.hdr.dims[4] = 1
newImg.trustCalMinMax = false
console.log(newImg)
// Add the output to niivue
nv1.addVolume(newImg)
nv1.setColormap(newImg.id, "actc")
nv1.setOpacity(1, 0.5)
segmentBtn.onclick()
}

main()
Binary file added public/model.onnx
Binary file not shown.
Binary file removed public/model_5_channels.onnx
Binary file not shown.

0 comments on commit 5df670e

Please sign in to comment.