@@ -471,9 +471,17 @@ def collate_fn(examples):
        unet, optimizer, train_dataloader, lr_scheduler
    )

-    # Move text_encoder and vae to gpu
-    text_encoder.to(accelerator.device)
-    vae.to(accelerator.device)
+    weight_dtype = torch.float32
+    if args.mixed_precision == "fp16":
+        weight_dtype = torch.float16
+    elif args.mixed_precision == "bf16":
+        weight_dtype = torch.bfloat16
+
+    # Move text_encoder and vae to gpu.
+    # For mixed precision training we cast the text_encoder and vae weights to half precision,
+    # as these models are only used for inference; keeping weights in full precision is not required.
+    text_encoder.to(accelerator.device, dtype=weight_dtype)
+    vae.to(accelerator.device, dtype=weight_dtype)

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
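The hunk above derives a storage dtype from the script's `mixed_precision` flag and casts the frozen `text_encoder` and `vae` to it; only the trained `unet` keeps full-precision weights. Below is a minimal, self-contained sketch of that dtype-selection pattern; the helper name and the `nn.Linear` stand-in are illustrative, not part of the PR:

```python
import torch
from torch import nn

def resolve_weight_dtype(mixed_precision: str) -> torch.dtype:
    # Frozen, inference-only models can live in half precision;
    # anything that receives gradients should stay in float32.
    if mixed_precision == "fp16":
        return torch.float16
    if mixed_precision == "bf16":
        return torch.bfloat16
    return torch.float32

weight_dtype = resolve_weight_dtype("fp16")
frozen = nn.Linear(4, 4)                 # stands in for vae / text_encoder
frozen.to("cpu", dtype=weight_dtype)     # .to(device, dtype=...) casts the module in place
print(next(frozen.parameters()).dtype)   # torch.float16
```

Since neither the VAE nor the text encoder receives gradient updates in this script, halving their storage roughly halves their memory footprint without affecting training.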
@@ -509,11 +517,11 @@ def collate_fn(examples):
            with accelerator.accumulate(unet):
                # Convert images to latent space
                with torch.no_grad():
-                    latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
+                    latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
                    latents = latents * 0.18215

                # Sample noise that we'll add to the latents
-                noise = torch.randn(latents.shape).to(latents.device)
+                noise = torch.randn_like(latents)
                bsz = latents.shape[0]
                # Sample a random timestep for each image
                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
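Both fixes in this hunk follow the same principle: tensors entering the loop must match `weight_dtype`. The pixel values are cast before the VAE encode, and `torch.randn_like` replaces `torch.randn(...).to(...)` so the noise inherits both the device and the dtype of the latents in a single allocation. A small runnable illustration of the difference (shapes are arbitrary):

```python
import torch

latents = torch.randn(2, 4, 8, 8, dtype=torch.float16)

# Old pattern: allocated in float32 on the default device, then copied over.
# The dtype silently disagrees with the half-precision latents.
noise_old = torch.randn(latents.shape).to(latents.device)
print(noise_old.dtype)   # torch.float32

# New pattern: one allocation matching both device and dtype of `latents`.
noise_new = torch.randn_like(latents)
print(noise_new.dtype)   # torch.float16
```

With the old pattern, adding float32 noise to float16 latents would promote the result back to float32, wasting the memory the half-precision cast was meant to save.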
@@ -539,12 +547,12 @@ def collate_fn(examples):
                    loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()

                    # Compute prior loss
-                    prior_loss = F.mse_loss(noise_pred_prior, noise_prior, reduction="none").mean([1, 2, 3]).mean()
+                    prior_loss = F.mse_loss(noise_pred_prior.float(), noise_prior.float(), reduction="mean")

                    # Add the prior loss to the instance loss.
                    loss = loss + args.prior_loss_weight * prior_loss
                else:
-                    loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
+                    loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")

                accelerator.backward(loss)
                if accelerator.sync_gradients:
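The loss changes do two things. First, `noise_pred.float()` and `noise.float()` upcast the half-precision tensors before the squared-error reduction, which is the numerically fragile step under fp16. Second, `reduction="mean"` is an equivalent simplification of `reduction="none").mean([1, 2, 3]).mean()` here: every sample contributes the same number of elements, so the mean of per-sample means equals the global mean. A runnable sketch checking both points (tensors are random stand-ins for `noise_pred` / `noise`):

```python
import torch
import torch.nn.functional as F

noise_pred = torch.randn(2, 4, 8, 8, dtype=torch.float16)
noise = torch.randn(2, 4, 8, 8, dtype=torch.float16)

# The reduction runs in float32, so accumulating many small squared
# errors cannot underflow or overflow the fp16 range.
loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
print(loss.dtype)  # torch.float32

# Equal-sized samples make the per-sample mean-of-means identical
# to the global mean.
per_sample = F.mse_loss(noise_pred.float(), noise.float(), reduction="none").mean([1, 2, 3]).mean()
assert torch.allclose(loss, per_sample)
```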