
added unet example #2786

Open · wants to merge 3 commits into base: main

Conversation

anthonytorlucci (Contributor) commented:
Pull Request Template

Checklist

  • Confirmed that the `run-checks all` script has been executed.
  • Made sure the book is up to date with changes in this PR.

Related Issues/PRs

This is just an example for the community. No issues were created.

Changes

New example added illustrating how to create and train a UNet architecture for image segmentation.

Testing

No changes were made to the library code itself, so no tests were added.

codecov bot commented Feb 8, 2025

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 81.69%. Comparing base (b9653f5) to head (76a4a88).
Report is 24 commits behind head on main.

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #2786      +/-   ##
==========================================
- Coverage   83.64%   81.69%   -1.95%     
==========================================
  Files         826      851      +25     
  Lines      108746   113982    +5236     
==========================================
+ Hits        90956    93122    +2166     
- Misses      17790    20860    +3070     


@laggui (Member) left a comment:
This is awesome 😄 thanks for the addition!

I have a couple of comments below.

Files with review comments (resolved):

  • examples/unet/README.md
  • examples/unet/src/lib.rs
  • examples/unet/src/training.rs
  • examples/unet/src/unet_model.rs (two comments)
  • examples/unet/src/brain_tumor_data.rs
Comment on lines +254 to +256:

```rust
pub struct BrainTumorDataset {
    dataset: InMemDataset<BrainTumorItem>,
}
```
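For context, the wrapper struct quoted above can be pictured with a std-only sketch. The names below are illustrative; burn's actual `InMemDataset` and `Dataset` trait live in `burn::data::dataset` and differ in detail:

```rust
/// Std-only sketch of an in-memory dataset wrapper, mirroring the shape of
/// the `BrainTumorDataset` quoted above. Not burn's real API.
pub struct InMemDataset<T> {
    items: Vec<T>,
}

impl<T: Clone> InMemDataset<T> {
    pub fn new(items: Vec<T>) -> Self {
        Self { items }
    }

    /// Mirrors `Dataset::get`: fetch one item by index, if it exists.
    pub fn get(&self, index: usize) -> Option<T> {
        self.items.get(index).cloned()
    }

    /// Mirrors `Dataset::len`.
    pub fn len(&self) -> usize {
        self.items.len()
    }
}
```

The dataloader only ever needs random access by index plus a length, which is why this thin wrapper over a `Vec` is enough for training.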
laggui (Member) commented:

I was sure that you were going to use the segmentation dataset you added in #2426 😅

Any particular reason why you didn't? Looks like it should be compatible 🤔

anthonytorlucci (Contributor, Author) replied:

There are a few reasons for implementing the dataset in this example, but I plan to test and build more with the ImageFolderDataset in the future.

  1. It was easier to convert PyTorch code using input and output tensors.
  2. I'm still learning Rust (15 years of Python; 3 years of C++ in undergrad and grad school). My background is in geophysics, not computer science.
  3. This approach allowed me to dive a little deeper into how datasets are built and how the data loader works. There are still some traits and advanced Rust features to explore and learn here.

laggui (Member) replied:

All very good reasons! I was mostly curious since you were the one to add support hehe

We can definitely leave it as is and maybe update it in a follow-up PR if you're up for it.

examples/unet/src/infer.rs Outdated Show resolved Hide resolved
@laggui (Member) left a comment:

Awesome! Out of curiosity, do you have any sample predictions? I haven't tried the example locally yet.

Also, just some minor questions/comments. Otherwise the implementation LGTM 🙂

@@ -0,0 +1,12 @@
[package]
laggui (Member) commented:

If you want, you can add your username to the authors field since this is basically your contribution 🙂

Comment on lines +4 to +7
The training data used in the example can be downloaded from [kaggle - nikhilroxtomar - brain-tumor-segmentation](https://www.kaggle.com/datasets/nikhilroxtomar/brain-tumor-segmentation/data).

> [!NOTE]
> Not all images are the same resolution. Most are 512 x 512 pixels.
laggui (Member) commented:

In the sections below you describe how to resize the images; is this a required step for the data to be used in this example?

Everything else assumes 512 x 512, and I haven't downloaded the data yet to check.

If that's the case, I think we should probably provide the script to pre-process the images (instead of having it hidden in the DATA.md file). Even better would be to automatically download the data as we do for other examples, but that's not a hard requirement.
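The pre-processing step being discussed boils down to resampling every image to a common resolution. As a std-only illustration of the idea (a real script would use the `image` crate's resize methods; the helper name and buffer layout here are assumptions), a nearest-neighbor resampler over a packed RGB buffer might look like:

```rust
/// Nearest-neighbor resample of a packed, row-major RGB (u8) buffer from
/// (src_w x src_h) to (dst_w x dst_h). Sketch only; real pre-processing
/// would use the `image` crate rather than raw index arithmetic.
fn resize_nearest(src: &[u8], src_w: usize, src_h: usize, dst_w: usize, dst_h: usize) -> Vec<u8> {
    let mut dst = Vec::with_capacity(dst_w * dst_h * 3);
    for y in 0..dst_h {
        for x in 0..dst_w {
            // Map each destination pixel back to its nearest source pixel.
            let sx = x * src_w / dst_w;
            let sy = y * src_h / dst_h;
            let i = (sy * src_w + sx) * 3;
            dst.extend_from_slice(&src[i..i + 3]);
        }
    }
    dst
}
```

For a segmentation dataset, the same resampling would need to be applied to both the image and its mask so that pixels stay aligned.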

Comment on lines +67 to +134
The following snippet illustrates how to convert an image file to a `burn::tensor::Tensor` with shape `[channels, height, width]` (the layout expected by Conv2d input).

```rust
// Create a tensor from an image with shape [channels, height, width]
// (the layout expected by Conv2d).
let p: std::path::PathBuf = std::path::Path::new("data_down4_n500_wh32")
    .join("images")
    .join("1_w32_h32.png");
let img: DynamicImage = image::open(p).expect("Failed to open image");
let width: usize = img.width() as usize;
let height: usize = img.height() as usize;
let rgb_img = img.to_rgb8();

// Note: `enumerate_pixels()` iterates row by row (y outer, x inner), which
// fills the buffer in [height, width, channel] order — NOT the
// [width, height, channel] order assumed by the reshape below.
let mut v: Vec<f32> = Vec::with_capacity(width * height * 3);
for x in 0..width as u32 {
    for y in 0..height as u32 {
        let pixel = rgb_img.get_pixel(x, y);
        // Convert each channel from u8 (0-255) to f32 (0.0-1.0)
        v.push(pixel[0] as f32 / 255.0);
        v.push(pixel[1] as f32 / 255.0);
        v.push(pixel[2] as f32 / 255.0);
    }
}
let d = TensorData::from(v.as_slice());
let u: Tensor<MyBackend, 1> = Tensor::from_data(d, &device);
// The buffer is in [width, height, channel] order; swapping dims 0 and 2
// yields [channel, height, width].
let u1 = u.reshape([width, height, 3]).swap_dims(0, 2);
println!("{:}", u1);
```
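As an aside, the layout change done above with `reshape` + `swap_dims` can also be expressed as plain index arithmetic on a flat buffer. A std-only sketch (the helper name is hypothetical, not part of the example) that permutes a row-major [height, width, channel] buffer to [channel, height, width]:

```rust
/// Permute a row-major HWC (height, width, channel) buffer to CHW
/// (channel, height, width). Illustrative only; burn does this for you
/// with reshape/swap_dims on the tensor itself.
fn hwc_to_chw(src: &[f32], h: usize, w: usize, c: usize) -> Vec<f32> {
    let mut dst = vec![0.0f32; src.len()];
    for y in 0..h {
        for x in 0..w {
            for ch in 0..c {
                // Source index: pixels are interleaved per channel.
                // Destination index: one full plane per channel.
                dst[ch * h * w + y * w + x] = src[(y * w + x) * c + ch];
            }
        }
    }
    dst
}
```

Writing the index math out once makes it easier to spot the kind of ordering bug flagged in the comment inside the snippet above.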

For small images you can use a fixed-size array instead, but keep in mind it is allocated on the stack, and large images will cause a stack overflow.

```rust
let mut r = [[[0.0f32; 32]; 32]; 3]; // large images will cause a stack-overflow here!
for x in 0..32 {
for y in 0..32 {
let pixel = rgb_img.get_pixel(x, y);
let xx = x as usize;
let yy = y as usize;
// Convert each channel from u8 (0-255) to f32 (0.0-1.0)
r[0][yy][xx] = pixel[0] as f32 / 255.0;
r[1][yy][xx] = pixel[1] as f32 / 255.0;
r[2][yy][xx] = pixel[2] as f32 / 255.0;
}
}

let d = TensorData::from(r);
let u2: Tensor<MyBackend, 3> = Tensor::from_data(d, &device);
println!("{:}", u2);
```

To confirm both constructions produce the same tensor:

```rust
let close_enough = u1.all_close(u2, Some(1e-5), Some(1e-8));
println!("{:}", close_enough);
```
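For intuition, `all_close` follows the usual numpy-style tolerance test, `|a - b| <= atol + rtol * |b|`, applied element-wise. A std-only sketch of that check over flat slices (this is an illustration, not burn's implementation):

```rust
/// Element-wise closeness check with relative and absolute tolerances,
/// in the style of numpy's allclose. Sketch only.
fn all_close(a: &[f32], b: &[f32], rtol: f32, atol: f32) -> bool {
    a.len() == b.len()
        && a.iter()
            .zip(b)
            .all(|(x, y)| (x - y).abs() <= atol + rtol * y.abs())
}
```

The relative term dominates for large values and the absolute term for values near zero, which is why both tolerances are passed separately in the snippet above.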
laggui (Member) commented:

I guess this is more of a learning note than anything, given that the dataset already includes these steps?
