
Commit 8e963d1

bghira and sayakpaul authored
7529 do not disable autocast for cuda devices (#7530)
* 7529 do not disable autocast for cuda devices
* Remove typecasting error check for non-mps platforms, as a correct autocast implementation makes it a non-issue
* add autocast fix to other training examples
* disable native_amp for dreambooth (sdxl)
* disable native_amp for pix2pix (sdxl)
* remove tests from remaining files
* disable native_amp on huggingface accelerator for every training example that uses it
* convert more usages of autocast to nullcontext, make style fixes
* make style fixes
* style.
* Empty-Commit

---------

Co-authored-by: bghira <[email protected]>
Co-authored-by: Sayak Paul <[email protected]>
1 parent 2b04ec2 commit 8e963d1
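
Across all of the touched scripts the change is the same: pick the autocast context from the device Accelerate actually runs on, and fall back to a no-op context on Apple Silicon, where torch.autocast was not usable at the time. A minimal standalone sketch of that pattern (the model id and prompt below are illustrative, not taken from the diff):

from contextlib import nullcontext

import torch
from accelerate import Accelerator
from diffusers import StableDiffusionPipeline

accelerator = Accelerator()
pipeline = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5").to(accelerator.device)

if torch.backends.mps.is_available():
    # MPS: run inference without autocast instead of assuming a CUDA-only context.
    autocast_ctx = nullcontext()
else:
    # Derive the autocast backend ("cuda", "cpu", ...) from the actual device
    # instead of hard-coding "cuda".
    autocast_ctx = torch.autocast(accelerator.device.type)

with autocast_ctx:
    image = pipeline("a photo of a cat", num_inference_steps=25).images[0]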

47 files changed: +312 −118 lines changed


examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py

+7 −1

@@ -23,6 +23,7 @@
 import re
 import shutil
 import warnings
+from contextlib import nullcontext
 from pathlib import Path
 from typing import List, Optional
 
@@ -1844,7 +1845,12 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
     generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
     pipeline_args = {"prompt": args.validation_prompt}
 
-    with torch.cuda.amp.autocast():
+    if torch.backends.mps.is_available():
+        autocast_ctx = nullcontext()
+    else:
+        autocast_ctx = torch.autocast(accelerator.device.type)
+
+    with autocast_ctx:
         images = [
             pipeline(**pipeline_args, generator=generator).images[0]
             for _ in range(args.num_validation_images)

examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py

+6 −7

@@ -14,7 +14,6 @@
 # See the License for the specific language governing permissions and
 
 import argparse
-import contextlib
 import gc
 import hashlib
 import itertools
@@ -26,6 +25,7 @@
 import re
 import shutil
 import warnings
+from contextlib import nullcontext
 from pathlib import Path
 from typing import List, Optional
 
@@ -2192,13 +2192,12 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
     # run inference
     generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
     pipeline_args = {"prompt": args.validation_prompt}
-    inference_ctx = (
-        contextlib.nullcontext()
-        if "playground" in args.pretrained_model_name_or_path
-        else torch.cuda.amp.autocast()
-    )
+    if torch.backends.mps.is_available() or "playground" in args.pretrained_model_name_or_path:
+        autocast_ctx = nullcontext()
+    else:
+        autocast_ctx = torch.autocast(accelerator.device.type)
 
-    with inference_ctx:
+    with autocast_ctx:
         images = [
             pipeline(**pipeline_args, generator=generator).images[0]
             for _ in range(args.num_validation_images)
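
The SDXL variants also fold the pre-existing Playground exception into the new check: checkpoints whose path contains "playground" skip autocast just like MPS does. A small sketch of that combined condition under the same assumptions as above (the Namespace and model path are illustrative stand-ins for the script's parsed CLI arguments):

from argparse import Namespace
from contextlib import nullcontext

import torch
from accelerate import Accelerator

accelerator = Accelerator()
# Illustrative stand-in for the script's parsed CLI arguments.
args = Namespace(pretrained_model_name_or_path="playgroundai/playground-v2.5-1024px-aesthetic")

# Skip autocast on MPS and for Playground checkpoints; autocast everywhere else.
if torch.backends.mps.is_available() or "playground" in args.pretrained_model_name_or_path:
    autocast_ctx = nullcontext()
else:
    autocast_ctx = torch.autocast(accelerator.device.type)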

examples/amused/train_amused.py

+3

@@ -430,6 +430,9 @@ def main(args):
         log_with=args.report_to,
         project_config=accelerator_project_config,
     )
+    # Disable AMP for MPS.
+    if torch.backends.mps.is_available():
+        accelerator.native_amp = False
 
     if accelerator.is_main_process:
         os.makedirs(args.output_dir, exist_ok=True)
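
Scripts that lean on Accelerate's built-in mixed precision rather than an explicit autocast block take the other route and switch native AMP off when the process lands on MPS, as in the three added lines above. A minimal sketch of that setup (the mixed_precision choice is illustrative; accelerator.native_amp is the flag the commit sets):

import torch
from accelerate import Accelerator

# mixed_precision would normally come from the script's CLI arguments.
accelerator = Accelerator(mixed_precision="fp16" if torch.cuda.is_available() else "no")

# Disable AMP for MPS, mirroring the change above.
if torch.backends.mps.is_available():
    accelerator.native_amp = False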

examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py

+13 −3

@@ -23,6 +23,7 @@
 import os
 import random
 import shutil
+from contextlib import nullcontext
 from pathlib import Path
 from typing import List, Union
 
@@ -238,6 +239,10 @@ def train_dataloader(self):
 
 def log_validation(vae, unet, args, accelerator, weight_dtype, step):
     logger.info("Running validation... ")
+    if torch.backends.mps.is_available():
+        autocast_ctx = nullcontext()
+    else:
+        autocast_ctx = torch.autocast(accelerator.device.type, dtype=weight_dtype)
 
     unet = accelerator.unwrap_model(unet)
     pipeline = StableDiffusionPipeline.from_pretrained(
@@ -274,7 +279,7 @@ def log_validation(vae, unet, args, accelerator, weight_dtype, step):
 
     for _, prompt in enumerate(validation_prompts):
         images = []
-        with torch.autocast("cuda", dtype=weight_dtype):
+        with autocast_ctx:
             images = pipeline(
                 prompt=prompt,
                 num_inference_steps=4,
@@ -1172,6 +1177,11 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok
     ).input_ids.to(accelerator.device)
     uncond_prompt_embeds = text_encoder(uncond_input_ids)[0]
 
+    if torch.backends.mps.is_available():
+        autocast_ctx = nullcontext()
+    else:
+        autocast_ctx = torch.autocast(accelerator.device.type)
+
     # 16. Train!
     total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
 
@@ -1300,7 +1310,7 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok
     # estimates to predict the data point in the augmented PF-ODE trajectory corresponding to the next ODE
     # solver timestep.
     with torch.no_grad():
-        with torch.autocast("cuda"):
+        with autocast_ctx:
             # 1. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and conditional embedding c
             cond_teacher_output = teacher_unet(
                 noisy_model_input.to(weight_dtype),
@@ -1359,7 +1369,7 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok
     # 9. Get target LCM prediction on x_prev, w, c, t_n (timesteps)
     # Note that we do not use a separate target network for LCM-LoRA distillation.
     with torch.no_grad():
-        with torch.autocast("cuda", dtype=weight_dtype):
+        with autocast_ctx:
             target_noise_pred = unet(
                 x_prev.float(),
                 timesteps,

examples/consistency_distillation/train_lcm_distill_lora_sdxl.py

+7 −1

@@ -22,6 +22,7 @@
 import os
 import random
 import shutil
+from contextlib import nullcontext
 from pathlib import Path
 
 import accelerate
@@ -146,7 +147,12 @@ def log_validation(vae, args, accelerator, weight_dtype, step, unet=None, is_fin
 
     for _, prompt in enumerate(validation_prompts):
         images = []
-        with torch.autocast("cuda", dtype=weight_dtype):
+        if torch.backends.mps.is_available():
+            autocast_ctx = nullcontext()
+        else:
+            autocast_ctx = torch.autocast(accelerator.device.type, dtype=weight_dtype)
+
+        with autocast_ctx:
             images = pipeline(
                 prompt=prompt,
                 num_inference_steps=4,

examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py

+18 −3

@@ -24,6 +24,7 @@
 import os
 import random
 import shutil
+from contextlib import nullcontext
 from pathlib import Path
 from typing import List, Union
 
@@ -256,6 +257,10 @@ def train_dataloader(self):
 
 def log_validation(vae, unet, args, accelerator, weight_dtype, step):
     logger.info("Running validation... ")
+    if torch.backends.mps.is_available():
+        autocast_ctx = nullcontext()
+    else:
+        autocast_ctx = torch.autocast(accelerator.device.type, dtype=weight_dtype)
 
     unet = accelerator.unwrap_model(unet)
     pipeline = StableDiffusionXLPipeline.from_pretrained(
@@ -291,7 +296,7 @@ def log_validation(vae, unet, args, accelerator, weight_dtype, step):
 
     for _, prompt in enumerate(validation_prompts):
         images = []
-        with torch.autocast("cuda", dtype=weight_dtype):
+        with autocast_ctx:
             images = pipeline(
                 prompt=prompt,
                 num_inference_steps=4,
@@ -1353,7 +1358,12 @@ def compute_embeddings(
     # estimates to predict the data point in the augmented PF-ODE trajectory corresponding to the next ODE
     # solver timestep.
     with torch.no_grad():
-        with torch.autocast("cuda"):
+        if torch.backends.mps.is_available() or "playground" in args.pretrained_model_name_or_path:
+            autocast_ctx = nullcontext()
+        else:
+            autocast_ctx = torch.autocast(accelerator.device.type)
+
+        with autocast_ctx:
             # 1. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and conditional embedding c
             cond_teacher_output = teacher_unet(
                 noisy_model_input.to(weight_dtype),
@@ -1416,7 +1428,12 @@ def compute_embeddings(
 
 # 9. Get target LCM prediction on x_prev, w, c, t_n (timesteps)
 # Note that we do not use a separate target network for LCM-LoRA distillation.
 with torch.no_grad():
-    with torch.autocast("cuda", enabled=True, dtype=weight_dtype):
+    if torch.backends.mps.is_available():
+        autocast_ctx = nullcontext()
+    else:
+        autocast_ctx = torch.autocast(accelerator.device.type, dtype=weight_dtype)
+
+    with autocast_ctx:
         target_noise_pred = unet(
             x_prev.float(),
             timesteps,

examples/consistency_distillation/train_lcm_distill_sd_wds.py

+19 −3

@@ -23,6 +23,7 @@
 import os
 import random
 import shutil
+from contextlib import nullcontext
 from pathlib import Path
 from typing import List, Union
 
@@ -252,7 +253,12 @@ def log_validation(vae, unet, args, accelerator, weight_dtype, step, name="targe
 
     for _, prompt in enumerate(validation_prompts):
         images = []
-        with torch.autocast("cuda"):
+        if torch.backends.mps.is_available():
+            autocast_ctx = nullcontext()
+        else:
+            autocast_ctx = torch.autocast(accelerator.device.type)
+
+        with autocast_ctx:
             images = pipeline(
                 prompt=prompt,
                 num_inference_steps=4,
@@ -1257,7 +1263,12 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok
     # estimates to predict the data point in the augmented PF-ODE trajectory corresponding to the next ODE
     # solver timestep.
     with torch.no_grad():
-        with torch.autocast("cuda"):
+        if torch.backends.mps.is_available():
+            autocast_ctx = nullcontext()
+        else:
+            autocast_ctx = torch.autocast(accelerator.device.type)
+
+        with autocast_ctx:
             # 1. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and conditional embedding c
             cond_teacher_output = teacher_unet(
                 noisy_model_input.to(weight_dtype),
@@ -1315,7 +1326,12 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok
 
     # 9. Get target LCM prediction on x_prev, w, c, t_n (timesteps)
     with torch.no_grad():
-        with torch.autocast("cuda", dtype=weight_dtype):
+        if torch.backends.mps.is_available():
+            autocast_ctx = nullcontext()
+        else:
+            autocast_ctx = torch.autocast(accelerator.device.type, dtype=weight_dtype)
+
+        with autocast_ctx:
             target_noise_pred = target_unet(
                 x_prev.float(),
                 timesteps,

examples/consistency_distillation/train_lcm_distill_sdxl_wds.py

+19 −3

@@ -24,6 +24,7 @@
 import os
 import random
 import shutil
+from contextlib import nullcontext
 from pathlib import Path
 from typing import List, Union
 
@@ -270,7 +271,12 @@ def log_validation(vae, unet, args, accelerator, weight_dtype, step, name="targe
 
     for _, prompt in enumerate(validation_prompts):
         images = []
-        with torch.autocast("cuda"):
+        if torch.backends.mps.is_available():
+            autocast_ctx = nullcontext()
+        else:
+            autocast_ctx = torch.autocast(accelerator.device.type)
+
+        with autocast_ctx:
             images = pipeline(
                 prompt=prompt,
                 num_inference_steps=4,
@@ -1355,7 +1361,12 @@ def compute_embeddings(
     # estimates to predict the data point in the augmented PF-ODE trajectory corresponding to the next ODE
     # solver timestep.
    with torch.no_grad():
-        with torch.autocast("cuda"):
+        if torch.backends.mps.is_available():
+            autocast_ctx = nullcontext()
+        else:
+            autocast_ctx = torch.autocast(accelerator.device.type)
+
+        with autocast_ctx:
             # 1. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and conditional embedding c
             cond_teacher_output = teacher_unet(
                 noisy_model_input.to(weight_dtype),
@@ -1417,7 +1428,12 @@ def compute_embeddings(
 
     # 9. Get target LCM prediction on x_prev, w, c, t_n (timesteps)
    with torch.no_grad():
-        with torch.autocast("cuda", dtype=weight_dtype):
+        if torch.backends.mps.is_available():
+            autocast_ctx = nullcontext()
+        else:
+            autocast_ctx = torch.autocast(accelerator.device.type, dtype=weight_dtype)
+
+        with autocast_ctx:
             target_noise_pred = target_unet(
                 x_prev.float(),
                 timesteps,

examples/controlnet/train_controlnet.py

+4

@@ -752,6 +752,10 @@ def main(args):
         project_config=accelerator_project_config,
     )
 
+    # Disable AMP for MPS.
+    if torch.backends.mps.is_available():
+        accelerator.native_amp = False
+
     # Make one log on every process with the configuration for debugging.
     logging.basicConfig(
         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",

examples/controlnet/train_controlnet_sdxl.py

+10 −7

@@ -14,14 +14,14 @@
 # See the License for the specific language governing permissions and
 
 import argparse
-import contextlib
 import functools
 import gc
 import logging
 import math
 import os
 import random
 import shutil
+from contextlib import nullcontext
 from pathlib import Path
 
 import accelerate
@@ -125,11 +125,10 @@ def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step,
     )
 
     image_logs = []
-    inference_ctx = (
-        contextlib.nullcontext()
-        if (is_final_validation or torch.backends.mps.is_available())
-        else torch.autocast("cuda")
-    )
+    if is_final_validation or torch.backends.mps.is_available():
+        autocast_ctx = nullcontext()
+    else:
+        autocast_ctx = torch.autocast(accelerator.device.type)
 
     for validation_prompt, validation_image in zip(validation_prompts, validation_images):
         validation_image = Image.open(validation_image).convert("RGB")
@@ -138,7 +137,7 @@ def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step,
         images = []
 
         for _ in range(args.num_validation_images):
-            with inference_ctx:
+            with autocast_ctx:
                 image = pipeline(
                     prompt=validation_prompt, image=validation_image, num_inference_steps=20, generator=generator
                 ).images[0]
@@ -811,6 +810,10 @@ def main(args):
         project_config=accelerator_project_config,
     )
 
+    # Disable AMP for MPS.
+    if torch.backends.mps.is_available():
+        accelerator.native_amp = False
+
     # Make one log on every process with the configuration for debugging.
     logging.basicConfig(
         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",

examples/custom_diffusion/train_custom_diffusion.py

+4

@@ -676,6 +676,10 @@ def main(args):
         project_config=accelerator_project_config,
     )
 
+    # Disable AMP for MPS.
+    if torch.backends.mps.is_available():
+        accelerator.native_amp = False
+
     if args.report_to == "wandb":
         if not is_wandb_available():
             raise ImportError("Make sure to install wandb if you want to use it for logging during training.")

examples/dreambooth/train_dreambooth.py

+4

@@ -821,6 +821,10 @@ def main(args):
         project_config=accelerator_project_config,
     )
 
+    # Disable AMP for MPS.
+    if torch.backends.mps.is_available():
+        accelerator.native_amp = False
+
     if args.report_to == "wandb":
         if not is_wandb_available():
             raise ImportError("Make sure to install wandb if you want to use it for logging during training.")

examples/dreambooth/train_dreambooth_lora.py

+4

@@ -749,6 +749,10 @@ def main(args):
         project_config=accelerator_project_config,
    )
 
+    # Disable AMP for MPS.
+    if torch.backends.mps.is_available():
+        accelerator.native_amp = False
+
     if args.report_to == "wandb":
         if not is_wandb_available():
             raise ImportError("Make sure to install wandb if you want to use it for logging during training.")

0 commit comments
