Skip to content

Commit 1a8b3c2

Browse files
sayakpaulbghirakashif
authored
[Training] SD3 training fixes (huggingface#8917)
* SD3 training fixes Co-authored-by: bghira <[email protected]> * rewrite noise addition part to respect the eqn. * styler * Update examples/dreambooth/README_sd3.md Co-authored-by: Kashif Rasul <[email protected]> --------- Co-authored-by: bghira <[email protected]> Co-authored-by: Kashif Rasul <[email protected]>
1 parent 56e772a commit 1a8b3c2

File tree

3 files changed

+36
-9
lines changed

3 files changed

+36
-9
lines changed

examples/dreambooth/README_sd3.md

+3-1
Original file line numberDiff line numberDiff line change
@@ -183,4 +183,6 @@ accelerate launch train_dreambooth_lora_sd3.py \
183183

184184
## Other notes
185185

186-
We default to the "logit_normal" weighting scheme for the loss following the SD3 paper. Thanks to @bghira for helping us discover that for other weighting schemes supported from the training script, training may incur numerical instabilities.
186+
1. We default to the "logit_normal" weighting scheme for the loss following the SD3 paper. Thanks to @bghira for helping us discover that for other weighting schemes supported from the training script, training may incur numerical instabilities.
187+
2. Thanks to `bghira`, `JinxuXiang`, and `bendanzzc` for helping us discover a bug in how VAE encoding was being done previously. This has been fixed in [#8917](https://github.com/huggingface/diffusers/pull/8917).
188+
3. Additionally, we now have the option to control if we want to apply preconditioning to the model outputs via a `--precondition_outputs` CLI arg. It affects how the model `target` is calculated as well.

examples/dreambooth/train_dreambooth_lora_sd3.py

+16-4
Original file line numberDiff line numberDiff line change
@@ -523,6 +523,13 @@ def parse_args(input_args=None):
523523
default=1.29,
524524
help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.",
525525
)
526+
parser.add_argument(
527+
"--precondition_outputs",
528+
type=int,
529+
default=1,
530+
help="Flag indicating if we are preconditioning the model outputs or not as done in EDM. This affects how "
531+
"model `target` is calculated.",
532+
)
526533
parser.add_argument(
527534
"--optimizer",
528535
type=str,
@@ -1636,7 +1643,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
16361643

16371644
# Convert images to latent space
16381645
model_input = vae.encode(pixel_values).latent_dist.sample()
1639-
model_input = model_input * vae.config.scaling_factor
1646+
model_input = (model_input - vae.config.shift_factor) * vae.config.scaling_factor
16401647
model_input = model_input.to(dtype=weight_dtype)
16411648

16421649
# Sample noise that we'll add to the latents
@@ -1656,8 +1663,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
16561663
timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device)
16571664

16581665
# Add noise according to flow matching.
1666+
# zt = (1 - texp) * x + texp * z1
16591667
sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype)
1660-
noisy_model_input = sigmas * noise + (1.0 - sigmas) * model_input
1668+
noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise
16611669

16621670
# Predict the noise residual
16631671
model_pred = transformer(
@@ -1670,14 +1678,18 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
16701678

16711679
# Follow: Section 5 of https://arxiv.org/abs/2206.00364.
16721680
# Preconditioning of the model outputs.
1673-
model_pred = model_pred * (-sigmas) + noisy_model_input
1681+
if args.precondition_outputs:
1682+
model_pred = model_pred * (-sigmas) + noisy_model_input
16741683

16751684
# these weighting schemes use a uniform timestep sampling
16761685
# and instead post-weight the loss
16771686
weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)
16781687

16791688
# flow matching loss
1680-
target = model_input
1689+
if args.precondition_outputs:
1690+
target = model_input
1691+
else:
1692+
target = noise - model_input
16811693

16821694
if args.with_prior_preservation:
16831695
# Chunk the noise and model_pred into two parts and compute the loss on each part separately.

examples/dreambooth/train_dreambooth_sd3.py

+17-4
Original file line numberDiff line numberDiff line change
@@ -494,6 +494,13 @@ def parse_args(input_args=None):
494494
default=1.29,
495495
help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.",
496496
)
497+
parser.add_argument(
498+
"--precondition_outputs",
499+
type=int,
500+
default=1,
501+
help="Flag indicating if we are preconditioning the model outputs or not as done in EDM. This affects how "
502+
"model `target` is calculated.",
503+
)
497504
parser.add_argument(
498505
"--optimizer",
499506
type=str,
@@ -1549,7 +1556,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
15491556

15501557
# Convert images to latent space
15511558
model_input = vae.encode(pixel_values).latent_dist.sample()
1552-
model_input = model_input * vae.config.scaling_factor
1559+
model_input = (model_input - vae.config.shift_factor) * vae.config.scaling_factor
15531560
model_input = model_input.to(dtype=weight_dtype)
15541561

15551562
# Sample noise that we'll add to the latents
@@ -1569,8 +1576,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
15691576
timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device)
15701577

15711578
# Add noise according to flow matching.
1579+
# zt = (1 - texp) * x + texp * z1
15721580
sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype)
1573-
noisy_model_input = sigmas * noise + (1.0 - sigmas) * model_input
1581+
noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise
15741582

15751583
# Predict the noise residual
15761584
if not args.train_text_encoder:
@@ -1598,13 +1606,18 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
15981606

15991607
# Follow: Section 5 of https://arxiv.org/abs/2206.00364.
16001608
# Preconditioning of the model outputs.
1601-
model_pred = model_pred * (-sigmas) + noisy_model_input
1609+
if args.precondition_outputs:
1610+
model_pred = model_pred * (-sigmas) + noisy_model_input
1611+
16021612
# these weighting schemes use a uniform timestep sampling
16031613
# and instead post-weight the loss
16041614
weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)
16051615

16061616
# flow matching loss
1607-
target = model_input
1617+
if args.precondition_outputs:
1618+
target = model_input
1619+
else:
1620+
target = noise - model_input
16081621

16091622
if args.with_prior_preservation:
16101623
# Chunk the noise and model_pred into two parts and compute the loss on each part separately.

0 commit comments

Comments
 (0)