@@ -298,6 +298,12 @@ def parse_args(input_args=None):
298
298
default = None ,
299
299
help = "A prompt that is used during validation to verify that the model is learning." ,
300
300
)
301
+ parser .add_argument (
302
+ "--final_validation_prompt" ,
303
+ type = str ,
304
+ default = None ,
305
+ help = "A prompt that is used during a final validation to verify that the model is learning. Ignored if `--validation_prompt` is provided." ,
306
+ )
301
307
parser .add_argument (
302
308
"--num_validation_images" ,
303
309
type = int ,
@@ -1367,6 +1373,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
1367
1373
noisy_model_input = (1.0 - sigmas ) * noise + sigmas * model_input
1368
1374
1369
1375
# Predict the noise residual
1376
+ # Lumina uses the opposite time convention (t=0 is noise, t=1 is the image), so rescale the timesteps by num_train_timesteps and flip them
1377
+ timesteps = 1 - timesteps / noise_scheduler .config .num_train_timesteps
1370
1378
model_pred = transformer (
1371
1379
hidden_states = noisy_model_input ,
1372
1380
encoder_hidden_states = prompt_embeds ,
@@ -1379,8 +1387,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
1379
1387
# and instead post-weight the loss
1380
1388
weighting = compute_loss_weighting_for_sd3 (weighting_scheme = args .weighting_scheme , sigmas = sigmas )
1381
1389
1382
- # flow matching loss
1383
- target = noise - model_input
1390
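+ # flow matching loss; the target sign is flipped (model_input - noise instead of noise - model_input) to match the reversed timesteps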
1391
+ target = model_input - noise
1384
1392
1385
1393
if args .with_prior_preservation :
1386
1394
# Chunk the noise and model_pred into two parts and compute the loss on each part separately.
@@ -1505,8 +1513,10 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
1505
1513
1506
1514
# run inference
1507
1515
images = []
1508
- if args .validation_prompt and args .num_validation_images > 0 :
1509
- pipeline_args = {"prompt" : args .validation_prompt , "system_prompt" : args .system_prompt }
1516
+ if (args .validation_prompt and args .num_validation_images > 0 ) or (args .final_validation_prompt ):
1517
+ prompt_to_use = args .validation_prompt if args .validation_prompt else args .final_validation_prompt
1518
+ args .num_validation_images = args .num_validation_images if args .num_validation_images else 1
1519
+ pipeline_args = {"prompt" : prompt_to_use , "system_prompt" : args .system_prompt }
1510
1520
images = log_validation (
1511
1521
pipeline = pipeline ,
1512
1522
args = args ,
0 commit comments