Skip to content

Commit 8075776

Browse files
minor changes to image saving and controlnet loading
1 parent d82a3e6 commit 8075776

File tree

3 files changed

+13
-13
lines changed

3 files changed

+13
-13
lines changed

dit_embedder.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ def __init__(
3232
in_chans=in_chans,
3333
embed_dim=self.hidden_size,
3434
strict_img_size=pos_embed_max_size is None,
35+
device=device,
36+
dtype=dtype,
3537
)
3638

3739
self.t_embedder = TimestepEmbedder(self.hidden_size, dtype=dtype, device=device)
@@ -41,14 +43,14 @@ def __init__(
4143

4244
self.transformer_blocks = nn.ModuleList(
4345
DismantledBlock(
44-
hidden_size=self.hidden_size, num_heads=num_attention_heads, qkv_bias=True
46+
hidden_size=self.hidden_size, num_heads=num_attention_heads, qkv_bias=True, device=device, dtype=dtype
4547
)
4648
for _ in range(num_layers)
4749
)
4850

4951
self.controlnet_blocks = nn.ModuleList([])
5052
for _ in range(len(self.transformer_blocks)):
51-
controlnet_block = nn.Linear(self.hidden_size, self.hidden_size)
53+
controlnet_block = nn.Linear(self.hidden_size, self.hidden_size, device=device, dtype=dtype)
5254
self.controlnet_blocks.append(controlnet_block)
5355

5456
self.pos_embed_input = PatchEmbed(
@@ -57,6 +59,8 @@ def __init__(
5759
in_chans=in_chans,
5860
embed_dim=self.hidden_size,
5961
strict_img_size=False,
62+
dtype=dtype,
63+
device=device
6064
)
6165
self.is_8b = True
6266

sd3_impls.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ def __init__(
149149
pooled_projection_size=pooled_projection_size,
150150
device=device,
151151
dtype=dtype,
152-
).to(device=device, dtype=dtype)
152+
)
153153

154154
def apply_model(self, x, sigma, c_crossattn=None, y=None, skip_layers=[], controlnet_cond=None):
155155
dtype = self.get_dtype()

sd3_infer.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from PIL import Image
1818
from safetensors import safe_open
1919
from tqdm import tqdm
20+
import re
2021

2122
import sd3_impls
2223
from other_impls import SD3Tokenizer, SDClipModel, SDXLClipG, T5XXLModel
@@ -168,9 +169,6 @@ def __init__(
168169
).eval()
169170
load_into(f, self.model, "model.", "cuda", torch.float16)
170171
if control_model_file is not None:
171-
self.model.control_model = self.model.control_model.to(
172-
device=device, dtype=torch.float16
173-
)
174172
control_model_ckpt = safe_open(
175173
control_model_file, framework="pt", device=device
176174
)
@@ -388,8 +386,6 @@ def vae_encode(self, image, controlnet_cond: bool = False) -> torch.Tensor:
388386
image_torch = 2.0 * image_torch - 1.0
389387
image_torch = image_torch.cuda()
390388
self.vae.model = self.vae.model.cuda()
391-
if controlnet_cond:
392-
image_torch = image_torch * 255
393389
latent = self.vae.model.encode(image_torch).cpu()
394390
self.vae.model = self.vae.model.cpu()
395391
self.print("Encoded")
@@ -426,6 +422,7 @@ def _image_to_latent(self, image, width, height, controlnet_cond: bool = False):
426422
image_data = Image.open(image)
427423
image_data = image_data.resize((width, height), Image.LANCZOS)
428424
latent = self.vae_encode(image_data, controlnet_cond)
425+
# latent, _ = DiagonalGaussianRegularizer()(latent)
429426
latent = SD3LatentFormat().process_in(latent)
430427
return latent
431428

@@ -480,8 +477,9 @@ def gen_image(
480477
skip_layer_config,
481478
)
482479
image = self.vae_decode(sampled_latent)
480+
os.makedirs(out_dir, exist_ok=False)
483481
save_path = os.path.join(out_dir, f"{i:06d}.png")
484-
self.print(f"Will save to {save_path}")
482+
self.print(f"Saving to to {save_path}")
485483
image.save(save_path)
486484
self.print("Done")
487485

@@ -572,7 +570,6 @@ def main(
572570
model_folder,
573571
text_encoder_device,
574572
verbose,
575-
load_tokenizers=False,
576573
)
577574

578575
if isinstance(prompt, str):
@@ -582,6 +579,7 @@ def main(
582579
else:
583580
prompts = [prompt]
584581

582+
sanitized_prompt = re.sub(r'[^\w\-\.]', '_', prompt)
585583
out_dir = os.path.join(
586584
out_dir,
587585
(
@@ -592,11 +590,9 @@ def main(
592590
else ""
593591
)
594592
),
595-
os.path.splitext(os.path.basename(prompt))[0][:50]
593+
os.path.splitext(os.path.basename(sanitized_prompt))[0][:50]
596594
+ (postfix or datetime.datetime.now().strftime("_%Y-%m-%dT%H-%M-%S")),
597595
)
598-
print(f"Saving images to {out_dir}")
599-
os.makedirs(out_dir, exist_ok=False)
600596

601597
inferencer.gen_image(
602598
prompts,

0 commit comments

Comments
 (0)