
create a script to train autoencoderkl #10605

Merged: 15 commits into huggingface:main on Jan 27, 2025

Conversation

lavinal712 (Contributor)

Add AutoencoderKL Training Script

Description

This PR adds a complete training script for AutoencoderKL models. The script supports the following features:

  • Multiple loss functions (L1/L2 reconstruction loss, LPIPS perceptual loss, KL divergence)
  • Integrated adversarial training
  • Mixed precision training support
  • Comprehensive training logging and validation
  • Hugging Face Hub model upload
  • Detailed command-line parameter configuration

Key Features

Core Training Pipeline

  • Complete VAE training loop implementation
  • Multi-GPU distributed training support
  • Gradient accumulation and checkpoint saving

Loss Functions

  • Reconstruction loss (L1/L2)
  • LPIPS perceptual loss
  • KL divergence regularization
  • Adversarial loss (see the sketch below for how these terms are typically combined)
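
For orientation, here is a minimal sketch of how these terms are typically combined into a single VAE-GAN objective. Names such as lpips_fn, discriminator, and the weighting factors are illustrative assumptions, not necessarily the exact variables used in the script.

```python
import torch
import torch.nn.functional as F

def vae_loss(targets, reconstructions, posterior, lpips_fn, discriminator,
             kl_weight=1e-6, perceptual_weight=1.0, disc_weight=0.5):
    # Pixel-space reconstruction loss (L1 shown here; L2/MSE is the other option).
    rec_loss = F.l1_loss(reconstructions, targets)
    # LPIPS perceptual loss between inputs and reconstructions.
    p_loss = lpips_fn(reconstructions, targets).mean()
    # KL divergence of the encoder posterior against a standard normal prior.
    kl_loss = posterior.kl().mean()
    # Adversarial term: encourage reconstructions that fool the discriminator.
    g_loss = -torch.mean(discriminator(reconstructions))
    return rec_loss + perceptual_weight * p_loss + kl_weight * kl_loss + disc_weight * g_loss
```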

Optimization & Performance

  • 8-bit Adam optimizer support
  • Integrated xFormers memory optimization
  • TF32 acceleration support

Monitoring & Validation

  • TensorBoard and WandB logging
  • Periodic validation with sample image saving
  • Detailed training metrics monitoring

@sayakpaul

@sayakpaul (Member)

Hello, thanks so much for your contributions. Could you perhaps provide some decent results you obtained with the training script? Could you also help explain the main differences between this and the training script we have for vqgan?

@sayakpaul (Member)

@ariG23498 you might be interested in following this PR as you were looking for this for a while.

@lavinal712 (Contributor, Author)

This code is inspired by https://github.com/CompVis/latent-diffusion, https://github.com/Stability-AI/generative-models, #894, and #3801, and aims to provide a streamlined way to fine-tune or train VAEs with diffusers using minimal code. The key distinction between this implementation and the VQGAN script is the incorporation of KL loss for VAE training, along with support for training the decoder independently.
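
To illustrate the decoder-only option, a rough sketch assuming the standard AutoencoderKL module layout (the script's actual flag for enabling this may be named differently):

```python
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")
vae.requires_grad_(True)
# Freeze the encoder side so the latent space stays fixed and only the decoder is trained.
vae.encoder.requires_grad_(False)
vae.quant_conv.requires_grad_(False)

trainable_params = [p for p in vae.parameters() if p.requires_grad]
```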

@lavinal712 (Contributor, Author)

[attached: five sample result images from training]

@lavinal712 (Contributor, Author)

The above results were obtained using this script to train an SD-VAE from scratch on ImageNet, with only 1,000 steps of training completed.

@lavinal712 (Contributor, Author)

I apologize that I am currently unable to provide more thoroughly trained results. Experiments have shown that the training speed of this script is relatively slow, and the trained VAE model can only offer basic image reconstruction at this stage. Further experiments are yet to be conducted.

@sayakpaul (Member)

Thanks for this, and the results aren't underwhelming at all! Perhaps we could place the project under "research_projects" for now and expedite the merging? And then once you have more time to conduct experiments, we could bring it back to examples/?

WDYT?

@sayakpaul (Member) left a review:

Thanks, this is already very high-quality. I left some comments. LMK if they make sense.

```
accelerate config
```

## Training on ImageNet
Member:

Let's provide a smaller dataset here in the example.

## Training on ImageNet

```bash
accelerate launch --multi_gpu --num_processes 4 --mixed_precision bf16 train_autoencoderkl.py \
```
Member:

Let's keep it for a single GPU and then make a note about multi-GPU later.

```bash
--report_to wandb \
--mixed_precision bf16 \
--train_data_dir /path/to/ImageNet/train \
--validation_image ./image.png \
```
Member:

Where does it come from?

Contributor (Author):

The validation images are eight images randomly selected from the ImageNet validation set. Here they are simply represented by the placeholder ./image.png for illustration.

```python
import wandb

# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.30.0.dev0")
```
Member:

Suggested change:

```diff
- check_min_version("0.30.0.dev0")
+ check_min_version("0.33.0.dev0")
```

```
@@ -0,0 +1,1042 @@
import argparse
```
Member:

Let's add licensing.

```python
logger = get_logger(__name__)


def image_grid(imgs, rows, cols):
```
Member:

Can we make use of make_image_grid from diffusers.utils instead?
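
For reference, a small sketch of how the local helper could be swapped for the library utility (assuming the images are PIL images, which is what make_image_grid expects; the file names are hypothetical):

```python
from diffusers.utils import make_image_grid
from PIL import Image

images = [Image.open(p) for p in ["orig_0.png", "recon_0.png", "orig_1.png", "recon_1.png"]]
grid = make_image_grid(images, rows=2, cols=2)
grid.save("grid.png")
```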

Contributor (Author):

I will check it.

```python
tracker.log(
    {
        "Original (left), Reconstruction (right)": [
            wandb.Image(torchvision.utils.make_grid(image))
```
Member:

If we're using torchvision.utils.make_grid(), do we still need our own grid utility?

Contributor (Author):

This code comes from #3801. Next, I will review the code to ensure that modules and functions are fully utilized.

```python
).repo_id

# Load AutoencoderKL
if args.pretrained_model_name_or_path is None and args.model_config_name_or_path is None:
```
Member:

pretrained_model_name_or_path should be enough to do the init, no?

Contributor (Author):

pretrained_model_name_or_path allows people to fine-tune an existing VAE, while model_config_name_or_path enables users to configure the parameters of the VAE through a config.json file for training from scratch.
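
In other words, the two arguments select between two construction paths; a rough sketch (the script's actual branching may differ):

```python
from diffusers import AutoencoderKL

if args.pretrained_model_name_or_path is not None:
    # Fine-tune an existing VAE checkpoint.
    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path)
else:
    # Train from scratch with hyperparameters taken from a config.json.
    config = AutoencoderKL.load_config(args.model_config_name_or_path)
    vae = AutoencoderKL.from_config(config)
```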

Member:

If both are passed, we should error out, no? If we're not already doing that, let's maybe add a check in parse_args()?
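
Something along these lines inside parse_args() would cover it (a sketch, not the final wording):

```python
if args.pretrained_model_name_or_path is not None and args.model_config_name_or_path is not None:
    raise ValueError(
        "Pass only one of --pretrained_model_name_or_path or --model_config_name_or_path, not both."
    )
```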

Comment on lines 877 to 886
```python
targets = batch["pixel_values"].to(dtype=weight_dtype)
if accelerator.num_processes > 1:
    posterior = vae.module.encode(targets).latent_dist
else:
    posterior = vae.encode(targets).latent_dist
latents = posterior.sample()
if accelerator.num_processes > 1:
    reconstructions = vae.module.decode(latents).sample
else:
    reconstructions = vae.decode(latents).sample
```
Member:

Prefer using accelerator.unwrap_model(vae) after computing targets. In case of num_processes=1, that would just become a no-op.
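
Concretely, the suggestion amounts to something like this sketch:

```python
targets = batch["pixel_values"].to(dtype=weight_dtype)
# unwrap_model() returns the underlying module when wrapped for DDP and is a no-op otherwise.
unwrapped_vae = accelerator.unwrap_model(vae)
posterior = unwrapped_vae.encode(targets).latent_dist
latents = posterior.sample()
reconstructions = unwrapped_vae.decode(latents).sample
```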

```python
else:
    reconstructions = vae.decode(latents).sample

if (step // args.gradient_accumulation_steps) % 2 == 0 or global_step < args.disc_start:
```
Member:

accelerate should be able to take care of gradient accumulation.

Can we use that?

Contributor (Author):

I will check it.

Contributor (Author):

This section of code handles the optimizers for both the VAE and the discriminator. Gradient accumulation from accelerate has already been integrated.
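
For readers following along, the alternating update being described looks roughly like this (vae_loss, disc_loss, and the optimizer names are illustrative, not the script's exact variables):

```python
with accelerator.accumulate(vae):
    train_vae = (step // args.gradient_accumulation_steps) % 2 == 0 or global_step < args.disc_start
    if train_vae:
        # VAE step: reconstruction + LPIPS + KL + adversarial terms.
        accelerator.backward(vae_loss)
        vae_optimizer.step()
        vae_optimizer.zero_grad()
    else:
        # Discriminator step on real vs. reconstructed images.
        accelerator.backward(disc_loss)
        disc_optimizer.step()
        disc_optimizer.zero_grad()
```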

```bash
--pretrained_model_name_or_path stabilityai/sd-vae-ft-mse \
--dataset_name=cifar10 \
--image_column=img \
--validation_image images/bird.jpg images/car.jpg images/dog.jpg images/frog.jpg \
```
Member:

Users won't know about these images. Maybe add a note?

Comment on lines 60 to 61
```python
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
# check_min_version("0.33.0.dev0")
```
Member:

We should always check this. In fact, let's remove diffusers from requirements.txt and ask users to install diffusers from source in the README.

Contributor (Author):

My bad, I forgot to remove the comment.

@sayakpaul (Member) left a review:

Just two comments and then we should be good to go.

@lavinal712 (Contributor, Author)

I'm not very familiar with the pull request process. Does this mean the code has been successfully submitted?

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@sayakpaul (Member)

@lavinal712 can you fix the code quality by running make style && make quality?

@lavinal712 (Contributor, Author)

> @lavinal712 can you fix the code quality by running make style && make quality?

Now it is OK on my end.

@sayakpaul merged commit 4fa2459 into huggingface:main on Jan 27, 2025
9 checks passed
@sayakpaul (Member)

Thanks for your contributions!

@priyammaz commented on Feb 19, 2025

So I have been working on building an LDM from scratch (for teaching purposes) and used this code as a reference, so thank you! I did have a question, though: I notice that you don't have anything updating your discriminator. You compute the disc_loss like the following:

```python
disc_loss = disc_factor * disc_loss(logits_real, logits_fake)
```

but as far as I can tell, no gradients are computed with this loss?
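
For comparison, a typical discriminator update in this kind of VAE-GAN setup would also run a backward pass and an optimizer step on that loss (disc_optimizer and disc_loss_fn are illustrative names, not taken from either script):

```python
logits_real = discriminator(targets)
logits_fake = discriminator(reconstructions.detach())  # detach so the VAE receives no gradient here
d_loss = disc_factor * disc_loss_fn(logits_real, logits_fake)
d_loss.backward()
disc_optimizer.step()
disc_optimizer.zero_grad()
```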

I reference here my own VAE trainer, which is giving good results so far on ImageNet, Conceptual Captions, and CelebA-HQ. They are still training and will take a week or so to finish, but it's a start!

[attached: sample reconstructions at iterations 18750, 25000, and 30750]

In all these cases, the PatchGAN discriminator has not kicked in yet; that'll happen later today or tomorrow once training reaches that iteration.

The autoencoder being trained is very close to AutoencoderKL, just rewritten for my teaching purposes.

I also have a training script for a VQ-VAE, if that would be interesting to add as well.
