Skip to content

Commit 7df9edb

Browse files
Fix VAE logic and 2B ControlNets, and speed up model loading by loading ControlNets to CUDA if available
2 parents 0221c91 + 6ae9733 commit 7df9edb

6 files changed

+64
-35
lines changed

.gitignore

+2-1
Original file line numberDiff line numberDiff line change
@@ -167,4 +167,5 @@ cython_debug/
167167
#.idea/
168168

169169
.vscode/
170-
*.out.*
170+
*.out.*
171+
*.pt

dit_embedder.py

+9-5
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ def __init__(
3232
in_chans=in_chans,
3333
embed_dim=self.hidden_size,
3434
strict_img_size=pos_embed_max_size is None,
35+
device=device,
36+
dtype=dtype,
3537
)
3638

3739
self.t_embedder = TimestepEmbedder(self.hidden_size, dtype=dtype, device=device)
@@ -41,14 +43,14 @@ def __init__(
4143

4244
self.transformer_blocks = nn.ModuleList(
4345
DismantledBlock(
44-
hidden_size=self.hidden_size, num_heads=num_attention_heads, qkv_bias=True
46+
hidden_size=self.hidden_size, num_heads=num_attention_heads, qkv_bias=True, device=device, dtype=dtype
4547
)
4648
for _ in range(num_layers)
4749
)
4850

4951
self.controlnet_blocks = nn.ModuleList([])
5052
for _ in range(len(self.transformer_blocks)):
51-
controlnet_block = nn.Linear(self.hidden_size, self.hidden_size)
53+
controlnet_block = nn.Linear(self.hidden_size, self.hidden_size, device=device, dtype=dtype)
5254
self.controlnet_blocks.append(controlnet_block)
5355

5456
self.pos_embed_input = PatchEmbed(
@@ -57,7 +59,10 @@ def __init__(
5759
in_chans=in_chans,
5860
embed_dim=self.hidden_size,
5961
strict_img_size=False,
62+
dtype=dtype,
63+
device=device
6064
)
65+
self.using_8b_controlnet: bool = False
6166

6267
def forward(
6368
self,
@@ -66,10 +71,9 @@ def forward(
6671
y: Tensor,
6772
scale: int = 1,
6873
timestep: Optional[Tensor] = None,
69-
is_8b: bool = False
7074
) -> Tuple[Tensor, List[Tensor]]:
7175

72-
if not is_8b:
76+
if not self.using_8b_controlnet:
7377
x = self.x_embedder(x)
7478
timestep = timestep * 1000
7579
c = self.t_embedder(timestep, dtype=x.dtype)
@@ -83,7 +87,7 @@ def forward(
8387

8488
for block in self.transformer_blocks:
8589
out = block(x, c)
86-
if is_8b:
90+
if self.using_8b_controlnet:
8791
x = out
8892
block_out += (out,)
8993

mmditx.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -294,7 +294,7 @@ def post_attention(self, x: torch.Tensor) -> torch.Tensor:
294294

295295
def forward(self, x: torch.Tensor) -> torch.Tensor:
296296
(q, k, v) = self.pre_attention(x)
297-
x = attention(q, k, v, self.num_heads)
297+
x = attention(q, k, v, self.num_heads, self.attn_mode)
298298
x = self.post_attention(x)
299299
return x
300300

other_impls.py

+17-5
Original file line numberDiff line numberDiff line change
@@ -7,21 +7,33 @@
77
import torch
88
from torch import nn
99
from transformers import CLIPTokenizer, T5TokenizerFast
10+
from einops import rearrange
11+
12+
try:
13+
import xformers.ops
14+
except ImportError:
15+
xformers.ops = None
16+
print("xformers not found, attn_mode='xformers' will not work")
1017

1118
#################################################################################################
1219
### Core/Utility
1320
#################################################################################################
1421

1522

16-
def attention(q, k, v, heads, mask=None):
23+
def attention(q, k, v, heads, mask=None, attn_mode: str = "torch"):
1724
"""Convenience wrapper around a basic attention operation"""
1825
b, _, dim_head = q.shape
1926
dim_head //= heads
2027
q, k, v = map(lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2), (q, k, v))
21-
out = torch.nn.functional.scaled_dot_product_attention(
22-
q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False
23-
)
24-
return out.transpose(1, 2).reshape(b, -1, heads * dim_head)
28+
if attn_mode == "torch":
29+
out = torch.nn.functional.scaled_dot_product_attention(
30+
q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False
31+
)
32+
return out.transpose(1, 2).reshape(b, -1, heads * dim_head)
33+
elif attn_mode == "xformers":
34+
x = xformers.ops.memory_efficient_attention(q, k, v)
35+
x = rearrange(x, "b h n d -> b n (h d)")
36+
return x
2537

2638

2739
class Mlp(nn.Module):

sd3_impls.py

+4-6
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ def __init__(
148148
pooled_projection_size=pooled_projection_size,
149149
device=device,
150150
dtype=dtype,
151-
).to(device=device, dtype=dtype)
151+
)
152152

153153
def apply_model(self, x, sigma, c_crossattn=None, y=None, skip_layers=[], controlnet_cond=None):
154154
dtype = self.get_dtype()
@@ -159,17 +159,15 @@ def apply_model(self, x, sigma, c_crossattn=None, y=None, skip_layers=[], contro
159159
controlnet_cond = controlnet_cond.to(dtype=x.dtype, device=x.device)
160160
controlnet_cond = controlnet_cond.repeat(x.shape[0], 1, 1, 1)
161161

162-
# 8B ControlNets were trained with a slightly different architecture.
163-
is_8b = y_cond.shape[-1] == self.control_model.y_embedder.mlp[0].in_features
164-
if not is_8b:
162+
if not self.control_model.using_8b_controlnet:
165163
y_cond = self.diffusion_model.y_embedder(y)
166164

167165
x_controlnet = x
168-
if is_8b:
166+
if self.control_model.using_8b_controlnet:
169167
hw = x.shape[-2:]
170168
x_controlnet = self.diffusion_model.x_embedder(x) + self.diffusion_model.cropped_pos_embed(hw)
171169
controlnet_hidden_states = self.control_model(
172-
x_controlnet, controlnet_cond, y_cond, 1, sigma.to(torch.float32), is_8b
170+
x_controlnet, controlnet_cond, y_cond, 1, sigma.to(torch.float32)
173171
)
174172
model_output = self.diffusion_model(
175173
x.to(dtype),

sd3_infer.py

+31-17
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from PIL import Image
1818
from safetensors import safe_open
1919
from tqdm import tqdm
20+
import re
2021

2122
import sd3_impls
2223
from other_impls import SD3Tokenizer, SDClipModel, SDXLClipG, T5XXLModel
@@ -61,7 +62,9 @@ def load_into(ckpt, model, prefix, device, dtype=None, remap=None):
6162
obj.requires_grad_(False)
6263
# print(f"K: {model_key}, O: {obj.shape} T: {tensor.shape}")
6364
if obj.shape != tensor.shape:
64-
print(f"W: shape mismatch for key {model_key}, {obj.shape} != {tensor.shape}")
65+
print(
66+
f"W: shape mismatch for key {model_key}, {obj.shape} != {tensor.shape}"
67+
)
6568
obj.set_(tensor)
6669
except Exception as e:
6770
print(f"Failed to load key '{key}' in safetensors file: {e}")
@@ -148,6 +151,11 @@ class SD3:
148151
def __init__(
149152
self, model, shift, control_model_file=None, verbose=False, device="cpu"
150153
):
154+
155+
# NOTE 8B ControlNets were trained with a slightly different forward pass and conditioning,
156+
# so this is a flag to enable that logic.
157+
self.using_8b_controlnet = False
158+
151159
with safe_open(model, framework="pt", device="cpu") as f:
152160
control_model_ckpt = None
153161
if control_model_file is not None:
@@ -165,9 +173,6 @@ def __init__(
165173
).eval()
166174
load_into(f, self.model, "model.", "cuda", torch.float16)
167175
if control_model_file is not None:
168-
self.model.control_model = self.model.control_model.to(
169-
device=device, dtype=torch.float16
170-
)
171176
control_model_ckpt = safe_open(
172177
control_model_file, framework="pt", device=device
173178
)
@@ -179,6 +184,9 @@ def __init__(
179184
dtype=torch.float16,
180185
remap=CONTROLNET_MAP,
181186
)
187+
188+
self.using_8b_controlnet = self.model.control_model.y_embedder.mlp[0].in_features == 2048
189+
self.model.control_model.using_8b_controlnet = self.using_8b_controlnet
182190
control_model_ckpt = None
183191

184192

@@ -252,7 +260,7 @@ def load(
252260
model_folder: str = MODEL_FOLDER,
253261
text_encoder_device: str = "cpu",
254262
verbose=False,
255-
load_tokenizers: bool = True
263+
load_tokenizers: bool = True,
256264
):
257265
self.verbose = verbose
258266
print("Loading tokenizers...")
@@ -374,19 +382,19 @@ def do_sampling(
374382
self.print("Sampling done")
375383
return latent
376384

377-
def vae_encode(self, image, controlnet_cond: bool = False) -> torch.Tensor:
385+
def vae_encode(self, image, using_8b_controlnet: bool = False) -> torch.Tensor:
378386
self.print("Encoding image to latent...")
379387
image = image.convert("RGB")
380388
image_np = np.array(image).astype(np.float32) / 255.0
381389
image_np = np.moveaxis(image_np, 2, 0)
382390
batch_images = np.expand_dims(image_np, axis=0).repeat(1, axis=0)
383391
image_torch = torch.from_numpy(batch_images).cuda()
384-
if not controlnet_cond:
392+
if using_8b_controlnet:
385393
image_torch = 2.0 * image_torch - 1.0
394+
else:
395+
image_torch = image_torch * 255
386396
image_torch = image_torch.cuda()
387397
self.vae.model = self.vae.model.cuda()
388-
if controlnet_cond:
389-
image_torch = image_torch * 255
390398
latent = self.vae.model.encode(image_torch).cpu()
391399
self.vae.model = self.vae.model.cpu()
392400
self.print("Encoded")
@@ -411,10 +419,10 @@ def vae_decode(self, latent) -> Image.Image:
411419
self.print("Decoded")
412420
return out_image
413421

414-
def _image_to_latent(self, image, width, height, controlnet_cond: bool = False):
422+
def _image_to_latent(self, image, width, height, using_8b_controlnet: bool = False):
415423
image_data = Image.open(image)
416424
image_data = image_data.resize((width, height), Image.LANCZOS)
417-
latent = self.vae_encode(image_data, controlnet_cond)
425+
latent = self.vae_encode(image_data, using_8b_controlnet)
418426
latent = SD3LatentFormat().process_in(latent)
419427
return latent
420428

@@ -442,7 +450,7 @@ def gen_image(
442450
latent = latent.cuda()
443451
if controlnet_cond_image:
444452
controlnet_cond = self._image_to_latent(
445-
controlnet_cond_image, width, height, True
453+
controlnet_cond_image, width, height, self.sd3.using_8b_controlnet
446454
)
447455
neg_cond = self.get_cond("")
448456
seed_num = None
@@ -468,8 +476,9 @@ def gen_image(
468476
skip_layer_config,
469477
)
470478
image = self.vae_decode(sampled_latent)
479+
os.makedirs(out_dir, exist_ok=False)
471480
save_path = os.path.join(out_dir, f"{i:06d}.png")
472-
self.print(f"Will save to {save_path}")
481+
self.print(f"Saving to to {save_path}")
473482
image.save(save_path)
474483
self.print("Done")
475484

@@ -553,7 +562,13 @@ def main(
553562
inferencer = SD3Inferencer()
554563

555564
inferencer.load(
556-
model, vae, shift, controlnet_ckpt, model_folder, text_encoder_device, verbose
565+
model,
566+
vae,
567+
shift,
568+
controlnet_ckpt,
569+
model_folder,
570+
text_encoder_device,
571+
verbose,
557572
)
558573

559574
if isinstance(prompt, str):
@@ -563,6 +578,7 @@ def main(
563578
else:
564579
prompts = [prompt]
565580

581+
sanitized_prompt = re.sub(r'[^\w\-\.]', '_', prompt)
566582
out_dir = os.path.join(
567583
out_dir,
568584
(
@@ -573,11 +589,9 @@ def main(
573589
else ""
574590
)
575591
),
576-
os.path.splitext(os.path.basename(prompt))[0][:50]
592+
os.path.splitext(os.path.basename(sanitized_prompt))[0][:50]
577593
+ (postfix or datetime.datetime.now().strftime("_%Y-%m-%dT%H-%M-%S")),
578594
)
579-
print(f"Saving images to {out_dir}")
580-
os.makedirs(out_dir, exist_ok=False)
581595

582596
inferencer.gen_image(
583597
prompts,

0 commit comments

Comments
 (0)