 import os
 import random
 import shutil
+from contextlib import nullcontext
 from pathlib import Path
 from typing import List, Union
 
@@ -238,6 +239,10 @@ def train_dataloader(self):
 
 def log_validation(vae, unet, args, accelerator, weight_dtype, step):
     logger.info("Running validation... ")
+    if torch.backends.mps.is_available():
+        autocast_ctx = nullcontext()
+    else:
+        autocast_ctx = torch.autocast(accelerator.device.type, dtype=weight_dtype)
 
     unet = accelerator.unwrap_model(unet)
     pipeline = StableDiffusionPipeline.from_pretrained(
@@ -274,7 +279,7 @@ def log_validation(vae, unet, args, accelerator, weight_dtype, step):
 
     for _, prompt in enumerate(validation_prompts):
         images = []
-        with torch.autocast("cuda", dtype=weight_dtype):
+        with autocast_ctx:
             images = pipeline(
                 prompt=prompt,
                 num_inference_steps=4,
@@ -1172,6 +1177,11 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok
     ).input_ids.to(accelerator.device)
     uncond_prompt_embeds = text_encoder(uncond_input_ids)[0]
 
+    if torch.backends.mps.is_available():
+        autocast_ctx = nullcontext()
+    else:
+        autocast_ctx = torch.autocast(accelerator.device.type)
+
     # 16. Train!
     total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
 
@@ -1300,7 +1310,7 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok
                 # estimates to predict the data point in the augmented PF-ODE trajectory corresponding to the next ODE
                 # solver timestep.
                 with torch.no_grad():
-                    with torch.autocast("cuda"):
+                    with autocast_ctx:
                         # 1. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and conditional embedding c
                         cond_teacher_output = teacher_unet(
                             noisy_model_input.to(weight_dtype),
@@ -1359,7 +1369,7 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok
                 # 9. Get target LCM prediction on x_prev, w, c, t_n (timesteps)
                 # Note that we do not use a separate target network for LCM-LoRA distillation.
                 with torch.no_grad():
-                    with torch.autocast("cuda", dtype=weight_dtype):
+                    with autocast_ctx:
                         target_noise_pred = unet(
                             x_prev.float(),
                             timesteps,
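Taken together, these hunks replace the hard-coded `torch.autocast("cuda", ...)` blocks with an `autocast_ctx` chosen once per run: a no-op `nullcontext()` when the MPS backend is available (presumably because `torch.autocast` did not support the `mps` device type at the time), and `torch.autocast(accelerator.device.type, ...)` otherwise. Below is a minimal standalone sketch of the same pattern; the `make_autocast_ctx` helper and the `__main__` demo are illustrative only and are not part of the training script.

```python
from contextlib import nullcontext

import torch


def make_autocast_ctx(device_type, dtype=None):
    # Hypothetical helper mirroring the pattern in this diff: torch.autocast
    # historically errored on the "mps" device type, so fall back to a no-op
    # context manager there and use real autocast everywhere else.
    if device_type == "mps":
        return nullcontext()
    if dtype is not None:
        return torch.autocast(device_type, dtype=dtype)
    return torch.autocast(device_type)


if __name__ == "__main__":
    # Resolve a device type the way accelerator.device.type would.
    if torch.cuda.is_available():
        device_type = "cuda"
    elif torch.backends.mps.is_available():
        device_type = "mps"
    else:
        device_type = "cpu"

    autocast_ctx = make_autocast_ctx(device_type, dtype=torch.float16 if device_type == "cuda" else None)
    with autocast_ctx:
        x = torch.randn(4, 4, device=device_type)
        y = x @ x  # reduced precision under autocast, full precision on MPS
    print(device_type, y.dtype)
```

The single context object can then wrap validation, teacher, and target predictions alike, which is exactly how the `with autocast_ctx:` hunks above reuse it.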