
Commit 81bdbb5

DreamBooth DeepSpeed support for under 8 GB VRAM training (#735)
* Support DeepSpeed
* DreamBooth DeepSpeed documentation
* Remove unnecessary casts and update documentation. Due to recent commits, some casts to half precision are no longer necessary. Mention that DeepSpeed's version of Adam is about 2x faster.
* Review comments
1 parent 71ca10c commit 81bdbb5

2 files changed: +55 -7 lines

examples/dreambooth/README.md

Lines changed: 40 additions & 0 deletions
@@ -119,6 +119,46 @@ accelerate launch train_dreambooth.py \

This hunk appends a new section after the existing training command (which ends with `--max_train_steps=800`) and just before the `## Inference` heading:

### Training on an 8 GB GPU:

By using [DeepSpeed](https://www.deepspeed.ai/) it's possible to offload some tensors from VRAM to either CPU or NVMe, allowing training with less VRAM.

DeepSpeed needs to be enabled with `accelerate config`. During configuration, answer yes to "Do you want to use DeepSpeed?". With DeepSpeed stage 2, fp16 mixed precision, and offloading of both the parameters and the optimizer state to the CPU, it's possible to train on under 8 GB of VRAM, with the drawback of requiring significantly more system RAM (about 25 GB). See the [documentation](https://huggingface.co/docs/accelerate/usage_guides/deepspeed) for more DeepSpeed configuration options.
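Purely as an illustration (this config is not part of the commit, and the exact keys written by `accelerate config` may differ across versions), answering the prompts that way might produce a config along these lines:

```yaml
# Illustrative sketch only: an approximate accelerate config for DeepSpeed
# ZeRO stage 2 with fp16 and CPU offload of parameters and optimizer state.
compute_environment: LOCAL_MACHINE
distributed_type: DEEPSPEED
deepspeed_config:
  zero_stage: 2
  offload_optimizer_device: cpu
  offload_param_device: cpu
  gradient_accumulation_steps: 1
mixed_precision: fp16
num_processes: 1
```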
Changing the default Adam optimizer to DeepSpeed's optimized version, `deepspeed.ops.adam.DeepSpeedCPUAdam`, gives a substantial speedup (about 2x), but enabling it requires a CUDA toolchain with the same version as the one PyTorch was built with. The 8-bit optimizer does not seem to be compatible with DeepSpeed at the moment.
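A hypothetical sketch of the optimizer swap described above (the names `model` and `learning_rate` stand in for the script's `unet` and `args.learning_rate`; the import-based fallback is illustrative, not the script's actual behavior):

```python
import torch

try:
    # DeepSpeed's CPU-offload Adam; using it requires DeepSpeed built against
    # the same CUDA toolchain version as PyTorch.
    from deepspeed.ops.adam import DeepSpeedCPUAdam as AdamClass
except ImportError:
    AdamClass = torch.optim.AdamW  # fall back to the default optimizer

model = torch.nn.Linear(8, 8)  # toy stand-in for the UNet being trained
learning_rate = 5e-6

optimizer = AdamClass(model.parameters(), lr=learning_rate)
```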
```bash
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
export INSTANCE_DIR="path-to-instance-images"
export CLASS_DIR="path-to-class-images"
export OUTPUT_DIR="path-to-save-model"

accelerate launch train_dreambooth.py \
  --pretrained_model_name_or_path=$MODEL_NAME --use_auth_token \
  --instance_data_dir=$INSTANCE_DIR \
  --class_data_dir=$CLASS_DIR \
  --output_dir=$OUTPUT_DIR \
  --with_prior_preservation --prior_loss_weight=1.0 \
  --instance_prompt="a photo of sks dog" \
  --class_prompt="a photo of dog" \
  --resolution=512 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=1 --gradient_checkpointing \
  --learning_rate=5e-6 \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --num_class_images=200 \
  --max_train_steps=800 \
  --mixed_precision=fp16
```

examples/dreambooth/train_dreambooth.py

Lines changed: 15 additions & 7 deletions
```diff
@@ -471,9 +471,17 @@ def collate_fn(examples):
         unet, optimizer, train_dataloader, lr_scheduler
     )
 
-    # Move text_encoder and vae to gpu
-    text_encoder.to(accelerator.device)
-    vae.to(accelerator.device)
+    weight_dtype = torch.float32
+    if args.mixed_precision == "fp16":
+        weight_dtype = torch.float16
+    elif args.mixed_precision == "bf16":
+        weight_dtype = torch.bfloat16
+
+    # Move text_encoder and vae to gpu.
+    # For mixed precision training we cast the text_encoder and vae weights to half-precision,
+    # as these models are only used for inference; keeping weights in full precision is not required.
+    text_encoder.to(accelerator.device, dtype=weight_dtype)
+    vae.to(accelerator.device, dtype=weight_dtype)
 
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
```
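As a rough, self-contained illustration of why this cast helps (a sketch with an assumed toy module, not code from the commit): casting a frozen, inference-only module to fp16 halves its parameter memory.

```python
import torch

# Toy module standing in for the frozen text_encoder/vae.
module = torch.nn.Linear(1024, 1024)
fp32_bytes = sum(p.numel() * p.element_size() for p in module.parameters())

module.to(dtype=torch.float16)  # same cast as text_encoder.to(..., dtype=weight_dtype)
fp16_bytes = sum(p.numel() * p.element_size() for p in module.parameters())

print(fp32_bytes, fp16_bytes)  # fp16 weights take half the memory
```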
```diff
@@ -509,11 +517,11 @@ def collate_fn(examples):
             with accelerator.accumulate(unet):
                 # Convert images to latent space
                 with torch.no_grad():
-                    latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
+                    latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
                     latents = latents * 0.18215
 
                 # Sample noise that we'll add to the latents
-                noise = torch.randn(latents.shape).to(latents.device)
+                noise = torch.randn_like(latents)
                 bsz = latents.shape[0]
                 # Sample a random timestep for each image
                 timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
```
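`torch.randn_like` samples standard-normal noise with the same shape, dtype, and device as its argument, which is why the noise no longer needs an explicit `.to(latents.device)` or dtype cast. A minimal illustration (toy tensor assumed, not from the commit):

```python
import torch

latents = torch.zeros(2, 4, 64, 64).half()  # toy stand-in for the VAE latents
noise = torch.randn_like(latents)           # inherits shape, dtype, and device
print(noise.dtype)                          # torch.float16
```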
```diff
@@ -539,12 +547,12 @@ def collate_fn(examples):
                     loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
 
                     # Compute prior loss
-                    prior_loss = F.mse_loss(noise_pred_prior, noise_prior, reduction="none").mean([1, 2, 3]).mean()
+                    prior_loss = F.mse_loss(noise_pred_prior.float(), noise_prior.float(), reduction="mean")
 
                     # Add the prior loss to the instance loss.
                     loss = loss + args.prior_loss_weight * prior_loss
                 else:
-                    loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
+                    loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
 
                 accelerator.backward(loss)
                 if accelerator.sync_gradients:
```
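The loss lines above replace `reduction="none"` followed by `.mean([1, 2, 3]).mean()` with an equivalent single mean over all elements, and upcast the fp16 tensors with `.float()` so the reduction runs in full precision. A minimal sketch of the pattern with assumed toy tensors:

```python
import torch
import torch.nn.functional as F

# Toy fp16 tensors standing in for the UNet prediction and the target noise.
noise_pred = torch.randn(2, 4, 64, 64).half()
noise = torch.randn(2, 4, 64, 64).half()

# Upcast to fp32 before the mean so the reduction is computed in full precision.
loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
print(loss.dtype)  # torch.float32
```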
