-
Notifications
You must be signed in to change notification settings - Fork 501
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
added unet example #2786
base: main
Are you sure you want to change the base?
added unet example #2786
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
[package] | ||
name = "unet" | ||
edition.workspace = true | ||
license.workspace = true | ||
version.workspace = true | ||
publish = false | ||
|
||
|
||
[dependencies] | ||
burn = { path = "../../crates/burn", features = ["train", "wgpu"] } | ||
image = { workspace = true } | ||
rand = { workspace = true } |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,214 @@ | ||
# Data Gathering and Preprocessing | ||
To train and use the unet example, data must be provided in the form of image pairs. | ||
|
||
The training data used in the example can be downloaded from: [kaggle - nikhilroxtomar - brain-tumor-segmentation ](https://www.kaggle.com/datasets/nikhilroxtomar/brain-tumor-segmentation/data). | ||
|
||
> [!NOTE] | ||
> Not all images are the same resolution. Most are 512 x 512 pixels. | ||
Comment on lines
+4
to
+7
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In the sections below you describe how to resize the images, is this a required step for the data to be used in this example? Everything else works for 512 x 512 and I didn't download the data yet to see. If that's the case, I think we should probably provide the script to pre-process the images (instead of having it hidden in the DATA.md file). Even better would be to automatically download the data as we do for other examples but that's not a hard requirement. |
||
|
||
## Requirements | ||
* Source files should be in an `images` directory, while corresponding target files are in a `masks` directory. | ||
* All images have a corresponding mask with a consistent naming convention. | ||
* All images and masks have a uniform size, that is a constant and consistent width and height. | ||
* All masks have only two classes; class A is black (0,0,0) and class B is white (255,255,255). | ||
|
||
Recommended directory structure: | ||
|- unet | ||
| |- src | ||
| |- data | ||
| |- images | ||
| |- 1.png | ||
| |- 2.png | ||
| |- ... | ||
| |- masks | ||
| |- 1.png | ||
| |- 2.png | ||
| |- ... | ||
|
||
## Dataset, Batcher, and DataLoader | ||
The model takes a burn tensor with shape `[batch_size, channels, height, width]` as input and outputs a tensor with the shape `[batch_size, 1, height, width]`. The purpose of the data pipeline is to convert the images (png files) to tensors for training, testing, and validation. | ||
|
||
Set the constants in the [`brain_tumor_data.rs`](./src/brain_tumor_data.rs) to match the data: | ||
|
||
```rust | ||
// height and width of the images used in training | ||
pub const WIDTH: usize = 512; | ||
pub const HEIGHT: usize = 512; | ||
const TRAINING_DATA_DIRECTORY_STR: &str = "data"; | ||
``` | ||
|
||
To confirm the batching strategy, you can test the dataloader using: | ||
|
||
```rust | ||
use burn::data::dataloader::DataLoaderBuilder; | ||
use unet::segmentation_data::BrainTumorBatcher; | ||
use burn_unet_example::segmentation_data::BrainTumorDataset; | ||
|
||
type MyBackend = NdArray<f32, i64, i8>; | ||
let batcher_train = BrainTumorBatcher::<MyBackend>::new(device.clone()); | ||
let train_dataset: BrainTumorDataset = | ||
BrainTumorDataset::train().expect("Failed to build training dataset"); | ||
let dataloader_train = DataLoaderBuilder::new(batcher_train) | ||
.batch_size(2) | ||
.shuffle(42) | ||
.num_workers(1) | ||
.build(train_dataset); | ||
println!("{}", dataloader_train.num_items()); | ||
|
||
for batch in dataloader_train.iter() { | ||
println!("training ..."); | ||
println!("{:?}", batch.source_tensor.shape()); | ||
} | ||
``` | ||
|
||
## Additional notes | ||
|
||
#### Image to Tensor conversion | ||
The following script illustrates how to convert an image file to a `burn::tensor::Tensor` with shape `[channels, height, width]` (aligns with Conv2D input). | ||
|
||
```rust | ||
// two ways to create tensors from images with shape [channel, height, width] (matching the expected Conv2d) | ||
let p: std::path::PathBuf = std::path::Path::new("data_down4_n500_wh32") | ||
.join("images") | ||
.join("1_w32_h32.png"); | ||
let img: DynamicImage = image::open(p).expect("Failed to open image"); | ||
let width: usize = img.width() as usize; | ||
let height: usize = img.height() as usize; | ||
let mut v: Vec<f32> = Vec::<f32>::with_capacity(width * height * 3); | ||
let rgb_img = match img { | ||
DynamicImage::ImageRgb8(rgb_img) => rgb_img.clone(), | ||
_ => img.to_rgb8(), | ||
}; | ||
// Iterate over pixels and fill the array in parallel -> WRONG ORDER!!!!!! | ||
//--- rgb_img | ||
//--- .enumerate_pixels() | ||
//--- .into_iter() | ||
//--- .for_each(|(x, y, pixel)| { | ||
//--- // Convert each channel from u8 (0-255) to f32 (0.0-1.0) | ||
//--- v.push(pixel[0] as f32 / 255.0); | ||
//--- v.push(pixel[1] as f32 / 255.0); | ||
//--- v.push(pixel[2] as f32 / 255.0); | ||
//--- }); | ||
// println!("{:?}", v); | ||
for x in 0..32 { | ||
for y in 0..32 { | ||
let pixel = rgb_img.get_pixel(x, y); | ||
v.push(pixel[0] as f32 / 255.0); | ||
v.push(pixel[1] as f32 / 255.0); | ||
v.push(pixel[2] as f32 / 255.0); | ||
} | ||
} | ||
let a: Box<[f32]> = v.into_boxed_slice(); | ||
let d = TensorData::from(&*a); | ||
let u: Tensor<MyBackend, 1> = Tensor::from_data(d, &device); | ||
let u1 = u.reshape([width, height, 3]).swap_dims(0, 2); | ||
println!("{:}", u1); | ||
``` | ||
|
||
For small images, you can use an array, but keep in mind this is created on the stack and cause a stack overflow. | ||
|
||
```rust | ||
let mut r = [[[0.0f32; 32]; 32]; 3]; // large images will cause a stack-overflow here! | ||
for x in 0..32 { | ||
for y in 0..32 { | ||
let pixel = rgb_img.get_pixel(x, y); | ||
let xx = x as usize; | ||
let yy = y as usize; | ||
// Convert each channel from u8 (0-255) to f32 (0.0-1.0) | ||
r[0][yy][xx] = pixel[0] as f32 / 255.0; | ||
r[1][yy][xx] = pixel[1] as f32 / 255.0; | ||
r[2][yy][xx] = pixel[2] as f32 / 255.0; | ||
} | ||
} | ||
|
||
let d = TensorData::from(r); | ||
let u2: Tensor<MyBackend, 3> = Tensor::from_data(d, &device); | ||
println!("{:}", u2); | ||
``` | ||
|
||
To confirm both constructors are identical: | ||
|
||
```rust | ||
let close_enough = u1.all_close(u2, Some(1e-5), Some(1e-8)); | ||
println!("{:}", close_enough); | ||
``` | ||
Comment on lines
+67
to
+134
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I guess this is more of a learning note than anything, given that the dataset already includes these steps? |
||
|
||
#### Downsampling | ||
If memory is an issue, smaller input images will result in a smaller model. Downsample the images and masks with: | ||
|
||
```rust | ||
pub fn downsample_images( | ||
input_dir: &Path, | ||
output_dir: &Path, | ||
downscaling_factor: u32, | ||
) -> Result<(), Error> { | ||
// Create output directory if it doesn't exist | ||
fs::create_dir_all(output_dir)?; | ||
|
||
let factor: u32 = 1u32 << downscaling_factor; | ||
|
||
// Iterate over the entries in the directory | ||
for entry in fs::read_dir(input_dir)? { | ||
let entry = entry?; | ||
let path = entry.path(); | ||
|
||
// Check if the file has a .png extension | ||
if path.extension().and_then(|ext| ext.to_str()) == Some("png") { | ||
// Open the image file | ||
//let img: DynamicImage = ImageReader::open(&path)?.decode()?; | ||
let img: DynamicImage = ImageReader::open(&path) | ||
.map_err(|e| Error::new(ErrorKind::Other, e))? | ||
.decode() | ||
.map_err(|e| Error::new(ErrorKind::Other, e))?; | ||
|
||
// Downsample the image | ||
let downsampled_image = img.resize( | ||
img.width() / factor, | ||
img.height() / factor, | ||
FilterType::Nearest, | ||
); | ||
|
||
// Get new dimensions | ||
let (new_width, new_height) = downsampled_image.dimensions(); | ||
|
||
// Create new filename with dimensions | ||
let original_stem = path | ||
.file_stem() | ||
.and_then(|s| s.to_str()) | ||
.unwrap_or("unnamed"); | ||
|
||
let new_filename = format!("{}_w{}_h{}.png", original_stem, new_width, new_height); | ||
|
||
// Create full output path | ||
let output_path = output_dir.join(new_filename); | ||
|
||
// Save the downsampled image | ||
downsampled_image | ||
.save(output_path) | ||
.map_err(|e| Error::new(ErrorKind::Other, e))?; | ||
} | ||
} | ||
|
||
Ok(()) | ||
} | ||
``` | ||
|
||
For example, | ||
|
||
```rust | ||
let images_data_dir = Path::new("data").join("images"); | ||
let masks_data_dir = Path::new("data").join("masks"); | ||
let images_d4_data_dir = Path::new("data1").join("images"); | ||
let masks_d4_data_dir = Path::new("data1").join("masks"); | ||
|
||
let _ = utils::downsample_images(images_data_dir.as_ref(), images_d4_data_dir.as_ref(), 1); | ||
let _ = utils::downsample_images(masks_data_dir.as_ref(), masks_d4_data_dir.as_ref(), 1); | ||
``` | ||
|
||
_downscaling_factor_ mapping for input size 512 pixels: | ||
* 0: 512/1 -> 512 | ||
* 1: 512/2 -> 256 | ||
* 2: 512/4 -> 128 | ||
* 3: 512/8 -> 64 | ||
* 4: 512/16 -> 32 | ||
* 5: 512/32 -> 16 |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
## Model Architecture Implementation | ||
### Key Components | ||
The implementation consists of several key components: | ||
|
||
1. **DoubleConv Block** | ||
* A basic building block that performs two consecutive convolutions | ||
* Each convolution is followed by batch normalization and ReLU activation | ||
* Used throughout the network for feature extraction | ||
|
||
2. **Down Block** | ||
* Handles downsampling in the encoder path | ||
* Combines max pooling with a DoubleConv block | ||
* Reduces spatial dimensions while increasing feature channels | ||
|
||
3. **Up Block** | ||
* Handles upsampling in the decoder path | ||
* Uses transposed convolution followed by a DoubleConv block | ||
* Increases spatial dimensions while decreasing feature channels | ||
|
||
4. **OutConv Block** | ||
* Final output layer | ||
* Single convolution with ReLU activation | ||
* Maps features to the desired number of output channels | ||
|
||
5. **UNet Structure** The main UNet struct combines these components in a classic U-shaped architecture: | ||
* Encoder path: Initial DoubleConv followed by 4 Down blocks | ||
* Decoder path: 4 Up blocks with skip connections | ||
* Final output layer: OutConv block | ||
|
||
### Encoder Path | ||
The encoder path consists of a series of `Down` modules that downscale the input image while increasing the number of channels. The `Down` modules apply max pooling with a stride of 2, followed by a convolutional block. | ||
|
||
* **inc:** The initial convolutional block that takes the input image and produces a feature map with 64 channels. | ||
* **down1, down2, down3, down4:** The subsequent `Down` modules that downscale the feature map while increasing the number of channels to 128, 256, 512, and 1024, respectively. | ||
|
||
### Decoder Path | ||
The decoder path consists of a series of Up modules that upsample the feature map while decreasing the number of channels. The `Up` modules apply transposed convolution with a stride of 2, followed by a convolutional block. | ||
|
||
* **up4, up3, up2, up1:** The `Up` modules that upsample the feature map while decreasing the number of channels to 512, 256, 128, and 64, respectively. | ||
|
||
### Final Output Layer | ||
The final output layer consists of an `OutConv` module that applies a convolutional operation with a kernel size of 1x1 to produce the final output. | ||
|
||
### Forward Pass | ||
The forward pass through the U-Net architecture involves the following steps: | ||
|
||
1. The input tensor `[batch_size, 3, height, width]` (3 channel images) is passed through the `inc` convolutional block to produce a feature map with 64 channels. | ||
2. The feature map is then passed through the `down1`, `down2`, `down3`, and `down4` modules progressively reducing spatial dimensions while increasing channels to produce a feature map with 1024 channels. | ||
3. The feature map is then passed through the `up4`, `up3`, `up2`, and `up1` modules progressively increasing spatial dimensions while decreasing channels to produce a feature map with 64 channels. | ||
4. The feature map is then passed through the `outc` convolutional block to produce the final output tensor `[batch_size, 1, height, width]` (single channel segmentation mask). | ||
|
||
### Skip Connections | ||
The implementation also includes concatenated skip connections between the encoder and decoder paths. The skip connections allow the model to preserve spatial information and improve the accuracy of the segmentation. |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
# Modified UNet | ||
A burn implementation of a modified form of UNet (convolutional neural network) for image segmentation. This work is based on the original paper _U-Net: Convolutional Networks for Biomedical Image Segmentation_^1^ [arxiv](https://arxiv.org/abs/1505.04597). For more information about the model architecture in this implementation, see [MODEL ARCHITECTURE](./MODEL_ARCHITECTURE.md). | ||
|
||
## Usage | ||
laggui marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
> [!CAUTION] | ||
> Data is not provided for this example and it will not run. See [DATA](./DATA.md) for an explanation of the requirements and a link to sample data publicly available. | ||
|
||
### Training | ||
```sh | ||
cargo run --example unet_train --release | ||
``` | ||
|
||
### Inference | ||
```sh | ||
cargo run --example unet_infer --release | ||
``` | ||
|
||
|
||
--- | ||
References: | ||
|
||
Ronneberger, O., Fischer, P., Brox, T., 2015. U-net: Convolutional networks for biomedical image segmentation, in: International Conference on Medical image computing and computer-assisted intervention, Springer. pp. 234–241. |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
use burn::backend::wgpu::Wgpu; | ||
use burn::backend::wgpu::WgpuDevice; | ||
use std::path::Path; | ||
|
||
fn main() { | ||
type MyBackend = Wgpu<f32, i32>; | ||
let device = WgpuDevice::default(); | ||
|
||
let training_artifact_dir = Path::new("artifacts"); | ||
let infer_artifact_dir = Path::new("inferred_segmentations"); | ||
let _ = unet::infer::infer::<MyBackend>( | ||
training_artifact_dir, | ||
infer_artifact_dir, | ||
&device, | ||
Path::new("data").join("images").join("1.png").as_path(), | ||
); | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
use std::path::Path; | ||
|
||
use burn::backend::wgpu::Wgpu; | ||
use burn::backend::wgpu::WgpuDevice; | ||
use burn::backend::Autodiff; | ||
use burn::optim::AdamConfig; | ||
use unet::training::UNetTrainingConfig; | ||
use unet::unet_model::UNetConfig; | ||
|
||
fn main() { | ||
type MyBackend = Wgpu<f32, i32>; | ||
let device = WgpuDevice::default(); | ||
|
||
// Model training | ||
let training_artifact_dir = Path::new("artifacts"); | ||
let model_config = UNetConfig::new(); | ||
let optimizer_config = AdamConfig::new() | ||
.with_beta_1(0.9) | ||
.with_beta_2(0.999) | ||
.with_epsilon(1e-8); | ||
let training_config = UNetTrainingConfig::new(model_config, optimizer_config) | ||
.with_num_epochs(1) | ||
.with_batch_size(4) | ||
.with_num_workers(1) | ||
.with_seed(42) | ||
.with_learning_rate(1e-3); | ||
|
||
unet::training::train::<Autodiff<MyBackend>>(training_artifact_dir, training_config, &device); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If you want you can add your username to the authors field since this is basically your contribution 🙂