create a script to train autoencoderkl #10605
Conversation
Hello, thanks so much for your contributions. Could you perhaps provide some decent results you obtained with the training script? Could you also help explain the main differences between this and the training script we have for VQGAN?
@ariG23498 you might be interested in following this PR, as you were looking for this for a while.
This code is inspired by https://github.com/CompVis/latent-diffusion, https://github.com/Stability-AI/generative-models, #894, and #3801, and aims to provide a streamlined approach for fine-tuning or training VAEs using diffusers with minimal code. The key distinction between this implementation and VQGAN is the incorporation of KL loss for VAE training, along with support for training the decoder independently.
The above results were obtained using this script to train an SD-VAE from scratch on ImageNet, with only 1,000 steps of training completed.
I apologize that I am currently unable to provide more thoroughly trained results. Experiments have shown that the training speed of this script is relatively slow, and the trained VAE model can only offer basic image reconstruction at this stage. Further experiments are yet to be conducted.
Thanks for this, and the results aren't underwhelming at all! Perhaps we could place the project under "research_projects" for now and expedite the merging? And then once you have more time to conduct experiments, we could bring it back. WDYT?
Thanks, this is already very high-quality. I left some comments. LMK if they make sense.
examples/autoencoderkl/README.md
Outdated
accelerate config
```

## Training on ImageNet
Let's provide a smaller dataset here in the example.
examples/autoencoderkl/README.md
Outdated
## Training on ImageNet

```bash
accelerate launch --multi_gpu --num_processes 4 --mixed_precision bf16 train_autoencoderkl.py \
Let's keep it for a single GPU and then make a note about multi-GPU later.
examples/autoencoderkl/README.md
Outdated
--report_to wandb \
--mixed_precision bf16 \
--train_data_dir /path/to/ImageNet/train \
--validation_image ./image.png \
Where does it come from?
The validation images are eight images randomly selected from the ImageNet validation set. Here they are simply represented as an abstract ./image.png for illustrative purposes.
import wandb

# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.30.0.dev0")
check_min_version("0.30.0.dev0") | |
check_min_version("0.33.0.dev0") |
@@ -0,0 +1,1042 @@
import argparse
Let's add licensing.
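For reference, diffusers example scripts typically open with an Apache-2.0 header along these lines (the year and copyright holder below are placeholders):

```python
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
```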
logger = get_logger(__name__)


def image_grid(imgs, rows, cols):
Can we make use of `from diffusers.utils import make_image_grid`?
I will check it.
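For reference, a minimal sketch of that replacement, assuming the validation outputs are PIL images (`images`, the grid shape, and the output path are illustrative):

```python
from diffusers.utils import make_image_grid

# `images` is assumed to be a list of PIL.Image instances, e.g. original /
# reconstruction validation pairs; rows * cols must equal len(images).
grid = make_image_grid(images, rows=2, cols=4)
grid.save("validation_grid.png")
```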
tracker.log(
    {
        "Original (left), Reconstruction (right)": [
            wandb.Image(torchvision.utils.make_grid(image))
If we're using torchvision.utils.make_grid(), do we still need to have our own grid utility?
This code comes from #3801. Next, I will review the code to ensure that modules and functions are fully utilized.
).repo_id

# Load AutoencoderKL
if args.pretrained_model_name_or_path is None and args.model_config_name_or_path is None:
`pretrained_model_name_or_path` should be enough to do the init, no?
`pretrained_model_name_or_path` allows people to fine-tune an existing VAE, while `model_config_name_or_path` enables users to configure the parameters of the VAE through a config.json file for training from scratch.
If both are passed, we should error out, no? If we're not already doing that, let's maybe add a check in `parse_args()`?
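A minimal sketch of such a guard, assuming it runs at the end of `parse_args()` (argument definitions abbreviated to the two relevant flags):

```python
import argparse


def parse_args():
    parser = argparse.ArgumentParser(description="AutoencoderKL training script.")
    parser.add_argument("--pretrained_model_name_or_path", type=str, default=None)
    parser.add_argument("--model_config_name_or_path", type=str, default=None)
    args = parser.parse_args()

    # The two init paths are mutually exclusive: fine-tune an existing VAE,
    # or build a fresh one from a config.json and train from scratch.
    if args.pretrained_model_name_or_path is not None and args.model_config_name_or_path is not None:
        raise ValueError(
            "Pass only one of `--pretrained_model_name_or_path` or `--model_config_name_or_path`."
        )
    return args
```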
targets = batch["pixel_values"].to(dtype=weight_dtype)
if accelerator.num_processes > 1:
    posterior = vae.module.encode(targets).latent_dist
else:
    posterior = vae.encode(targets).latent_dist
latents = posterior.sample()
if accelerator.num_processes > 1:
    reconstructions = vae.module.decode(latents).sample
else:
    reconstructions = vae.decode(latents).sample
Prefer using `accelerator.unwrap_model(vae)` after computing `targets`. In case of `num_processes=1`, that would just become a no-op.
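A sketch of the suggested refactor, reusing the variables from the hunk above:

```python
targets = batch["pixel_values"].to(dtype=weight_dtype)

# Unwrapping once removes the `.module` branching; with num_processes=1,
# unwrap_model simply returns the model unchanged.
unwrapped_vae = accelerator.unwrap_model(vae)
posterior = unwrapped_vae.encode(targets).latent_dist
latents = posterior.sample()
reconstructions = unwrapped_vae.decode(latents).sample
```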
else:
    reconstructions = vae.decode(latents).sample

if (step // args.gradient_accumulation_steps) % 2 == 0 or global_step < args.disc_start:
accelerate should be able to take care of gradient accumulation. Can we use that?
I will check it.
This section of code handles the optimizers for both the VAE and the discriminator. Gradient accumulation from accelerate has already been integrated.
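For context, accelerate's managed accumulation pattern looks roughly like this; the loss terms and optimizer name are illustrative, not the script's actual ones:

```python
# `accelerator` is assumed to be constructed with
# Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps),
# so gradient sync and scaling are handled inside the context manager.
with accelerator.accumulate(vae):
    loss = recon_loss + kl_weight * kl_loss + g_loss
    accelerator.backward(loss)
    optimizer.step()
    optimizer.zero_grad()
```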
--pretrained_model_name_or_path stabilityai/sd-vae-ft-mse \
--dataset_name=cifar10 \
--image_column=img \
--validation_image images/bird.jpg images/car.jpg images/dog.jpg images/frog.jpg \
Users won't know about these images. Maybe add a note?
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
# check_min_version("0.33.0.dev0")
We should always check this. In fact, let's remove diffusers from requirements.txt and ask users to install diffusers from source in the README.
My bad. I forgot to remove the comment.
Just two comments and then we should be good to go.
I'm not very clear about the pull request process. Does this mean the code has been successfully submitted?
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
@lavinal712 can you fix the code quality by running
Now it is OK for me.
Thanks for your contributions!
So I have been working on creating an LDM from scratch (for teaching purposes) and used this code as a reference, so thank you! I did have a question, though: I notice that you don't have anything updating your discriminator. You compute the disc_loss like the following:
but no gradients were computed with this loss as far as I can tell? I reference here my own VAE trainer, which is giving good results so far on ImageNet, Conceptual Captions, and CelebA-HQ. They are still training and will take a week or so to finish, but it's a start! In all these cases, the PatchGAN discriminator has not kicked in yet; that'll happen later today or tomorrow once training reaches that iteration. The autoencoder being trained is very close to AutoencoderKL, just rewritten for my teaching purposes. I also have a training script for a VQ-VAE, if that would be interesting to add as well.
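For reference, a typical discriminator step in this style of VAE-GAN training looks roughly like the sketch below (hinge loss; `discriminator`, `targets`, `reconstructions`, and `disc_optimizer` are illustrative names, not the script's actual ones). The key point is that `disc_loss` needs its own backward pass and optimizer step, otherwise the discriminator never learns:

```python
import torch
import torch.nn.functional as F

# Hinge loss on real targets vs. detached reconstructions; detaching keeps
# discriminator gradients from flowing back into the VAE.
logits_real = discriminator(targets)
logits_fake = discriminator(reconstructions.detach())
disc_loss = 0.5 * (
    torch.mean(F.relu(1.0 - logits_real)) + torch.mean(F.relu(1.0 + logits_fake))
)

disc_optimizer.zero_grad()
disc_loss.backward()
disc_optimizer.step()
```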
Add AutoencoderKL Training Script

Description

This PR adds a complete training script for AutoencoderKL models. The script supports the following features:

Key Features
- Core Training Pipeline
- Loss Functions
- Optimization & Performance
- Monitoring & Validation

@sayakpaul