Skip to content

Commit 158a5a8

Browse files
lmxyylinoytsaban
andauthored
Remove the FP32 Wrapper when evaluating (#10617)
Remove the FP32 Wrapper Co-authored-by: Linoy Tsaban <[email protected]>
1 parent 012d08b commit 158a5a8

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

examples/dreambooth/train_dreambooth_flux.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1716,9 +1716,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
17161716
pipeline = FluxPipeline.from_pretrained(
17171717
args.pretrained_model_name_or_path,
17181718
vae=vae,
1719-
text_encoder=accelerator.unwrap_model(text_encoder_one),
1720-
text_encoder_2=accelerator.unwrap_model(text_encoder_two),
1721-
transformer=accelerator.unwrap_model(transformer),
1719+
text_encoder=accelerator.unwrap_model(text_encoder_one, keep_fp32_wrapper=False),
1720+
text_encoder_2=accelerator.unwrap_model(text_encoder_two, keep_fp32_wrapper=False),
1721+
transformer=accelerator.unwrap_model(transformer, keep_fp32_wrapper=False),
17221722
revision=args.revision,
17231723
variant=args.variant,
17241724
torch_dtype=weight_dtype,

0 commit comments

Comments
 (0)