Skip to content

Commit 1fd4b2c

Browse files
authored
Merge pull request #887 from dfinity/ulan/image-classification-simd
Use Wasm SIMD in the image classification example
2 parents 4f44852 + b6ac198 commit 1fd4b2c

File tree

12 files changed

+130
-40
lines changed

12 files changed

+130
-40
lines changed

rust/image-classification/Cargo.lock

Lines changed: 69 additions & 14 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

rust/image-classification/README.md

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,12 @@
33
This is an ICP smart contract that accepts an image from the user and runs image classification inference.
44
The smart contract consists of two canisters:
55

6-
- the backend canister embeds the [the Tract ONNX inference engine](https://github.com/sonos/tract) with [the MobileNet v2-7 model](https://github.com/onnx/models/tree/main/validated/vision/classification/mobilenet). It provides a `classify()` endpoint for the frontend code to call.
6+
- the backend canister embeds the [the Tract ONNX inference engine](https://github.com/sonos/tract) with [the MobileNet v2-7 model](https://github.com/onnx/models/tree/main/validated/vision/classification/mobilenet).
7+
It provides `classify()` and `classify_query()` endpoints for the frontend code to call.
8+
The former endpoint is used for replicated execution (running on all nodes) whereas the latter runs only on a single node.
79
- the frontend canister contains the Web assets such as HTML, JS, CSS that are served to the browser.
810

9-
Note that currently Wasm execution is not optimized for this workload.
10-
A single call executes about 24B instructions (~10s).
11-
12-
This is expected to improve in the future with:
13-
14-
- faster deterministic floating-point operations.
15-
- Wasm SIMD (Single-Instruction Multiple Data).
16-
17-
The ICP mainnet subnets and `dfx` running a replica version older than [463296](https://dashboard.internetcomputer.org/release/463296c0bc82ad5999b70245e5f125c14ba7d090) may fail with an instruction-limit-exceeded error.
11+
This example uses Wasm SIMD instructions that are available in `dfx` version `0.20.2-beta.0` or newer.
1812

1913
# Dependencies
2014

@@ -45,12 +39,18 @@ Install NodeJS dependencies for the frontend:
4539
npm install
4640
```
4741

42+
Install `wasm-opt`:
43+
44+
```
45+
cargo install wasm-opt
46+
```
47+
4848
# Build
4949

5050
```
5151
dfx start --background
5252
dfx deploy
5353
```
5454

55-
If the deployment is successfull, the it will show the `frontend` URL.
55+
If the deployment is successful, the it will show the `frontend` URL.
5656
Open that URL in browser to interact with the smart contract.

rust/image-classification/dfx.json

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,7 @@
55
"package": "backend",
66
"type": "custom",
77
"wasm": "target/wasm32-wasi/release/backend-ic.wasm",
8-
"build": [
9-
"cargo build --release --target=wasm32-wasi",
10-
"wasi2ic ./target/wasm32-wasi/release/backend.wasm ./target/wasm32-wasi/release/backend-ic.wasm"
11-
]
12-
8+
"build": [ "bash build.sh" ]
139
},
1410
"frontend": {
1511
"dependencies": [

rust/image-classification/src/backend/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,6 @@ prost = "0.11.0"
1616
prost-types = "0.11.0"
1717
image = { version = "0.24", features = ["png"], default-features = false }
1818
serde = { version = "1.0", features = ["derive"] }
19-
tract-onnx = { git = "https://github.com/sonos/tract", version = "=0.21.2-pre" }
19+
tract-onnx = { git = "https://github.com/sonos/tract", rev = "2a2914ac29390cc08963301c9f3d437b52dd321a" }
2020
ic-stable-structures = "0.6"
2121
ic-wasi-polyfill = { git = "https://github.com/wasm-forge/ic-wasi-polyfill", version = "0.3.17" }

rust/image-classification/src/backend/backend.did

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,4 +14,5 @@ type ClassificationResult = variant {
1414

1515
service : {
1616
"classify": (image: blob) -> (ClassificationResult);
17+
"classify_query": (image: blob) -> (ClassificationResult) query;
1718
}

rust/image-classification/src/backend/src/lib.rs

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
1-
use std::cell::RefCell;
21
use candid::{CandidType, Deserialize};
3-
use ic_stable_structures::{memory_manager::{MemoryId, MemoryManager}, DefaultMemoryImpl};
2+
use ic_stable_structures::{
3+
memory_manager::{MemoryId, MemoryManager},
4+
DefaultMemoryImpl,
5+
};
6+
use std::cell::RefCell;
47

58
mod onnx;
69

@@ -31,7 +34,7 @@ enum ClassificationResult {
3134
Err(ClassificationError),
3235
}
3336

34-
#[ic_cdk::update]
37+
#[ic_cdk::query]
3538
fn classify(image: Vec<u8>) -> ClassificationResult {
3639
let result = match onnx::classify(image) {
3740
Ok(result) => ClassificationResult::Ok(result),
@@ -42,6 +45,17 @@ fn classify(image: Vec<u8>) -> ClassificationResult {
4245
result
4346
}
4447

48+
#[ic_cdk::query]
49+
fn classify_query(image: Vec<u8>) -> ClassificationResult {
50+
let result = match onnx::classify(image) {
51+
Ok(result) => ClassificationResult::Ok(result),
52+
Err(err) => ClassificationResult::Err(ClassificationError {
53+
message: err.to_string(),
54+
}),
55+
};
56+
result
57+
}
58+
4559
#[ic_cdk::init]
4660
fn init() {
4761
let wasi_memory = MEMORY_MANAGER.with(|m| m.borrow().get(WASI_MEMORY_ID));

rust/image-classification/src/declarations/backend/backend.did

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,4 +14,5 @@ type ClassificationResult = variant {
1414

1515
service : {
1616
"classify": (image: blob) -> (ClassificationResult);
17+
"classify_query": (image: blob) -> (ClassificationResult) query;
1718
}

rust/image-classification/src/declarations/backend/backend.did.d.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ export type ClassificationResult = { 'Ok' : Array<Classification> } |
88
{ 'Err' : ClassificationError };
99
export interface _SERVICE {
1010
'classify' : ActorMethod<[Uint8Array | number[]], ClassificationResult>,
11+
'classify_query' : ActorMethod<[Uint8Array | number[]], ClassificationResult>,
1112
}
1213
export declare const idlFactory: IDL.InterfaceFactory;
1314
export declare const init: (args: { IDL: typeof IDL }) => IDL.Type[];

rust/image-classification/src/declarations/backend/backend.did.js

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,11 @@ export const idlFactory = ({ IDL }) => {
1010
});
1111
return IDL.Service({
1212
'classify' : IDL.Func([IDL.Vec(IDL.Nat8)], [ClassificationResult], []),
13+
'classify_query' : IDL.Func(
14+
[IDL.Vec(IDL.Nat8)],
15+
[ClassificationResult],
16+
['query'],
17+
),
1318
});
1419
};
1520
export const init = ({ IDL }) => { return []; };

rust/image-classification/src/frontend/assets/main.css

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ textarea {
4747
flex-flow: row;
4848
justify-content: left;
4949
align-items: center;
50+
margin-bottom: 20px;
5051
}
5152

5253
.toggle-switch {
@@ -71,7 +72,6 @@ textarea {
7172
bottom: 0;
7273
background-color: #ccc;
7374
border-radius: 18px;
74-
transition: 0.4s;
7575
}
7676

7777
.slider:before {
@@ -83,7 +83,6 @@ textarea {
8383
bottom: 1px;
8484
background-color: white;
8585
border-radius: 50%;
86-
transition: 0.4s;
8786
}
8887

8988
input:checked+.slider {
@@ -129,15 +128,15 @@ li {
129128

130129
@keyframes astrodance {
131130
0% {
132-
transform: translate(-50%,-50%) rotate(-20deg)
131+
transform: translate(-50%, -50%) rotate(-20deg)
133132
}
134133

135134
50% {
136-
transform: translate(-50%,-50%) rotate(10deg)
135+
transform: translate(-50%, -50%) rotate(10deg)
137136
}
138137

139138
to {
140-
transform: translate(-50%,-50%) rotate(-20deg)
139+
transform: translate(-50%, -50%) rotate(-20deg)
141140
}
142141
}
143142

rust/image-classification/src/frontend/src/index.html

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,15 @@ <h1>ICP image classification</h1>
2020
<input id="file" class="file" name="file" type="file" accept="image/png, image/jpeg" />
2121
<div id="container">
2222
<div id="message"></div>
23+
<div class="option invisible" id="replicated_option">
24+
<label>
25+
<div class="toggle-switch">
26+
<input type="checkbox" id="replicated">
27+
<span class="slider"></span>
28+
</div>
29+
&nbsp; replicated execution
30+
</label>
31+
</div>
2332
<img id="loader" src="loader.svg" class="loader invisible" />
2433
<button id="classify" class="clean-button invisible" disabled>Go!</button>
2534
</div>

0 commit comments

Comments
 (0)