From cf94b6a3595a5915024436fc5e21945bddbf3164 Mon Sep 17 00:00:00 2001 From: lavinal712 Date: Sat, 18 Jan 2025 12:20:26 +0000 Subject: [PATCH 1/8] create a script to train vae --- examples/autoencoderkl/README.md | 42 + examples/autoencoderkl/requirements.txt | 16 + examples/autoencoderkl/train_autoencoderkl.py | 971 ++++++++++++++++++ 3 files changed, 1029 insertions(+) create mode 100644 examples/autoencoderkl/README.md create mode 100644 examples/autoencoderkl/requirements.txt create mode 100644 examples/autoencoderkl/train_autoencoderkl.py diff --git a/examples/autoencoderkl/README.md b/examples/autoencoderkl/README.md new file mode 100644 index 000000000000..a953bfe29ba5 --- /dev/null +++ b/examples/autoencoderkl/README.md @@ -0,0 +1,42 @@ +# AutoencoderKL training example + +## Installing the dependencies + +Before running the scripts, make sure to install the library's training dependencies: + +**Important** + +To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment: +```bash +git clone https://github.com/huggingface/diffusers +cd diffusers +pip install . +``` + +Then cd in the example folder and run +```bash +pip install -r requirements.txt +``` + + +And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with: + +```bash +accelerate config +``` + +## Training on ImageNet + +```bash +accelerate launch --multi_gpu --num_processes 4 --mixed_precision bf16 train_autoencoderkl.py \ + --pretrained_model_name_or_path stabilityai/sdxl-vae \ + --num_train_epochs 100 \ + --gradient_accumulation_steps 2 \ + --learning_rate 4.5e-6 \ + --lr_scheduler cosine \ + --report_to wandb \ + --mixed_precision bf16 \ + --train_data_dir /path/to/ImageNet/train \ + --validation_image ./image.png \ + --decoder_only +``` diff --git a/examples/autoencoderkl/requirements.txt b/examples/autoencoderkl/requirements.txt new file mode 100644 index 000000000000..0a0e4d91c295 --- /dev/null +++ b/examples/autoencoderkl/requirements.txt @@ -0,0 +1,16 @@ +accelerate>=0.16.0 +bitsandbytes +datasets +diffusers +huggingface_hub +lpips +numpy +packaging +Pillow +taming_transformers +torch +torchvision +tqdm +transformers +wandb +xformers diff --git a/examples/autoencoderkl/train_autoencoderkl.py b/examples/autoencoderkl/train_autoencoderkl.py new file mode 100644 index 000000000000..19d4a5b4407f --- /dev/null +++ b/examples/autoencoderkl/train_autoencoderkl.py @@ -0,0 +1,971 @@ +import argparse +import contextlib +import gc +import logging +import math +import os +import random +import shutil +from pathlib import Path + +import accelerate +import numpy as np +import lpips +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +import torchvision +import transformers +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import ProjectConfiguration, set_seed +from datasets import load_dataset +from huggingface_hub import create_repo, upload_folder +from packaging import version +from PIL import Image +from taming.modules.losses.vqperceptual import * +from torchvision import transforms +from tqdm.auto import tqdm + +import diffusers +from diffusers import AutoencoderKL +from diffusers.optimization import get_scheduler +from diffusers.utils import 
check_min_version, is_wandb_available +from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card +from diffusers.utils.import_utils import is_xformers_available +from diffusers.utils.torch_utils import is_compiled_module + + +if is_wandb_available(): + import wandb + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. +check_min_version("0.30.0.dev0") + +logger = get_logger(__name__) + + +def image_grid(imgs, rows, cols): + assert len(imgs) == rows * cols + + w, h = imgs[0].size + grid = Image.new("RGB", size=(cols * w, rows * h)) + + for i, img in enumerate(imgs): + grid.paste(img, box=(i % cols * w, i // cols * h)) + return grid + + +def log_validation( + vae, args, accelerator, weight_dtype, step, is_final_validation=False +): + logger.info("Running validation... ") + + if not is_final_validation: + vae = accelerator.unwrap_model(vae) + else: + vae = AutoencoderKL.from_pretrained(args.output_dir, torch_dtype=weight_dtype) + + images = [] + inference_ctx = contextlib.nullcontext() if is_final_validation else torch.autocast("cuda") + + image_transforms = transforms.Compose( + [ + transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(args.resolution), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) + + for i, validation_image in enumerate(args.validation_image): + validation_image = Image.open(validation_image).convert("RGB") + targets = image_transforms(validation_image).to(weight_dtype) + targets = targets.unsqueeze(0).to(vae.device) + + with inference_ctx: + reconstructions = vae(targets).sample + + images.append( + torch.cat([targets.cpu(), reconstructions.cpu()], axis=0) + ) + + tracker_key = "test" if is_final_validation else "validation" + for tracker in accelerator.trackers: + if tracker.name == "tensorboard": + np_images = np.stack([np.asarray(img) for img in images]) + tracker.writer.add_images( + "Original (left), Reconstruction (right)", np_images, step + ) + elif tracker.name == "wandb": + tracker.log( + { + "Original (left), Reconstruction (right)": [ + wandb.Image(torchvision.utils.make_grid(image)) + for _, image in enumerate(images) + ] + } + ) + else: + logger.warn(f"image logging not implemented for {tracker.gen_images}") + + gc.collect() + torch.cuda.empty_cache() + + return images + + +def save_model_card(repo_id: str, images=None, base_model=str, repo_folder=None): + img_str = "" + if images is not None: + img_str = "You can find some example images below.\n\n" + image_grid(images, 1, "example").save(os.path.join(repo_folder, f"images_{i}.png")) + img_str += f"![images_{i})](./images_{i}.png)\n" + + model_description = f""" +# autoencoderkl-{repo_id} + +These are autoencoderkl weights trained on {base_model} with new type of conditioning. 
+{img_str}
+"""
+    model_card = load_or_create_model_card(
+        repo_id_or_path=repo_id,
+        from_training=True,
+        license="creativeml-openrail-m",
+        base_model=base_model,
+        model_description=model_description,
+        inference=True,
+    )
+
+    tags = [
+        "stable-diffusion",
+        "stable-diffusion-diffusers",
+        "image-to-image",
+        "diffusers",
+        "autoencoderkl",
+        "diffusers-training",
+    ]
+    model_card = populate_model_card(model_card, tags=tags)
+
+    model_card.save(os.path.join(repo_folder, "README.md"))
+
+
+def parse_args(input_args=None):
+    parser = argparse.ArgumentParser(description="Simple example of an AutoencoderKL training script.")
+    parser.add_argument(
+        "--pretrained_model_name_or_path",
+        type=str,
+        default=None,
+        required=True,
+        help="Path to pretrained model or model identifier from huggingface.co/models.",
+    )
+    parser.add_argument(
+        "--revision",
+        type=str,
+        default=None,
+        required=False,
+        help="Revision of pretrained model identifier from huggingface.co/models.",
+    )
+    parser.add_argument(
+        "--output_dir",
+        type=str,
+        default="autoencoderkl-model",
+        help="The output directory where the model predictions and checkpoints will be written.",
+    )
+    parser.add_argument(
+        "--cache_dir",
+        type=str,
+        default=None,
+        help="The directory where the downloaded models and datasets will be stored.",
+    )
+    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+    parser.add_argument(
+        "--resolution",
+        type=int,
+        default=512,
+        help=(
+            "The resolution for input images; all images in the train/validation dataset will be resized to this"
+            " resolution."
+        ),
+    )
+    parser.add_argument(
+        "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
+    )
+    parser.add_argument("--num_train_epochs", type=int, default=1)
+    parser.add_argument(
+        "--max_train_steps",
+        type=int,
+        default=None,
+        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+    )
+    parser.add_argument(
+        "--checkpointing_steps",
+        type=int,
+        default=500,
+        help=(
+            "Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`."
+            " In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference."
+            " Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components."
+            " See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step-by-step"
+            " instructions."
+        ),
+    )
+    parser.add_argument(
+        "--checkpoints_total_limit",
+        type=int,
+        default=None,
+        help=("Max number of checkpoints to store."),
+    )
+    parser.add_argument(
+        "--resume_from_checkpoint",
+        type=str,
+        default=None,
+        help=(
+            "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+            ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ), + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=4.5e-6, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + default=False, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument( + "--lr_num_cycles", + type=int, + default=1, + help="Number of hard resets of the lr in cosine_with_restarts scheduler.", + ) + parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.") + parser.add_argument( + "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." + ) + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=0, + help=( + "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." + ), + ) + parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") + parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") + parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") + parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + parser.add_argument( + "--allow_tf32", + action="store_true", + help=( + "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" + " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" + ), + ) + parser.add_argument( + "--report_to", + type=str, + default="tensorboard", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' + ), + ) + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. 
Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+            " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the"
+            " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+        ),
+    )
+    parser.add_argument(
+        "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
+    )
+    parser.add_argument(
+        "--set_grads_to_none",
+        action="store_true",
+        help=(
+            "Save more memory by setting grads to None instead of zero. Be aware that this changes certain"
+            " behaviors, so disable this argument if it causes any problems. More info:"
+            " https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html"
+        ),
+    )
+    parser.add_argument(
+        "--dataset_name",
+        type=str,
+        default=None,
+        help=(
+            "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
+            " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
+            " or to a folder containing files that 🤗 Datasets can understand."
+        ),
+    )
+    parser.add_argument(
+        "--dataset_config_name",
+        type=str,
+        default=None,
+        help="The config of the Dataset, leave as None if there's only one config.",
+    )
+    parser.add_argument(
+        "--train_data_dir",
+        type=str,
+        default=None,
+        help=(
+            "A folder containing the training data. Folder contents must follow the structure described in"
+            " https://huggingface.co/docs/datasets/image_dataset#imagefolder; only the images are used and no"
+            " captions are required. Ignored if `dataset_name` is specified."
+        ),
+    )
+    parser.add_argument(
+        "--image_column", type=str, default="image", help="The column of the dataset containing the target image."
+    )
+    parser.add_argument(
+        "--max_train_samples",
+        type=int,
+        default=None,
+        help=(
+            "For debugging purposes or quicker training, truncate the number of training examples to this "
+            "value if set."
+        ),
+    )
+    parser.add_argument(
+        "--validation_image",
+        type=str,
+        default=None,
+        nargs="+",
+        help="A set of paths to images to be reconstructed every `--validation_steps` steps and logged to `--report_to`.",
+    )
+    parser.add_argument(
+        "--validation_steps",
+        type=int,
+        default=100,
+        help=(
+            "Run validation every X steps. Validation consists of reconstructing each image in"
+            " `--validation_image` and logging the originals alongside their reconstructions."
+ ), + ) + parser.add_argument( + "--tracker_project_name", + type=str, + default="train_autoencoderkl", + help=( + "The `project_name` argument passed to Accelerator.init_trackers for" + " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator" + ), + ) + parser.add_argument( + "--rec_loss", + type=str, + default="l2", + help="The loss function for VAE reconstruction loss.", + ) + parser.add_argument( + "--kl_scale", + type=float, + default=1e-6, + help="Scaling factor for the Kullback-Leibler divergence penalty term.", + ) + parser.add_argument( + "--lpips_scale", + type=float, + default=0.5, + help="Scaling factor for the LPIPS metric", + ) + parser.add_argument( + "--disc_start", + type=int, + default=50001, + help="Start for the discriminator", + ) + parser.add_argument( + "--disc_factor", + type=float, + default=1.0, + help="Scaling factor for the discriminator", + ) + parser.add_argument( + "--disc_scale", + type=float, + default=1.0, + help="Scaling factor for the discriminator", + ) + parser.add_argument( + "--decoder_only", + action="store_true", + help="Only train the VAE decoder.", + ) + + if input_args is not None: + args = parser.parse_args(input_args) + else: + args = parser.parse_args() + + if args.dataset_name is None and args.train_data_dir is None: + raise ValueError("Specify either `--dataset_name` or `--train_data_dir`") + + if args.resolution % 8 != 0: + raise ValueError( + "`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the diffusion model." + ) + + return args + + +def make_train_dataset(args, accelerator): + # Get the datasets: you can either provide your own training and evaluation files (see below) + # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub). + + # In distributed training, the load_dataset function guarantees that only one local process can concurrently + # download the dataset. + if args.dataset_name is not None: + # Downloading and loading a dataset from the hub. + dataset = load_dataset( + args.dataset_name, + args.dataset_config_name, + cache_dir=args.cache_dir, + data_dir=args.train_data_dir, + ) + else: + data_files = {} + if args.train_data_dir is not None: + data_files["train"] = os.path.join(args.train_data_dir, "**") + dataset = load_dataset( + "imagefolder", + data_files=data_files, + cache_dir=args.cache_dir, + ) + # See more about loading custom images at + # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script + + # Preprocessing the datasets. + # We need to tokenize inputs and targets. + column_names = dataset["train"].column_names + + # 6. Get the column names for input/target. + if args.image_column is None: + image_column = column_names[0] + logger.info(f"image column defaulting to {image_column}") + else: + image_column = args.image_column + if image_column not in column_names: + raise ValueError( + f"`--image_column` value '{args.image_column}' not found in dataset columns. 
Dataset columns are: {', '.join(column_names)}" + ) + + image_transforms = transforms.Compose( + [ + transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(args.resolution), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) + + def preprocess_train(examples): + images = [image.convert("RGB") for image in examples[image_column]] + images = [image_transforms(image) for image in images] + + examples["pixel_values"] = images + + return examples + + with accelerator.main_process_first(): + if args.max_train_samples is not None: + dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples)) + # Set the training transforms + train_dataset = dataset["train"].with_transform(preprocess_train) + + return train_dataset + + +def collate_fn(examples): + pixel_values = torch.stack([example["pixel_values"] for example in examples]) + pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() + + return {"pixel_values": pixel_values} + + +def main(args): + if args.report_to == "wandb" and args.hub_token is not None: + raise ValueError( + "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token." + " Please use `huggingface-cli login` to authenticate with the Hub." + ) + + logging_dir = Path(args.output_dir, args.logging_dir) + + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) + + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + project_config=accelerator_project_config, + ) + + # Disable AMP for MPS. + if torch.backends.mps.is_available(): + accelerator.native_amp = False + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. 
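+    # Note (added for clarity): `set_seed` from `accelerate.utils` seeds Python's
+    # `random`, NumPy and torch with the same value on every process; it also accepts
+    # `device_specific=True` if each process should derive its own seed instead.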
+ if args.seed is not None: + set_seed(args.seed) + + # Handle the repository creation + if accelerator.is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + if args.push_to_hub: + repo_id = create_repo( + repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token + ).repo_id + + # Load AutoencoderKL + vae = AutoencoderKL.from_pretrained( + args.pretrained_model_name_or_path, revision=args.revision + ) + lpips_loss_fn = lpips.LPIPS(net="vgg") + discriminator = NLayerDiscriminator( + input_nc=3, n_layers=3, use_actnorm=False, + ).apply(weights_init) + + # `accelerate` 0.16.0 will have better support for customized saving + if version.parse(accelerate.__version__) >= version.parse("0.16.0"): + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + def save_model_hook(models, weights, output_dir): + if accelerator.is_main_process: + i = len(weights) - 1 + + while len(weights) > 0: + weights.pop() + model = models[i] + + if isinstance(model, AutoencoderKL): + sub_dir = "autoencoderkl" + model.save_pretrained(os.path.join(output_dir, sub_dir)) + else: + sub_dir = "discriminator" + os.makedirs(os.path.join(output_dir, sub_dir), exist_ok=True) + torch.save(model.state_dict(), os.path.join(output_dir, sub_dir, "pytorch_model.bin")) + + i -= 1 + + def load_model_hook(models, input_dir): + while len(models) > 0: + # pop models so that they are not loaded again + model = models.pop() + + # load diffusers style into model + load_model = AutoencoderKL.from_pretrained(input_dir, subfolder="autoencoderkl") + model.register_to_config(**load_model.config) + + model.load_state_dict(load_model.state_dict()) + del load_model + + accelerator.register_save_state_pre_hook(save_model_hook) + accelerator.register_load_state_pre_hook(load_model_hook) + + + vae.requires_grad_(True) + if args.decoder_only: + vae.encoder.requires_grad_(False) + if getattr(vae, "quant_conv", None): + vae.quant_conv.requires_grad_(False) + vae.train() + lpips_loss_fn.requires_grad_(False) + discriminator.requires_grad_(True) + discriminator.train() + + if args.enable_xformers_memory_efficient_attention: + if is_xformers_available(): + import xformers + + xformers_version = version.parse(xformers.__version__) + if xformers_version == version.parse("0.0.16"): + logger.warning( + "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details." + ) + vae.enable_xformers_memory_efficient_attention() + else: + raise ValueError("xformers is not available. Make sure it is installed correctly") + + if args.gradient_checkpointing: + vae.enable_gradient_checkpointing() + + # Check that all trainable models are in full precision + low_precision_error_string = ( + " Please make sure to always have all model weights in full float32 precision when starting training - even if" + " doing mixed precision training, copy of the weights should still be float32." 
+ ) + + # Enable TF32 for faster training on Ampere GPUs, + # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + if args.allow_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + + if args.scale_lr: + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + ) + + # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." + ) + + optimizer_class = bnb.optim.AdamW8bit + else: + optimizer_class = torch.optim.AdamW + + params_to_optimize = filter(lambda p: p.requires_grad, vae.parameters()) + params_to_optimize_2 = filter(lambda p: p.requires_grad, discriminator.parameters()) + optimizer = optimizer_class( + params_to_optimize, + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + optimizer_2 = optimizer_class( + params_to_optimize_2, + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + train_dataset = make_train_dataset(args, accelerator) + + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + shuffle=True, + collate_fn=collate_fn, + batch_size=args.train_batch_size, + num_workers=args.dataloader_num_workers, + ) + + # Scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, + num_training_steps=args.max_train_steps * accelerator.num_processes, + num_cycles=args.lr_num_cycles, + power=args.lr_power, + ) + + # Prepare everything with our `accelerator`. + vae, discriminator, optimizer, optimizer_2, train_dataloader, lr_scheduler = accelerator.prepare( + vae, discriminator, optimizer, optimizer_2, train_dataloader, lr_scheduler + ) + + # For mixed precision training we cast the text_encoder and vae weights to half-precision + # as these models are only used for inference, keeping weights in full precision is not required. + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + # Move vae to device and cast to weight_dtype + vae.to(accelerator.device, dtype=weight_dtype) + lpips_loss_fn.to(accelerator.device, dtype=weight_dtype) + discriminator.to(accelerator.device, dtype=weight_dtype) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # We need to initialize the trackers we use, and also store our configuration. 
+ # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + tracker_config = dict(vars(args)) + accelerator.init_trackers(args.tracker_project_name, config=tracker_config) + + # Train! + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num batches each epoch = {len(train_dataloader)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + global_step = 0 + first_epoch = 0 + + # Potentially load in the weights and states from a previous save + if args.resume_from_checkpoint: + if args.resume_from_checkpoint != "latest": + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the most recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + + if path is None: + accelerator.print( + f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." + ) + args.resume_from_checkpoint = None + initial_global_step = 0 + else: + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path)) + global_step = int(path.split("-")[1]) + + initial_global_step = global_step + first_epoch = global_step // num_update_steps_per_epoch + else: + initial_global_step = 0 + + progress_bar = tqdm( + range(0, args.max_train_steps), + initial=initial_global_step, + desc="Steps", + # Only show the progress bar once on each machine. + disable=not accelerator.is_local_main_process, + ) + + image_logs = None + for epoch in range(first_epoch, args.num_train_epochs): + for step, batch in enumerate(train_dataloader): + # Convert images to latent space and reconstruct from them + targets = batch["pixel_values"].to(dtype=weight_dtype) + if accelerator.num_processes > 1: + posterior = vae.module.encode(targets).latent_dist + else: + posterior = vae.encode(targets).latent_dist + latents = posterior.sample() + if accelerator.num_processes > 1: + reconstructions = vae.module.decode(latents).sample + else: + reconstructions = vae.decode(latents).sample + + if (step // args.gradient_accumulation_steps) % 2 == 0 or global_step < args.disc_start: + with accelerator.accumulate(vae): + # reconstruction loss. Pixel level differences between input vs output + if args.rec_loss == "l2": + rec_loss = F.mse_loss(reconstructions.float(), targets.float(), reduction="none") + else: + rec_loss = F.l1_loss(reconstructions.float(), targets.float(), reduction="none") + # perceptual loss. 
The high level feature mean squared error loss + with torch.no_grad(): + lpips_loss = lpips_loss_fn(reconstructions, targets) + + rec_loss = rec_loss + args.lpips_scale * lpips_loss + nll_loss = rec_loss + nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] + + kl_loss = posterior.kl() + kl_loss = torch.sum(kl_loss) / kl_loss.shape[0] + + logits_fake = discriminator(reconstructions) + g_loss = -torch.mean(logits_fake) + last_layer = accelerator.unwrap_model(vae).decoder.conv_out.weight + nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] + g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] + disc_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) + disc_weight = torch.clamp(disc_weight, 0.0, 1e4).detach() + disc_weight = disc_weight * args.disc_scale + disc_factor = args.disc_factor if global_step >= args.disc_start else 0.0 + + loss = nll_loss + args.kl_scale * kl_loss + disc_weight * disc_factor * g_loss + + logs = { + "loss": loss.detach().mean().item(), + "nll_loss": nll_loss.detach().mean().item(), + "rec_loss": rec_loss.detach().mean().item(), + "lpips_loss": lpips_loss.detach().mean().item(), + "kl_loss": kl_loss.detach().mean().item(), + "disc_weight": disc_weight.detach().mean().item(), + "disc_factor": torch.tensor(disc_factor), + "g_loss": g_loss.detach().mean().item(), + "lr": lr_scheduler.get_last_lr()[0] + } + + accelerator.backward(loss) + if accelerator.sync_gradients: + params_to_clip = vae.parameters() + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=args.set_grads_to_none) + else: + with accelerator.accumulate(discriminator): + logits_real = discriminator(targets) + logits_fake = discriminator(reconstructions) + disc_loss = hinge_d_loss + disc_factor = args.disc_factor if global_step >= args.disc_start else 0.0 + disc_loss = disc_factor * disc_loss(logits_real, logits_fake) + logs = { + "disc_loss": disc_loss.detach().mean().item(), + "logits_real": logits_real.detach().mean().item(), + "logits_fake": logits_fake.detach().mean().item(), + } + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + if accelerator.is_main_process: + if global_step % args.checkpointing_steps == 0: + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if args.checkpoints_total_limit is not None: + checkpoints = os.listdir(args.output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= args.checkpoints_total_limit: + num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + accelerator.save_state(save_path) + logger.info(f"Saved state to {save_path}") + + if global_step == 1 or 
global_step % args.validation_steps == 0:
+                        image_logs = log_validation(
+                            vae,
+                            args,
+                            accelerator,
+                            weight_dtype,
+                            global_step,
+                        )
+
+            progress_bar.set_postfix(**logs)
+            accelerator.log(logs, step=global_step)
+
+            if global_step >= args.max_train_steps:
+                break
+
+    # Create the pipeline using the trained modules and save it.
+    accelerator.wait_for_everyone()
+    if accelerator.is_main_process:
+        vae = accelerator.unwrap_model(vae)
+        vae.save_pretrained(args.output_dir)
+
+        # Run a final round of validation.
+        image_logs = None
+        image_logs = log_validation(
+            vae=vae,
+            args=args,
+            accelerator=accelerator,
+            weight_dtype=weight_dtype,
+            step=global_step,
+            is_final_validation=True,
+        )
+
+        if args.push_to_hub:
+            save_model_card(
+                repo_id,
+                images=image_logs,
+                base_model=args.pretrained_model_name_or_path,
+                repo_folder=args.output_dir,
+            )
+            upload_folder(
+                repo_id=repo_id,
+                folder_path=args.output_dir,
+                commit_message="End of training",
+                ignore_patterns=["step_*", "epoch_*"],
+            )
+
+    accelerator.end_training()
+
+
+if __name__ == "__main__":
+    args = parse_args()
+    main(args)

From dd3a0a3408f1b6f5ebaf9589ff5a280d49e66136 Mon Sep 17 00:00:00 2001
From: lavinal712
Date: Mon, 20 Jan 2025 08:30:41 +0000
Subject: [PATCH 2/8] update train_autoencoderkl.py

---
 examples/autoencoderkl/train_autoencoderkl.py | 137 +++++++++++++-----
 1 file changed, 104 insertions(+), 33 deletions(-)

diff --git a/examples/autoencoderkl/train_autoencoderkl.py b/examples/autoencoderkl/train_autoencoderkl.py
index 19d4a5b4407f..8ae936bfe7e1 100644
--- a/examples/autoencoderkl/train_autoencoderkl.py
+++ b/examples/autoencoderkl/train_autoencoderkl.py
@@ -23,13 +23,16 @@
 from huggingface_hub import create_repo, upload_folder
 from packaging import version
 from PIL import Image
-from taming.modules.losses.vqperceptual import *
+from taming.modules.losses.vqperceptual import (
+    hinge_d_loss, vanilla_d_loss, weights_init, NLayerDiscriminator
+)
 from torchvision import transforms
 from tqdm.auto import tqdm
 
 import diffusers
 from diffusers import AutoencoderKL
 from diffusers.optimization import get_scheduler
+from diffusers.training_utils import EMAModel
 from diffusers.utils import check_min_version, is_wandb_available
 from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
 from diffusers.utils.import_utils import is_xformers_available
@@ -56,6 +59,7 @@ def image_grid(imgs, rows, cols):
     return grid
 
 
+@torch.no_grad()
 def log_validation(
     vae, args, accelerator, weight_dtype, step, is_final_validation=False
 ):
@@ -80,8 +84,8 @@ def log_validation(
 
     for i, validation_image in enumerate(args.validation_image):
         validation_image = Image.open(validation_image).convert("RGB")
-        targets = image_transforms(validation_image).to(weight_dtype)
-        targets = targets.unsqueeze(0).to(vae.device)
+        targets = image_transforms(validation_image).to(accelerator.device, weight_dtype)
+        targets = targets.unsqueeze(0)
 
         with inference_ctx:
             reconstructions = vae(targets).sample
@@ -112,15 +116,15 @@ def log_validation(
     gc.collect()
     torch.cuda.empty_cache()
 
-    return images
+    return images
 
 
 def save_model_card(repo_id: str, images=None, base_model=str, repo_folder=None):
     img_str = ""
     if images is not None:
         img_str = "You can find some example images below.\n\n"
-        image_grid(images, 1, "example").save(os.path.join(repo_folder, f"images_{i}.png"))
-        img_str += f"![images_{i})](./images_{i}.png)\n"
+        image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f"images.png"))
+        img_str += f"![images](./images.png)\n"
model_description = f""" # autoencoderkl-{repo_id} @@ -156,9 +160,14 @@ def parse_args(input_args=None): "--pretrained_model_name_or_path", type=str, default=None, - required=True, help="Path to pretrained model or model identifier from huggingface.co/models.", ) + parser.add_argument( + "--model_config_name_or_path", + type=str, + default=None, + help="The config of the VAE model to train, leave as None to use standard VAE model configuration.", + ) parser.add_argument( "--revision", type=str, @@ -242,6 +251,12 @@ def parse_args(input_args=None): default=4.5e-6, help="Initial learning rate (after the potential warmup period) to use.", ) + parser.add_argument( + "--disc_learning_rate", + type=float, + default=4.5e-6, + help="Initial learning rate (after the potential warmup period) to use.", + ) parser.add_argument( "--scale_lr", action="store_true", @@ -257,6 +272,15 @@ def parse_args(input_args=None): ' "constant", "constant_with_warmup"]' ), ) + parser.add_argument( + "--disc_lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) parser.add_argument( "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." ) @@ -270,6 +294,7 @@ def parse_args(input_args=None): parser.add_argument( "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." ) + parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.") parser.add_argument( "--dataloader_num_workers", type=int, @@ -417,7 +442,7 @@ def parse_args(input_args=None): help="Scaling factor for the Kullback-Leibler divergence penalty term.", ) parser.add_argument( - "--lpips_scale", + "--perceptual_scale", type=float, default=0.5, help="Scaling factor for the LPIPS metric", @@ -440,6 +465,12 @@ def parse_args(input_args=None): default=1.0, help="Scaling factor for the discriminator", ) + parser.add_argument( + "--disc_loss", + type=str, + default="hinge", + help="Loss function for the discriminator", + ) parser.add_argument( "--decoder_only", action="store_true", @@ -587,19 +618,28 @@ def main(args): ).repo_id # Load AutoencoderKL - vae = AutoencoderKL.from_pretrained( - args.pretrained_model_name_or_path, revision=args.revision - ) - lpips_loss_fn = lpips.LPIPS(net="vgg") - discriminator = NLayerDiscriminator( - input_nc=3, n_layers=3, use_actnorm=False, - ).apply(weights_init) + if args.pretrained_model_name_or_path is None and args.model_config_name_or_path is None: + config = AutoencoderKL.load_config("stabilityai/sd-vae-ft-mse") + vae = AutoencoderKL.from_config(config) + elif args.pretrained_model_name_or_path is not None: + vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, revision=args.revision) + else: + config = AutoencoderKL.load_config(args.model_config_name_or_path) + vae = AutoencoderKL.from_config(config) + if args.use_ema: + ema_vae = EMAModel(vae.parameters(), model_cls=AutoencoderKL, model_config=vae.config) + perceptual_loss = lpips.LPIPS(net="vgg").eval() + discriminator = NLayerDiscriminator(input_nc=3, n_layers=3, use_actnorm=False).apply(weights_init) # `accelerate` 0.16.0 will have better support for customized saving if version.parse(accelerate.__version__) >= version.parse("0.16.0"): # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format def save_model_hook(models, weights, 
output_dir): if accelerator.is_main_process: + if args.use_ema: + sub_dir = "autoencoderkl_ema" + ema_vae.save_pretrained(os.path.join(output_dir, sub_dir)) + i = len(weights) - 1 while len(weights) > 0: @@ -618,13 +658,22 @@ def save_model_hook(models, weights, output_dir): def load_model_hook(models, input_dir): while len(models) > 0: + if args.use_ema: + sub_dir = "autoencoderkl_ema" + load_model = EMAModel.from_pretrained(os.path.join(input_dir, sub_dir), AutoencoderKL) + ema_vae.load_state_dict(load_model.state_dict()) + ema_vae.to(accelerator.device) + del load_model + # pop models so that they are not loaded again model = models.pop() - - # load diffusers style into model + load_model = NLayerDiscriminator(input_nc=3, n_layers=3, use_actnorm=False).load_state_dict(os.path.join(input_dir, "discriminator", "pytorch_model.bin")) + model.load_state_dict(load_model.state_dict()) + del load_model + + model = models.pop() load_model = AutoencoderKL.from_pretrained(input_dir, subfolder="autoencoderkl") model.register_to_config(**load_model.config) - model.load_state_dict(load_model.state_dict()) del load_model @@ -638,7 +687,6 @@ def load_model_hook(models, input_dir): if getattr(vae, "quant_conv", None): vae.quant_conv.requires_grad_(False) vae.train() - lpips_loss_fn.requires_grad_(False) discriminator.requires_grad_(True) discriminator.train() @@ -688,7 +736,7 @@ def load_model_hook(models, input_dir): optimizer_class = torch.optim.AdamW params_to_optimize = filter(lambda p: p.requires_grad, vae.parameters()) - params_to_optimize_2 = filter(lambda p: p.requires_grad, discriminator.parameters()) + disc_params_to_optimize = filter(lambda p: p.requires_grad, discriminator.parameters()) optimizer = optimizer_class( params_to_optimize, lr=args.learning_rate, @@ -696,9 +744,9 @@ def load_model_hook(models, input_dir): weight_decay=args.adam_weight_decay, eps=args.adam_epsilon, ) - optimizer_2 = optimizer_class( - params_to_optimize_2, - lr=args.learning_rate, + disc_optimizer = optimizer_class( + disc_params_to_optimize, + lr=args.disc_learning_rate, betas=(args.adam_beta1, args.adam_beta2), weight_decay=args.adam_weight_decay, eps=args.adam_epsilon, @@ -729,10 +777,18 @@ def load_model_hook(models, input_dir): num_cycles=args.lr_num_cycles, power=args.lr_power, ) + disc_lr_scheduler = get_scheduler( + args.disc_lr_scheduler, + optimizer=disc_optimizer, + num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, + num_training_steps=args.max_train_steps * accelerator.num_processes, + num_cycles=args.lr_num_cycles, + power=args.lr_power, + ) # Prepare everything with our `accelerator`. 
- vae, discriminator, optimizer, optimizer_2, train_dataloader, lr_scheduler = accelerator.prepare( - vae, discriminator, optimizer, optimizer_2, train_dataloader, lr_scheduler + vae, discriminator, optimizer, disc_optimizer, train_dataloader, lr_scheduler, disc_lr_scheduler = accelerator.prepare( + vae, discriminator, optimizer, disc_optimizer, train_dataloader, lr_scheduler, disc_lr_scheduler ) # For mixed precision training we cast the text_encoder and vae weights to half-precision @@ -743,10 +799,12 @@ def load_model_hook(models, input_dir): elif accelerator.mixed_precision == "bf16": weight_dtype = torch.bfloat16 - # Move vae to device and cast to weight_dtype + # Move VAE, perceptual loss and discriminator to device and cast to weight_dtype vae.to(accelerator.device, dtype=weight_dtype) - lpips_loss_fn.to(accelerator.device, dtype=weight_dtype) + perceptual_loss.to(accelerator.device, dtype=weight_dtype) discriminator.to(accelerator.device, dtype=weight_dtype) + if args.use_ema: + ema_vae.to(accelerator.device, dtype=weight_dtype) # We need to recalculate our total training steps as the size of the training dataloader may have changed. num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) @@ -812,6 +870,8 @@ def load_model_hook(models, input_dir): image_logs = None for epoch in range(first_epoch, args.num_train_epochs): + vae.train() + discriminator.train() for step, batch in enumerate(train_dataloader): # Convert images to latent space and reconstruct from them targets = batch["pixel_values"].to(dtype=weight_dtype) @@ -834,9 +894,9 @@ def load_model_hook(models, input_dir): rec_loss = F.l1_loss(reconstructions.float(), targets.float(), reduction="none") # perceptual loss. The high level feature mean squared error loss with torch.no_grad(): - lpips_loss = lpips_loss_fn(reconstructions, targets) + p_loss = perceptual_loss(reconstructions, targets) - rec_loss = rec_loss + args.lpips_scale * lpips_loss + rec_loss = rec_loss + args.perceptual_scale * p_loss nll_loss = rec_loss nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] @@ -859,10 +919,10 @@ def load_model_hook(models, input_dir): "loss": loss.detach().mean().item(), "nll_loss": nll_loss.detach().mean().item(), "rec_loss": rec_loss.detach().mean().item(), - "lpips_loss": lpips_loss.detach().mean().item(), + "p_loss": p_loss.detach().mean().item(), "kl_loss": kl_loss.detach().mean().item(), "disc_weight": disc_weight.detach().mean().item(), - "disc_factor": torch.tensor(disc_factor), + "disc_factor": disc_factor, "g_loss": g_loss.detach().mean().item(), "lr": lr_scheduler.get_last_lr()[0] } @@ -878,18 +938,21 @@ def load_model_hook(models, input_dir): with accelerator.accumulate(discriminator): logits_real = discriminator(targets) logits_fake = discriminator(reconstructions) - disc_loss = hinge_d_loss + disc_loss = hinge_d_loss if args.disc_loss == "hinge" else vanilla_d_loss disc_factor = args.disc_factor if global_step >= args.disc_start else 0.0 disc_loss = disc_factor * disc_loss(logits_real, logits_fake) logs = { "disc_loss": disc_loss.detach().mean().item(), "logits_real": logits_real.detach().mean().item(), "logits_fake": logits_fake.detach().mean().item(), + "disc_lr": disc_lr_scheduler.get_last_lr()[0] } # Checks if the accelerator has performed an optimization step behind the scenes if accelerator.sync_gradients: progress_bar.update(1) global_step += 1 + if args.use_ema: + ema_vae.step(vae.parameters()) if accelerator.is_main_process: if global_step % args.checkpointing_steps == 
0: @@ -918,6 +981,9 @@ def load_model_hook(models, input_dir): logger.info(f"Saved state to {save_path}") if global_step == 1 or global_step % args.validation_steps == 0: + if args.use_ema: + ema_vae.store(vae.parameters()) + ema_vae.copy_to(vae.parameters()) image_logs = log_validation( vae, args, @@ -925,6 +991,8 @@ def load_model_hook(models, input_dir): weight_dtype, global_step, ) + if args.use_ema: + ema_vae.restore(vae.parameters()) progress_bar.set_postfix(**logs) accelerator.log(logs, step=global_step) @@ -936,8 +1004,11 @@ def load_model_hook(models, input_dir): accelerator.wait_for_everyone() if accelerator.is_main_process: vae = accelerator.unwrap_model(vae) + discriminator = accelerator.unwrap_model(discriminator) + if args.use_ema: + ema_vae.copy_to(vae.parameters()) vae.save_pretrained(args.output_dir) - + torch.save(discriminator.state_dict(), os.path.join(args.output_dir, "pytorch_model.bin")) # Run a final round of validation. image_logs = None image_logs = log_validation( From cdee155232a8950305ec94d60eb19f004ed0f0c0 Mon Sep 17 00:00:00 2001 From: lavinal712 Date: Wed, 22 Jan 2025 01:17:33 +0000 Subject: [PATCH 3/8] update train_autoencoderkl.py --- .../autoencoderkl/README.md | 19 +++++++- .../autoencoderkl/requirements.txt | 0 .../autoencoderkl/train_autoencoderkl.py | 48 +++++++++---------- .../autoencoderkl/train_autoencoderkl.sh | 10 ++++ 4 files changed, 51 insertions(+), 26 deletions(-) rename examples/{ => research_projects}/autoencoderkl/README.md (67%) rename examples/{ => research_projects}/autoencoderkl/requirements.txt (100%) rename examples/{ => research_projects}/autoencoderkl/train_autoencoderkl.py (97%) create mode 100644 examples/research_projects/autoencoderkl/train_autoencoderkl.sh diff --git a/examples/autoencoderkl/README.md b/examples/research_projects/autoencoderkl/README.md similarity index 67% rename from examples/autoencoderkl/README.md rename to examples/research_projects/autoencoderkl/README.md index a953bfe29ba5..c0b26294422f 100644 --- a/examples/autoencoderkl/README.md +++ b/examples/research_projects/autoencoderkl/README.md @@ -25,11 +25,26 @@ And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) e accelerate config ``` +## Training on CIFAR10 + +```bash +accelerate launch train_autoencoderkl.py \ + --pretrained_model_name_or_path stabilityai/sd-vae-ft-mse \ + --dataset_name=cifar10 \ + --image_column=img \ + --validation_image images/bird.jpg images/car.jpg images/dog.jpg images/frog.jpg \ + --num_train_epochs 100 \ + --gradient_accumulation_steps 2 \ + --learning_rate 4.5e-6 \ + --lr_scheduler cosine \ + --report_to wandb \ +``` + ## Training on ImageNet ```bash -accelerate launch --multi_gpu --num_processes 4 --mixed_precision bf16 train_autoencoderkl.py \ - --pretrained_model_name_or_path stabilityai/sdxl-vae \ +accelerate launch train_autoencoderkl.py \ + --pretrained_model_name_or_path stabilityai/sd-vae-ft-mse \ --num_train_epochs 100 \ --gradient_accumulation_steps 2 \ --learning_rate 4.5e-6 \ diff --git a/examples/autoencoderkl/requirements.txt b/examples/research_projects/autoencoderkl/requirements.txt similarity index 100% rename from examples/autoencoderkl/requirements.txt rename to examples/research_projects/autoencoderkl/requirements.txt diff --git a/examples/autoencoderkl/train_autoencoderkl.py b/examples/research_projects/autoencoderkl/train_autoencoderkl.py similarity index 97% rename from examples/autoencoderkl/train_autoencoderkl.py rename to 
examples/research_projects/autoencoderkl/train_autoencoderkl.py index 8ae936bfe7e1..e6fd0382cf2a 100644 --- a/examples/autoencoderkl/train_autoencoderkl.py +++ b/examples/research_projects/autoencoderkl/train_autoencoderkl.py @@ -1,3 +1,18 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and + import argparse import contextlib import gc @@ -33,7 +48,7 @@ from diffusers import AutoencoderKL from diffusers.optimization import get_scheduler from diffusers.training_utils import EMAModel -from diffusers.utils import check_min_version, is_wandb_available +from diffusers.utils import check_min_version, is_wandb_available, make_image_grid from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card from diffusers.utils.import_utils import is_xformers_available from diffusers.utils.torch_utils import is_compiled_module @@ -43,22 +58,11 @@ import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.30.0.dev0") +# check_min_version("0.33.0.dev0") logger = get_logger(__name__) -def image_grid(imgs, rows, cols): - assert len(imgs) == rows * cols - - w, h = imgs[0].size - grid = Image.new("RGB", size=(cols * w, rows * h)) - - for i, img in enumerate(imgs): - grid.paste(img, box=(i % cols * w, i // cols * h)) - return grid - - @torch.no_grad() def log_validation( vae, args, accelerator, weight_dtype, step, is_final_validation=False @@ -111,7 +115,7 @@ def log_validation( } ) else: - logger.warn(f"image logging not implemented for {tracker.gen_images}") + logger.warn(f"image logging not implemented for {tracker.name}") gc.collect() torch.cuda.empty_cache() @@ -123,7 +127,7 @@ def save_model_card(repo_id: str, images=None, base_model=str, repo_folder=None) img_str = "" if images is not None: img_str = "You can find some example images below.\n\n" - image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f"images.png")) + make_image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f"images.png")) img_str += f"![images](./images.png)\n" model_description = f""" @@ -875,23 +879,19 @@ def load_model_hook(models, input_dir): for step, batch in enumerate(train_dataloader): # Convert images to latent space and reconstruct from them targets = batch["pixel_values"].to(dtype=weight_dtype) - if accelerator.num_processes > 1: - posterior = vae.module.encode(targets).latent_dist - else: - posterior = vae.encode(targets).latent_dist + posterior = accelerator.unwrap_model(vae).encode(targets).latent_dist latents = posterior.sample() - if accelerator.num_processes > 1: - reconstructions = vae.module.decode(latents).sample - else: - reconstructions = vae.decode(latents).sample + reconstructions = accelerator.unwrap_model(vae).decode(latents).sample if (step // args.gradient_accumulation_steps) % 2 == 0 or global_step < args.disc_start: with accelerator.accumulate(vae): # reconstruction loss. 
Pixel level differences between input vs output if args.rec_loss == "l2": rec_loss = F.mse_loss(reconstructions.float(), targets.float(), reduction="none") - else: + elif args.rec_loss == "l1": rec_loss = F.l1_loss(reconstructions.float(), targets.float(), reduction="none") + else: + raise ValueError(f"Invalid reconstruction loss type: {args.rec_loss}") # perceptual loss. The high level feature mean squared error loss with torch.no_grad(): p_loss = perceptual_loss(reconstructions, targets) diff --git a/examples/research_projects/autoencoderkl/train_autoencoderkl.sh b/examples/research_projects/autoencoderkl/train_autoencoderkl.sh new file mode 100644 index 000000000000..8c42245cc90f --- /dev/null +++ b/examples/research_projects/autoencoderkl/train_autoencoderkl.sh @@ -0,0 +1,10 @@ +accelerate launch train_autoencoderkl.py \ + --pretrained_model_name_or_path stabilityai/sd-vae-ft-mse \ + --dataset_name=cifar10 \ + --image_column=img \ + --validation_image /home/azureuser/v-yuqianhong/ImageNet/ILSVRC2012/val/n01491361/ILSVRC2012_val_00002922.JPEG \ + --num_train_epochs 100 \ + --gradient_accumulation_steps 2 \ + --learning_rate 4.5e-6 \ + --lr_scheduler cosine \ + --report_to wandb \ \ No newline at end of file From f37ffa7930fc7240973052ce79a5c959bbc10ae8 Mon Sep 17 00:00:00 2001 From: lavinal712 Date: Wed, 22 Jan 2025 01:18:13 +0000 Subject: [PATCH 4/8] update train_autoencoderkl.py --- .../autoencoderkl/train_autoencoderkl.sh | 10 ---------- 1 file changed, 10 deletions(-) delete mode 100644 examples/research_projects/autoencoderkl/train_autoencoderkl.sh diff --git a/examples/research_projects/autoencoderkl/train_autoencoderkl.sh b/examples/research_projects/autoencoderkl/train_autoencoderkl.sh deleted file mode 100644 index 8c42245cc90f..000000000000 --- a/examples/research_projects/autoencoderkl/train_autoencoderkl.sh +++ /dev/null @@ -1,10 +0,0 @@ -accelerate launch train_autoencoderkl.py \ - --pretrained_model_name_or_path stabilityai/sd-vae-ft-mse \ - --dataset_name=cifar10 \ - --image_column=img \ - --validation_image /home/azureuser/v-yuqianhong/ImageNet/ILSVRC2012/val/n01491361/ILSVRC2012_val_00002922.JPEG \ - --num_train_epochs 100 \ - --gradient_accumulation_steps 2 \ - --learning_rate 4.5e-6 \ - --lr_scheduler cosine \ - --report_to wandb \ \ No newline at end of file From dafd929456f0ce5366fa107051a083907f07624e Mon Sep 17 00:00:00 2001 From: lavinal712 Date: Wed, 22 Jan 2025 01:35:26 +0000 Subject: [PATCH 5/8] add a check of --pretrained_model_name_or_path and --model_config_name_or_path --- .../research_projects/autoencoderkl/train_autoencoderkl.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/examples/research_projects/autoencoderkl/train_autoencoderkl.py b/examples/research_projects/autoencoderkl/train_autoencoderkl.py index e6fd0382cf2a..c1d609aeebf4 100644 --- a/examples/research_projects/autoencoderkl/train_autoencoderkl.py +++ b/examples/research_projects/autoencoderkl/train_autoencoderkl.py @@ -486,6 +486,9 @@ def parse_args(input_args=None): else: args = parser.parse_args() + if args.pretrained_model_name_or_path is not None and args.model_config_name_or_path is not None: + raise ValueError("Cannot specify both `--pretrained_model_name_or_path` and `--model_config_name_or_path`") + if args.dataset_name is None and args.train_data_dir is None: raise ValueError("Specify either `--dataset_name` or `--train_data_dir`") From 80d39797b9b9c6e4549dc6313d30c3b48cdab772 Mon Sep 17 00:00:00 2001 From: lavinal712 Date: Wed, 22 Jan 2025 01:47:24 +0000 Subject: [PATCH 6/8] 
From 80d39797b9b9c6e4549dc6313d30c3b48cdab772 Mon Sep 17 00:00:00 2001
From: lavinal712
Date: Wed, 22 Jan 2025 01:47:24 +0000
Subject: [PATCH 6/8] remove the comment, remove diffusers from
 requirements.txt, add a note about --validation_image

---
 examples/research_projects/autoencoderkl/README.md              | 2 ++
 examples/research_projects/autoencoderkl/requirements.txt       | 1 -
 examples/research_projects/autoencoderkl/train_autoencoderkl.py | 2 +-
 3 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/examples/research_projects/autoencoderkl/README.md b/examples/research_projects/autoencoderkl/README.md
index c0b26294422f..c62018312da5 100644
--- a/examples/research_projects/autoencoderkl/README.md
+++ b/examples/research_projects/autoencoderkl/README.md
@@ -27,6 +27,8 @@ accelerate config
 
 ## Training on CIFAR10
 
+Please replace the validation image with your own image.
+
 ```bash
 accelerate launch train_autoencoderkl.py \
   --pretrained_model_name_or_path stabilityai/sd-vae-ft-mse \
diff --git a/examples/research_projects/autoencoderkl/requirements.txt b/examples/research_projects/autoencoderkl/requirements.txt
index 0a0e4d91c295..fe501252b46a 100644
--- a/examples/research_projects/autoencoderkl/requirements.txt
+++ b/examples/research_projects/autoencoderkl/requirements.txt
@@ -1,7 +1,6 @@
 accelerate>=0.16.0
 bitsandbytes
 datasets
-diffusers
 huggingface_hub
 lpips
 numpy
diff --git a/examples/research_projects/autoencoderkl/train_autoencoderkl.py b/examples/research_projects/autoencoderkl/train_autoencoderkl.py
index c1d609aeebf4..05a1cd045130 100644
--- a/examples/research_projects/autoencoderkl/train_autoencoderkl.py
+++ b/examples/research_projects/autoencoderkl/train_autoencoderkl.py
@@ -58,7 +58,7 @@
 import wandb
 
 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-# check_min_version("0.33.0.dev0")
+check_min_version("0.33.0.dev0")
 
 logger = get_logger(__name__)
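Dropping `diffusers` from requirements.txt pairs with re-enabling `check_min_version`: the example now relies on a source install, and the guard fails fast when the installed library is too old. A small sketch of what that guard does at import time (the version string is taken from the hunk above; `check_min_version` is the real `diffusers.utils` helper):

```python
# Sketch: the import-time guard the patch re-enables.
# check_min_version raises ImportError when the installed diffusers
# (here, the source install from `pip install .`) is older than required.
from diffusers.utils import check_min_version

check_min_version("0.33.0.dev0")
print("diffusers is recent enough to run this example")
```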
From ca62ba801cbb8da8477339e750b21a43a085b458 Mon Sep 17 00:00:00 2001
From: lavinal712
Date: Sun, 26 Jan 2025 06:49:32 +0000
Subject: [PATCH 7/8] update autoencoderkl.py

---
 .../autoencoderkl/train_autoencoderkl.py      | 90 ++++++++++---------
 1 file changed, 49 insertions(+), 41 deletions(-)

diff --git a/examples/research_projects/autoencoderkl/train_autoencoderkl.py b/examples/research_projects/autoencoderkl/train_autoencoderkl.py
index 05a1cd045130..c5788dba8691 100644
--- a/examples/research_projects/autoencoderkl/train_autoencoderkl.py
+++ b/examples/research_projects/autoencoderkl/train_autoencoderkl.py
@@ -19,13 +19,12 @@
 import logging
 import math
 import os
-import random
 import shutil
 from pathlib import Path
 
 import accelerate
-import numpy as np
 import lpips
+import numpy as np
 import torch
 import torch.nn.functional as F
 import torch.utils.checkpoint
@@ -38,9 +37,7 @@
 from huggingface_hub import create_repo, upload_folder
 from packaging import version
 from PIL import Image
-from taming.modules.losses.vqperceptual import (
-    hinge_d_loss, vanilla_d_loss, weights_init, NLayerDiscriminator
-)
+from taming.modules.losses.vqperceptual import NLayerDiscriminator, hinge_d_loss, vanilla_d_loss, weights_init
 from torchvision import transforms
 from tqdm.auto import tqdm
@@ -93,22 +90,22 @@ def log_validation(
 
         with inference_ctx:
             reconstructions = vae(targets).sample
-        
+
         images.append(
             torch.cat([targets.cpu(), reconstructions.cpu()], axis=0)
         )
-    
+
     tracker_key = "test" if is_final_validation else "validation"
     for tracker in accelerator.trackers:
         if tracker.name == "tensorboard":
             np_images = np.stack([np.asarray(img) for img in images])
             tracker.writer.add_images(
-                "Original (left), Reconstruction (right)", np_images, step
+                f"{tracker_key}: Original (left), Reconstruction (right)", np_images, step
             )
         elif tracker.name == "wandb":
             tracker.log(
                 {
-                    "Original (left), Reconstruction (right)": [
+                    f"{tracker_key}: Original (left), Reconstruction (right)": [
                         wandb.Image(torchvision.utils.make_grid(image))
                         for _, image in enumerate(images)
                     ]
                 }
             )
@@ -127,8 +124,8 @@ def save_model_card(repo_id: str, images=None, base_model=str, repo_folder=None)
     img_str = ""
     if images is not None:
         img_str = "You can find some example images below.\n\n"
-        make_image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f"images.png"))
-        img_str += f"![images](./images.png)\n"
+        make_image_grid(images, 1, len(images)).save(os.path.join(repo_folder, "images.png"))
+        img_str += "![images](./images.png)\n"
 
     model_description = f"""
 # autoencoderkl-{repo_id}
@@ -529,7 +526,7 @@ def make_train_dataset(args, accelerator):
     # Preprocessing the datasets.
     # We need to tokenize inputs and targets.
     column_names = dataset["train"].column_names
-    
+
     # 6. Get the column names for input/target.
     if args.image_column is None:
         image_column = column_names[0]
@@ -540,7 +537,7 @@
         raise ValueError(
             f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
         )
-    
+
     image_transforms = transforms.Compose(
         [
             transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
@@ -580,7 +577,7 @@ def main(args):
             "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
             " Please use `huggingface-cli login` to authenticate with the Hub."
         )
-    
+
     logging_dir = Path(args.output_dir, args.logging_dir)
 
     accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
@@ -591,7 +588,7 @@ def main(args):
         log_with=args.report_to,
         project_config=accelerator_project_config,
     )
-    
+
     # Disable AMP for MPS.
     if torch.backends.mps.is_available():
         accelerator.native_amp = False
@@ -623,7 +620,7 @@
             repo_id = create_repo(
                 repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
             ).repo_id
-    
+
     # Load AutoencoderKL
     if args.pretrained_model_name_or_path is None and args.model_config_name_or_path is None:
         config = AutoencoderKL.load_config("stabilityai/sd-vae-ft-mse")
@@ -637,7 +634,13 @@
         ema_vae = EMAModel(vae.parameters(), model_cls=AutoencoderKL, model_config=vae.config)
     perceptual_loss = lpips.LPIPS(net="vgg").eval()
     discriminator = NLayerDiscriminator(input_nc=3, n_layers=3, use_actnorm=False).apply(weights_init)
-    
+
+    # Taken from [Sayak Paul's Diffusers PR #6511](https://github.com/huggingface/diffusers/pull/6511/files)
+    def unwrap_model(model):
+        model = accelerator.unwrap_model(model)
+        model = model._orig_mod if is_compiled_module(model) else model
+        return model
+
     # `accelerate` 0.16.0 will have better support for customized saving
     if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
         # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
@@ -677,7 +680,7 @@ def load_model_hook(models, input_dir):
             load_model = NLayerDiscriminator(input_nc=3, n_layers=3, use_actnorm=False).load_state_dict(os.path.join(input_dir, "discriminator", "pytorch_model.bin"))
             model.load_state_dict(load_model.state_dict())
             del load_model
-    
+
             model = models.pop()
             load_model = AutoencoderKL.from_pretrained(input_dir, subfolder="autoencoderkl")
             model.register_to_config(**load_model.config)
@@ -686,8 +689,8 @@ def load_model_hook(models, input_dir):
 
         accelerator.register_save_state_pre_hook(save_model_hook)
         accelerator.register_load_state_pre_hook(load_model_hook)
-    
-    
+
+
     vae.requires_grad_(True)
     if args.decoder_only:
         vae.encoder.requires_grad_(False)
@@ -696,7 +699,7 @@ def load_model_hook(models, input_dir):
     vae.train()
     discriminator.requires_grad_(True)
     discriminator.train()
-    
+
     if args.enable_xformers_memory_efficient_attention:
         if is_xformers_available():
             import xformers
@@ -709,16 +712,21 @@ def load_model_hook(models, input_dir):
             vae.enable_xformers_memory_efficient_attention()
         else:
             raise ValueError("xformers is not available. Make sure it is installed correctly")
-    
+
     if args.gradient_checkpointing:
         vae.enable_gradient_checkpointing()
-    
+
     # Check that all trainable models are in full precision
     low_precision_error_string = (
         " Please make sure to always have all model weights in full float32 precision when starting training - even if"
         " doing mixed precision training, copy of the weights should still be float32."
     )
-    
+
+    if unwrap_model(vae).dtype != torch.float32:
+        raise ValueError(
+            f"VAE loaded as datatype {unwrap_model(vae).dtype}. {low_precision_error_string}"
+        )
+
     # Enable TF32 for faster training on Ampere GPUs,
     # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
     if args.allow_tf32:
@@ -728,7 +736,7 @@ def load_model_hook(models, input_dir):
         args.learning_rate = (
             args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
         )
-    
+
     # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
     if args.use_8bit_adam:
         try:
@@ -741,7 +749,7 @@ def load_model_hook(models, input_dir):
         optimizer_class = bnb.optim.AdamW8bit
     else:
         optimizer_class = torch.optim.AdamW
-    
+
     params_to_optimize = filter(lambda p: p.requires_grad, vae.parameters())
     disc_params_to_optimize = filter(lambda p: p.requires_grad, discriminator.parameters())
     optimizer = optimizer_class(
@@ -760,7 +768,7 @@ def load_model_hook(models, input_dir):
     )
 
     train_dataset = make_train_dataset(args, accelerator)
-    
+
     train_dataloader = torch.utils.data.DataLoader(
         train_dataset,
         shuffle=True,
@@ -768,14 +776,14 @@ def load_model_hook(models, input_dir):
         batch_size=args.train_batch_size,
         num_workers=args.dataloader_num_workers,
     )
-    
+
     # Scheduler and math around the number of training steps.
     overrode_max_train_steps = False
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
     if args.max_train_steps is None:
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
         overrode_max_train_steps = True
-    
+
     lr_scheduler = get_scheduler(
         args.lr_scheduler,
         optimizer=optimizer,
@@ -792,12 +800,12 @@ def load_model_hook(models, input_dir):
         num_cycles=args.lr_num_cycles,
         power=args.lr_power,
     )
-    
+
     # Prepare everything with our `accelerator`.
     vae, discriminator, optimizer, disc_optimizer, train_dataloader, lr_scheduler, disc_lr_scheduler = accelerator.prepare(
         vae, discriminator, optimizer, disc_optimizer, train_dataloader, lr_scheduler, disc_lr_scheduler
     )
-    
+
     # For mixed precision training we cast the text_encoder and vae weights to half-precision
     # as these models are only used for inference, keeping weights in full precision is not required.
     weight_dtype = torch.float32
@@ -805,14 +813,14 @@ def load_model_hook(models, input_dir):
         weight_dtype = torch.float16
     elif accelerator.mixed_precision == "bf16":
         weight_dtype = torch.bfloat16
-    
+
     # Move VAE, perceptual loss and discriminator to device and cast to weight_dtype
     vae.to(accelerator.device, dtype=weight_dtype)
     perceptual_loss.to(accelerator.device, dtype=weight_dtype)
     discriminator.to(accelerator.device, dtype=weight_dtype)
     if args.use_ema:
         ema_vae.to(accelerator.device, dtype=weight_dtype)
-    
+
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
     if overrode_max_train_steps:
@@ -850,7 +858,7 @@ def load_model_hook(models, input_dir):
             dirs = [d for d in dirs if d.startswith("checkpoint")]
             dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
            path = dirs[-1] if len(dirs) > 0 else None
-    
+
             if path is None:
                 accelerator.print(
                     f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
@@ -866,7 +874,7 @@ def load_model_hook(models, input_dir):
             first_epoch = global_step // num_update_steps_per_epoch
     else:
         initial_global_step = 0
-    
+
     progress_bar = tqdm(
         range(0, args.max_train_steps),
         initial=initial_global_step,
@@ -898,7 +906,7 @@ def load_model_hook(models, input_dir):
                 # perceptual loss. The high level feature mean squared error loss
                 with torch.no_grad():
                     p_loss = perceptual_loss(reconstructions, targets)
-    
+
                 rec_loss = rec_loss + args.perceptual_scale * p_loss
                 nll_loss = rec_loss
                 nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
@@ -915,9 +923,9 @@ def load_model_hook(models, input_dir):
                 disc_weight = torch.clamp(disc_weight, 0.0, 1e4).detach()
                 disc_weight = disc_weight * args.disc_scale
                 disc_factor = args.disc_factor if global_step >= args.disc_start else 0.0
-    
+
                 loss = nll_loss + args.kl_scale * kl_loss + disc_weight * disc_factor * g_loss
-    
+
                 logs = {
                     "loss": loss.detach().mean().item(),
                     "nll_loss": nll_loss.detach().mean().item(),
@@ -929,7 +937,7 @@ def load_model_hook(models, input_dir):
                     "g_loss": g_loss.detach().mean().item(),
                     "lr": lr_scheduler.get_last_lr()[0]
                 }
-    
+
                 accelerator.backward(loss)
                 if accelerator.sync_gradients:
                     params_to_clip = vae.parameters()
@@ -1002,7 +1010,7 @@ def load_model_hook(models, input_dir):
 
             if global_step >= args.max_train_steps:
                 break
-    
+
     # Create the pipeline using the trained modules and save it.
     accelerator.wait_for_everyone()
     if accelerator.is_main_process:
@@ -1036,7 +1044,7 @@ def load_model_hook(models, input_dir):
                 commit_message="End of training",
                 ignore_patterns=["step_*", "epoch_*"],
             )
-    
+
     accelerator.end_training()
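One detail in `load_model_hook` above deserves a flag: `Module.load_state_dict` returns a report of missing and unexpected keys, not the module, so assigning its result to `load_model` and later calling `load_model.state_dict()` would fail when a checkpoint is actually restored. A hedged sketch of the conventional shape for that restore step, assuming the same on-disk layout as the diff (`discriminator/pytorch_model.bin`); this is an illustration, not the PR's code:

```python
import os

import torch
from taming.modules.losses.vqperceptual import NLayerDiscriminator


def load_discriminator(input_dir: str) -> NLayerDiscriminator:
    """Rebuild the discriminator and restore its weights from a checkpoint dir."""
    disc = NLayerDiscriminator(input_nc=3, n_layers=3, use_actnorm=False)
    state_dict = torch.load(
        os.path.join(input_dir, "discriminator", "pytorch_model.bin"),
        map_location="cpu",
    )
    disc.load_state_dict(state_dict)  # returns missing/unexpected keys, not the module
    return disc
```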
") if not is_final_validation: @@ -91,23 +89,18 @@ def log_validation( with inference_ctx: reconstructions = vae(targets).sample - images.append( - torch.cat([targets.cpu(), reconstructions.cpu()], axis=0) - ) + images.append(torch.cat([targets.cpu(), reconstructions.cpu()], axis=0)) tracker_key = "test" if is_final_validation else "validation" for tracker in accelerator.trackers: if tracker.name == "tensorboard": np_images = np.stack([np.asarray(img) for img in images]) - tracker.writer.add_images( - f"{tracker_key}: Original (left), Reconstruction (right)", np_images, step - ) + tracker.writer.add_images(f"{tracker_key}: Original (left), Reconstruction (right)", np_images, step) elif tracker.name == "wandb": tracker.log( { f"{tracker_key}: Original (left), Reconstruction (right)": [ - wandb.Image(torchvision.utils.make_grid(image)) - for _, image in enumerate(images) + wandb.Image(torchvision.utils.make_grid(image)) for _, image in enumerate(images) ] } ) @@ -677,7 +670,9 @@ def load_model_hook(models, input_dir): # pop models so that they are not loaded again model = models.pop() - load_model = NLayerDiscriminator(input_nc=3, n_layers=3, use_actnorm=False).load_state_dict(os.path.join(input_dir, "discriminator", "pytorch_model.bin")) + load_model = NLayerDiscriminator(input_nc=3, n_layers=3, use_actnorm=False).load_state_dict( + os.path.join(input_dir, "discriminator", "pytorch_model.bin") + ) model.load_state_dict(load_model.state_dict()) del load_model @@ -690,7 +685,6 @@ def load_model_hook(models, input_dir): accelerator.register_save_state_pre_hook(save_model_hook) accelerator.register_load_state_pre_hook(load_model_hook) - vae.requires_grad_(True) if args.decoder_only: vae.encoder.requires_grad_(False) @@ -723,9 +717,7 @@ def load_model_hook(models, input_dir): ) if unwrap_model(vae).dtype != torch.float32: - raise ValueError( - f"VAE loaded as datatype {unwrap_model(vae).dtype}. {low_precision_error_string}" - ) + raise ValueError(f"VAE loaded as datatype {unwrap_model(vae).dtype}. {low_precision_error_string}") # Enable TF32 for faster training on Ampere GPUs, # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices @@ -802,7 +794,15 @@ def load_model_hook(models, input_dir): ) # Prepare everything with our `accelerator`. - vae, discriminator, optimizer, disc_optimizer, train_dataloader, lr_scheduler, disc_lr_scheduler = accelerator.prepare( + ( + vae, + discriminator, + optimizer, + disc_optimizer, + train_dataloader, + lr_scheduler, + disc_lr_scheduler, + ) = accelerator.prepare( vae, discriminator, optimizer, disc_optimizer, train_dataloader, lr_scheduler, disc_lr_scheduler ) @@ -935,7 +935,7 @@ def load_model_hook(models, input_dir): "disc_weight": disc_weight.detach().mean().item(), "disc_factor": disc_factor, "g_loss": g_loss.detach().mean().item(), - "lr": lr_scheduler.get_last_lr()[0] + "lr": lr_scheduler.get_last_lr()[0], } accelerator.backward(loss) @@ -956,7 +956,7 @@ def load_model_hook(models, input_dir): "disc_loss": disc_loss.detach().mean().item(), "logits_real": logits_real.detach().mean().item(), "logits_fake": logits_fake.detach().mean().item(), - "disc_lr": disc_lr_scheduler.get_last_lr()[0] + "disc_lr": disc_lr_scheduler.get_last_lr()[0], } # Checks if the accelerator has performed an optimization step behind the scenes if accelerator.sync_gradients: