Commit 61cbb21

fixes
1 parent 9f85028 commit 61cbb21

File tree

1 file changed: +11 -6 lines


examples/dreambooth/train_dreambooth_lora_lumina2.py

Lines changed: 11 additions & 6 deletions
@@ -1220,6 +1220,7 @@ def compute_text_embeddings(prompt, text_encoding_pipeline):
             prompt_attention_mask = torch.cat([prompt_attention_mask, class_prompt_attention_mask], dim=0)
 
     vae_config_scaling_factor = vae.config.scaling_factor
+    vae_config_shift_factor = vae.config.shift_factor
     if args.cache_latents:
         latents_cache = []
         vae = vae.to(accelerator.device)
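Note (not part of the commit): shift_factor sits alongside scaling_factor in the VAE config, and on older AutoencoderKL checkpoints it can be None or missing. A minimal sketch of reading both constants with a safe fallback; the helper name is hypothetical:

# Sketch only, assuming a loaded diffusers AutoencoderKL; get_latent_norm_constants is a hypothetical helper.
def get_latent_norm_constants(vae):
    scaling_factor = vae.config.scaling_factor
    # shift_factor may be absent or None on VAEs that predate it; 0.0 keeps the arithmetic unchanged
    shift_factor = getattr(vae.config, "shift_factor", None) or 0.0
    return scaling_factor, shift_factor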
@@ -1334,9 +1335,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
 
         for step, batch in enumerate(train_dataloader):
             models_to_accumulate = [transformer]
-            with accelerator.accumulate(models_to_accumulate):
-                prompts = batch["prompts"]
+            prompts = batch["prompts"]
 
+            with accelerator.accumulate(models_to_accumulate):
                 # encode batch prompts when custom prompts are provided for each image -
                 if train_dataset.custom_instance_prompts:
                     prompt_embeds, prompt_attention_mask = compute_text_embeddings(prompts, text_encoding_pipeline)
@@ -1350,7 +1351,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                     model_input = vae.encode(pixel_values).latent_dist.sample()
                     if args.offload:
                         vae = vae.to("cpu")
-                model_input = model_input * vae_config_scaling_factor
+                model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor
                 model_input = model_input.to(dtype=weight_dtype)
 
                 # Sample noise that we'll add to the latents
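Note (not part of the commit): the changed line normalizes latents the way a shift-aware VAE expects, subtracting shift_factor before multiplying by scaling_factor; decoding inverts both steps. A self-contained sketch with illustrative values:

# Sketch only; the factor values and tensor shape are illustrative, not Lumina2's real config.
import torch

scaling_factor = 0.3611   # stands in for vae.config.scaling_factor
shift_factor = 0.1159     # stands in for vae.config.shift_factor

raw_latents = torch.randn(1, 16, 32, 32)                      # vae.encode(...).latent_dist.sample()
model_input = (raw_latents - shift_factor) * scaling_factor   # normalization used during training
recovered = model_input / scaling_factor + shift_factor       # inverse applied before vae.decode()
assert torch.allclose(recovered, raw_latents, atol=1e-5)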
@@ -1380,8 +1381,12 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 timesteps = timesteps / noise_scheduler.config.num_train_timesteps
                 model_pred = transformer(
                     hidden_states=noisy_model_input,
-                    encoder_hidden_states=prompt_embeds,
-                    encoder_attention_mask=prompt_attention_mask,
+                    encoder_hidden_states=prompt_embeds.repeat(len(prompts), 1, 1)
+                    if not train_dataset.custom_instance_prompts
+                    else prompt_embeds,
+                    encoder_attention_mask=prompt_attention_mask.repeat(len(prompts), 1)
+                    if not train_dataset.custom_instance_prompts
+                    else prompt_attention_mask,
                     timestep=timesteps,
                     return_dict=False,
                 )[0]
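Note (not part of the commit): when the dataset has no per-image prompts, a single precomputed instance-prompt embedding is repeated to match the batch size; with custom per-image prompts the embeddings are already batch-sized. A minimal sketch of the same broadcast on dummy tensors; shapes are illustrative stand-ins for the text encoder outputs:

# Sketch only; sequence length and hidden size below are illustrative.
import torch

batch_size = 4                                    # len(prompts) in the training loop
prompt_embeds = torch.randn(1, 256, 2304)         # [1, seq_len, hidden_dim]
prompt_attention_mask = torch.ones(1, 256, dtype=torch.long)

batched_embeds = prompt_embeds.repeat(batch_size, 1, 1)        # -> [batch_size, 256, 2304]
batched_mask = prompt_attention_mask.repeat(batch_size, 1)     # -> [batch_size, 256]
assert batched_embeds.shape[0] == batched_mask.shape[0] == batch_size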
@@ -1536,7 +1541,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
             base_model=args.pretrained_model_name_or_path,
             instance_prompt=args.instance_prompt,
             system_prompt=args.system_prompt,
-            validation_prompt=args.validation_prompt,
+            validation_prompt=args.validation_prompt or args.final_validation_prompt,
             repo_folder=args.output_dir,
         )
         upload_folder(
