added unet example #2786
base: main
Conversation
Codecov Report
All modified and coverable lines are covered by tests ✅

Additional details and impacted files:

```
@@            Coverage Diff             @@
##             main    #2786      +/-   ##
==========================================
- Coverage   83.64%   81.69%    -1.95%
==========================================
  Files         826      851       +25
  Lines      108746   113982     +5236
==========================================
+ Hits        90956    93122     +2166
- Misses      17790    20860     +3070
```

☔ View full report in Codecov by Sentry.
This is awesome 😄 thanks for the addition!
I have a couple of comments below
```rust
pub struct BrainTumorDataset {
    dataset: InMemDataset<BrainTumorItem>,
}
```
I was sure that you were going to use the segmentation dataset you added in #2426 😅
Any particular reason why you didn't? Looks like it should be compatible 🤔
There are a few reasons for implementing the dataset in this example, but I plan to test and build more with the ImageFolderDataset in the future.
- It was easier to convert PyTorch code using input and output tensors.
- I'm still learning Rust (15 years of Python; 3 years of C++ in undergrad and grad school). My background is in geophysics, not computer science.
- This approach allowed me to dive a little deeper into how datasets are built and how the data loader works. There are still some traits and advanced Rust features to explore and learn here.
All very good reasons! Was mostly curious since you were the one to add support hehe
We can definitely leave it as is and maybe update it in a follow-up PR if you're up for it.
Awesome! Out of curiosity, do you have any sample predictions? Haven't tried the example locally yet.
Also, just some minor questions/comments. Otherwise implementation LGTM 🙂
```
@@ -0,0 +1,12 @@
[package]
```
If you want you can add your username to the authors field since this is basically your contribution 🙂
The training data used in the example can be downloaded from: [kaggle - nikhilroxtomar - brain-tumor-segmentation](https://www.kaggle.com/datasets/nikhilroxtomar/brain-tumor-segmentation/data).

> [!NOTE]
> Not all images are the same resolution. Most are 512 x 512 pixels.
In the sections below you describe how to resize the images, is this a required step for the data to be used in this example?
Everything else seems to assume 512 x 512, and I haven't downloaded the data yet to check.
If that's the case, I think we should probably provide the script to pre-process the images (instead of having it hidden in the DATA.md file). Even better would be to automatically download the data as we do for other examples but that's not a hard requirement.
The following script illustrates how to convert an image file to a `burn::tensor::Tensor` with shape `[channels, height, width]` (the layout expected by `Conv2d`).
```rust
// Two ways to create tensors from images with shape [channels, height, width]
// (matching the input expected by Conv2d).
let p: std::path::PathBuf = std::path::Path::new("data_down4_n500_wh32")
    .join("images")
    .join("1_w32_h32.png");
let img: DynamicImage = image::open(p).expect("Failed to open image");
let width: usize = img.width() as usize;
let height: usize = img.height() as usize;
let mut v: Vec<f32> = Vec::<f32>::with_capacity(width * height * 3);
let rgb_img = match img {
    DynamicImage::ImageRgb8(rgb_img) => rgb_img,
    _ => img.to_rgb8(),
};
// Note: `enumerate_pixels()` visits pixels row by row (x varies fastest),
// which fills the buffer in the WRONG order for the reshape below:
//--- rgb_img
//---     .enumerate_pixels()
//---     .for_each(|(x, y, pixel)| {
//---         // Convert each channel from u8 (0-255) to f32 (0.0-1.0)
//---         v.push(pixel[0] as f32 / 255.0);
//---         v.push(pixel[1] as f32 / 255.0);
//---         v.push(pixel[2] as f32 / 255.0);
//---     });
// Instead, iterate column-major (x outer, y inner) so that
// reshape([width, height, 3]) + swap_dims(0, 2) yields [channels, height, width].
for x in 0..width as u32 {
    for y in 0..height as u32 {
        let pixel = rgb_img.get_pixel(x, y);
        // Convert each channel from u8 (0-255) to f32 (0.0-1.0)
        v.push(pixel[0] as f32 / 255.0);
        v.push(pixel[1] as f32 / 255.0);
        v.push(pixel[2] as f32 / 255.0);
    }
}
let a: Box<[f32]> = v.into_boxed_slice();
let d = TensorData::from(&*a);
let u: Tensor<MyBackend, 1> = Tensor::from_data(d, &device);
let u1 = u.reshape([width, height, 3]).swap_dims(0, 2);
println!("{:}", u1);
```
For small images, you can use a fixed-size array, but keep in mind that it is allocated on the stack and can cause a stack overflow for large images.
```rust
let mut r = [[[0.0f32; 32]; 32]; 3]; // large images will cause a stack overflow here!
for x in 0..32 {
    for y in 0..32 {
        let pixel = rgb_img.get_pixel(x, y);
        let xx = x as usize;
        let yy = y as usize;
        // Convert each channel from u8 (0-255) to f32 (0.0-1.0)
        r[0][yy][xx] = pixel[0] as f32 / 255.0;
        r[1][yy][xx] = pixel[1] as f32 / 255.0;
        r[2][yy][xx] = pixel[2] as f32 / 255.0;
    }
}

let d = TensorData::from(r);
let u2: Tensor<MyBackend, 3> = Tensor::from_data(d, &device);
println!("{:}", u2);
```
To confirm both approaches produce identical tensors:
```rust
let close_enough = u1.all_close(u2, Some(1e-5), Some(1e-8));
println!("{:}", close_enough);
```
I guess this is more of a learning note than anything, given that the dataset already includes these steps?
Pull Request Template

Checklist

- The `run-checks all` script has been executed.

Related Issues/PRs

This is just an example for the community. No issues were created.

Changes

New example added illustrating how to create and train a UNet architecture for image segmentation.

Testing

No changes to the library code itself; no tests added.