@@ -1220,6 +1220,7 @@ def compute_text_embeddings(prompt, text_encoding_pipeline):
             prompt_attention_mask = torch.cat([prompt_attention_mask, class_prompt_attention_mask], dim=0)
 
     vae_config_scaling_factor = vae.config.scaling_factor
+    vae_config_shift_factor = vae.config.shift_factor
     if args.cache_latents:
         latents_cache = []
         vae = vae.to(accelerator.device)
@@ -1334,9 +1335,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
 
         for step, batch in enumerate(train_dataloader):
             models_to_accumulate = [transformer]
-            with accelerator.accumulate(models_to_accumulate):
-                prompts = batch["prompts"]
+            prompts = batch["prompts"]
 
+            with accelerator.accumulate(models_to_accumulate):
                 # encode batch prompts when custom prompts are provided for each image -
                 if train_dataset.custom_instance_prompts:
                     prompt_embeds, prompt_attention_mask = compute_text_embeddings(prompts, text_encoding_pipeline)
@@ -1350,7 +1351,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                     model_input = vae.encode(pixel_values).latent_dist.sample()
                     if args.offload:
                         vae = vae.to("cpu")
-                model_input = model_input * vae_config_scaling_factor
+                model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor
                 model_input = model_input.to(dtype=weight_dtype)
 
                 # Sample noise that we'll add to the latents
@@ -1380,8 +1381,12 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 timesteps = timesteps / noise_scheduler.config.num_train_timesteps
                 model_pred = transformer(
                     hidden_states=noisy_model_input,
-                    encoder_hidden_states=prompt_embeds,
-                    encoder_attention_mask=prompt_attention_mask,
+                    encoder_hidden_states=prompt_embeds.repeat(len(prompts), 1, 1)
+                    if not train_dataset.custom_instance_prompts
+                    else prompt_embeds,
+                    encoder_attention_mask=prompt_attention_mask.repeat(len(prompts), 1)
+                    if not train_dataset.custom_instance_prompts
+                    else prompt_attention_mask,
                     timestep=timesteps,
                     return_dict=False,
                 )[0]
@@ -1536,7 +1541,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
             base_model=args.pretrained_model_name_or_path,
             instance_prompt=args.instance_prompt,
             system_prompt=args.system_prompt,
-            validation_prompt=args.validation_prompt,
+            validation_prompt=args.validation_prompt or args.final_validation_prompt,
             repo_folder=args.output_dir,
         )
         upload_folder(