Skip to content

Commit 857b3a0

Browse files
committed
major updates.
1 parent 6c1f56d commit 857b3a0

File tree

1 file changed

+14
-4
lines changed

1 file changed

+14
-4
lines changed

examples/dreambooth/train_dreambooth_lora_lumina2.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -298,6 +298,12 @@ def parse_args(input_args=None):
298298
default=None,
299299
help="A prompt that is used during validation to verify that the model is learning.",
300300
)
301+
parser.add_argument(
302+
"--final_validation_prompt",
303+
type=str,
304+
default=None,
305+
help="A prompt that is used during a final validation to verify that the model is learning. Ignored if `--validation_prompt` is provided.",
306+
)
301307
parser.add_argument(
302308
"--num_validation_images",
303309
type=int,
@@ -1367,6 +1373,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
13671373
noisy_model_input = (1.0 - sigmas) * noise + sigmas * model_input
13681374

13691375
# Predict the noise residual
1376+
# reverse the timestep since Lumina uses t=0 as the noise and t=1 as the image
1377+
timesteps = 1 - timesteps / noise_scheduler.config.num_train_timesteps
13701378
model_pred = transformer(
13711379
hidden_states=noisy_model_input,
13721380
encoder_hidden_states=prompt_embeds,
@@ -1379,8 +1387,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
13791387
# and instead post-weight the loss
13801388
weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)
13811389

1382-
# flow matching loss
1383-
target = noise - model_input
1390+
# flow matching loss (reversed)
1391+
target = model_input - noise
13841392

13851393
if args.with_prior_preservation:
13861394
# Chunk the noise and model_pred into two parts and compute the loss on each part separately.
@@ -1505,8 +1513,10 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
15051513

15061514
# run inference
15071515
images = []
1508-
if args.validation_prompt and args.num_validation_images > 0:
1509-
pipeline_args = {"prompt": args.validation_prompt, "system_prompt": args.system_prompt}
1516+
if (args.validation_prompt and args.num_validation_images > 0) or (args.final_validation_prompt):
1517+
prompt_to_use = args.validation_prompt if args.validation_prompt else args.final_validation_prompt
1518+
args.num_validation_images = args.num_validation_images if args.num_validation_images else 1
1519+
pipeline_args = {"prompt": prompt_to_use, "system_prompt": args.system_prompt}
15101520
images = log_validation(
15111521
pipeline=pipeline,
15121522
args=args,

0 commit comments

Comments
 (0)