Skip to content

Commit

Permalink
upload
Browse files Browse the repository at this point in the history
  • Loading branch information
hanayik committed Sep 25, 2024
0 parents commit 6ec947f
Show file tree
Hide file tree
Showing 14 changed files with 1,840 additions and 0 deletions.
21 changes: 21 additions & 0 deletions .github/workflows/ghpages.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
name: Build and Deploy
on:
push:
branches:
- main
jobs:
build-and-publish-live-demo:
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v4

- name: Install and Build
run: |
npm install
npm run build
- name: Deploy
uses: JamesIves/github-pages-deploy-action@v4
with:
branch: demo # The branch the action should deploy to.
folder: dist # The folder the action should deploy.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
node_modules
.DS_Store
dist
Empty file added LICENSE
Empty file.
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# NiiVue + ScribblePromt

TBD
69 changes: 69 additions & 0 deletions docs/export.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import torch
import torch.nn.functional as F
from scribbleprompt.unet import UNet

class Predictor:
"""
wrapper for ScribblePrompt-UNet model with ONNX export functionality.
"""
def __init__(self, path: str, verbose: bool = True):
self.path = path
self.verbose = verbose
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.build_model()
self.load()
self.model.eval()
self.to_device()

def build_model(self):
"""
build the ScribblePrompt-UNet model.
"""
self.model = UNet(
in_channels=5,
out_channels=1,
features=[192, 192, 192, 192],
)

def load(self):
"""
load the state of the model from a checkpoint file.
"""
with open(self.path, "rb") as f:
state = torch.load(f, map_location=self.device)
self.model.load_state_dict(state, strict=True)
if self.verbose:
print(f"loaded checkpoint from {self.path} to {self.device}")

def to_device(self):
"""
move the model to the appropriate device.
"""
self.model = self.model.to(self.device)

def export_to_onnx(self, onnx_path="model.onnx"):
"""
export the model to ONNX format with dynamic H and W (height and width).
"""
# prepare a dummy input with arbitrary H and W, as ONNX export requires a concrete input shape
dummy_input = torch.randn(1, 5, 256, 256).to(self.device)
torch.onnx.export(
self.model,
dummy_input,
onnx_path,
export_params=True,
opset_version=20,
do_constant_folding=True,
input_names=['input'],
output_names=['output'],
# make H and W dynamic, along with the batch size
dynamic_axes={'input': {0: 'batch_size', 2: 'height', 3: 'width'},
'output': {0: 'batch_size', 2: 'height', 3: 'width'}}
)
print(f"model exported to {onnx_path}")

# usage from CLI
if __name__ == "__main__":
checkpoint_path = "../checkpoints/ScribblePrompt_unet_v1_nf192_res128.pt"
predictor = Predictor(checkpoint_path)
predictor.export_to_onnx("scribbleprompt_unet.onnx")
39 changes: 39 additions & 0 deletions index.html
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
<!doctype html>
<html lang="en">

<head>
<meta charset="UTF-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<link rel="stylesheet" href="./niivue.css" />
<title>Niivue ScribblePromt</title>
</head>

<body>
<header>
<button id="segmentBtn">Segment</button>
<label for="clipCheck">Clip Plane</label>
<input type="checkbox" id="clipCheck" unchecked />
<label for="opacitySlider0">Background Opacity</label>
<input type="range" min="0" max="255" value="255" class="slider" id="opacitySlider0" />
&nbsp;
<label for="opacitySlider1">Overlay Opacity</label>
<input type="range" min="0" max="255" value="128" class="slider" id="opacitySlider1" />
&nbsp;
<!-- drawing opacity slider -->
<label for="opacitySlider2">Drawing Opacity</label>
<input type="range" min="0" max="255" value="255" class="slider" id="opacitySlider2" />
&nbsp;
<!-- conform checkbox -->
<label for="conform">Conform</label>
<input type="checkbox" id="conform" unchecked />
<button id="saveImgBtn">Save segmentation</button>
&nbsp;
<div id="loadingCircle" class="loading-circle hidden"></div>
</header>
<main id="canvas-container">
<canvas id="gl1"></canvas>
</main>
<footer id="intensity">&nbsp;</footer>
<script type="module" src="/main.js"></script>
</body>
</html>
Loading

0 comments on commit 6ec947f

Please sign in to comment.